Bootstrap

科大讯飞-群聊对话角色要素提取:lora微调范式

0 引入

#AI夏令营 #Datawhale #夏令营
主要为参与datawhale夏令营活动:
https://datawhaler.feishu.cn/wiki/VIy8ws47ii2N79kOt9zcXnbXnuS
比赛连接:
https://challenge.xfyun.cn/topic/info?type=role-element-extraction&ch=dw24_y0SCtd
代码参考:
数据集制作:
https://aistudio.baidu.com/projectdetail/8090135
平台微调:
https://training.xfyun.cn/overview

感觉这种平台式的微调,如果平台做的比较好的话,还是非常方便的。 而且微调完之后就可以直接发布为服务,巨方便。这样的话,这个微调的模型就不用自己去部署了,直接调用api就可以使用了,而且可以把自己的模型分享给别人使用。

1 数据集制作:

1.1 环境配置

这里我们需要先对原始群聊数据做初步抽取,我们需要准备一下讯飞3.5的api环境配置。和baseline1的配置一样。

pip uninstall websocket-client
pip install --upgrade spark_ai_python websocket-client

api连通测试:

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import numpy as np
from tqdm import tqdm


def chatbot(prompt):
    #星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
    SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
    #星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
    SPARKAI_APP_ID = ''
    SPARKAI_API_SECRET = ''
    SPARKAI_API_KEY = ''
    #星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
    SPARKAI_DOMAIN = 'generalv3.5'
    spark = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    messages = [ChatMessage(
        role="user",
        content=prompt
    )]
    handler = ChunkPrintHandler()
    a = spark.generate([messages], callbacks=[handler])
    return a.generations[0][0].message.content

1.2 数据处理prompt:

这里我们对原群聊对话设计了一个总结Prompt,目的是将原始对话内容进行精简。方便做微调数据。
一方面直接将群聊对话作为数据集的话,会导致上下文过长,超过限制。还有上下文太长会导致抽取效果变差。
过长的上下文也会导致训练时长和费用倍增。(比如我做了一个数据集要花3000多块钱跑完。就算能跑可能也要1-2天……)

好了我们来说说prompt。这个prompt相较于baseline01区别比较明显,对需要抽取的任务做了一次总结。总结了四个方面:

客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
通过总结后的数据一方面节约了微调的运算资源,一方面也让数据被清洗后更容易被模型理解,达到更好的抽取效果。

content = ''
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

****群聊对话****
{content}

****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''

1.3 训练数据集制作:

jsonl_data 是用来训练的规范单行数据,需要由训练数据组成一个jsonl文件(每行是一个json数据的文件),格式如下:

jsonl_data = {"instruction":"假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。","input":"请调小空气净化器的湿度到1","output":"{\"intent\":\"CONTROL\",\"slots\":[{\"name\":\"device\",\"normValue\":\"airCleaner\",\"value\":\"空气净化器\"},{\"name\":\"insType\",\"normValue\":\"set\",\"value\":\"调小\"},{\"name\":\"attr\",\"normValue\":\"humidity\",\"value\":\"湿度\"},{\"name\":\"attrValue\",\"normValue\":\"1\",\"value\":\"1\"}],\"sample\":\"请调小空气净化器的湿度到1\"}"}

{‘instruction’: ‘假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。’, ‘input’: ‘请调小空气净化器的湿度到1’, ‘output’: ‘{“intent”:“CONTROL”,“slots”:[{“name”:“device”,“normValue”:“airCleaner”,“value”:“空气净化器”},{“name”:“insType”,“normValue”:“set”,“value”:“调小”},{“name”:“attr”,“normValue”:“humidity”,“value”:“湿度”},{“name”:“attrValue”,“normValue”:“1”,“value”:“1”}],“sample”:“请调小空气净化器的湿度到1”}’}

需要训练的数据文件在官网下载后是train.json。

import json

# 打开并读取JSON文件
with open('train.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

这里我们通过星火3.5api清洗原来的数据,总结后按照刚才看到得单行jsonl存储格式将数据存入traindata.jsonl中。大家可以经过处理后自行查阅traindata.jsonl文件,看看都有啥。
这里的训练时长大概40min左右,请耐心等待。这段等待的时间可以看看后面的内容。

学到了,原来大模型微调的数据集是这样构建的,

# 训练集制作

# 打开一个文件用于写入,如果文件已存在则会被覆盖
with open('traindata.jsonl', 'w', encoding='utf-8') as file:
    # 训练集行数(130)不符合要求,范围:1500~90000000
    # 遍历数据列表,并将每一行写入文件
    # 这里为了满足微调需求我们重复12次数据集 130*12=1560
    
    for line_data in tqdm(data):
        line_input = line_data["chat_text"] 
        line_output = line_data["infos"]
        content = line_input
        
        prompt = f'''
                你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

                ****群聊对话****
                {content}

                ****分析数据****
                客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
                客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
                客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
                跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

                ****注意****
                1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
                2.不要输出分析内容
                3.输出内容格式为md格式
                '''
        res = chatbot(prompt=prompt)
        # print(res)
        line_write = {
            "instruction":jsonl_data["instruction"],
            "input":json.dumps(res, ensure_ascii=False),
            "output":json.dumps(line_output, ensure_ascii=False)
        }
        # 因为数据共有130行,为了能满足训练需要的1500条及以上,我们将正常训练数据扩充12倍。
        for time in range(12):
            file.write(json.dumps(line_write, ensure_ascii=False) + '\n')  # '\n' 用于在每行末尾添加换行符

结果:
image.png

1.4 测试集数据制作:

测试数据和训练数据相似,都是通过api清洗后存储。

# 验证集制作(提交版本)
# input,target

import json

# 打开并读取JSON文件
with open('test_data.json', 'r', encoding='utf-8') as file:
    data_test = json.load(file)

这里的验证数据我们以csv文件存储,有input和target两列,由于我们没有这些数据的真实标签,我这里将target列设置为’-'。
测试集text.csv文件大概需要20min能得到,也请大家耐心等待~

import csv

# 打开一个文件用于写入CSV数据
with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:
    # 创建一个csv writer对象
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["input","target"])
    # 遍历数据列表,并将每一行写入CSV文件
    for line_data in tqdm(data_test):
        content = line_data["chat_text"]
        prompt = f'''
                你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

                ****群聊对话****
                {content}

                ****分析数据****
                客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
                客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
                客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
                跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

                ****注意****
                1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
                2.不要输出分析内容
                3.输出内容格式为md格式
                '''
        res = chatbot(prompt=prompt)
        
        # print(line_data["chat_text"])
        ## 文件内容校验失败: test.jsonl(不含表头起算)第1行的内容不符合规则,限制每组input和target字符数量总和上限为8000,当前行字符数量:10721
        line_list = [res, "-"]   
        csvwriter.writerow(line_list)
        # break

