Bootstrap

# 星火 (iFlytek Spark) v3 + LangChain 集成示例：见下方代码

import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import ssl
import websocket  # 使用websocket_client
import langchain
import logging
from config.settings import SPARK
from urllib.parse import urlparse
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
from typing import Optional, List, Dict, Mapping, Any
from langchain.llms.base import LLM
from langchain.cache import InMemoryCache
 
# Root logger configuration for this script: INFO and above.
logging.basicConfig(level=logging.INFO)
# Turn on LangChain's global in-memory cache of LLM responses.
langchain.llm_cache = InMemoryCache()
# Module-level buffer for streamed response text chunks; being a global,
# it is shared by every WebSocket connection in this process.
result_list: List[str] = []
 
 
def _construct_query(prompt, temperature, max_tokens):
    """Build the JSON-serializable request body for one Spark chat turn.

    :param prompt: the user message, sent as a single ``user`` turn.
    :param temperature: sampling temperature, sent as
        ``parameter.chat.temperature``.
    :param max_tokens: upper bound on the number of generated tokens.
    :returns: dict with ``header`` / ``parameter`` / ``payload`` sections.
    """
    return {
        "header": {
            "app_id": SPARK.get("appid"),
            # Fixed user id. NOTE(review): consider making this per end-user.
            "uid": '12345',
        },
        "parameter": {
            "chat": {
                "domain": SPARK.get("domain_v3"),  # e.g. "generalv3"
                # The v3 chat API lists the sampling parameter as
                # "temperature"; the previously used "random_threshold" key
                # is not in its parameter list, so the value was likely ignored.
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
        },
        "payload": {
            "message": {
                "text": [
                    {"role": "user", "content": prompt},
                ]
            }
        },
    }
 
 
def _run(ws, *args):
    """Serialize the chat request for *ws* and send it over the socket.

    Started on a separate thread by ``on_open``. The prompt and sampling
    settings are read from attributes stored on the socket object
    (``question``, ``temperature``, ``max_tokens``). Extra positional
    arguments are accepted and ignored.
    """
    request = _construct_query(
        prompt=ws.question,
        temperature=ws.temperature,
        max_tokens=ws.max_tokens,
    )
    payload = json.dumps(request)
    ws.send(payload)
 
 
def on_error(ws, error, *args, **kwargs):
    """WebSocket error callback: report *error* on stdout.

    Any extra positional or keyword arguments are accepted and ignored.
    """
    label = "error:"
    print(label, error)
 
 
def on_close(ws, *args, **kwargs):
    """WebSocket close callback: print a short notice.

    Close status details passed as extra arguments are ignored.
    """
    notice = "closed..."
    print(notice)
 
 
def on_open(ws, *args, **kwargs):
    """WebSocket open callback: send the request from a new thread.

    ``_run`` is started via ``_thread.start_new_thread``, so this callback
    returns without waiting for the send to finish. Extra arguments are ignored.
    """
    worker_args = (ws,)
    thread.start_new_thread(_run, worker_args)
 
 
def on_message(ws, message):
    """Handle one streamed JSON frame from the Spark chat WebSocket.

    A non-zero (or missing) ``header.code`` is treated as a server error. The
    error is printed, any partial answer collected so far is discarded, and
    the socket is closed.

    Otherwise the frame's text chunk is appended to a buffer kept on *ws*
    itself. When ``payload.choices.status == 2`` (final frame), the chunks are
    joined into ``ws.content`` and the socket is closed.

    :param ws: the WebSocket object that received the frame.
    :param message: raw frame text (JSON).
    """
    data = json.loads(message)
    code = (data.get("header") or {}).get("code", -1)

    # Buffer per connection instead of in a module-level list. Concurrent or
    # failed requests then cannot leak text into each other's answers.
    chunks = getattr(ws, "_spark_chunks", None)
    if chunks is None:
        chunks = []
        setattr(ws, "_spark_chunks", chunks)

    if code != 0:
        print(f'请求错误: {code}, {data}')
        chunks.clear()  # drop partial text so it is never returned
        ws.close()
        return

    choices = data["payload"]["choices"]
    status = choices["status"]
    text = choices.get("text") or []
    content = text[0].get("content", "") if text else ""
    chunks.append(content)
    if status == 2:
        ws.close()
        setattr(ws, "content", "".join(chunks))
        print(chunks)
        chunks.clear()
 
 
class Spark(LLM):
    """LangChain ``LLM`` wrapper for the iFlytek Spark (星火) v3 chat API.

    Following LangChain's LLM wrapper pattern, two members carry the logic:

    * ``_call`` -- the main invocation: takes the prompt and returns the
      model's reply.
    * ``_identifying_params`` -- a dict describing the model and its main
      parameters.

    Each call opens a WebSocket to a signed URL (see ``_get_url``). The
    streamed answer is collected by the module-level ``on_message`` callback.
    """

    # Endpoint from config, e.g. ws://spark-api.xf-yun.com/v3.1/chat
    gpt_url: Optional[str] = SPARK.get('spark_url_v3')
    host: str = urlparse(gpt_url).netloc  # host part of the endpoint (signed)
    path: str = urlparse(gpt_url).path  # path part of the endpoint (signed)
    max_tokens: int = 1024
    temperature: float = 0.5

    @property
    def _llm_type(self) -> str:
        """Short model type identifier reported to LangChain."""
        return "Spark"

    def _get_url(self) -> str:
        """Build the authenticated WebSocket URL for one request.

        An HMAC-SHA256 signature over ``host``, ``date`` and the request line
        is computed with the configured ``api_secret``. It is base64-encoded
        into the ``authorization`` query parameter together with ``date`` and
        ``host``.
        """
        # RFC 1123 formatted current time.
        date = format_date_time(mktime(datetime.now().timetuple()))

        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        api_secret = SPARK.get('api_secret')
        api_key = SPARK.get('api_key')

        signature_sha = hmac.new(api_secret.encode('utf-8'),
                                 signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode('utf-8')

        authorization_origin = (
            f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature_sha_base64}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')

        query = {
            "authorization": authorization,
            "date": date,
            "host": self.host,
        }
        # Deliberately not printed/logged: the query carries a signed credential.
        return self.gpt_url + '?' + urlencode(query)

    def _post(self, prompt: str) -> str:
        """Send *prompt* over a fresh WebSocket and block until it closes.

        Returns the concatenated answer. Returns ``""`` if no final frame set
        ``content`` on the socket (e.g. after a server error).
        """
        ws_url = self._get_url()
        ws = websocket.WebSocketApp(ws_url, on_message=on_message, on_error=on_error,
                                    on_close=on_close, on_open=on_open)
        # Request parameters ride on the socket object; the send thread
        # started from on_open reads them from there.
        ws.question = prompt
        ws.temperature = self.temperature
        ws.max_tokens = self.max_tokens
        # NOTE(review): TLS certificate verification is disabled (CERT_NONE);
        # consider enabling it for wss:// endpoints.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        return getattr(ws, "content", "")

    @staticmethod
    def _enforce_stop(text: str, stop: Optional[List[str]]) -> str:
        """Truncate *text* at the earliest occurrence of any stop sequence."""
        if not stop:
            return text
        cut = len(text)
        for token in stop:
            if not token:
                continue
            idx = text.find(token)
            if idx != -1 and idx < cut:
                cut = idx
        return text[:cut]

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[Any] = None,
              **kwargs: Any) -> str:
        """Run one prompt through Spark and return the reply text.

        The service has no server-side stop words. *stop* is therefore
        applied client-side by cutting the reply at the first stop sequence.
        *run_manager* and extra keyword arguments are accepted so newer
        LangChain versions can call this method. They are currently unused.
        """
        content = self._post(prompt)
        return self._enforce_stop(content, stop)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that identify this model configuration.

        Sampling settings are included so cached responses for one
        temperature/max_tokens setting are not served for another.
        """
        return {
            "url": self.gpt_url,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
 
 
if __name__ == "__main__":
    # Manual smoke test; needs valid SPARK credentials in config.settings.
    llm = Spark(temperature=0.9)
    result = llm("你好啊", stop=["you"])
    print(result)