Fine-tuning Llama3 for Multimodal Image Understanding with XTuner
In this experiment we fine-tune our own LLaVA multimodal image-text understanding model, built from Llama3-8B-Instruct and the Image Projector pre-trained by the XTuner team. The experiment runs on InternStudio with 24 GB of GPU memory.
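As a mental model: LLaVA couples a frozen visual encoder (CLIP ViT) to the LLM through a small image projector that maps patch features into the LLM's token-embedding space. The sketch below illustrates only that projection step with made-up dimensions and weights; the real projector is a trained module inside the checkpoint, not this toy matrix multiply.

```python
# Conceptual sketch of the LLaVA image projector: it maps visual-encoder
# patch features into the LLM's embedding space. All numbers here are
# illustrative, not the real model's.

def project(features, weight):
    """Multiply each patch feature vector by a projection matrix."""
    out = []
    for vec in features:  # one feature vector per image patch
        row = [sum(v * w for v, w in zip(vec, col)) for col in zip(*weight)]
        out.append(row)
    return out

# 2 patch features of dim 3, projected into an embedding space of dim 2
features = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]
weight = [[1.0, 0.0],   # 3x2 projection matrix
          [0.0, 1.0],
          [1.0, 1.0]]
tokens = project(features, weight)
print(tokens)  # [[3.0, 2.0], [1.0, 2.0]]
```

The projected vectors are then interleaved with text-token embeddings and fed to the LLM, which is why only the projector (plus LoRA adapters) needs training here.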
Environment, Model, and Data Preparation
Setting up the environment
First, set up the environment. The following commands create a base environment with python=3.10 and pytorch=2.1.2+cu121.
conda create -n llama3 python=3.10
conda activate llama3
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
Next, install XTuner.
cd ~
git clone -b v0.1.18 https://github.com/InternLM/XTuner
cd XTuner
pip install -e .[all]
Finally, clone the Llama3-Tutorial repository.
cd ~
git clone https://github.com/SmartFlowAI/Llama3-Tutorial
Model preparation
Preparing the Llama3 weights
Before fine-tuning, prepare the Llama3-8B-Instruct model weights. They are available in the share folder inside InternStudio.
mkdir -p ~/model
cd ~/model
ln -s /root/share/new_models/meta-llama/Meta-Llama-3-8B-Instruct .
Next, prepare the openai/clip-vit-large-patch14-336 weights that LLaVA needs, i.e. the Visual Encoder weights.
mkdir -p ~/model
cd ~/model
ln -s /root/share/new_models/openai/clip-vit-large-patch14-336 .
Prepare the Image Projector weights.
mkdir -p ~/model
cd ~/model
ln -s /root/share/new_models/xtuner/llama3-llava-iter_2181.pth .
Data preparation
To get results quickly, we deliberately overfit on a small, repeated dataset. Run the following commands:
cd ~
git clone https://github.com/InternLM/tutorial -b camp2
python ~/tutorial/xtuner/llava/llava_data/repeat.py \
-i ~/tutorial/xtuner/llava/llava_data/unique_data.json \
-o ~/tutorial/xtuner/llava/llava_data/repeated_data.json \
-n 200
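The repeat.py script above simply duplicates every sample in unique_data.json n times so the model can overfit quickly. Its behavior can be approximated as follows (a sketch; the real script's argument handling may differ):

```python
import json
import os
import tempfile

def repeat_data(in_path, out_path, n):
    """Duplicate every sample in a JSON list n times, mimicking repeat.py."""
    with open(in_path) as f:
        data = json.load(f)
    with open(out_path, "w") as f:
        json.dump(data * n, f)

# Demonstrate with a throwaway one-sample dataset
tmp = tempfile.mkdtemp()
src = os.path.join(tmp, "unique_data.json")
dst = os.path.join(tmp, "repeated_data.json")
with open(src, "w") as f:
    json.dump([{"id": "0", "image": "oph.jpg", "conversations": []}], f)

repeat_data(src, dst, 200)
with open(dst) as f:
    repeated = json.load(f)
print(len(repeated))  # 200
```

With -n 200, the single-image dataset becomes 200 identical samples, which is why the loss in the log below drops so steeply.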
If you want to build your own dataset, refer here.
Fine-tuning
Launch training with the following command:
xtuner train /root/Llama3-Tutorial/configs/llama3-llava/llava_llama3_8b_instruct_qlora_clip_vit_large_p14_336_lora_e1_finetune.py --work-dir /root/llama3_llava_pth --deepspeed deepspeed_zero2_offload
Training uses about 14,170 MiB of GPU memory and takes roughly 4 hours 48 minutes.
05/09 11:24:43 - mmengine - INFO - Iter(train) [ 10/1200] lr: 5.1430e-05 eta: 4:48:31 time: 14.5475 data_time: 0.0383 memory: 14172 loss: 1.4195
05/09 11:26:59 - mmengine - INFO - Iter(train) [ 20/1200] lr: 1.0857e-04 eta: 4:36:40 time: 13.5897 data_time: 0.0298 memory: 14170 loss: 0.5519
05/09 11:29:13 - mmengine - INFO - Iter(train) [ 30/1200] lr: 1.6571e-04 eta: 4:30:23 time: 13.4613 data_time: 0.0273 memory: 14170 loss: 0.5242
05/09 11:31:29 - mmengine - INFO - Iter(train) [ 40/1200] lr: 2.0000e-04 eta: 4:26:41 time: 13.5791 data_time: 0.0361 memory: 14170 loss: 0.4559
05/09 11:33:52 - mmengine - INFO - Iter(train) [ 50/1200] lr: 1.9994e-04 eta: 4:26:19 time: 14.3003 data_time: 0.0397 memory: 14169 loss: 0.3685
05/09 11:36:12 - mmengine - INFO - Iter(train) [ 60/1200] lr: 1.9981e-04 eta: 4:24:22 time: 14.0092 data_time: 0.0411 memory: 14172 loss: 0.4149
05/09 11:38:31 - mmengine - INFO - Iter(train) [ 70/1200] lr: 1.9960e-04 eta: 4:22:07 time: 13.9411 data_time: 0.0617 memory: 14170 loss: 0.2311
05/09 11:40:50 - mmengine - INFO - Iter(train) [ 80/1200] lr: 1.9933e-04 eta: 4:19:34 time: 13.8193 data_time: 0.0326 memory: 14170 loss: 0.1266
05/09 11:43:11 - mmengine - INFO - Iter(train) [ 90/1200] lr: 1.9898e-04 eta: 4:17:38 time: 14.0893 data_time: 0.0308 memory: 14170 loss: 0.1708
05/09 11:45:30 - mmengine - INFO - Iter(train) [ 100/1200] lr: 1.9856e-04 eta: 4:15:22 time: 13.9608 data_time: 0.0320 memory: 14169 loss: 0.1202
05/09 11:47:51 - mmengine - INFO - Iter(train) [ 110/1200] lr: 1.9807e-04 eta: 4:13:21 time: 14.1097 data_time: 0.0408 memory: 14172 loss: 0.0798
05/09 11:50:13 - mmengine - INFO - Iter(train) [ 120/1200] lr: 1.9750e-04 eta: 4:11:26 time: 14.2204 data_time: 0.0433 memory: 14170 loss: 0.2290
05/09 11:52:32 - mmengine - INFO - Iter(train) [ 130/1200] lr: 1.9687e-04 eta: 4:08:59 time: 13.8805 data_time: 0.0462 memory: 14170 loss: 0.1079
05/09 11:54:49 - mmengine - INFO - Iter(train) [ 140/1200] lr: 1.9616e-04 eta: 4:06:14 time: 13.6250 data_time: 0.0364 memory: 14169 loss: 0.0395
05/09 11:56:58 - mmengine - INFO - Iter(train) [ 150/1200] lr: 1.9539e-04 eta: 4:02:43 time: 12.9145 data_time: 0.0451 memory: 14168 loss: 0.1383
05/09 11:59:17 - mmengine - INFO - Iter(train) [ 160/1200] lr: 1.9454e-04 eta: 4:00:29 time: 13.9496 data_time: 0.0380 memory: 14172 loss: 0.1375
05/09 12:01:35 - mmengine - INFO - Iter(train) [ 170/1200] lr: 1.9363e-04 eta: 3:58:05 time: 13.7808 data_time: 0.0380 memory: 14170 loss: 0.0263
05/09 12:03:51 - mmengine - INFO - Iter(train) [ 180/1200] lr: 1.9264e-04 eta: 3:55:30 time: 13.5891 data_time: 0.0426 memory: 14170 loss: 0.2581
05/09 12:06:10 - mmengine - INFO - Iter(train) [ 190/1200] lr: 1.9159e-04 eta: 3:53:16 time: 13.9399 data_time: 0.0498 memory: 14170 loss: 0.1230
05/09 12:08:27 - mmengine - INFO - Iter(train) [ 200/1200] lr: 1.9048e-04 eta: 3:50:49 time: 13.6799 data_time: 0.0493 memory: 14169 loss: 0.0601
05/09 12:10:48 - mmengine - INFO - Iter(train) [ 210/1200] lr: 1.8930e-04 eta: 3:48:41 time: 14.0703 data_time: 0.0408 memory: 14172 loss: 0.1540
05/09 12:13:11 - mmengine - INFO - Iter(train) [ 220/1200] lr: 1.8805e-04 eta: 3:46:41 time: 14.2892 data_time: 0.0405 memory: 14170 loss: 0.1472
05/09 12:15:32 - mmengine - INFO - Iter(train) [ 230/1200] lr: 1.8674e-04 eta: 3:44:35 time: 14.1704 data_time: 0.0360 memory: 14170 loss: 0.0342
05/09 12:17:53 - mmengine - INFO - Iter(train) [ 240/1200] lr: 1.8536e-04 eta: 3:42:24 time: 14.1000 data_time: 0.0391 memory: 14170 loss: 0.0607
05/09 12:20:13 - mmengine - INFO - Iter(train) [ 250/1200] lr: 1.8393e-04 eta: 3:40:06 time: 13.9302 data_time: 0.0382 memory: 14169 loss: 0.0634
05/09 12:22:35 - mmengine - INFO - Iter(train) [ 260/1200] lr: 1.8243e-04 eta: 3:38:01 time: 14.2698 data_time: 0.0495 memory: 14172 loss: 0.0763
05/09 12:24:57 - mmengine - INFO - Iter(train) [ 270/1200] lr: 1.8087e-04 eta: 3:35:48 time: 14.1202 data_time: 0.0336 memory: 14170 loss: 0.0760
05/09 12:27:27 - mmengine - INFO - Iter(train) [ 280/1200] lr: 1.7925e-04 eta: 3:34:07 time: 15.0706 data_time: 0.0356 memory: 14170 loss: 0.0674
05/09 12:29:56 - mmengine - INFO - Iter(train) [ 290/1200] lr: 1.7758e-04 eta: 3:32:15 time: 14.8495 data_time: 0.0328 memory: 14170 loss: 0.0619
05/09 12:32:17 - mmengine - INFO - Iter(train) [ 300/1200] lr: 1.7585e-04 eta: 3:29:59 time: 14.1217 data_time: 0.0418 memory: 14169 loss: 0.0213
05/09 12:34:36 - mmengine - INFO - Iter(train) [ 310/1200] lr: 1.7406e-04 eta: 3:27:36 time: 13.8896 data_time: 0.0345 memory: 14172 loss: 0.0649
05/09 12:37:04 - mmengine - INFO - Iter(train) [ 320/1200] lr: 1.7222e-04 eta: 3:25:39 time: 14.8585 data_time: 0.0411 memory: 14172 loss: 0.0383
05/09 12:39:27 - mmengine - INFO - Iter(train) [ 330/1200] lr: 1.7033e-04 eta: 3:23:26 time: 14.2799 data_time: 0.0332 memory: 14170 loss: 0.0712
05/09 12:41:52 - mmengine - INFO - Iter(train) [ 340/1200] lr: 1.6838e-04 eta: 3:21:18 time: 14.5005 data_time: 0.0291 memory: 14170 loss: 0.0852
05/09 12:44:11 - mmengine - INFO - Iter(train) [ 350/1200] lr: 1.6639e-04 eta: 3:18:54 time: 13.9209 data_time: 0.0301 memory: 14169 loss: 0.0176
05/09 12:46:40 - mmengine - INFO - Iter(train) [ 360/1200] lr: 1.6435e-04 eta: 3:16:53 time: 14.8487 data_time: 0.0294 memory: 14172 loss: 0.0503
05/09 12:49:04 - mmengine - INFO - Iter(train) [ 370/1200] lr: 1.6226e-04 eta: 3:14:39 time: 14.3601 data_time: 0.0289 memory: 14170 loss: 0.0551
05/09 12:51:29 - mmengine - INFO - Iter(train) [ 380/1200] lr: 1.6012e-04 eta: 3:12:28 time: 14.5297 data_time: 0.0293 memory: 14170 loss: 0.0452
05/09 12:53:50 - mmengine - INFO - Iter(train) [ 390/1200] lr: 1.5795e-04 eta: 3:10:08 time: 14.1204 data_time: 0.0276 memory: 14169 loss: 0.0306
05/09 12:56:10 - mmengine - INFO - Iter(train) [ 400/1200] lr: 1.5573e-04 eta: 3:07:46 time: 14.0397 data_time: 0.0302 memory: 14169 loss: 0.0028
05/09 12:58:31 - mmengine - INFO - Iter(train) [ 410/1200] lr: 1.5346e-04 eta: 3:05:24 time: 14.0401 data_time: 0.0541 memory: 14172 loss: 0.1111
05/09 13:00:56 - mmengine - INFO - Iter(train) [ 420/1200] lr: 1.5116e-04 eta: 3:03:12 time: 14.5506 data_time: 0.0393 memory: 14170 loss: 0.0110
05/09 13:03:21 - mmengine - INFO - Iter(train) [ 430/1200] lr: 1.4883e-04 eta: 3:00:58 time: 14.4790 data_time: 0.0370 memory: 14170 loss: 0.8442
05/09 13:05:41 - mmengine - INFO - Iter(train) [ 440/1200] lr: 1.4645e-04 eta: 2:58:35 time: 13.9606 data_time: 0.0434 memory: 14170 loss: 0.0277
05/09 13:08:01 - mmengine - INFO - Iter(train) [ 450/1200] lr: 1.4405e-04 eta: 2:56:12 time: 14.0194 data_time: 0.0314 memory: 14168 loss: 0.0015
05/09 13:10:24 - mmengine - INFO - Iter(train) [ 460/1200] lr: 1.4161e-04 eta: 2:53:55 time: 14.3201 data_time: 0.0301 memory: 14172 loss: 0.0922
05/09 13:12:47 - mmengine - INFO - Iter(train) [ 470/1200] lr: 1.3914e-04 eta: 2:51:36 time: 14.2605 data_time: 0.0315 memory: 14170 loss: 0.0422
05/09 13:15:10 - mmengine - INFO - Iter(train) [ 480/1200] lr: 1.3664e-04 eta: 2:49:18 time: 14.2796 data_time: 0.0325 memory: 14170 loss: 0.0175
05/09 13:17:34 - mmengine - INFO - Iter(train) [ 490/1200] lr: 1.3412e-04 eta: 2:47:02 time: 14.4799 data_time: 0.0321 memory: 14170 loss: 0.0168
05/09 13:20:01 - mmengine - INFO - Iter(train) [ 500/1200] lr: 1.3157e-04 eta: 2:44:48 time: 14.6202 data_time: 0.0311 memory: 14169 loss: 0.0106
05/09 13:20:01 - mmengine - INFO - after_train_iter in EvaluateChatHook.
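To monitor convergence, the mmengine log lines above can be parsed with a small script. This is a sketch based on the log format shown in this run; field names and spacing may vary across mmengine versions.

```python
import re

# Matches the iteration number and loss in mmengine train-log lines,
# e.g. "... Iter(train) [  10/1200] lr: ... loss: 1.4195"
LOG_RE = re.compile(r"Iter\(train\) \[\s*(\d+)/\d+\].*?loss: ([\d.]+)")

def parse_losses(lines):
    """Return (iteration, loss) pairs extracted from mmengine log lines."""
    pairs = []
    for line in lines:
        m = LOG_RE.search(line)
        if m:
            pairs.append((int(m.group(1)), float(m.group(2))))
    return pairs

# Two lines copied from the log above
sample = [
    "05/09 11:24:43 - mmengine - INFO - Iter(train) [  10/1200] lr: 5.1430e-05 eta: 4:48:31 time: 14.5475 data_time: 0.0383 memory: 14172 loss: 1.4195",
    "05/09 13:20:01 - mmengine - INFO - Iter(train) [ 500/1200] lr: 1.3157e-04 eta: 2:44:48 time: 14.6202 data_time: 0.0311 memory: 14169 loss: 0.0106",
]
print(parse_losses(sample))  # [(10, 1.4195), (500, 0.0106)]
```

Plotting these pairs makes the steep loss drop (from ~1.42 at iter 10 down to ~0.01 by iter 500) easy to see, which is expected when overfitting a 200-copy dataset.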
After training, convert both the original image projector and the fine-tuned image projector to HuggingFace format, in preparation for trying out the results below.
xtuner convert pth_to_hf ~/Llama3-Tutorial/configs/llama3-llava/llava_llama3_8b_instruct_qlora_clip_vit_large_p14_336_lora_e1_finetune.py \
~/model/llama3-llava-iter_2181.pth \
~/llama3_llava_pth/pretrain_iter_2181_hf
Converting the original image projector to HuggingFace format, as shown in the figure below:
Convert the fine-tuned image projector to HuggingFace format:
xtuner convert pth_to_hf ~/Llama3-Tutorial/configs/llama3-llava/llava_llama3_8b_instruct_qlora_clip_vit_large_p14_336_lora_e1_finetune.py \
~/llama3_llava_pth/iter_1200.pth \
~/llama3_llava_pth/iter_1200_hf
After the conversion, we can try out the fine-tuned model on the command line.
Question 1: Describe this image.
Question 2: What is the equipment in the image?
Pretrain model
export MKL_SERVICE_FORCE_INTEL=1
xtuner chat /root/model/Meta-Llama-3-8B-Instruct \
--visual-encoder /root/model/clip-vit-large-patch14-336 \
--llava /root/llama3_llava_pth/pretrain_iter_2181_hf \
--prompt-template llama3_chat \
--image /root/tutorial/xtuner/llava/llava_data/test_img/oph.jpg
As we can see, the Pretrain model only produces caption-style tags for the image and cannot actually answer questions.
Finetuned model
export MKL_SERVICE_FORCE_INTEL=1
xtuner chat /root/model/Meta-Llama-3-8B-Instruct \
--visual-encoder /root/model/clip-vit-large-patch14-336 \
--llava /root/llama3_llava_pth/iter_1200_hf \
--prompt-template llama3_chat \
--image /root/tutorial/xtuner/llava/llava_data/test_img/oph.jpg
After fine-tuning, the model can answer our questions based on the image. This experiment follows the Llama3-Tutorial; interested readers are encouraged to check it out.