
Commit 674f54b

update gradient ckpt
oahzxl committed Feb 18, 2024
1 parent dc19bb0 commit 674f54b
Showing 3 changed files with 23 additions and 2 deletions.
20 changes: 19 additions & 1 deletion dit/models/dit.py
@@ -15,6 +15,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint
 from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
 
 
@@ -182,6 +183,11 @@ def __init__(
         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
         self.initialize_weights()
 
+        self.gradient_checkpointing = False
+
+    def gradient_checkpointing_enable(self):
+        self.gradient_checkpointing = True
+
     def initialize_weights(self):
         # Initialize transformer layers:
         def _basic_init(module):
@@ -234,6 +240,13 @@ def unpatchify(self, x):
         imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
         return imgs
 
+    @staticmethod
+    def create_custom_forward(module):
+        def custom_forward(*inputs):
+            return module(*inputs)
+
+        return custom_forward
+
     def forward(self, x, t, y):
         """
         Forward pass of DiT.
@@ -245,8 +258,13 @@ def forward(self, x, t, y):
         t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
         y = self.y_embedder(y, self.training)  # (N, D)
         c = t + y  # (N, D)
+
         for block in self.blocks:
-            x = block(x, c)  # (N, T, D)
+            if self.gradient_checkpointing:
+                x = torch.utils.checkpoint.checkpoint(self.create_custom_forward(block), x, c)
+            else:
+                x = block(x, c)  # (N, T, D)
+
         x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
         x = self.unpatchify(x)  # (N, out_channels, H, W)
         return x
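
torch.utils.checkpoint.checkpoint drops the activations of the wrapped call during the forward pass and recomputes them when the backward pass needs them, trading extra compute for per-block memory. As a point of reference only (not part of this commit), below is a minimal standalone sketch of the same per-block pattern; ToyBlock is a made-up stand-in for a DiT block, and the second call shows PyTorch's optional use_reentrant=False mode, in which the module can be passed to checkpoint directly without the custom_forward wrapper.

# Illustrative sketch only; ToyBlock is not a class from this repo.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, c):
        # Mirrors block(x, c): the conditioning vector c is broadcast onto the tokens.
        return x + self.proj(x + c.unsqueeze(1))


def create_custom_forward(module):
    # Same wrapper shape as in the diff above: checkpoint() re-invokes this
    # closure with the saved inputs when it recomputes the forward during backward.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward


block = ToyBlock()
x = torch.randn(2, 16, 64, requires_grad=True)
c = torch.randn(2, 64)

# Wrapper form, as in the commit (PyTorch's default checkpointing mode).
out1 = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c)

# Non-reentrant form: the module itself can be passed as the callable.
out2 = torch.utils.checkpoint.checkpoint(block, x, c, use_reentrant=False)

(out1.sum() + out2.sum()).backward()  # gradients flow through both paths
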
3 changes: 3 additions & 0 deletions dit/train.py
@@ -226,6 +226,8 @@ def main(args):
     model = DiT_models[args.model](input_size=latent_size, num_classes=args.num_classes).to(device).to(dtype)
     model_numel = get_model_numel(model)
     logger.info(f"Model params: {format_numel_str(model_numel)}")
+    if args.grad_checkpoint:
+        model.gradient_checkpointing_enable()
 
     # Create ema and vae model
     # Note that parameter initialization is done within the DiT constructor
@@ -375,5 +377,6 @@ def main(args):
     parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
     parser.add_argument("--lr", type=float, default=1e-4, help="Gradient clipping value")
+    parser.add_argument("--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
     args = parser.parse_args()
     main(args)
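
A rough way to sanity-check the effect of --grad_checkpoint is to compare peak CUDA memory for a deep block stack with and without per-block checkpointing. The sketch below is not from this repo: the block stack, sizes, and run() helper are made up for illustration, while torch.cuda.reset_peak_memory_stats and torch.cuda.max_memory_allocated are standard PyTorch calls.

# Hypothetical memory comparison (illustration only, requires a CUDA device).
import torch
import torch.nn as nn
import torch.utils.checkpoint


def run(use_checkpoint, depth=24, dim=1024, batch=4, tokens=256, device="cuda"):
    blocks = nn.ModuleList(
        [nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
         for _ in range(depth)]
    ).to(device)
    x = torch.randn(batch, tokens, dim, device=device, requires_grad=True)
    torch.cuda.reset_peak_memory_stats(device)
    h = x
    for block in blocks:
        if use_checkpoint:
            # Recompute this block's activations in backward instead of storing them.
            h = torch.utils.checkpoint.checkpoint(block, h, use_reentrant=False)
        else:
            h = block(h)
    h.sum().backward()
    return torch.cuda.max_memory_allocated(device) / 1024**2


if torch.cuda.is_available():
    print(f"no checkpointing:   {run(False):.0f} MiB peak")
    print(f"with checkpointing: {run(True):.0f} MiB peak")
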
2 changes: 1 addition & 1 deletion dit/train.sh
@@ -1 +1 @@
-torchrun --nnodes=1 --nproc_per_node=1 train.py --model DiT-XL/2
+torchrun --nnodes=1 --nproc_per_node=1 train.py --model DiT-XL/2 --grad_checkpoint
