support cfg parallel (#158)
* Enabling conditional parallel for classifier-free guidance

* Seeding bug fix.

* update cp

* update

---------

Co-authored-by: ExtremeViscent <[email protected]>
oahzxl and ExtremeViscent committed Aug 23, 2024
1 parent 135631f commit 4335469
Showing 5 changed files with 136 additions and 16 deletions.
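What the commit enables, in outline: classifier-free guidance runs every denoising step on a batch of two samples (the conditional and the unconditional branch). With cfg parallelism, that pair is split along the batch dimension across a dedicated 2-rank group, each rank runs the transformer on its half, and the halves are gathered back before the guidance combination. The sketch below illustrates that pattern with plain torch.distributed collectives on CPU; it is a standalone illustration, not OpenDiT's actual split_sequence/gather_sequence helpers, and the tensor shapes and port number are made up.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int):
    # CPU-only sketch: a 2-rank "cfg parallel" group over gloo (port is arbitrary).
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)

    # One CFG step feeds the model a batch of two: [conditional, unconditional].
    cfg_batch = torch.arange(2 * 4, dtype=torch.float32).reshape(2, 4)

    # Split along dim 0: rank 0 keeps the conditional half, rank 1 the unconditional half.
    local = cfg_batch.chunk(world_size, dim=0)[rank]

    # Stand-in for the transformer forward pass on the local half.
    local_out = local * 2

    # Gather the halves back so every rank sees the full batch before applying guidance.
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out)
    full_out = torch.cat(gathered, dim=0)

    if rank == 0:
        print(full_out.shape)  # torch.Size([2, 4])
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)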
52 changes: 39 additions & 13 deletions opendit/core/parallel_mgr.py
@@ -6,35 +6,46 @@
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from torch.distributed import ProcessGroup

from opendit.utils.logging import init_dist_logger
from opendit.utils.logging import init_dist_logger, logger
from opendit.utils.utils import set_seed

PARALLEL_MANAGER = None


class ParallelManager(ProcessGroupMesh):
def __init__(self, dp_size, sp_size, dp_axis, sp_axis):
super().__init__(dp_size, sp_size)
self.dp_axis = dp_axis
self.dp_group: ProcessGroup = self.get_group_along_axis(self.dp_axis)
def __init__(self, dp_size, cp_size, sp_size):
super().__init__(dp_size, cp_size, sp_size)
dp_axis, cp_axis, sp_axis = 0, 1, 2

self.dp_size = dp_size
self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis)
self.dp_rank = dist.get_rank(self.dp_group)

self.cp_size = cp_size
self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
self.cp_rank = dist.get_rank(self.cp_group)

self.sp_size = sp_size
self.sp_axis = sp_axis
self.sp_group: ProcessGroup = self.get_group_along_axis(self.sp_axis)
self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
self.sp_rank = dist.get_rank(self.sp_group)
self.enable_sp = sp_size > 1

logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")


def set_parallel_manager(dp_size, sp_size, dp_axis=0, sp_axis=1):
def set_parallel_manager(dp_size, cp_size, sp_size):
global PARALLEL_MANAGER
PARALLEL_MANAGER = ParallelManager(dp_size, sp_size, dp_axis, sp_axis)
PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)


def get_data_parallel_group():
return PARALLEL_MANAGER.dp_group


def get_data_parallel_size():
return PARALLEL_MANAGER.dp_size


def get_data_parallel_rank():
return PARALLEL_MANAGER.dp_rank

@@ -51,6 +62,14 @@ def get_sequence_parallel_rank():
return PARALLEL_MANAGER.sp_rank


def get_cfg_parallel_group():
return PARALLEL_MANAGER.cp_group


def get_cfg_parallel_size():
return PARALLEL_MANAGER.cp_size


def enable_sequence_parallel():
if PARALLEL_MANAGER is None:
return False
@@ -61,22 +80,29 @@ def get_parallel_manager():
return PARALLEL_MANAGER


def initialize(seed: Optional[int] = None, sp_size: Optional[int] = None):
def initialize(seed: Optional[int] = None, sp_size: Optional[int] = None, enable_cp=True):
if not dist.is_initialized():
colossalai.launch_from_torch({})
init_dist_logger()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# init sequence parallel
if sp_size is None:
sp_size = dist.get_world_size()
dp_size = 1
else:
assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size"
dp_size = dist.get_world_size() // sp_size

set_parallel_manager(dp_size, sp_size)
# update cfg parallel
if enable_cp and sp_size % 2 == 0:
sp_size = sp_size // 2
cp_size = 2
else:
cp_size = 1

set_parallel_manager(dp_size, cp_size, sp_size)

if seed is not None:
local_seed = seed + get_data_parallel_rank()
set_seed(local_seed)
set_seed(seed + get_data_parallel_rank())
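For reference, the group-sizing rule in the updated initialize can be read as a small pure function: cfg parallelism borrows a factor of 2 from the sequence-parallel group whenever sp_size is even, and is silently skipped otherwise. The restatement below (the parallel_sizes name and the example world sizes are invented for illustration) is a sketch, not OpenDiT API.

from typing import Optional, Tuple


def parallel_sizes(world_size: int, sp_size: Optional[int] = None, enable_cp: bool = True) -> Tuple[int, int, int]:
    """Restatement of the sizing rule in `initialize` (illustration only)."""
    if sp_size is None:
        sp_size, dp_size = world_size, 1
    else:
        assert world_size % sp_size == 0, "world_size must be divisible by sp_size"
        dp_size = world_size // sp_size
    # cfg parallelism takes a factor of 2 out of the sequence-parallel group when it can.
    if enable_cp and sp_size % 2 == 0:
        sp_size, cp_size = sp_size // 2, 2
    else:
        cp_size = 1
    return dp_size, cp_size, sp_size


assert parallel_sizes(8) == (1, 2, 4)             # default: all ranks go to sp, then cp takes a factor of 2
assert parallel_sizes(8, sp_size=4) == (2, 2, 2)  # explicit sp_size leaves 2 ranks for dp
assert parallel_sizes(3, sp_size=3) == (1, 1, 3)  # odd sp_size: cfg parallel stays disabled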
35 changes: 34 additions & 1 deletion opendit/models/latte/latte_t2v.py
@@ -9,6 +9,7 @@


from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Tuple

import torch
@@ -50,7 +51,13 @@
if_broadcast_temporal,
save_mlp_output,
)
from opendit.core.parallel_mgr import enable_sequence_parallel, get_sequence_parallel_group
from opendit.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_group,
get_cfg_parallel_size,
get_sequence_parallel_group,
)
from opendit.utils.utils import batch_func


@maybe_allow_in_graph
@@ -1182,6 +1189,28 @@ def forward(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""

# 0. Split batch for data parallelism
if get_cfg_parallel_size() > 1:
(
hidden_states,
timestep,
encoder_hidden_states,
added_cond_kwargs,
class_labels,
attention_mask,
encoder_attention_mask,
) = batch_func(
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
hidden_states,
timestep,
encoder_hidden_states,
added_cond_kwargs,
class_labels,
attention_mask,
encoder_attention_mask,
)

input_batch_size, c, frame, h, w = hidden_states.shape
frame = frame - use_image_num
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
@@ -1422,6 +1451,10 @@ def forward(
)
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()

# 3. Gather batch for data parallelism
if get_cfg_parallel_size() > 1:
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)

if not return_dict:
return (output,)

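A shape-level view of why splitting before Latte's forward pass helps: with a CFG batch of 2, each rank of the 2-way cfg group keeps one sample, so the subsequent "b c f h w -> (b f) c h w" fold produces half as many frame slices per rank, and the gather at the end restores the full batch on the output. The sketch below is a local, single-process illustration in which torch.chunk and torch.cat stand in for split_sequence/gather_sequence, and the latent shape is made up.

import torch
from einops import rearrange

# Hypothetical CFG batch: 2 samples (cond + uncond), 4 latent channels, 16 frames, 32x32 latents.
hidden_states = torch.randn(2, 4, 16, 32, 32)

# Per-rank view after the dim-0 split across the 2-way cfg group (here: rank 0's half).
local = hidden_states.chunk(2, dim=0)[0]              # (1, 4, 16, 32, 32)

# The forward pass folds frames into the batch dim, so each rank processes
# half as many frame slices as it would without cfg parallelism.
folded = rearrange(local, "b c f h w -> (b f) c h w")
assert folded.shape == (1 * 16, 4, 32, 32)

# After the blocks, the output is unfolded and the two halves are gathered back to batch 2.
local_out = rearrange(folded, "(b f) c h w -> b c f h w", b=1)
output = torch.cat([local_out, local_out], dim=0)      # stand-in for gather_sequence along dim 0
assert output.shape == hidden_states.shape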
20 changes: 19 additions & 1 deletion opendit/models/opensora/stdit3.py
@@ -9,6 +9,7 @@


import os
from functools import partial

import numpy as np
import torch
@@ -36,7 +37,13 @@
if_broadcast_temporal,
save_mlp_output,
)
from opendit.core.parallel_mgr import enable_sequence_parallel, get_sequence_parallel_group
from opendit.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_size,
get_data_parallel_group,
get_sequence_parallel_group,
)
from opendit.utils.utils import batch_func

from .modules import (
Attention,
@@ -449,6 +456,12 @@ def encode_text(self, y, mask=None):
def forward(
self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
):
# === Split batch ===
if get_cfg_parallel_size() > 1:
x, timestep, y, x_mask, mask = batch_func(
partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
)

dtype = self.x_embedder.proj.weight.dtype
B = x.size(0)
x = x.to(dtype)
@@ -545,6 +558,11 @@ def forward(

# cast to float32 for better accuracy
x = x.to(torch.float32)

# === Gather Output ===
if get_cfg_parallel_size() > 1:
x = gather_sequence(x, get_data_parallel_group(), dim=0)

return x

def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w):
31 changes: 30 additions & 1 deletion opendit/models/opensora_plan/latte.py
@@ -11,6 +11,7 @@
import json
import os
from dataclasses import dataclass
from functools import partial
from importlib import import_module
from typing import Any, Callable, Dict, Optional, Tuple

@@ -63,8 +64,14 @@
if_broadcast_temporal,
save_mlp_output,
)
from opendit.core.parallel_mgr import enable_sequence_parallel, get_sequence_parallel_group
from opendit.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_group,
get_cfg_parallel_size,
get_sequence_parallel_group,
)
from opendit.utils.logging import logger
from opendit.utils.utils import batch_func

if is_xformers_available():
import xformers
@@ -2458,6 +2465,24 @@ def forward(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# 0. Split batch
if get_cfg_parallel_size() > 1:
(
hidden_states,
timestep,
encoder_hidden_states,
class_labels,
attention_mask,
encoder_attention_mask,
) = batch_func(
partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
hidden_states,
timestep,
encoder_hidden_states,
class_labels,
attention_mask,
encoder_attention_mask,
)
input_batch_size, c, frame, h, w = hidden_states.shape
frame = frame - use_image_num # 20-4=16
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
@@ -2739,6 +2764,10 @@ def forward(
)
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()

# 3. Gather batch for data parallelism
if get_cfg_parallel_size() > 1:
output = gather_sequence(output, get_cfg_parallel_group(), dim=0)

if not return_dict:
return (output,)

14 changes: 14 additions & 0 deletions opendit/utils/utils.py
@@ -35,6 +35,20 @@ def str_to_dtype(x: str):
raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")


def batch_func(func, *args):
"""
Apply a function to each element of a batch.
"""
batch = []
for arg in args:
if isinstance(arg, torch.Tensor) and arg.shape[0] == 2:
batch.append(func(arg))
else:
batch.append(arg)

return batch


def merge_args(args1, args2):
"""
Merge two argparse Namespace objects.
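A usage sketch for the new helper, assuming OpenDiT is importable: batch_func applies the given function only to tensor arguments whose leading dimension is exactly 2 (the CFG pair) and passes everything else, such as None placeholders or tensors with other batch sizes, through unchanged. The fake_split stand-in and the example shapes below are hypothetical and mimic a single rank's view of a dim-0 split.

import torch

from opendit.utils.utils import batch_func


def fake_split(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for partial(split_sequence, process_group=..., dim=0): keep this rank's half.
    return t.chunk(2, dim=0)[0]


hidden = torch.randn(2, 8)          # CFG pair: gets split
timestep = torch.tensor([10, 10])   # also batch 2: gets split
mask = None                         # non-tensor placeholder: passed through
labels = torch.randn(5, 3)          # batch != 2: passed through

hidden, timestep, mask, labels = batch_func(fake_split, hidden, timestep, mask, labels)
assert hidden.shape == (1, 8)
assert timestep.shape == (1,)
assert mask is None and labels.shape == (5, 3)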
