LoRA Fine-Tuning 2: How Model Fine-Tuning Works

1. The Principle of LoRA Fine-Tuning

LoRA (Low-Rank Adaptation of LLMs), i.e. low-rank adaptation of large language models, is the most commonly used parameter-efficient fine-tuning method.
LoRA fine-tuning approximates full-parameter fine-tuning of an LLM with far fewer trainable parameters, producing a small set of trained incremental parameters. This reduces GPU memory usage and the amount of computation needed for back-propagation.
The principle of LoRA fine-tuning is illustrated in the figure below:


[Figure: LoRA principle diagram - lora原理.png]

During training, the pre-trained model parameters (the blue part of the figure) are frozen; back-propagation only updates the parameters of the low-rank matrices A and B on the right (the yellow part). LoRA only affects the model's linear-layer and convolutional-layer parameters. As the figure shows, LoRA adds parameters to the forward pass during training, but the number of parameters that back-propagation has to update becomes much smaller.
How are the trained A and B matrices used in later inference? For each adapted linear layer, the bypass matrices B and A are multiplied together and the product is added to the corresponding pre-trained weight; the result serves as the weight of the new model at inference time. A minimal numeric sketch of this computation is given below.
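
The following is a minimal PyTorch sketch of that computation, not the peft implementation; the dimensions, the rank r and the scaling factor alpha are assumed example values chosen to match the Qwen2 q_proj layer analyzed below.

import torch

d_out, d_in, r, alpha = 1536, 1536, 8, 32     # assumed sizes: q_proj shape, LoRA rank r, scaling alpha

W0 = torch.randn(d_out, d_in)                 # frozen pre-trained weight
A = torch.randn(r, d_in) * 0.01               # trainable low-rank matrix A (r x d_in)
B = torch.randn(d_out, r) * 0.01              # trainable low-rank matrix B (in real LoRA, B starts at zero)

x = torch.randn(d_in)

# Training-time forward pass: frozen base output plus the scaled low-rank bypass.
h = W0 @ x + (alpha / r) * (B @ (A @ x))

# Merging for inference: fold B @ A into the base weight once, then discard the bypass.
W_merged = W0 + (alpha / r) * (B @ A)
print((W_merged @ x - h).abs().max())         # only floating-point noise remains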

2. LoRA参数微调过程分析

2.1 模型结构变化

Continuing the analysis from the previous post (https://www.jianshu.com/p/2839cd5bda14),
the model structure before the LoRA transformation prints as follows:

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbedding()
  )
  (lm_head): Linear(in_features=1536, out_features=151936, bias=False)
)

The model structure after the LoRA transformation prints as follows:

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen2ForCausalLM(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 1536)
        (layers): ModuleList(
          (0-27): 28 x Qwen2DecoderLayer(
            (self_attn): Qwen2SdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=1536, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=256, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=256, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (v_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=256, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=256, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (o_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=1536, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (rotary_emb): Qwen2RotaryEmbedding()
            )
            (mlp): Qwen2MLP(
              (gate_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=8960, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8960, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (up_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=8960, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8960, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (down_proj): lora.Linear(
                (base_layer): Linear(in_features=8960, out_features=1536, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=8960, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (act_fn): SiLU()
            )
            (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
            (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
          )
        )
        (norm): Qwen2RMSNorm((1536,), eps=1e-06)
        (rotary_emb): Qwen2RotaryEmbedding()
      )
      (lm_head): Linear(in_features=1536, out_features=151936, bias=False)
    )
  )
)

We can see that base_layer is the original layer of the model, while the layers whose names start with lora_ are the newly added ones.
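
A structure like the one printed above can be reproduced with peft directly. The sketch below is only an illustration, not what swift does internally: the checkpoint name and lora_alpha are assumptions, while r=8, dropout=0.05 and the seven target projection layers are chosen to match the printout.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Assumed base checkpoint with the Qwen2 structure shown above.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

lora_config = LoraConfig(
    r=8,                    # rank of A/B, matching the in/out features of 8 in the printout
    lora_alpha=32,          # assumed scaling factor
    lora_dropout=0.05,      # matches Dropout(p=0.05) in the printout
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(model, lora_config)
print(peft_model)                           # prints a wrapped structure like the one above
peft_model.print_trainable_parameters()     # only the lora_A / lora_B weights are trainable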

2.2 Code That Implements the Model Changes

In the swift library, the prepare_model function in swift/llm/tuner.py calls into peft via Swift.prepare_model.
The constructor of peft's PeftModelForCausalLM class in turn calls the constructor of peft.tuners.lora.model.LoraModel, which performs the structural changes to the model.
This code is fairly involved and was not traced in full detail; three key snippets are excerpted below.
The first snippet walks the model and finds the linear layers to replace (a simplified conceptual version follows the excerpt):

            for key in key_list:
                ......

                self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
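
Conceptually, this search boils down to walking the module tree and matching module names against the configured target_modules. The helper below is a simplified sketch of that idea, not the actual peft code:

import torch.nn as nn

def find_target_modules(model: nn.Module, target_modules):
    """Return (parent, child_name, child) triples for every linear layer whose
    qualified name ends with one of the configured target module names."""
    matches = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(name.endswith(t) for t in target_modules):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            matches.append((parent, child_name, module))
    return matches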

The second snippet replaces the original layer of the pre-trained model with the newly created LoRA module. Note the guard: if the adapter being added is not among the active adapters, the new module's gradients are frozen, so back-propagation will not update its parameters. A conceptual sketch of the wrapper module follows the excerpt.

......
             else:
                 new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
                 if adapter_name not in self.active_adapters:
                     # adding an additional adapter: it is not automatically trainable
                     new_module.requires_grad_(False)
                 self._replace_module(parent, target_name, new_module, target)
......
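
What _create_new_module and _replace_module accomplish can be illustrated with the following simplified stand-in for peft's lora.Linear: a wrapper that keeps the frozen original layer as base_layer, adds the trainable A/B bypass, and is then attached to the parent module in place of the original layer. This is a sketch under those assumptions, not the real implementation:

import torch.nn as nn

class LoRALinear(nn.Module):
    """Simplified stand-in for peft's lora.Linear: frozen base layer plus trainable bypass."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 32, dropout: float = 0.05):
        super().__init__()
        self.base_layer = base
        self.base_layer.requires_grad_(False)        # the pre-trained weights stay frozen
        self.lora_dropout = nn.Dropout(dropout)
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)           # B starts at zero, so the initial update is zero
        self.scaling = alpha / r

    def forward(self, x):
        return self.base_layer(x) + self.scaling * self.lora_B(self.lora_A(self.lora_dropout(x)))

def replace_module(parent: nn.Module, child_name: str, old: nn.Linear):
    # Swap the original linear layer for the LoRA wrapper on its parent module.
    setattr(parent, child_name, LoRALinear(old))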

In the following code, the set_adapter and _mark_only_adapters_as_trainable instance methods of the LoraModel class adjust which parameters are frozen (a simplified equivalent follows the excerpt):

            # layers will be activated, which we don't want.
            self.set_adapter(self.active_adapters)
            self._mark_only_adapters_as_trainable(model)
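
The net effect of _mark_only_adapters_as_trainable is name-based freezing: everything is frozen except the parameters that belong to the LoRA bypass. A simplified equivalent, not the actual peft implementation, is:

import torch.nn as nn

def mark_only_lora_as_trainable(model: nn.Module) -> None:
    # Freeze every parameter, then re-enable gradients only for parameters
    # whose names mark them as part of the LoRA bypass (lora_A / lora_B).
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name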

3. Using the Model Parameters After Fine-Tuning

The merge_lora function in swift/llm/infer.py of the swift library implements this step. It loads the checkpoint of the LoRA-trained model, merges the adapter parameters into the pre-trained weights, and saves the merged weights to merged_lora_path.
At inference time the base model is instantiated and the weights in merged_lora_path are loaded for inference. An excerpt of the code is shown below:

         .......
         ckpt_dir, ckpt_name = os.path.split(args.ckpt_dir)
         merged_lora_path = os.path.join(ckpt_dir, f'{ckpt_name}-merged')
         logger.info(f'merged_lora_path: `{merged_lora_path}`')
         if os.path.exists(merged_lora_path) and not replace_if_exists:
             logger.info(f'The weight directory for the merged LoRA already exists in {args.ckpt_dir}, '
                         'skipping the saving process. '
                         'you can pass `replace_if_exists=True` to overwrite it.')
         else:
             if device_map is None:
                 device_map = args.merge_device_map
             logger.info(f'merge_device_map: {device_map}')
             model, template = prepare_model_template(args, device_map=device_map, task='export')
             logger.info('Merge LoRA...')
             Swift.merge_and_unload(model)
             model = model.model
             logger.info('Saving merged weights...')
             save_checkpoint(
                 model,
                 template.tokenizer,
                 model.model_dir,
                 args.ckpt_dir,
                 merged_lora_path,
                 save_safetensors=args.save_safetensors,
                 sft_args_kwargs={'dtype': args.dtype})
             logger.info(f'Successfully merged LoRA and saved in {merged_lora_path}.')
         logger.info("Setting args.sft_type: 'full'")
         logger.info(f'Setting args.ckpt_dir: {merged_lora_path}')
         args.sft_type = 'full'
         args.ckpt_dir = merged_lora_path
         return merged_lora_path
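
The underlying merge operation (folding B @ A into the base weights) can also be performed with peft directly. The sketch below is a hedged illustration in which the checkpoint paths are placeholders, not values taken from swift:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_path = "Qwen/Qwen2-1.5B-Instruct"        # assumed base checkpoint
adapter_path = "path/to/lora-checkpoint"      # placeholder: directory containing the trained adapter
merged_path = "path/to/lora-checkpoint-merged"

base = AutoModelForCausalLM.from_pretrained(base_path, torch_dtype="auto")
model = PeftModel.from_pretrained(base, adapter_path)

# Fold B @ A into the base weights and strip the adapter wrappers,
# leaving a plain Qwen2ForCausalLM that can be loaded like a fully fine-tuned model.
merged = model.merge_and_unload()
merged.save_pretrained(merged_path, safe_serialization=True)
AutoTokenizer.from_pretrained(base_path).save_pretrained(merged_path)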