lora_r is double when converting olora to lora. #2075

Closed
JaheimLee opened this issue Sep 18, 2024 · 4 comments · Fixed by #2077

Comments

@JaheimLee

JaheimLee commented Sep 18, 2024

System Info

  • transformers version: 4.44.2
  • Platform: Linux-5.13.0-30-generic-x86_64-with-glibc2.31
  • Python version: 3.12.4
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.3
  • Accelerate version: 0.34.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 3090
  • Peft: 0.12.0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from transformers import AutoModel
from peft import get_peft_model, LoraConfig

base_model = AutoModel.from_pretrained("facebook/opt-350m")
olora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules='all-linear',
    init_lora_weights='olora',
)
olora_model = get_peft_model(base_model, olora_config)
init_path = './tmp/init'
olora_model.save_pretrained(init_path) # Save the model *before* performing any training

# Train the model
# train(olora_model) # Your training loop

# Save the model after training, converting the OLoRA adapter to a plain LoRA adapter
olora_model.save_pretrained('./tmp/lora', path_initial_model_for_weight_conversion=init_path)
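For reference, a minimal sketch of how the converted adapter would then be loaded (my own illustration, not part of the original report; paths match the script above):

from transformers import AutoModel
from peft import PeftModel

# The converted adapter applies on top of the *original* base weights,
# so the unmodified pretrained model can be used directly.
base_model = AutoModel.from_pretrained("facebook/opt-350m")
model = PeftModel.from_pretrained(base_model, './tmp/lora')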

Expected behavior

The lora_r of the init adapter is 16.

{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "facebook/opt-350m",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": false,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 32,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "k_proj",
    "q_proj",
    "fc1",
    "out_proj",
    "project_out",
    "project_in",
    "v_proj",
    "fc2"
  ],
  "task_type": null,
  "use_dora": false,
  "use_rslora": false
}

But the converted one is 32.

{
  "alpha_pattern": {},
  "auto_mapping": {
    "base_model_class": "OPTModel",
    "parent_library": "transformers.models.opt.modeling_opt"
  },
  "base_model_name_or_path": "facebook/opt-350m",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 64,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 32,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "k_proj",
    "q_proj",
    "fc1",
    "out_proj",
    "project_out",
    "project_in",
    "v_proj",
    "fc2"
  ],
  "task_type": null,
  "use_dora": false,
  "use_rslora": false
}

The model size is also doubled.
Is this expected?

@JaheimLee JaheimLee changed the title lora r is double when converting olora to lora. lora_r is double when converting olora to lora. Sep 18, 2024
@BenjaminBossan (Member)

Yes, this is expected. Methods like OLoRA modify the base weights too. To convert the OLoRA weights to LoRA weights, it must be ensured that the adapter still works on top of the original, unmodified base weights. This is only possible by folding the base-weight modification into the OLoRA weights, which doubles their size. The reason is not quite straightforward, but it's explained here (the explanation is for LoftQ, but the same idea applies to OLoRA).

Ping @tokenizer-decode for info.
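To illustrate the intuition, here is a minimal numerical sketch (my own illustration of the idea, not PEFT's actual conversion code; the scaling factor lora_alpha / r is omitted for simplicity). OLoRA replaces the base weight W with W - B_init @ A_init, so an adapter meant to apply to the original W has to absorb that difference, and stacking the two rank-r factor pairs yields a rank-2r adapter:

import torch

d, k, r = 64, 32, 16
W = torch.randn(d, k)                                  # original base weight
B_init, A_init = torch.randn(d, r), torch.randn(r, k)  # initial factors (real OLoRA derives these via QR)
W_mut = W - B_init @ A_init                            # mutated base weight stored by the OLoRA model

B, A = torch.randn(d, r), torch.randn(r, k)            # factors after training
target = W_mut + B @ A                                 # what the trained OLoRA model computes

# To reproduce the same result on top of the original W, stack the factors:
# B' @ A' = B @ A - B_init @ A_init, which has rank up to 2r.
B_conv = torch.cat([B, B_init], dim=1)                 # shape (d, 2r)
A_conv = torch.cat([A, -A_init], dim=0)                # shape (2r, k)
assert torch.allclose(W + B_conv @ A_conv, target, atol=1e-4)

Since the effective scaling is lora_alpha / r and r doubles, lora_alpha is doubled as well to keep the scaling unchanged, which is why the saved config shows r=32 and lora_alpha=64.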

@JaheimLee (Author)

Got it, thanks for your reply.

@JaheimLee (Author)

Found a new problem: after the conversion, r and lora_alpha are doubled in the config of the still-loaded model as well, not just in the saved config. So it would be better to reset them to r and alpha after saving is finished.

@JaheimLee JaheimLee reopened this Sep 19, 2024
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Sep 19, 2024
Resolves huggingface#2075

When saving PiSSA or OLoRA with the option to convert to normal LoRA,
the LoRA weight shapes change, which means that some values like r and
alpha need to be adjusted in the saved PEFT config. However, these
modifications should be limited to the saved config, while the loaded
config should stay the same.

This PR implements this change by creating a copy of the config before
modifying it.
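A minimal sketch of the idea behind the fix (hypothetical code for illustration; save_with_conversion and the "default" adapter name are my assumptions, not the actual PR diff):

import copy

def save_with_conversion(model, save_dir, init_path):
    # Only the *saved* config reflects the doubled rank and alpha;
    # the model's live config is left untouched.
    config_to_save = copy.deepcopy(model.peft_config["default"])
    config_to_save.r *= 2
    config_to_save.lora_alpha *= 2
    # ... convert the weights relative to init_path and write config_to_save ...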
@BenjaminBossan (Member)

Good point @JaheimLee, I created a PR to address that: #2077.

BenjaminBossan added a commit that referenced this issue Sep 20, 2024 (same commit message as above; resolves #2075)