微调范式
0 引入
#AI夏令营 #Datawhale #夏令营
主要为参与datawhale夏令营活动:
https://datawhaler.feishu.cn/wiki/VIy8ws47ii2N79kOt9zcXnbXnuS
比赛连接:
https://challenge.xfyun.cn/topic/info?type=role-element-extraction&ch=dw24_y0SCtd
代码参考:
数据集制作:
https://aistudio.baidu.com/projectdetail/8090135
平台微调:
https://training.xfyun.cn/overview
感觉这种平台式的微调,如果平台做的比较好的话,还是非常方便的。 而且微调完之后就可以直接发布为服务,巨方便。这样的话,这个微调的模型就不用自己去部署了,直接调用api就可以使用了,而且可以把自己的模型分享给别人使用。
1 数据集制作:
1.1 环境配置
这里我们需要先对原始群聊数据做初步抽取,我们需要准备一下讯飞3.5的api环境配置。和baseline1的配置一样。
pip uninstall websocket-client
pip install --upgrade spark_ai_python websocket-client
api连通测试:
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import numpy as np
from tqdm import tqdm
def chatbot(prompt):
#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = ''
SPARKAI_API_SECRET = ''
SPARKAI_API_KEY = ''
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'
spark = ChatSparkLLM(
spark_api_url=SPARKAI_URL,
spark_app_id=SPARKAI_APP_ID,
spark_api_key=SPARKAI_API_KEY,
spark_api_secret=SPARKAI_API_SECRET,
spark_llm_domain=SPARKAI_DOMAIN,
streaming=False,
)
messages = [ChatMessage(
role="user",
content=prompt
)]
handler = ChunkPrintHandler()
a = spark.generate([messages], callbacks=[handler])
return a.generations[0][0].message.content
1.2 数据处理prompt:
这里我们对原群聊对话设计了一个总结Prompt,目的是将原始对话内容进行精简。方便做微调数据。
一方面直接将群聊对话作为数据集的话,会导致上下文过长,超过限制。还有上下文太长会导致抽取效果变差。
过长的上下文也会导致训练时长和费用倍增。(比如我做了一个数据集要花3000多块钱跑完。就算能跑可能也要1-2天……)
好了我们来说说prompt。这个prompt相较于baseline01区别比较明显,对需要抽取的任务做了一次总结。总结了四个方面:
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
通过总结后的数据一方面节约了微调的运算资源,一方面也让数据被清洗后更容易被模型理解,达到更好的抽取效果。
content = ''
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
1.3 训练数据集制作:
jsonl_data 是用来训练的规范单行数据,需要由训练数据组成一个jsonl文件(每行是一个json数据的文件),格式如下:
jsonl_data = {"instruction":"假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。","input":"请调小空气净化器的湿度到1","output":"{\"intent\":\"CONTROL\",\"slots\":[{\"name\":\"device\",\"normValue\":\"airCleaner\",\"value\":\"空气净化器\"},{\"name\":\"insType\",\"normValue\":\"set\",\"value\":\"调小\"},{\"name\":\"attr\",\"normValue\":\"humidity\",\"value\":\"湿度\"},{\"name\":\"attrValue\",\"normValue\":\"1\",\"value\":\"1\"}],\"sample\":\"请调小空气净化器的湿度到1\"}"}
{‘instruction’: ‘假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。’, ‘input’: ‘请调小空气净化器的湿度到1’, ‘output’: ‘{“intent”:“CONTROL”,“slots”:[{“name”:“device”,“normValue”:“airCleaner”,“value”:“空气净化器”},{“name”:“insType”,“normValue”:“set”,“value”:“调小”},{“name”:“attr”,“normValue”:“humidity”,“value”:“湿度”},{“name”:“attrValue”,“normValue”:“1”,“value”:“1”}],“sample”:“请调小空气净化器的湿度到1”}’}
需要训练的数据文件在官网下载后是train.json。
import json
# 打开并读取JSON文件
with open('train.json', 'r', encoding='utf-8') as file:
data = json.load(file)
这里我们通过星火3.5api清洗原来的数据,总结后按照刚才看到得单行jsonl存储格式将数据存入traindata.jsonl中。大家可以经过处理后自行查阅traindata.jsonl文件,看看都有啥。
这里的训练时长大概40min左右,请耐心等待。这段等待的时间可以看看后面的内容。
学到了,原来大模型微调的数据集是这样构建的,
# 训练集制作
# 打开一个文件用于写入,如果文件已存在则会被覆盖
with open('traindata.jsonl', 'w', encoding='utf-8') as file:
# 训练集行数(130)不符合要求,范围:1500~90000000
# 遍历数据列表,并将每一行写入文件
# 这里为了满足微调需求我们重复12次数据集 130*12=1560
for line_data in tqdm(data):
line_input = line_data["chat_text"]
line_output = line_data["infos"]
content = line_input
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
res = chatbot(prompt=prompt)
# print(res)
line_write = {
"instruction":jsonl_data["instruction"],
"input":json.dumps(res, ensure_ascii=False),
"output":json.dumps(line_output, ensure_ascii=False)
}
# 因为数据共有130行,为了能满足训练需要的1500条及以上,我们将正常训练数据扩充12倍。
for time in range(12):
file.write(json.dumps(line_write, ensure_ascii=False) + '\n') # '\n' 用于在每行末尾添加换行符
结果:
1.4 测试集数据制作:
测试数据和训练数据相似,都是通过api清洗后存储。
# 验证集制作(提交版本)
# input,target
import json
# 打开并读取JSON文件
with open('test_data.json', 'r', encoding='utf-8') as file:
data_test = json.load(file)
这里的验证数据我们以csv文件存储,有input和target两列,由于我们没有这些数据的真实标签,我这里将target列设置为’-'。
测试集text.csv文件大概需要20min能得到,也请大家耐心等待~
import csv
# 打开一个文件用于写入CSV数据
with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:
# 创建一个csv writer对象
csvwriter = csv.writer(csvfile)
csvwriter.writerow(["input","target"])
# 遍历数据列表,并将每一行写入CSV文件
for line_data in tqdm(data_test):
content = line_data["chat_text"]
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
res = chatbot(prompt=prompt)
# print(line_data["chat_text"])
## 文件内容校验失败: test.jsonl(不含表头起算)第1行的内容不符合规则,限制每组input和target字符数量总和上限为8000,当前行字符数量:10721
line_list = [res, "-"]
csvwriter.writerow(line_list)
# break
2 模型微调:
再次吐槽一下讯飞的平台,真的经常抽风,感觉根本没想要面向个人使用者使用。就比如我训练3个epoch,能用上一天多:
科大讯飞的微调平台:https://training.xfyun.cn/overview
2.1 上传数据集:
- 点击确定,等待上传成功,- 运行成功训练表示数据上传结束
2.2 测试集上传:
上传我们的test.csv文件即可。
2.3 平台微调:
这个平台里面,有一些常见的模型;
进入创建微调页面:https://training.xfyun.cn/model/add
基本配置与版本配置如下,我们选择性能比较好的Spark Pro模型~
数据配置与参数配置,选择代金券,通义协议,提交训练。
2.4 结果处理
训练完成后,进入模型页面。
点击右侧,发布为服务
点击左侧导航栏的我的模型服务接着拿到resourceId、APPID、APIKey、APISecret
然后就可以直接调用api进行推理了。
微调推理部分填入APPID、APIKey、APISecret(注意顺序)
3 模型推理:
# 定义写入函数
def write_json(json_file_path, data):
#"""写入json文件"""
with open(json_file_path, 'w') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
import SparkApi
import json
#以下密钥信息从控制台获取
appid = "" #填写控制台中获取的 APPID 信息
api_secret = "" #填写控制台中获取的 APISecret 信息
api_key ="" #填写控制台中获取的 APIKey 信息
#调用微调大模型时,设置为“patch”
domain = "patchv3"
#云端环境的服务地址
# Spark_url = "wss://spark-api-n.xf-yun.com/v1.1/chat" # 微调v1.5环境的地址
Spark_url = "wss://spark-api-n.xf-yun.com/v3.1/chat" # 微调v3.0环境的地址
text =[]
# length = 0
def getText(role,content):
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
def getlength(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def checklen(text):
while (getlength(text) > 8000):
del text[0]
return text
def core_run(text,prompt):
# print('prompt',prompt)
text.clear
Input = prompt
question = checklen(getText("user",Input))
SparkApi.answer =""
# print("星火:",end = "")
SparkApi.main(appid,api_key,api_secret,Spark_url,domain,question)
getText("assistant",SparkApi.answer)
# print(text)
return text[-1]['content']
text = []
res = core_run(text,'你好吗?')
print(res)
在这里出现问题:
我知道怎么破解连接出现问题的了,可以先用不连接resouceid的进行连通测试;然后,再在108行写入微调好的模型的resouceid。再进行连通测试。这样就可以了。
在SparkApi.py文件的108行,引号中填入你的resourceId
写入:
import pandas as pd
import re
# 读取Excel文件
df_test = pd.read_csv('test.csv',)
data_dict_empty = {
"基本信息-姓名": "",
"基本信息-手机号码": "",
"基本信息-邮箱": "",
"基本信息-地区": "",
"基本信息-详细地址": "",
"基本信息-性别": "",
"基本信息-年龄": "",
"基本信息-生日": "",
"咨询类型": [],
"意向产品": [],
"购买异议点": [],
"客户预算-预算是否充足": "",
"客户预算-总体预算金额": "",
"客户预算-预算明细": "",
"竞品信息": "",
"客户是否有意向": "",
"客户是否有卡点": "",
"客户购买阶段": "",
"下一步跟进计划-参与人": [],
"下一步跟进计划-时间点": "",
"下一步跟进计划-具体事项": ""
}
代码比较鲁棒的,使用try, except,的方式,避免了遇到错误后,后面的样例无法运行的情况。
submit_data = []
for id,line_data in tqdm(enumerate(df_test['input'])):
# print(line_data)
content = line_data
text = []
prompt = json.dumps(content,ensure_ascii=False)
# print(json.dumps(content,ensure_ascii=False))
res = core_run(text,prompt)
try:
data_dict = json.loads(res)
except json.JSONDecodeError as e:
data_dict = data_dict_empty
submit_data.append({"infos":data_dict,"index":id+1})
# 预计执行8min
write_json("submit.json",submit_data)
然后就可以在讯飞的平台上提交结果了。
评分的性能: