
CUDA OOM during ckpt saving for Llama2-70b #142

Closed

lwmlyy opened this issue Aug 24, 2023 · 17 comments

@lwmlyy

lwmlyy commented Aug 24, 2023

Hi, I am using 8 x A100-80GB to LoRA-finetune Llama2-70b. Training and evaluation during epoch 1 went well, but it went OOM when saving the PEFT model. The nightly version of PyTorch is used.

The following command is used:
torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --model_name ../Llama-2-70b-chat-hf --micro_batch_size 1 --batch_size_training 1 --dist_checkpoint_root_folder ../Llama-2-70b-chat-hf/ --dist_checkpoint_folder fine-tuned --use_peft --peft_method lora --lr 3e-4 --epoch 2 --pure_bf16 --alpaca_dataset --output_dir llama-70b-lorawallsft

"we are about to save the PEFT modules", it went CUDA OOM after this log is printed.

@clechristophe

I have the same problem. I think everything is brought back to the first rank before saving and it causes CUDA OOM. I was able to save one .distcp file per GPU but I'm not sure how to get just the LoRA adapter file from there...
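
For reference, newer PyTorch (2.2+) ships format utilities that can consolidate the per-rank .distcp shards offline, after which the adapter tensors can be filtered out. A minimal sketch, assuming the sharded checkpoint layout produced by the recipe and that adapter keys contain "lora_"; the paths are placeholders:

import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Consolidate the sharded (.distcp) checkpoint into a single torch.save file.
dcp_to_torch_save("fine-tuned-checkpoint/", "consolidated.pt")

# Keep only the LoRA adapter tensors from the consolidated state dict.
full_sd = torch.load("consolidated.pt", map_location="cpu")
full_sd = full_sd.get("model", full_sd)  # the recipe nests the model under a "model" key
lora_sd = {k: v for k, v in full_sd.items() if "lora_" in k}
torch.save(lora_sd, "adapter_weights.pt")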

@gongy

gongy commented Sep 13, 2023

I'm also running into this (albeit with 4 A100 80GB). Wondering if there is a way we can work around it - happy to make a contribution if the direction is clear.

Seems like a shame to have this bug during save_pretrained when the rest of the training and evaluation works well.

  File "/opt/conda/lib/python3.9/site-packages/llama_recipes/utils/train_utils.py", line 142, in train
    model.save_pretrained(train_config.output_dir)
  File "/opt/conda/lib/python3.9/site-packages/peft/peft_model.py", line 167, in save_pretrained
    output_state_dict = get_peft_model_state_dict(
  File "/opt/conda/lib/python3.9/site-packages/peft/utils/save_and_load.py", line 41, in get_peft_model_state_dict
    state_dict = model.state_dict()
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1898, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1898, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1898, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  [Previous line repeated 8 more times]
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1894, in state_dict
    hook(self, prefix, keep_vars)
  File "/opt/conda/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 774, in _pre_state_dict_hook
    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 293, in _full_pre_state_dict_hook
    _common_unshard_pre_state_dict_hook(
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 157, in _common_unshard_pre_state_dict_hook
    _enter_unshard_params_ctx(
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 118, in _enter_unshard_params_ctx
    fsdp_state._unshard_params_ctx[module].__enter__()
  File "/opt/conda/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 196, in _unshard_fsdp_state_params
    _unshard(state, handle, computation_stream, computation_stream)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 329, in _unshard
    handle.unshard()
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/flat_param.py", line 1250, in unshard
    unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/flat_param.py", line 1276, in _alloc_padded_unsharded_flat_param
    _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type: ignore[attr-defined]
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/utils.py", line 166, in _alloc_storage
    tensor._typed_storage()._resize_(size.numel())
  File "/opt/conda/lib/python3.9/site-packages/torch/storage.py", line 921, in _resize_
    self._untyped_storage.resize_(size * self._element_size())
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 1 has a total capacty of 79.18 GiB of which 2.31 MiB is free. Process 138368 has 79.18 GiB memory in use. Of the allocated memory 76.72 GiB is allocated by PyTorch, and 245.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

@jshin49

jshin49 commented Sep 15, 2023

Any updates on this issue? I'm encountering the same problem with 8 x A100 (80GB) for Lora 70B

@matthieumeeus

I'm facing the same issue. Any workaround?

@anthonyrathe

anthonyrathe commented Sep 21, 2023

Also facing the same issue here. I noticed another thread facing a similar issue with LoRA fine-tuning (although with another model): philschmid/deep-learning-pytorch-huggingface#16
It seemed to be caused by peft versions after 0.2.0. Could this be related?

@gongy

gongy commented Oct 3, 2023

I found a workaround which involves allowing CPU offloading during the phase of saving the state dict.

I tested that end-to-end 70B training works with checkpointing on this repo.

I will try to find the time to merge in the changes soon, but you can find them here https://github.com/modal-labs/llama-recipes
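
The usual shape of that workaround is to wrap the save in FSDP's FULL_STATE_DICT context with offload_to_cpu=True and rank0_only=True, so the unsharded weights are gathered into host memory on rank 0 instead of being materialized on every GPU. A minimal sketch, not necessarily the exact change in the linked fork (model and train_config.output_dir are assumed from llama-recipes):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Gather the full state dict on rank 0 only and stream it to CPU while
# PEFT's save_pretrained() walks model.state_dict().
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    model.save_pretrained(train_config.output_dir)  # only the LoRA adapter is written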

@JazzTheSaver

I'm encountering the same problem with 6 x H800 (80GB) for Lora 70B

@yuanzhedong

I'm encountering the same problem with 8 x H100 (80GB) for Lora 70B

@HamidShojanazeri
Contributor

HamidShojanazeri commented Feb 27, 2024

Sorry for the late reply @yuanzhedong and everyone. Is this happening only with the alpaca dataset? I believe some of the issues from Sep should be resolved. We have CPU offload now, if that can be helpful.

I don't have H100s on hand right now, but I'm looking to get access and repro the issue.

@HamidShojanazeri
Contributor

@yuanzhedong It seems like an issue with transformers. I could repro this issue with the pip-installed transformers 4.38.1; installing from src (transformers 4.39.0.dev0) resolves it:

git clone https://github.com/huggingface/transformers.git
cd transformers/

pip install -e .

Can you please give it a try?

@Mugheeera

Mugheeera commented Mar 20, 2024

@HamidShojanazeri Thank you very much for your reply. I tried with a new install of transformers 4.39.0.dev0 from src but still encountered the same issue. I have also attempted to decrease the batch size, but to no avail. I am using a custom dataset modelled after the structure of the Alpaca dataset. I will now make some modifications to the dataset, proceed with a clean installation of llama-recipes, and see how it goes.

@HamidShojanazeri
Contributor

Sure, it shouldn't have anything to do with your batch size; the version conflict was the only way I could repro and bypass it. Please let me know how it goes.

@Mugheeera

Mugheeera commented Mar 22, 2024

I tried with the following package versions and a fresh install of llama-recipes, but still getting the same error.

accelerate 0.28.0, appdirs 1.4.4, bitsandbytes 0.43.0, black 24.3.0, datasets 2.18.0, fire 0.6.0, gradio 4.22.0, gradio_client 0.13.0, loralib 0.1.2, matplotlib 3.8.3, matplotlib-inline 0.1.6, optimum 1.17.1, peft 0.9.0, py7zr 0.21.0, scipy 1.12.0, sentencepiece 0.2.0, torch 2.3.0+cu118, transformers 4.39.0.dev0 (/home/mugheera/transformers)

Below is the command that I used to start the finetuning job

torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /home/mugheera/llama-hf/Llama-2-70b-chat-hf --pure_bf16 --output_dir /home/mugheera/PEFT/model --use_fast_kernels

Training Epoch: 1/3, step 46/47 completed (loss: 0.15002813935279846): 100%|████████████████████████████████████████████████| 47/47 [34:52<00:00, 44.52s/it]
Training Epoch: 1/3, step 46/47 completed (loss: 0.17413900792598724): 100%|████████████████████████████████████████████████| 47/47 [35:40<00:00, 45.54s/it]
Training Epoch: 1/3, step 46/47 completed (loss: 0.1314418613910675): 100%|█████████████████████████████████████████████████| 47/47 [34:16<00:00, 43.75s/it]
Training Epoch: 1/3, step 46/47 completed (loss: 0.18022574484348297): 100%|████████████████████████████████████████████████| 47/47 [35:46<00:00, 45.68s/it]
Max CUDA memory allocated was 31 GB
Max CUDA memory reserved was 39 GB
Peak active CUDA memory was 31 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 153 GB
evaluating Epoch: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [01:08<00:00, 4.00s/it]
evaluating Epoch: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [01:08<00:00, 4.02s/it]
evaluating Epoch: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [01:08<00:00, 4.02s/it]
evaluating Epoch: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [01:08<00:00, 4.02s/it]
eval_ppl=tensor(1.1495, device='cuda:0') eval_epoch_loss=tensor(0.1393, device='cuda:0')
we are about to save the PEFT modules
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/mugheera/new/llama-recipes/recipes/finetuning/finetuning.py", line 8, in
[rank2]: fire.Fire(main)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/fire/core.py", line 143, in Fire
[rank2]: component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/fire/core.py", line 477, in _Fire
[rank2]: component, remaining_args = _CallAndUpdateTrace(
[rank2]: ^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank2]: component = fn(*varargs, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/new/llama-recipes/src/llama_recipes/finetuning.py", line 265, in main
[rank2]: results = train(
[rank2]: ^^^^^^
[rank2]: File "/home/mugheera/new/llama-recipes/src/llama_recipes/utils/train_utils.py", line 187, in train
[rank2]: model.save_pretrained(train_config.output_dir)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/peft/peft_model.py", line 215, in save_pretrained
[rank2]: output_state_dict = get_peft_model_state_dict(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/peft/utils/save_and_load.py", line 71, in get_peft_model_state_dict
[rank2]: state_dict = model.state_dict()
[rank2]: ^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1911, in state_dict
[rank2]: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1911, in state_dict
[rank2]: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1911, in state_dict
[rank2]: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
[rank2]: [Previous line repeated 2 more times]
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1907, in state_dict
[rank2]: hook(self, prefix, keep_vars)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank2]: return func(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 786, in _pre_state_dict_hook
[rank2]: _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 307, in _full_pre_state_dict_hook
[rank2]: _common_unshard_pre_state_dict_hook(
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 174, in _common_unshard_pre_state_dict_hook
[rank2]: _enter_unshard_params_ctx(
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 138, in _enter_unshard_params_ctx
[rank2]: fsdp_state._unshard_params_ctx[module].__enter__()
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/contextlib.py", line 137, in __enter__
[rank2]: return next(self.gen)
[rank2]: ^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 196, in _unshard_fsdp_state_params
[rank2]: _unshard(state, handle, computation_stream, computation_stream)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 299, in _unshard
[rank2]: handle.unshard()
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 1307, in unshard
[rank2]: unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 1334, in _alloc_padded_unsharded_flat_param
[rank2]: _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined]
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/utils.py", line 168, in _alloc_storage
[rank2]: tensor._typed_storage()._resize_(size.numel())
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/storage.py", line 972, in _resize_
[rank2]: self._untyped_storage.resize_(size * self._element_size())
[rank2]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.60 GiB. GPU has a total capacity of 79.15 GiB of which 989.31 MiB is free. Process 4076404 has 4.03 GiB memory in use. Including non-PyTorch memory, this process has 74.00 GiB memory in use. Of the allocated memory 70.11 GiB is allocated by PyTorch, and 1.53 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management

And the following error in another iteration

evaluating Epoch: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:41<00:00, 3.75s/it]
eval_ppl=tensor(2.6441, device='cuda:0') eval_epoch_loss=tensor(0.9723, device='cuda:0')
we are about to save the PEFT modulested (loss: 1.2340718507766724): 19%|█████████▌ | 9/48 [06:16<28:13, 43.41s/it]
/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py:348: UserWarning: Failed to clone() tensor with name base_model.model.model.layers.23.mlp.up_proj.weight on rank 2. This may mean that this state_dict entry could point to invalid memory regions after returning from state_dict() call if this parameter is managed by FSDP. Please check clone implementation of base_model.model.model.layers.23.mlp.up_proj.weight. Error: CUDA out of memory. Tried to allocate 448.00 MiB. GPU has a total capacity of 79.15 GiB of which 313.31 MiB is free. Process 4076404 has 4.03 GiB memory in use. Including non-PyTorch memory, this process has 74.82 GiB memory in use. Of the allocated memory 72.29 GiB is allocated by PyTorch, and 175.13 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management
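
As a side note, the allocator hint in these messages can be tried directly; if it is set from Python, it has to happen before torch makes any CUDA allocation. A small sketch (an assumption that fragmentation contributes, not a confirmed fix from this thread):

# Opt into expandable segments, as the OOM message itself suggests.
# Must run before any CUDA allocation, e.g. at the very top of finetuning.py.
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")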

@HamidShojanazeri
Contributor

@Mugheeera I wonder if you are installing llama-recipes from src?

This seems to be working on my end, running H100, but regardless it should work on both A100 and H100. Logs

I could use the latest transformers as well, so not from src anymore:

accelerate 0.28.0, bitsandbytes 0.43.0, transformers 4.38.2, torch 2.3.0+cu121

@HamidShojanazeri
Contributor

Sounds like a stale issue; will close it for now, but feel free to re-open if you see the same issues.

@yueyugua

Hi @HamidShojanazeri, I also got the same OOM error when using 8 x H100s.
I'm using transformers '4.41.0.dev0' and built llama-recipes from source.
I observed that this error happens when training the 70B model (in my case, Llama 3 70B) and does not occur for Llama 3 8B.
I'm wondering why model.save_pretrained is used when both PEFT and FSDP are enabled, instead of save_model_and_optimizer_sharded as is done when not using PEFT:

if train_config.save_model and eval_epoch_loss < best_val_loss:
    if train_config.enable_fsdp:
        dist.barrier()
    if train_config.use_peft:
        if train_config.enable_fsdp:
            if rank==0:
                print(f"we are about to save the PEFT modules")
        else:
            print(f"we are about to save the PEFT modules")
        model.save_pretrained(train_config.output_dir)
        if train_config.enable_fsdp:
            if rank==0:
                print(f"PEFT modules are saved in {train_config.output_dir} directory")
        else:
            print(f"PEFT modules are saved in {train_config.output_dir} directory")
    else:
        if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
            save_model_checkpoint(
                model, optimizer, rank, train_config, epoch=epoch
            )
        elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
            print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
            print("=====================================================")
            save_model_and_optimizer_sharded(model, rank, train_config)
            if train_config.save_optimizer:
                save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
                print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                print("=====================================================")
Could the reason for the OOM be that the model needs to gather the full weights across ranks before model.save_pretrained?
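
That matches the tracebacks above: under FULL_STATE_DICT, model.state_dict() unshards each FSDP flat parameter, so every rank briefly materializes full layers even though only the small LoRA tensors are kept in the end. One way to avoid holding the gathered weights on every GPU is to collect them to CPU on rank 0 and filter there; a minimal sketch (the "lora_" key filter and output filename are assumptions, not the library's exact behavior):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Gather once into host memory on rank 0, then drop everything except the adapter tensors.
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    full_sd = model.state_dict()  # CPU tensors on rank 0, empty elsewhere
if rank == 0:
    lora_sd = {k: v for k, v in full_sd.items() if "lora_" in k}
    torch.save(lora_sd, f"{train_config.output_dir}/lora_only_state_dict.pt")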

@Lidongw

Lidongw commented Aug 7, 2024

Hi @HamidShojanazeri, I think you were testing with the 7B model; most of the people here are seeing the issue with the 70B model. I also had the same issue with the 70B model and the alpaca dataset. I have installed llama-recipes from src, but it is still not working.
