fix for pipeline parallel
Signed-off-by: Yi Dong <[email protected]>
doyend committed Oct 6, 2022
1 parent 323bca7 commit eee5d38
Showing 3 changed files with 49 additions and 5 deletions.
@@ -48,6 +48,10 @@

 try:
     from apex.transformer import parallel_state, tensor_parallel
+    from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
+        forward_backward_pipelining_without_interleaving,
+    )
+    from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining

     HAVE_APEX = True

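For context, these schedule imports sit inside the same availability guard used for the other Apex symbols. A minimal sketch of that pattern; the except branch is outside this hunk, so its exact form here is assumed:

# Sketch of the Apex availability guard; the except branch is assumed,
# since it is not part of this hunk.
try:
    from apex.transformer import parallel_state, tensor_parallel
    from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False  # callers check this flag before using any Apex schedule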
@@ -94,7 +98,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         self._inference_config = None
         self.hidden_size = self.frozen_model.cfg.hidden_size
         self.padded_vocab_size = self.frozen_model.padded_vocab_size
-        self.word_embeddings = self.frozen_model.model.language_model.embedding.word_embeddings
+        if self.frozen_model.model.pre_process:
+            self.word_embeddings = self.frozen_model.model.language_model.embedding.word_embeddings
         self._prompt_encoder_key = 'prompt_encoder'

     def add_virtual_prompt_params_to_param_group(self):
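The guard added above matters because, with pipeline parallelism, only the first pipeline stage constructs the word-embedding table (its pre_process flag is True); on later stages the attribute chain does not exist, so the old unconditional access fails. A minimal sketch of the idea, using a hypothetical helper name:

def get_word_embeddings(frozen_model):
    # Hypothetical helper: the embedding only exists where pre_process is True,
    # i.e. on the first pipeline stage.
    if frozen_model.model.pre_process:
        return frozen_model.model.language_model.embedding.word_embeddings
    # Later pipeline stages never build the embedding table.
    return None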
@@ -282,8 +287,47 @@ def fwd_bwd_step(self, batch, forward_only):
         Dataloader produces a global batch which is turned into a list of microbatches.
         The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
         """
-        sequence_parallel_enabled = (self.cfg.get("sequence_parallel", False),)
-        return super().fwd_bwd_step(batch, forward_only, sequence_parallel_enabled=sequence_parallel_enabled)
+        disable_autocast = False
+        sequence_parallel_enabled = self.cfg.get("sequence_parallel", False)
+        # Get seq length of batch
+        _, seq_length = batch[0].shape
+        tensor_shape = [seq_length + self.cfg.perceiver.hidden_steps, self.cfg.micro_batch_size, self.hidden_size]
+
+        if self.pipeline_parallel:
+            losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving(
+                forward_step_func=self.get_forward_output_and_loss_func(),
+                batch=batch,
+                model=self,
+                forward_only=forward_only,
+                tensor_shape=tensor_shape,
+                dtype=self.autocast_dtype,
+                disable_autocast=disable_autocast,
+                grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
+                sequence_parallel_enabled=sequence_parallel_enabled,
+            )
+        else:
+            losses_reduced_per_micro_batch = forward_backward_no_pipelining(
+                forward_step_func=self.get_forward_output_and_loss_func(),
+                batch=batch,
+                model=self,
+                forward_only=forward_only,
+                tensor_shape=tensor_shape,
+                dtype=self.autocast_dtype,
+                disable_autocast=disable_autocast,
+                grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
+            )
+
+        # only the last stages of the pipeline return losses
+        if losses_reduced_per_micro_batch:
+            # average loss across micro batches
+            loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
+            loss_tensor = torch.concat(loss_tensors_list)
+            loss_mean = loss_tensor.mean()
+        else:
+            # we're not on the last pipeline stage so no losses
+            loss_mean = torch.tensor(0.0).cuda()
+
+        return loss_mean

     def get_forward_output_and_loss_func(self):
         def fwd_output_and_loss_func(batch, model):
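For orientation, the reduction at the end of the new fwd_bwd_step mirrors how the Apex schedules report results: on the last pipeline stage the schedule returns one reduced-loss dict per micro-batch, while earlier stages get an empty list. A small self-contained illustration of that averaging (tensor values are made up):

import torch

# Hypothetical schedule output on the last pipeline stage:
# one {'avg': tensor} entry per micro-batch.
losses_reduced_per_micro_batch = [
    {'avg': torch.tensor([2.0])},
    {'avg': torch.tensor([4.0])},
]

if losses_reduced_per_micro_batch:
    loss_tensor = torch.concat([d['avg'] for d in losses_reduced_per_micro_batch])
    loss_mean = loss_tensor.mean()  # tensor(3.)
else:
    # Intermediate pipeline stages see no losses, so report zero.
    loss_mean = torch.tensor(0.0)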
@@ -1833,7 +1833,7 @@ def get_num_layers(self, num_layers):
                     num_layers = num_layers // num_ranks_in_encoder
                 else:
                     num_layers = num_layers // num_ranks_in_decoder
-            else:
+            elif self.model_type == ModelType.encoder_or_decoder:
                 assert (
                     num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0
                 ), 'num_layers must be divisible by pipeline_model_parallel_size'
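As a quick numeric check of the assertion in the new encoder_or_decoder branch (numbers chosen for illustration only): 24 layers split over a pipeline-parallel world size of 4 gives 6 layers per stage, whereas a world size of 5 would trip the assert.

# Illustration only; values are hypothetical.
num_layers = 24
pipeline_model_parallel_size = 4
assert num_layers % pipeline_model_parallel_size == 0, \
    'num_layers must be divisible by pipeline_model_parallel_size'
layers_per_stage = num_layers // pipeline_model_parallel_size  # 6 layers on each stage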
@@ -45,7 +45,7 @@ def __init__(
"""
"""
super().__init__()
self.encoder = MegatronPerceiverEncoderModule(**cfg)
self.encoder = MegatronPerceiverEncoderModule(**cfg, parent_model_type=None)
self.hidden = self.encoder.hidden_size
self.input_linear = nn.Linear(output_dim, self.hidden)
self.output_linear = nn.Linear(self.hidden, output_dim)
