Add O2 support for RETRO model (#4411)
* make sure post-ln works with new coeff

Signed-off-by: Yi Dong <[email protected]>

* add some comments

Signed-off-by: Yi Dong <[email protected]>

* second v

Signed-off-by: Yi Dong <[email protected]>

* fix headscale

Signed-off-by: Yi Dong <[email protected]>

* works both for pre-ln and post-ln

Signed-off-by: Yi Dong <[email protected]>

* fix unittest

Signed-off-by: Yi Dong <[email protected]>

* stop gradient

Signed-off-by: Yi Dong <[email protected]>

* stop gradient

Signed-off-by: Yi Dong <[email protected]>

* use half rotary embedding

Signed-off-by: Yi Dong <[email protected]>

* use default grad clip

Signed-off-by: Yi Dong <[email protected]>

* turn off rotary embedding

Signed-off-by: Yi Dong <[email protected]>

* added o2 support

Signed-off-by: Yi Dong <[email protected]>

* fix style

Signed-off-by: Yi Dong <[email protected]>

* add debugging

Signed-off-by: Yi Dong <[email protected]>

* make cyclic lr work

Signed-off-by: Yi Dong <[email protected]>

* o2 works with cyclic lr

Signed-off-by: Yi Dong <[email protected]>

* remove deepnet

Signed-off-by: Yi Dong <[email protected]>

* fix merge error

Signed-off-by: Yi Dong <[email protected]>

* update the comments

Signed-off-by: Yi Dong <[email protected]>

* added output scaling for stable training

Signed-off-by: Yi Dong <[email protected]>

* improve the debug code

Signed-off-by: Yi Dong <[email protected]>

* fix comment

Signed-off-by: Yi Dong <[email protected]>

* move debug hook above

Signed-off-by: Yi Dong <[email protected]>

* move optimizer config to base class

Signed-off-by: Yi Dong <[email protected]>

* address comments

Signed-off-by: Yi Dong <[email protected]>

Co-authored-by: Eric Harper <[email protected]>
yidong72 and ericharper committed Jun 28, 2022
1 parent d8785e0 commit c9f16fd
Showing 15 changed files with 375 additions and 114 deletions.
(first changed file; file path not shown)
@@ -68,6 +68,9 @@ model:
headscale: False # Whether to learn extra parameters that scale the output of each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']

dump_debug_info: False # dump out the debug information
dump_debug_info_to_file: False # dump out the debug information to files

# retro architecture
chunk_size: 64 # the chunk size used to retrieve
enc_num_layers: 4 # total number of encoder layers
@@ -81,6 +84,9 @@ model:
post_process: True # add pooler
bert_binary_head: True # BERT binary head

megatron_amp_O2: False # use AMP with O2 style mixed precision instead of native amp on-the-fly weight autocasting.
grad_allreduce_chunk_size_mb: 125

tokenizer:
library: 'megatron'
type: 'GPT2BPETokenizer'
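As context for the new model-config keys added above (dump_debug_info, dump_debug_info_to_file, megatron_amp_O2, grad_allreduce_chunk_size_mb), here is a minimal sketch of setting them programmatically with OmegaConf before building the model. The config file name below is a placeholder, not something defined in this commit.

from omegaconf import OmegaConf

cfg = OmegaConf.load("megatron_retro_config.yaml")  # placeholder path for the RETRO config
cfg.model.megatron_amp_O2 = True                    # enable O2-style mixed precision (master weights)
cfg.model.grad_allreduce_chunk_size_mb = 125        # chunk size (MB) for the main-grad all-reduce
cfg.model.dump_debug_info = False                   # optional debug dumps introduced in this commit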
10 changes: 7 additions & 3 deletions examples/nlp/language_modeling/megatron_retro_pretraining.py
@@ -20,7 +20,7 @@
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, NLPDDPPlugin
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, MegatronHalfPrecisionPlugin, NLPDDPPlugin
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager
@@ -31,9 +31,10 @@ def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
plugins = [
NLPDDPPlugin(
no_ddp_communication_hook=False,
no_ddp_communication_hook=True if megatron_amp_o2 else False,
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
)
@@ -47,7 +48,10 @@
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=cfg.model.get('hysteresis', 2),
)
plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
if megatron_amp_o2:
plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
else:
plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

if cfg.get('cluster_type', None) == 'BCP':
plugins.append(TorchElasticEnvironment())
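For orientation, a rough sketch of how the plugin list built above is typically consumed further down such a NeMo pretraining script; the exact Trainer call is not part of this diff and is shown here only as an assumption.

from pytorch_lightning import Trainer

# Assumed continuation of main(): cfg.trainer carries the usual Lightning arguments.
trainer = Trainer(plugins=plugins, **cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))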
1 change: 1 addition & 0 deletions nemo/collections/asr/metrics/multi_binary_acc.py
@@ -13,6 +13,7 @@
# limitations under the License.

import logging

import torch
from torchmetrics import Metric

(next changed file; file path not shown)
@@ -23,6 +23,7 @@
from nemo.collections.nlp.modules.common.megatron.clip_grads import clip_grad_norm_fp32
from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.core.optim import MainParamsOptimizerWrapper, prepare_lr_scheduler
from nemo.utils import logging

try:
@@ -221,3 +222,47 @@ def allreduce_gradients(self):
torch.distributed.all_reduce(coalesced, group=parallel_state.get_data_parallel_group())
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

def configure_optimizers(self):
self.setup_optimization()

# Wrap the baseline optimizer with the optimizer class with master parameters
if self.megatron_amp_o2 and self._optimizer is not None:
if self.cfg.precision == 'bf16':
fp32_grad_accum = True
contiguous_grad_bucket = True
elif self.cfg.precision == 16:
fp32_grad_accum = False
# TODO: contiguous grad bucket for fp16 is also planned to be supported
contiguous_grad_bucket = False
raise ValueError(
"fp16 training is not yet supported with O2. Please set megatron_amp_O2 to False in the model config."
)

# if using tensor parallel only, we can use async grad all-reduce
if self.cfg.get('pipeline_model_parallel_size', 1) == 1:
async_grad_allreduce = True
else:
async_grad_allreduce = False

self._optimizer = MainParamsOptimizerWrapper(
self._optimizer,
fp32_grad_accum=fp32_grad_accum,
contiguous_grad_bucket=contiguous_grad_bucket,
async_grad_allreduce=async_grad_allreduce,
grad_div_ar_fusion=self.cfg.get('grad_div_ar_fusion', True),
grad_allreduce_chunk_size_mb=self.cfg.get('grad_allreduce_chunk_size_mb', 125),
)

assert self._trainer.max_steps is not None, "'max_steps' is missing in trainer config."
if hasattr(self._cfg.optim, 'sched'):
sched_config = self._cfg.optim.sched
sched_config['max_steps'] = self._trainer.max_steps
self._scheduler = prepare_lr_scheduler(
optimizer=self._optimizer, scheduler_config=sched_config, train_dataloader=self._train_dl
)

if self._scheduler is None:
return self._optimizer
else:
return [self._optimizer], [self._scheduler]
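MainParamsOptimizerWrapper, as used above, keeps fp32 master copies of the half-precision parameters and steps the optimizer on those copies. The toy class below is not the NeMo implementation; every name in it is invented, and it only illustrates the master-weight idea while ignoring fp32 grad accumulation, contiguous buckets, and async all-reduce.

import torch

class ToyMasterParamsOptimizer:
    """Illustrative only: fp32 master weights for a model kept in fp16/bf16."""

    def __init__(self, optimizer, model_params):
        self.optimizer = optimizer
        self.model_params = list(model_params)  # half-precision params used in fwd/bwd
        # fp32 copies that the wrapped optimizer will actually update
        self.master_params = [
            p.detach().clone().float().requires_grad_(True) for p in self.model_params
        ]
        self.optimizer.param_groups[0]["params"] = self.master_params

    def step(self):
        # Copy half-precision grads into the fp32 master grads ...
        for model_p, master_p in zip(self.model_params, self.master_params):
            if model_p.grad is not None:
                master_p.grad = model_p.grad.detach().float()
        # ... take the optimizer step in fp32 ...
        self.optimizer.step()
        # ... then write the updated master weights back into the model params.
        with torch.no_grad():
            for model_p, master_p in zip(self.model_params, self.master_params):
                model_p.copy_(master_p)

    def zero_grad(self):
        self.optimizer.zero_grad(set_to_none=True)
        for p in self.model_params:
            p.grad = None

In the real wrapper, roughly speaking, fp32_grad_accum controls whether grads are accumulated directly in fp32, contiguous_grad_bucket packs the main grads into one buffer, and async_grad_allreduce overlaps the data-parallel all-reduce with backprop.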
(next changed file; file path not shown)
@@ -48,7 +48,6 @@
from nemo.collections.nlp.parts.nlp_overrides import GradScaler
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.optim import MainParamsOptimizerWrapper, prepare_lr_scheduler
from nemo.utils import AppState, logging

try:
@@ -589,49 +588,6 @@ def setup_test_data(self, cfg):
)
self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples)

def configure_optimizers(self):
self.setup_optimization()

# Wrap the baseline optimizer with the optimizer class with master parameters
if self.megatron_amp_o2 and self._optimizer is not None:
if self.cfg.precision == 'bf16':
fp32_grad_accum = True
contiguous_grad_bucket = True
elif self.cfg.precision == 16:
fp32_grad_accum = False
# TODO: contiguous grad bucket for fp16 is also planned to be supported
contiguous_grad_bucket = False
raise ValueError(
"fp16 training is not yet supported with O2. Please set megatron_amp_O2 to False in the model config."
)

# if using tensor parallel only, we can use async grad all-reduce
if self.cfg.get('pipeline_model_parallel_size', 1) == 1:
async_grad_allreduce = True
else:
async_grad_allreduce = False

self._optimizer = MainParamsOptimizerWrapper(
self._optimizer,
fp32_grad_accum=fp32_grad_accum,
contiguous_grad_bucket=contiguous_grad_bucket,
async_grad_allreduce=async_grad_allreduce,
grad_div_ar_fusion=self.cfg.get('grad_div_ar_fusion', True),
grad_allreduce_chunk_size_mb=self.cfg.get('grad_allreduce_chunk_size_mb', 125),
)

assert self._trainer.max_steps is not None, "'max_steps' is missing in trainer config."
sched_config = self._cfg.optim.sched
sched_config['max_steps'] = self._trainer.max_steps
self._scheduler = prepare_lr_scheduler(
optimizer=self._optimizer, scheduler_config=sched_config, train_dataloader=self._train_dl
)

if self._scheduler is None:
return self._optimizer
else:
return [self._optimizer], [self._scheduler]

def compute_consumed_samples(self, steps_since_resume=0):
app_state = AppState()
consumed_samples = (
(next changed file; file path not shown)
@@ -41,7 +41,6 @@
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam, TextGeneration
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.core.optim import MainParamsOptimizerWrapper, prepare_lr_scheduler
from nemo.utils import logging

try:
(next changed file; file path not shown)
@@ -36,7 +36,6 @@
)
from nemo.collections.nlp.parts.nlp_overrides import GradScaler
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.core.optim import MainParamsOptimizerWrapper, prepare_lr_scheduler
from nemo.utils import AppState, logging

try:
@@ -730,48 +729,6 @@ def setup_test_data(self, cfg):
consumed_samples = 0
self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples)

def configure_optimizers(self):
self.setup_optimization()

# Wrap the baseline optimizer with the optimizer class with master parameters
if self.megatron_amp_o2 and self._optimizer is not None:
if self.cfg.precision == 'bf16':
fp32_grad_accum = True
contiguous_grad_bucket = True

elif self.cfg.precision == 16:
fp32_grad_accum = False
# TODO: contiguous grad bucket for fp16 is also planned to be supported
contiguous_grad_bucket = False

# if using tensor parallel only, we can use async grad all-reduce
if self.cfg.get('pipeline_model_parallel_size', 1) == 1:
async_grad_allreduce = True
else:
async_grad_allreduce = False

self._optimizer = MainParamsOptimizerWrapper(
self._optimizer,
fp32_grad_accum=fp32_grad_accum,
contiguous_grad_bucket=contiguous_grad_bucket,
async_grad_allreduce=async_grad_allreduce,
grad_div_ar_fusion=self.cfg.get('grad_div_ar_fusion', True),
grad_allreduce_chunk_size_mb=self.cfg.get('grad_allreduce_chunk_size_mb', 125),
)

assert self._trainer.max_steps is not None, "'max_steps' is missing in trainer config."
if hasattr(self._cfg.optim, 'sched'):
sched_config = self._cfg.optim.sched
sched_config['max_steps'] = self._trainer.max_steps
self._scheduler = prepare_lr_scheduler(
optimizer=self._optimizer, scheduler_config=sched_config, train_dataloader=self._train_dl
)

if self._scheduler is None:
return self._optimizer
else:
return [self._optimizer], [self._scheduler]

def compute_consumed_samples(self, steps_since_resume=0):
app_state = AppState()
consumed_samples = (
(next changed file; file path not shown)
@@ -29,6 +29,7 @@
build_train_valid_test_datasets,
)
from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.retrieval_token_level_encoder_decoder import (
MegatronRetrievalTokenLevelEncoderDecoderModule,
)
@@ -64,6 +65,17 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):

# TODO does not support PP yet
self.model = self.model_provider_func(pre_process=True, post_process=True, add_encoder=True, add_decoder=True)

self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False)

if self.megatron_amp_o2:

# Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type
self.model.cuda(torch.cuda.current_device())

# Model wrapper to convert both model and inputs to half precision
self.model = Float16Module(module=self.model, precision=self.cfg.precision)

# self.setup_optimizer_param_groups()
if self.cfg.precision == 32:
self.autocast_dtype = torch.float
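The Float16Module wrapper imported and applied above converts the model's parameters to half precision and handles the float casts at the module boundary. A simplified, illustrative stand-in (not the NeMo/Megatron class) might look like this:

import torch

class ToyFloat16Module(torch.nn.Module):
    """Illustrative only: run a wrapped module in fp16 or bf16."""

    def __init__(self, module, precision):
        super().__init__()
        self.dtype = torch.bfloat16 if precision == "bf16" else torch.float16
        self.module = module.to(self.dtype)  # parameters and buffers stored in half precision

    def forward(self, *args, **kwargs):
        def cast_in(t):
            # Cast floating-point inputs down to the module dtype; leave everything else alone.
            return t.to(self.dtype) if torch.is_tensor(t) and t.is_floating_point() else t

        out = self.module(*[cast_in(a) for a in args], **{k: cast_in(v) for k, v in kwargs.items()})
        # Cast floating-point outputs back to fp32 so losses/metrics stay in full precision.
        return out.float() if torch.is_tensor(out) and out.is_floating_point() else out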
@@ -74,8 +86,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
else:
raise ValueError('precision must be in [32, 16, "bf16"]')
self.model.model_type = ModelType.encoder_and_decoder
# not using amp o2
self.megatron_amp_o2 = False
# self.grad_clip_pl_default = True

def _build_tokenizer(self):
self.tokenizer = get_nmt_tokenizer(
@@ -185,6 +196,23 @@ def training_step(self, batch, batch_idx):
reduced_loss = average_losses_across_data_parallel_group([lm_loss])
self._reduced_loss_buffer.append(reduced_loss[0])

# While async grad allreduce is enabled, bprop keeps moving forward without waiting for
# the async grad all-reduce work to finish. Hence, to guarantee correct gradient reduction,
# we cannot start the weight update until all async grad all-reduce work is done.
if self.megatron_amp_o2 and self.cfg.get('pipeline_model_parallel_size', 1) == 1:
torch.cuda.synchronize()

if self.megatron_amp_o2:
# when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# main grads are stored in the MainParamsOptimizer wrapper
self._optimizer.allreduce_main_grads()
else:
# async grad allreduce is not currently implemented for O1/autocasting mixed precision training
# no pipeline, so use the default pytorch lightning way of doing all_reduce
# self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf)
pass

if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
# Reduced loss for logging.
average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
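The torch.cuda.synchronize() added in this training_step exists because asynchronous gradient all-reduce lets backprop run ahead of the reduction kernels. The generic (non-NeMo) sketch below illustrates the same ordering constraint with plain torch.distributed; it assumes a process group is already initialized and is not taken from this commit.

import torch
import torch.distributed as dist

def toy_data_parallel_update(loss, params, optimizer):
    """Illustrative only: async grad all-reduce must finish before the weight update."""
    loss.backward()

    # Launch one asynchronous all-reduce per gradient; each call returns immediately.
    handles = [dist.all_reduce(p.grad, async_op=True) for p in params if p.grad is not None]

    # Without this wait, optimizer.step() could read partially reduced gradients
    # (the O2 path above uses torch.cuda.synchronize() to enforce the same ordering).
    for handle in handles:
        handle.wait()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)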
(remaining changed files not loaded in this view)