type
status
date
slug
summary
tags
category
icon
password
Code LoRA from Scratch
Coding LoRA from Scratch
例如,在常规微调中,我们计算权重矩阵W的权重更新为ΔW,而在 LoRA 中,我们通过两个较小矩阵 AB 的矩阵乘法来近似ΔW,如下图所示。(如果您熟悉 PCA 或 SVD,可以将其视为将ΔW分解为A和B。)
请注意,上图中的r是一个超参数,我们可以用它来指定用于自适应的低秩矩阵的秩。较小的r会导致更简单的低秩矩阵,从而导致自适应过程中需要学习的参数更少。这可能导致更快的训练速度,并可能降低计算需求。但是,对于较小的r,低秩矩阵捕获特定任务信息的能力会降低。
LoRA 通常应用于神经网络的线性(前馈)层
为了应用 LoRA,我们将神经网络中现有的线性层替换为 LinearWithLoRA 层,该层结合了原始线性层和 LoRALayer。
通过修改现有的 PyTorch 模型来实现 LoRA 时,一种简单的方法是将每个线性层替换为一个
LinearWithLoRA
层,该层将 Linear
层与我们之前 LoRALayer
实现相结合:实际上,要使用 LoRA 装备和微调模型,我们所要做的就是用我们新的
LinearWithLoRA
层替换其预训练的 Linear
层。我们将在下面即将到来的动手部分中看到如何将 LinearWithLoRA
层应用于预训练的语言模型。A Hands-On Example
LoRA 是一种可以应用于各种类型神经网络的方法,而不仅仅是像 GPT 或图像生成模型这样的生成模型。在本动手示例中,我们将训练一个小型 BERT 模型用于文本分类,因为分类准确性比生成的文本更容易评估。(这些实验的完整代码可以在本 Studio 的
02_finetune-with-lora.ipynb
文件中找到。)我们使用来自 transformers 库的预训练 DistilBERT(BERT 的一个较小版本)模型:
由于我们只想训练新的 LoRA 权重,因此我们通过将所有可训练参数的
requires_grad
设置为 False
来冻结所有模型参数:接下来,让我们使用
print(model)
简要检查模型的结构:根据下面的输出,我们可以看到模型包含 6 个包含线性层的 Transformer 层:
此外,该模型有两个
Linear
输出层:我们可以通过定义以下赋值函数和循环,选择性地为这些
Linear
层启用 LoRA:现在我们可以再次检查模型,使用
print(model)
查看其更新后的结构:正如我们上面看到的,
Linear
层已经被 LinearWithLoRA
层成功替换。在IMDb 电影评论分类数据集上的效果如下:
- Train acc: 92.15% 训练准确率:92.15%
- Val acc: 89.98% 验证准确率:89.98%
- Test acc: 89.44% 测试准确率:89.44%
Comparison to Traditional Finetuning
让我们先训练 DistilBERT 模型,但在训练期间只更新最后两层。我们可以先冻结所有模型权重,然后解冻两个线性输出层来实现这一点。
仅训练最后两层后,得到的分类性能如下:
- Train acc: 86.68% 训练准确率:86.68%
- Val acc: 87.26% 验证准确率:87.26%
- Test acc: 86.22% 测试准确率:86.22%
微调所有层后,得到的分类性能如下:
- Train acc: 96.41% 训练准确率:96.41%
- Val acc: 92.80% 验证准确率:92.80%
- Test acc: 92.31% 测试准确率:92.31%
LoRA 的表现优于对最后两层进行的传统微调,尽管它使用的参数少 4 倍。对所有层进行微调需要更新比 LoRA 设置多 450 倍的参数,但也导致测试准确率提高了 2%。
Optimizing the LoRA Configuration
这仅涉及将 LoRA 应用于注意力层的查询和值权重矩阵。可选地,我们也可以为其他层启用 LoRA。此外,我们可以通过修改秩(
lora_r
)来控制每个 LoRA 层中可训练参数的数量。要尝试不同的超参数配置,您可以使用我们简洁的
03_finetune-lora.py 脚本
,它接受超参数选择作为命令行参数:此外,您还可以切换其他超参数设置,例如:
提高 LoRA 性能的一种方法是手动调整这些超参数选择。但是,为了使超参数调整更方便,您也可以使用
03_gridsearch.py 脚本
,它将在所有可用的 GPU 上运行以下超参数网格:- Author:liamtech
- URL:https://liamtech.top/article/example-3
- Copyright:All articles in this blog, except for special statements, adopt BY-NC-SA agreement. Please indicate the source!