2 模型微调:

再次吐槽一下讯飞的平台,真的经常抽风,感觉根本没想要面向个人使用者使用。就比如我训练3个epoch,能用上一天多:
image.png
科大讯飞的微调平台:https://training.xfyun.cn/overview

2.1 上传数据集:

image.png
image.png
image.png
- 点击确定,等待上传成功,- 运行成功训练表示数据上传结束
image.png

2.2 测试集上传:

image.png
上传我们的test.csv文件即可。
image.png

2.3 平台微调:

这个平台里面,有一些常见的模型;
进入创建微调页面:https://training.xfyun.cn/model/add
基本配置与版本配置如下,我们选择性能比较好的Spark Pro模型~
image.png
数据配置与参数配置,选择代金券,通义协议,提交训练。
image.png
image.png

2.4 结果处理

训练完成后,进入模型页面。
image.png
点击右侧,发布为服务
image.png
点击左侧导航栏的我的模型服务接着拿到resourceId、APPID、APIKey、APISecret
image.png
然后就可以直接调用api进行推理了。
微调推理部分填入APPID、APIKey、APISecret(注意顺序)

3 模型推理:

# 定义写入函数

def write_json(json_file_path, data):
    #"""写入json文件"""
    with open(json_file_path, 'w') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
import SparkApi
import json
#以下密钥信息从控制台获取
appid = ""     #填写控制台中获取的 APPID 信息
api_secret = ""   #填写控制台中获取的 APISecret 信息
api_key =""    #填写控制台中获取的 APIKey 信息

#调用微调大模型时,设置为“patch”
domain = "patchv3"

#云端环境的服务地址
# Spark_url = "wss://spark-api-n.xf-yun.com/v1.1/chat"  # 微调v1.5环境的地址
Spark_url = "wss://spark-api-n.xf-yun.com/v3.1/chat"  # 微调v3.0环境的地址


text =[]

# length = 0

def getText(role,content):
    jsoncon = {}
    jsoncon["role"] = role
    jsoncon["content"] = content
    text.append(jsoncon)
    return text

def getlength(text):
    length = 0
    for content in text:
        temp = content["content"]
        leng = len(temp)
        length += leng
    return length

def checklen(text):
    while (getlength(text) > 8000):
        del text[0]
    return text
    


def core_run(text,prompt):
    # print('prompt',prompt)
    text.clear
    Input = prompt
    question = checklen(getText("user",Input))
    SparkApi.answer =""
    # print("星火:",end = "")
    SparkApi.main(appid,api_key,api_secret,Spark_url,domain,question)
    getText("assistant",SparkApi.answer)
    # print(text)
    return text[-1]['content']

text = []
res = core_run(text,'你好吗?')
print(res)

在这里出现问题:
在这里插入图片描述
我知道怎么破解连接出现问题的了,可以先用不连接resouceid的进行连通测试;然后,再在108行写入微调好的模型的resouceid。再进行连通测试。这样就可以了。

在SparkApi.py文件的108行,引号中填入你的resourceId
image.png

写入:

import pandas as pd
import re

# 读取Excel文件
df_test = pd.read_csv('test.csv',)

data_dict_empty = {
                "基本信息-姓名": "",
                "基本信息-手机号码": "",
                "基本信息-邮箱": "",
                "基本信息-地区": "",
                "基本信息-详细地址": "",
                "基本信息-性别": "",
                "基本信息-年龄": "",
                "基本信息-生日": "",
                "咨询类型": [],
                "意向产品": [],
                "购买异议点": [],
                "客户预算-预算是否充足": "",
                "客户预算-总体预算金额": "",
                "客户预算-预算明细": "",
                "竞品信息": "",
                "客户是否有意向": "",
                "客户是否有卡点": "",
                "客户购买阶段": "",
                "下一步跟进计划-参与人": [],
                "下一步跟进计划-时间点": "",
                "下一步跟进计划-具体事项": ""
            }

代码比较鲁棒的,使用try, except,的方式,避免了遇到错误后,后面的样例无法运行的情况。

submit_data = []
for id,line_data in tqdm(enumerate(df_test['input'])):
    # print(line_data)
    content = line_data
    text = []
    prompt = json.dumps(content,ensure_ascii=False)
    
    # print(json.dumps(content,ensure_ascii=False))
    res = core_run(text,prompt)
    try:
        data_dict = json.loads(res)
    except json.JSONDecodeError as e:
        data_dict = data_dict_empty
    submit_data.append({"infos":data_dict,"index":id+1})
# 预计执行8min
write_json("submit.json",submit_data)

然后就可以在讯飞的平台上提交结果了。
评分的性能:
26.png

;