
LangChain + Streamlit: Building a Chatbot That Supports Folder Upload / Local Database / Web Scraping

Preface

First, a caveat: I have since abandoned Streamlit for this project, because the chat box cannot automatically scroll to the bottom as new messages arrive (genuinely frustrating). Streamlit does not support this natively, and I never found a satisfactory way to do it with custom code either.
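For the record, the workaround people usually suggest is injecting a bit of JavaScript through streamlit.components.v1 to scroll the container manually. Below is a sketch of that idea; the DOM selector is an assumption about Streamlit's internal markup (which changes between releases), and as said above I never found a variant of this that behaved reliably.

import streamlit.components.v1 as components

def scroll_chat_to_bottom():
    # Components render inside an iframe, so reach up via window.parent.
    # The data-testid selector is a guess at Streamlit's internal DOM and
    # tends to break across releases -- the core flaw of this approach.
    components.html(
        """
        <script>
        const blocks = window.parent.document.querySelectorAll(
            'div[data-testid="stVerticalBlock"]');
        blocks.forEach(b => { b.scrollTop = b.scrollHeight; });
        </script>
        """,
        height=0,
    )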
Demo:
[screenshot of the chat UI]

Code

The FileLoader and Spider classes were covered in my previous posts, so I won't repeat them here.
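For readers who haven't seen those posts, here is a minimal sketch of the interface the rest of this article assumes; the bodies are illustrative stubs, not the real implementations:

class FileLoader:
    """Loads every supported file under a directory into LangChain Documents."""

    def __init__(self, dirname):
        self.dirname = dirname

    def load_all(self):
        # Walks self.dirname, dispatches each file to a matching LangChain
        # document loader, and returns the combined list of Documents.
        ...

class Spider:
    """Crawls a website and dumps the pages into a local folder."""

    def __init__(self, url):
        self.url = url

    def scrap_web(self):
        # Downloads the page(s) at self.url into a temp directory and returns
        # that directory's path, so it can be fed to FileLoader like a folder.
        ...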

1. Imports
import tkinter as tk
from tkinter import filedialog  # native folder-picker dialog (Streamlit has none)

import streamlit as st
from streamlit_chat import message  # chat-bubble component

from llm import MyLlm, FileLoader  # see section 5 and the earlier posts
from spider import Spider          # web crawler from the earlier spider post
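The two local imports imply a small project layout, something like this (file names inferred from the import statements):

# app.py     - this Streamlit script
# llm.py     - MyLlm and FileLoader (section 5 and the earlier posts)
# spider.py  - Spider (the earlier spider post)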

2. Initialization
# 'clicked' marks that a data source has been loaded; 'chain' holds the
# retrieval chain. Both live in session_state because Streamlit reruns
# this whole script from the top on every interaction.
if 'clicked' not in st.session_state:
    st.session_state.clicked = False

if 'chain' not in st.session_state:
    st.session_state.chain = None

st.title("Your own AI chatbot⭐")
3. Sidebar and click handlers
with st.sidebar:
    # A hidden tkinter root window lets us pop a native folder-picker
    # dialog; '-topmost' keeps the dialog above the browser window.
    root = tk.Tk()
    root.withdraw()
    root.wm_attributes('-topmost', 1)
    st.write('Please select a folder:')
    browse_clicked = st.button('Browse Folder')
    st.write('Load local database:')
    local_clicked = st.button('Load local')
    website_input = st.text_input(label="Scrape from web:", placeholder="Copy your website link here", key='web_input')
    web_clicked = st.button('Scrape web')


def load_chain_from_directory(dirname):
    # Load every file under dirname, build the retrieval chain from its
    # contents, and flag the app as ready to chat.
    loader = FileLoader(dirname)
    data = loader.load_all()
    llm = MyLlm()
    st.session_state.chain = llm.get_chain(data)
    st.session_state.clicked = True


# button click actions
if browse_clicked:
    dirname = str(filedialog.askdirectory(master=root))
    load_chain_from_directory(dirname)

if local_clicked:
    llm = MyLlm()
    st.session_state.chain = llm.get_chain_from_local()
    st.session_state.clicked = True

if web_clicked:
    spider = Spider(website_input)
    dirname = spider.scrap_web()
    load_chain_from_directory(dirname)
4. Start chatting once the data is loaded
Streamlit re-executes the whole script on every interaction, so the session_state flags set above are what keep the chat UI alive across reruns.
if st.session_state.clicked:

    def conversational_chat(query):
        # Run the chain with the accumulated chat history, then record the
        # new (question, answer) pair so follow-up questions keep context.
        result = st.session_state.chain.invoke({"question": query, "chat_history": st.session_state['history']})
        st.session_state['history'].append((query, result["answer"]))
        return result["answer"]


    # 'history' feeds the chain; 'past' and 'generated' (below) feed the
    # rendered chat bubbles.
    if 'history' not in st.session_state:
        st.session_state['history'] = []

    # 'generated' holds the bot's replies, 'past' the user's messages;
    # both are seeded so the chat window isn't empty on first load.
    if 'generated' not in st.session_state:
        st.session_state['generated'] = ["Hello! Ask me anything about your data 🤗"]

    if 'past' not in st.session_state:
        st.session_state['past'] = ["Hey!"]

    # container for the chat history
    response_container = st.container(height=400)
    # container for the user's text input
    container = st.container()

    with container:

        with st.form(key='my_form', clear_on_submit=True):
            user_input = st.text_input("Query:", placeholder="Talk to your data here :)", key='my_input')
            submit_button = st.form_submit_button(label='Send')

        if submit_button and user_input:
            print("here")

            output = conversational_chat(user_input)
            st.session_state['past'].append(user_input)
            st.session_state['generated'].append(output)

    if st.session_state['generated']:
        with response_container:
            # Re-render the full conversation on every rerun, oldest first.
            for i in range(len(st.session_state['generated'])):
                message(st.session_state["past"][i], is_user=True, key=str(i) + '_user', avatar_style="big-smile")
                message(st.session_state["generated"][i], key=str(i), avatar_style="thumbs")

5. The MyLlm class
This is the llm.py module imported in section 1. It embeds documents with a local sentence-transformers model, indexes them with FAISS (persisted to disk so a later session can reload them), and wraps the retriever and gpt-3.5-turbo in a ConversationalRetrievalChain.
# LangChain imports (paths match 0.1.x-era releases; adjust to your version)
from langchain_openai import ChatOpenAI
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain

OPENAI_API_KEY = "sk-..."               # your OpenAI API key
DB_FAISS_PATH = "vectorstore/db_faiss"  # wherever you persist the FAISS index


class MyLlm:

    # Loading the model
    def __init__(self):
        self.llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model_name="gpt-3.5-turbo")
        model_name = "sentence-transformers/all-MiniLM-L6-v2"
        model_kwargs = {'device': 'cpu'}
        self.embedding = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

    def get_chain(self, documents):
        # Split and embed the documents, persist the FAISS index to disk,
        # then build a conversational retrieval chain on top of it.
        text_splitter = RecursiveCharacterTextSplitter()
        documents = text_splitter.split_documents(documents)

        # embedding = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
        vector = FAISS.from_documents(documents, self.embedding)
        vector.save_local(DB_FAISS_PATH)
        retriever = vector.as_retriever()

        chain = ConversationalRetrievalChain.from_llm(llm=self.llm, retriever=retriever)
        return chain

    def get_chain_from_local(self):
        # Reload a previously saved FAISS index instead of re-embedding.
        # allow_dangerous_deserialization is required because FAISS indexes
        # are pickled; only load indexes you created yourself.
        vector = FAISS.load_local(DB_FAISS_PATH, self.embedding, allow_dangerous_deserialization=True)
        retriever = vector.as_retriever()

        chain = ConversationalRetrievalChain.from_llm(llm=self.llm, retriever=retriever)
        return chain

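To try it out, install the dependencies and launch the script with Streamlit. The package list below is my best guess for this code (exact names vary with your LangChain version), and app.py is whatever you named the main script:

pip install streamlit streamlit-chat langchain langchain-openai langchain-community faiss-cpu sentence-transformers
streamlit run app.py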