Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vae and AutoencoderKL] Final clean of LDM checkpoints #137

Merged
merged 9 commits into from
Jul 28, 2022
Prev Previous commit
Next Next commit
more progress
  • Loading branch information
patrickvonplaten committed Jul 27, 2022
commit 894bc12025989b64694f1cd5183295bf74b8b578
282 changes: 166 additions & 116 deletions src/diffusers/models/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ def __init__(
block_out_channels=(64,),
layers_per_block=2,
act_fn="silu",
# To delete
ch=None,
ch_mult=(1, 2, 4, 8),
num_res_blocks=None,
attn_resolutions=None,
dropout=0.0,
resamp_with_conv=True,
resolution=None,
z_channels=None,
double_z=True,
**ignore_kwargs,
# To delete
# ch=None,
# ch_mult=(1, 2, 4, 8),
# num_res_blocks=None,
# attn_resolutions=None,
# dropout=0.0,
# resamp_with_conv=True,
# resolution=None,
# z_channels=None,
# **ignore_kwargs,
):
super().__init__()
# self.ch = ch
Expand All @@ -47,23 +47,28 @@ def __init__(
# self.num_res_blocks = num_res_blocks
# self.resolution = resolution
# self.in_channels = in_channels
self.init_orig(
ch=ch,
ch_mult=ch_mult,
resolution=resolution,
z_channels=z_channels,
dropout=0.0,
attn_resolutions=attn_resolutions,
resamp_with_conv=resamp_with_conv,
num_res_blocks=num_res_blocks,
)
self.weights_is_set = False
self.layers_per_block = layers_per_block

if True:
block_out_channels = [ch * c for c in ch_mult]
down_block_types = [down_block_types[0] for _ in range(len(block_out_channels))]
self.layers_per_block = num_res_blocks
out_channels = z_channels
ch = block_out_channels[0]
ch_mult = [x // ch for x in block_out_channels]
resolution = None
z_channels = out_channels
attn_resolutions = ()
num_res_blocks = layers_per_block
resamp_with_conv = True

self.init_orig(
ch=ch,
ch_mult=ch_mult,
resolution=resolution,
z_channels=z_channels,
dropout=0.0,
attn_resolutions=attn_resolutions,
resamp_with_conv=resamp_with_conv,
num_res_blocks=num_res_blocks,
)
self.weights_is_set = False

self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

Expand Down Expand Up @@ -106,11 +111,14 @@ def __init__(
num_groups_out = 32
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * out_channels, 3, padding=1)

conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

def init_orig(self, ch, ch_mult, resolution, z_channels, dropout, attn_resolutions, resamp_with_conv, num_res_blocks):
# downsampling
curr_res = resolution
# curr_res = resolution
curr_res = 32
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
num_resolutions = len(ch_mult)
Expand Down Expand Up @@ -242,38 +250,42 @@ def __init__(
layers_per_block=2,
act_fn="silu",
# To delete
ch=None,
out_ch=None,
ch_mult=(1, 2, 4, 8),
num_res_blocks=None,
attn_resolutions=None,
dropout=0.0,
resamp_with_conv=True,
resolution=None,
z_channels=None,
give_pre_end=False,
**ignorekwargs,
# ch=None,
# out_ch=None,
# ch_mult=(1, 2, 4, 8),
# num_res_blocks=None,
# attn_resolutions=None,
# dropout=0.0,
# resamp_with_conv=True,
# resolution=None,
# z_channels=None,
# give_pre_end=False,
# **ignorekwargs,
):
super().__init__()

self.init_orig(
ch=ch,
ch_mult=ch_mult,
resolution=resolution,
z_channels=z_channels,
dropout=0.0,
attn_resolutions=attn_resolutions,
resamp_with_conv=resamp_with_conv,
out_ch=out_ch,
num_res_blocks=num_res_blocks,
)
self.weights_is_set = False
self.layers_per_block = layers_per_block

if True:
in_channels = z_channels
block_out_channels = [ch * c for c in ch_mult]
up_block_types = [up_block_types[0] for _ in range(len(block_out_channels))]
self.layers_per_block = num_res_blocks
ch = block_out_channels[0]
ch_mult = [x // ch for x in block_out_channels]
resolution = None
z_channels = in_channels
attn_resolutions = ()
num_res_blocks = layers_per_block
resamp_with_conv = True

self.init_orig(
ch=ch,
ch_mult=ch_mult,
resolution=resolution,
z_channels=z_channels,
dropout=0.0,
attn_resolutions=attn_resolutions,
resamp_with_conv=resamp_with_conv,
out_ch=out_channels,
num_res_blocks=num_res_blocks,
)
self.weights_is_set = False

self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

Expand Down Expand Up @@ -326,6 +338,7 @@ def init_orig(
self, ch, ch_mult, resolution, z_channels, dropout, attn_resolutions, resamp_with_conv, out_ch, num_res_blocks
):
# compute in_ch_mult, block_in and curr_res at lowest res
resolution = 32
block_in = ch * ch_mult[len(ch_mult) - 1]
curr_res = resolution // 2 ** (len(ch_mult) - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
Expand Down Expand Up @@ -616,15 +629,22 @@ class VQModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
ch,
out_ch,
num_res_blocks,
attn_resolutions,
in_channels,
resolution,
z_channels,
n_embed,
embed_dim,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
act_fn="silu",
# to delete
ch=None,
out_ch=None,
num_res_blocks=None,
attn_resolutions=None,
resolution=None,
z_channels=None,
n_embed=None,
embed_dim=None,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
ch_mult=(1, 2, 4, 8),
Expand All @@ -635,19 +655,22 @@ def __init__(
):
super().__init__()

if True:
block_out_channels = [ch * c for c in ch_mult]
down_block_types = [down_block_types[0] for _ in range(len(block_out_channels))]
up_block_types = [up_block_types[0] for _ in range(len(block_out_channels))]
layers_per_block = num_res_blocks
latent_channels = z_channels

# pass init params to Encoder
self.encoder = Encoder(
ch=ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
ch_mult=ch_mult,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
double_z=double_z,
give_pre_end=give_pre_end,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
double_z=False,
)

self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
Expand All @@ -656,17 +679,12 @@ def __init__(

# pass init params to Decoder
self.decoder = Decoder(
ch=ch,
out_ch=out_ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
ch_mult=ch_mult,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
give_pre_end=give_pre_end,
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
)

def encode(self, x):
Expand Down Expand Up @@ -695,14 +713,21 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
ch,
out_ch,
num_res_blocks,
attn_resolutions,
in_channels,
resolution,
z_channels,
embed_dim,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
act_fn="silu",
# to delete
ch=None,
out_ch=None,
num_res_blocks=None,
attn_resolutions=None,
resolution=None,
z_channels=None,
embed_dim=None,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
ch_mult=(1, 2, 4, 8),
Expand All @@ -713,36 +738,61 @@ def __init__(
):
super().__init__()

if True:
block_out_channels = [ch * c for c in ch_mult]
down_block_types = [down_block_types[0] for _ in range(len(block_out_channels))]
up_block_types = [up_block_types[0] for _ in range(len(block_out_channels))]
layers_per_block = num_res_blocks
latent_channels = z_channels

# pass init params to Encoder
self.encoder = Encoder(
ch=ch,
out_ch=out_ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
ch_mult=ch_mult,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
double_z=double_z,
give_pre_end=give_pre_end,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
double_z=True,
)

# self.encoder = Encoder(
# ch=ch,
# out_ch=out_ch,
# num_res_blocks=num_res_blocks,
# attn_resolutions=attn_resolutions,
# in_channels=in_channels,
# resolution=resolution,
# z_channels=z_channels,
# ch_mult=ch_mult,
# dropout=dropout,
# resamp_with_conv=resamp_with_conv,
# give_pre_end=give_pre_end,
# double_z=True,
# )
#
# pass init params to Decoder
self.decoder = Decoder(
ch=ch,
out_ch=out_ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
ch_mult=ch_mult,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
give_pre_end=give_pre_end,
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
)
# pass init params to Decoder
# self.decoder = Decoder(
# ch=ch,
# out_ch=out_ch,
# num_res_blocks=num_res_blocks,
# attn_resolutions=attn_resolutions,
# in_channels=in_channels,
# resolution=resolution,
# z_channels=z_channels,
# ch_mult=ch_mult,
# dropout=dropout,
# resamp_with_conv=resamp_with_conv,
# give_pre_end=give_pre_end,
# )

self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
Expand Down