# Super Resolution tools with Jax/Flax

```shell
pip install flaxsr
```

You can easily load models and losses, and train a model using custom train states.

## Train example
```python
import flaxsr
import jax
import jax.numpy as jnp
import numpy as np
import optax

# Build a model from the registry; this equals flaxsr.models.VDSR(**model_kwargs)
model_kwargs = {
    'n_filters': 64, 'n_blocks': 8, 'scale': 4
}
model = flaxsr.get("models", "vdsr", **model_kwargs)

# Combine a pixel loss and a perceptual (VGG) loss with per-loss weights
losses = [
    flaxsr.losses.L1Loss(reduce='sum'),
    flaxsr.get('losses', 'vgg', feats_from=(6, 8, 14,), before_act=False, reduce='mean')
]
loss_weights = (.1, 1.)
loss_wrapper = flaxsr.losses.LossWrapper(losses, loss_weights)

params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8, 8, 3), dtype=jnp.float32))
tx = optax.adam(1e-3)
state = flaxsr.training.TrainState.create(
    apply_fn=model.apply, params=params, tx=tx, losses=loss_wrapper
)

# One training step on a (low-resolution, high-resolution) batch
hr = jnp.ones((1, 32, 32, 3), dtype=jnp.float32)
lr = jnp.ones((1, 8, 8, 3), dtype=jnp.float32)
batch = (lr, hr)
state_new, loss = flaxsr.training.discriminative_train_step(state, batch)  # TODO: Fix This

assert state_new.step == 1
# The optimizer update should have changed the parameters
assert np.any(np.not_equal(
    state_new.params['params']['Conv_0']['kernel'],
    state.params['params']['Conv_0']['kernel'],
))
```
## Models
- SRCNN: srcnn
- FSRCNN: fsrcnn
- ESPCN: espcn
- VDSR: vdsr
- EDSR: edsr
- MDSR: mdsr
- SRResNet: srresnet
- SRGAN: srgan
- NCNet: ncnet
- ESRGAN: esrgan
- MEMNet: memnet
- RDN: rdn
- DRRN: drrn
- RCAN: rcan
- SAFMN: safmn
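Every model above is fetched through the same string-keyed lookup used in the train example: the lowercase key on the right selects a class, and keyword arguments are forwarded to its constructor. A minimal sketch of how such a registry dispatch could work (the `VDSR` stand-in class and `_REGISTRY` dict here are illustrative, not the actual flaxsr internals):

```python
from dataclasses import dataclass


@dataclass
class VDSR:
    """Stand-in for flaxsr.models.VDSR, holding its constructor kwargs."""
    n_filters: int = 64
    n_blocks: int = 8
    scale: int = 4


# kind -> name -> class, mirroring the "models"/"losses"/"layers" groups
_REGISTRY = {"models": {"vdsr": VDSR}}


def get(kind, name, **kwargs):
    """Look up a class by (kind, name) and instantiate it with kwargs."""
    return _REGISTRY[kind][name](**kwargs)


model = get("models", "vdsr", n_filters=64, n_blocks=8, scale=4)
```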
## Losses
- L1Loss: l1
- L2Loss: l2
- CharbonnierLoss: charbonnier
- VGGLoss: vgg
- MinmaxDiscriminatorLoss: minmax_discriminator
- MinmaxGeneratorLoss: minmax_generator
- LeastSquareDiscriminatorLoss: least_square_discriminator
- LeastSquareGeneratorLoss: least_square_generator
- RelativisticDiscriminatorLoss: relativistic_discriminator
- RelativisticGeneratorLoss: relativistic_generator
- TotalVariationLoss: tv
- FrequencyReconstructionLoss: freq_recon
- EdgeLoss: edge
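As one concrete example from the list, the Charbonnier loss is a smooth L1 penalty, `sqrt((pred - target)^2 + eps^2)`, widely used in SR training. A minimal NumPy sketch of the formula (the real `flaxsr.losses.CharbonnierLoss` signature may differ):

```python
import numpy as np


def charbonnier_loss(pred, target, eps=1e-3, reduce="mean"):
    """Charbonnier (smooth L1) penalty: sqrt(diff^2 + eps^2), then reduced."""
    diff = np.sqrt((pred - target) ** 2 + eps ** 2)
    return diff.mean() if reduce == "mean" else diff.sum()


hr = np.ones((1, 32, 32, 3), dtype=np.float32)
sr = np.zeros_like(hr)
loss = charbonnier_loss(sr, hr)  # close to 1.0 for a unit error
```

Near zero error it behaves like L2 (differentiable at 0); for large errors it grows like L1, which makes it less sensitive to outlier pixels than plain L2.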
## Layers
- DropPath: droppath
- DropPathFast: droppath_fast
- PixelShuffle: pixelshuffle
- NearestConv: nearestconv
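PixelShuffle rearranges channels into spatial positions (depth-to-space), the standard sub-pixel upsampling step in ESPCN-style models. A minimal NumPy sketch of the rearrangement for NHWC tensors (the real flaxsr layer is a Flax module; this only illustrates the operation):

```python
import numpy as np


def pixel_shuffle(x, scale):
    """Rearrange (N, H, W, C * scale^2) -> (N, H * scale, W * scale, C)."""
    n, h, w, c = x.shape
    assert c % (scale * scale) == 0, "channels must be divisible by scale^2"
    c_out = c // (scale * scale)
    x = x.reshape(n, h, w, scale, scale, c_out)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # interleave the scale x scale blocks
    return x.reshape(n, h * scale, w * scale, c_out)


x = np.arange(1 * 2 * 2 * 4, dtype=np.float32).reshape(1, 2, 2, 4)
y = pixel_shuffle(x, 2)  # (1, 2, 2, 4) -> (1, 4, 4, 1)
```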