
RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM: size mismatch when I load using adapter path but not checkpoint #2071

Closed
manitadayon opened this issue Sep 17, 2024 · 9 comments


@manitadayon

System Info

Multi-GPU setup with 2 A100 GPUs (80GB each).
Training is done using accelerate and DeepSpeed.

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

I have pretrained my LLM using DeepSpeed and accelerate on 2 GPUs. I have the following in my LoraConfig:

r = 16,
lora_alpha = 32,
target_modules = ['v_proj', 'q_proj', 'k_proj', 'o_proj']
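For reference, a minimal sketch of how this maps onto peft.LoraConfig (the dropout and task type below are placeholders, not necessarily what I used):

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["v_proj", "q_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,       # placeholder value
    task_type="CAUSAL_LM",   # assumed, since the base model is a causal LM
)
model = get_peft_model(model, lora_config)  # wraps the base model with LoRA adapters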

When I use the following code to load my model with the PEFT config, I get an error:

model = AutoModelForCausalLM.from_pretrained(base_path, config=config)
model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(model, adapter_path)

RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM: size mismatch

However, when I do this, everything works. Why?

model = PeftModel.from_pretrained(model, 'path to checkpoint')

To make it stranger, the code above works with no error under the following setting:

r = 8,
lora_alpha = 32,
target_modules = ['v_proj', 'q_proj']

Can anyone explain what is going on? Literally everything is the same; I only added two more layers to target_modules.

Expected behavior

I expect this code to run with no problem, since neither the tokenizer nor the base model has changed since the pretraining phase:

model = AutoModelForCausalLM.from_pretrained(base_path, config=config)
model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(model, adapter_path)
@BenjaminBossan
Member

When I use the following code to load my model with the PEFT config, I get an error:

Could you please post the full error message? Also, the more code you can share, the better I can help. Are you loading on the same multi-GPU setup?

However, when I do this, everything works. Why?

I don't quite see what the difference is supposed to be. In one case you pass a variable holding a string and in the other a string literal directly; is that it?

To make it stranger, the code above works with no error under the following setting:

Yes, that is really strange. You mean that if you train a model with these settings, the same code works for loading that model's adapter?

@manitadayon
Author

I will share the complete error message later on. But these are not the same scenario.

In the case where I receive the error, the adapter path points to the saved model and tokenizer, which are created as follows:

tokenizer.save_pretrained(adapter_path)
model.save_pretrained(adapter_path)

So in this case the model and tokenizer are saved after training is over.

In the second case, where there is no error, I am loading the PEFT model not from the adapter path but from the checkpoint location, which is saved during training by setting save_steps and output_dir.

I am seeing similar issues where people complain it is caused by DeepSpeed.
When my LoRA config had only q_proj and v_proj as target modules, I did not have this problem, but after adding k_proj and o_proj, the problem started.

During the DeepSpeed training I get:

stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use zero_to_fp32.py to recover weights.

which is neither an error nor a warning.
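For context, that setting lives in the ZeRO-3 section of the DeepSpeed config. A minimal sketch of that section as a Python dict (the other values are placeholders, not my actual config):

# Can be passed e.g. via TrainingArguments(deepspeed=ds_config).
ds_config = {
    "zero_optimization": {
        "stage": 3,
        # False: weights stay sharded on save and DeepSpeed prints the
        # "use zero_to_fp32.py to recover weights" message.
        # True: the 16-bit weights are gathered on rank 0 when the model is saved.
        "stage3_gather_16bit_weights_on_model_save": False,
    },
    "train_micro_batch_size_per_gpu": "auto",
}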
Please let me know if this makes sense.

@manitadayon
Author

Thanks for your response.
Here is the error message:

RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
        size mismatch for base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.0.self_attn.k_proj.lora_A.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.0.self_attn.k_proj.lora_B.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.0.self_attn.o_proj.lora_A.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.0.self_attn.o_proj.lora_B.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.1.self_attn.q_proj.lora_A.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.1.self_attn.q_proj.lora_B.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.1.self_attn.k_proj.lora_A.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.1.self_attn.k_proj.lora_B.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.1.self_attn.v_proj.lora_A.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.1.self_attn.v_proj.lora_B.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.1.self_attn.o_proj.lora_A.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).

@BenjaminBossan
Member

Okay, just to ensure that I understand correctly:

When you load from the checkpoint that the trainer automatically created for you, loading works, but not if you try to load the file that you saved using model.save_pretrained, is that right?

Could you please compare the sizes of the files created by both methods? It could be the case that the automatic checkpoint saves the full model, not only the PEFT adapter, in which case the checkpoint would be much larger.
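For example, something along these lines would show the difference (the paths are placeholders):

import os

# List every file in the two directories with its size in MB.
for path in ["<adapter_path>", "<checkpoint_path>"]:
    for name in sorted(os.listdir(path)):
        size_mb = os.path.getsize(os.path.join(path, name)) / 1e6
        print(f"{path}/{name}: {size_mb:.1f} MB")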

One thing you could try to see if it fixes the PEFT checkpoint for you is to gather the parameters before calling save_pretrained. So the code would be something like this:

import deepspeed

with deepspeed.zero.GatheredParameters(trainer.model.parameters()):
    trainer.model.save_pretrained(<path>)

Here is the error message

RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
        size mismatch for base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        size mismatch for base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight: copying a param with shape torch.Size([12, 8192]) from checkpoint, the shape in current model is torch.Size([8192,12]).
        ...

This is really strange, since it does not appear that parameters are missing, but instead that they have the wrong shape (transposed). I have never seen this.
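If you want to inspect the saved file directly, something like this would print the shapes that were actually written (assuming the adapter was saved as safetensors; for an adapter_model.bin you could use torch.load instead):

from safetensors import safe_open

# Print every tensor name and shape stored in the saved adapter file.
with safe_open("<adapter_path>/adapter_model.safetensors", framework="pt") as f:
    for key in f.keys():
        print(key, f.get_tensor(key).shape)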

@manitadayon
Author

When you load from the checkpoint that the trainer automatically created for you, loading works, but not if you try to load the file that you saved using model.save_pretrained, is that right?

Yes, that is correct.

Could you please compare the sizes of the files created by both methods? It could be the case that the automatic checkpoint saves the full model, not only the PEFT adapter, in which case the checkpoint would be much larger.

Indeed this is the case, and the checkpoint model is larger. I wonder why this happens when the target modules include o_proj and k_proj?

Apparently lots of people are experiencing the same problem, so I think there should be a permanent fix for this:

#272
#211
#293

@BenjaminBossan
Member

BenjaminBossan commented Sep 18, 2024

I would like to investigate this further but with the given information, I can't. So far, when I tried to reproduce this, the checkpoint that was created by model.save_pretrained was working as expected, so I must be missing something. If possible, please share your full DeepSpeed config, accelerate config, and training script. Could you please also run ls -la on the directories that contain the individual checkpoints and the final saved model?

Did you try what I suggested above using deepspeed.zero.GatheredParameters?

Update

I managed to create a situation that resulted in the checkpoint from model.save_pretrained being essentially empty (all tensors have shape 0) by increasing the rank r sufficiently, which probably triggered a different sharding behavior.

Using this context manager allowed me to save an intact checkpoint:

import deepspeed

with deepspeed.zero.GatheredParameters((p for n, p in trainer.model.named_parameters() if "lora" in n)):
    if trainer.accelerator.is_main_process:
        model.save_pretrained(<checkpoint_path>)

@manitadayon
Author

I have not tried that, but I will try it the next time I am training and update you here. I see you changed the command a bit; initially it was:

with deepspeed.zero.GatheredParameters(trainer.model.parameters()):
    trainer.model.save_pretrained(<path>)

Now you are saving only the parameters that have 'lora' in their name.

import deepspeed

with deepspeed.zero.GatheredParameters((p for n, p in trainer.model.named_parameters() if "lora" in n)):
    if trainer.accelerator.is_main_process:
        model.save_pretrained(<checkpoint_path>)

It is interesting, since in the case where the model loaded correctly without using the checkpoint, r was 8, and with r = 16 the issue appeared, which is strange since 16 is not that big either and should not cause much trouble.

Can you elaborate on why this happens? Is it because of DeepSpeed?

@BenjaminBossan
Member

Now you are saving only the parameters that have 'lora' in their name.

Yes, when calling save_pretrained on a PEFT model, we only save the adapter weights, not the base weights, since those are frozen anyway and can be recovered from the base model. This way, saving is faster and the checkpoint is small. That's why I added this optimization to only gather the LoRA parameters, as only those are saved.
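Roughly speaking (a simplified sketch, not the exact PEFT internals), only the adapter state dict is extracted and written to disk:

from peft import get_peft_model_state_dict

# Collect only the adapter (LoRA) tensors; the frozen base weights are not included.
adapter_state = get_peft_model_state_dict(model)
print(list(adapter_state.keys())[:4])  # e.g. ...lora_A.weight / ...lora_B.weight entries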

It is interesting, since in the case where the model loaded correctly without using the checkpoint, r was 8, and with r = 16 the issue appeared, which is strange since 16 is not that big either and should not cause much trouble.

Can you elaborate on why this happens? Is it because of DeepSpeed?

I don't know enough about what DeepSpeed does under the hood to determine how models are sharded. But from what we can observe, this appears to be enough to make a difference. Same with what you said earlier:

When my LoRA config had only q_proj and v_proj as target modules, I did not have this problem, but after adding k_proj and o_proj, the problem started.

This also should not make a big difference in the grand scheme of things but apparently does.

I have not tried that, but I will try it the next time I am training and update you here.

Thanks, hopefully it helps.

@manitadayon
Author

Thanks @BenjaminBossan, using this:

import deepspeed

with deepspeed.zero.GatheredParameters((p for n, p in trainer.model.named_parameters() if "lora" in n)):
    if trainer.accelerator.is_main_process:
        model.save_pretrained(<checkpoint_path>)

The saved model now loads just fine.
