1. LoRA Fine-tuning Principle
LoRA (Low-Rank Adaptation of LLMs), i.e. low-rank adaptation of large language models, is the most commonly used parameter-efficient fine-tuning method.
LoRA fine-tuning approximates full-parameter fine-tuning of an LLM with far fewer trainable parameters; what it produces is a small set of trained incremental parameters. This makes it an efficient fine-tuning approach that reduces GPU memory usage and the amount of computation needed for back-propagation.
The schematic of LoRA fine-tuning is shown in the figure below:
During training, the pretrained model parameters (the blue part of the figure) are frozen; back-propagation only updates the parameters of the low-rank matrices A and B (the yellow part on the right). LoRA only adapts the parameters of linear and convolutional layers. As the figure shows, during LoRA training the number of parameters involved in the forward pass increases slightly, but the number of parameters that back-propagation has to update becomes much smaller.
How are the trained A and B matrices then used for inference? For each adapted linear layer, the bypass matrices B and A are multiplied together and the product is added to that layer's pretrained weights; the merged weights are used as the new model for inference.
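To make the bypass concrete, here is a minimal sketch of a LoRA linear layer: the forward pass computes y = W·x + (alpha/r)·B·A·x, and merge folds B·A back into the base weight. This is an illustration rather than the peft implementation; the rank r, scaling alpha and dropout values are assumptions.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Illustrative only: a frozen base Linear plus a trainable low-rank bypass."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 32, dropout: float = 0.05):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)        # freeze the pretrained weights (and bias)
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)     # B starts at zero, so the bypass is initially a no-op
        self.lora_dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + (alpha / r) * B(A(dropout(x)))
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(self.lora_dropout(x)))

    @torch.no_grad()
    def merge(self) -> nn.Linear:
        # Fold the bypass into the base weight: W' = W + (alpha / r) * B @ A
        self.base.weight.add_(self.scaling * (self.lora_B.weight @ self.lora_A.weight))
        return self.base

# Example: wrap a 1536 -> 1536 projection, the same shape as q_proj in the Qwen2 printout below.
# layer = LoRALinear(nn.Linear(1536, 1536, bias=True))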
2. Analysis of the LoRA Fine-tuning Process
2.1 Changes to the Model Structure
Continuing the analysis from the previous post (https://www.jianshu.com/p/2839cd5bda14), the model structure before the LoRA transformation prints as follows:
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(151936, 1536)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=1536, out_features=1536, bias=True)
(k_proj): Linear(in_features=1536, out_features=256, bias=True)
(v_proj): Linear(in_features=1536, out_features=256, bias=True)
(o_proj): Linear(in_features=1536, out_features=1536, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
(up_proj): Linear(in_features=1536, out_features=8960, bias=False)
(down_proj): Linear(in_features=8960, out_features=1536, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((1536,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=1536, out_features=151936, bias=False)
)
The model structure after the LoRA transformation prints as follows:
PeftModelForCausalLM(
(base_model): LoraModel(
(model): Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(151936, 1536)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=1536, bias=True)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=1536, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(k_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=256, bias=True)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=256, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(v_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=256, bias=True)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=256, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(o_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=1536, bias=False)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=1536, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=8960, bias=False)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=8960, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(up_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=8960, bias=False)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=8960, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(down_proj): lora.Linear(
(base_layer): Linear(in_features=8960, out_features=1536, bias=False)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=8960, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=1536, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((1536,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=1536, out_features=151936, bias=False)
)
)
)
We can see that base_layer is the original layer of the base model, while the layers whose names start with lora_ are the newly added adapter layers.
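For reference, a structure very close to the one printed above can be produced with peft directly. The sketch below is only a reproduction attempt: r=8, lora_dropout=0.05 and the target module names are read off the printout, while the checkpoint name and lora_alpha are assumptions (swift configures these through its own arguments).

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Assumed base checkpoint; the printout (hidden size 1536, 28 layers) matches Qwen2-1.5B-Instruct.
model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2-1.5B-Instruct')

lora_config = LoraConfig(
    r=8,                      # rank: lora_A has out_features=8 in the printout
    lora_alpha=32,            # assumption; swift may use a different default
    lora_dropout=0.05,        # matches Dropout(p=0.05) in the printout
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj',
                    'gate_proj', 'up_proj', 'down_proj'],
    task_type='CAUSAL_LM',
)
peft_model = get_peft_model(model, lora_config)
print(peft_model)                         # prints a PeftModelForCausalLM structure like the one above
peft_model.print_trainable_parameters()   # only the lora_A / lora_B weights count as trainable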
2.2 Code That Implements the Model Modification
In the swift library, the prepare_model function in swift/llm/tuner.py goes through the Swift.prepare_model method and ultimately calls into peft.
In peft, the constructor of the PeftModelForCausalLM class calls the constructor of peft.tuners.lora.model.LoraModel, which carries out the structural modification of the model.
This code is fairly involved and I have not traced it in full; three key excerpts are quoted below.
The first excerpt iterates over the model's modules to find the linear layers that should be replaced:
for key in key_list:
    ......
    self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
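The matching of key against the configured target modules is, in essence, a suffix check of each module name (a simplified illustration only; the real peft helper also supports regex patterns):

# Simplified: list the module names that a target_modules setting would match.
target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
key_list = [key for key, _ in model.named_modules()]   # model: the base Qwen2ForCausalLM
matched = [key for key in key_list
           if any(key == t or key.endswith('.' + t) for t in target_modules)]
print(matched[:3])   # e.g. ['model.layers.0.self_attn.q_proj', ...]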
The second excerpt creates the new LoRA module and swaps it in for the original layer of the pretrained model. Note that if the adapter being added is not among the active adapters, its gradients are disabled right away (an additional adapter is not automatically trainable), so back-propagation will not update those parameters.
......
else:
    new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
    if adapter_name not in self.active_adapters:
        # adding an additional adapter: it is not automatically trainable
        new_module.requires_grad_(False)
    self._replace_module(parent, target_name, new_module, target)
......
In the third excerpt, the LoraModel instance methods set_adapter and _mark_only_adapters_as_trainable adjust which parameters are frozen, so that only the LoRA adapter parameters remain trainable:
# layers will be activated, which we don't want.
self.set_adapter(self.active_adapters)
self._mark_only_adapters_as_trainable(model)
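The effect of this freezing is easy to verify on the wrapped model: only the lora_A / lora_B weights of the active adapter still require gradients. A quick sanity check, assuming the peft_model built in the sketch of section 2.1:

trainable = [name for name, p in peft_model.named_parameters() if p.requires_grad]
frozen = [name for name, p in peft_model.named_parameters() if not p.requires_grad]
print(len(trainable), trainable[:2])   # only ...lora_A.default.weight / ...lora_B.default.weight entries
print(len(frozen), frozen[:2])         # base_layer weights, embeddings, norms, lm_head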
3. Using the Parameters After Fine-tuning
The merge_lora function in swift/llm/infer.py of the swift library implements this. It loads the LoRA checkpoint, merges the adapter parameters into the base weights, and saves the merged parameters under merged_lora_path; at inference time, the original model architecture is loaded together with the merged weights from merged_lora_path. An excerpt of the code follows:
......
ckpt_dir, ckpt_name = os.path.split(args.ckpt_dir)
merged_lora_path = os.path.join(ckpt_dir, f'{ckpt_name}-merged')
logger.info(f'merged_lora_path: `{merged_lora_path}`')
if os.path.exists(merged_lora_path) and not replace_if_exists:
    logger.info(f'The weight directory for the merged LoRA already exists in {args.ckpt_dir}, '
                'skipping the saving process. '
                'you can pass `replace_if_exists=True` to overwrite it.')
else:
    if device_map is None:
        device_map = args.merge_device_map
    logger.info(f'merge_device_map: {device_map}')
    model, template = prepare_model_template(args, device_map=device_map, task='export')
    logger.info('Merge LoRA...')
    Swift.merge_and_unload(model)
    model = model.model
    logger.info('Saving merged weights...')
    save_checkpoint(
        model,
        template.tokenizer,
        model.model_dir,
        args.ckpt_dir,
        merged_lora_path,
        save_safetensors=args.save_safetensors,
        sft_args_kwargs={'dtype': args.dtype})
    logger.info(f'Successfully merged LoRA and saved in {merged_lora_path}.')
logger.info("Setting args.sft_type: 'full'")
logger.info(f'Setting args.ckpt_dir: {merged_lora_path}')
args.sft_type = 'full'
args.ckpt_dir = merged_lora_path
return merged_lora_path
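The corresponding operation in plain peft is merge_and_unload, which adds the scaled B·A product onto every base_layer weight and removes the adapter wrappers. Below is a minimal sketch of doing the same merge with peft alone; the base checkpoint name and the LoRA checkpoint directory are placeholders.

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2-1.5B-Instruct')   # placeholder base checkpoint
lora = PeftModel.from_pretrained(base, 'output/checkpoint-xxx')           # placeholder LoRA checkpoint dir
merged = lora.merge_and_unload()   # W' = W + (alpha / r) * B @ A for every adapted layer, wrappers removed

merged.save_pretrained('output/checkpoint-xxx-merged')
AutoTokenizer.from_pretrained('Qwen/Qwen2-1.5B-Instruct').save_pretrained('output/checkpoint-xxx-merged')
# Inference can now load 'output/checkpoint-xxx-merged' as an ordinary full model.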