写在前面
先说结论:我最终放弃了 Streamlit 这个方案,因为聊天对话框无法自动滚动到底部(scroll to the bottom,真的很蠢)——Streamlit 本身不支持这个功能,我也没有找到合适的自定义实现方法。
Demo示例:
代码
这里用到的 FileLoader 和 Spider 在前几篇文章中都已经写过,这里不再重复。
1. 导包
import tkinter as tk
from tkinter import filedialog
import streamlit as st
from streamlit_chat import message
from llm import MyLlm, FileLoader
from spider import Spider
2. 初始化
# Seed the per-session flags only when they are missing, so values stored
# by the button handlers survive Streamlit's script reruns.
#   clicked -- becomes True once a data source has been loaded
#   chain   -- holds the retrieval chain built from that data source
for state_key, initial_value in (("clicked", False), ("chain", None)):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

st.title("Your own AI chatbot⭐")
3. 侧边栏(Sidebar)以及点击事件
# Hidden Tk root used as the parent of the native folder-picker dialog.
# It is withdrawn so no empty Tk window is shown, and marked topmost --
# presumably so the dialog opens above the browser window.
root = tk.Tk()
root.withdraw()
root.wm_attributes('-topmost', 1)

# Sidebar controls. The button return values and the text field content are
# read by the click handlers further down the script.
st.sidebar.write('Please select a folder:')
browse_clicked = st.sidebar.button('Browse Folder')

st.sidebar.write('Load local database:')
local_clicked = st.sidebar.button('Load local')

website_input = st.sidebar.text_input(
    label="Scrap from web:",
    placeholder="Copy your website link here",
    key='web_input',
)
web_clicked = st.sidebar.button('Scrap web')
def load_chain_from_directory(dirname):
    """Build a chain from the files under `dirname` and enable the chat UI.

    Loads everything `FileLoader(dirname)` returns from `load_all()`, hands
    it to `MyLlm.get_chain`, stores the resulting chain in
    `st.session_state.chain`, and sets `st.session_state.clicked` to True.
    """
    documents = FileLoader(dirname).load_all()
    st.session_state.chain = MyLlm().get_chain(documents)
    st.session_state.clicked = True
# Button click actions: each branch loads a data source into
# st.session_state.chain and flips st.session_state.clicked.
if browse_clicked:
    # askdirectory() returns an empty value when the dialog is cancelled.
    # Test it *before* str(), so a cancelled dialog is never forwarded as a
    # directory name (str() of an empty tuple would be the truthy '()').
    selected = filedialog.askdirectory(master=root)
    if selected:
        dirname = str(selected)
        load_chain_from_directory(dirname)
    else:
        st.sidebar.warning("No folder selected.")
if local_clicked:
    llm = MyLlm()
    st.session_state.chain = llm.get_chain_from_local()
    st.session_state.clicked = True
if web_clicked:
    # Skip the scrape when the link field is empty or whitespace-only
    # instead of handing an empty URL to the spider.
    target_url = website_input.strip() if website_input else ""
    if target_url:
        spider = Spider(target_url)
        dirname = spider.scrap_web()
        load_chain_from_directory(dirname)
    else:
        st.sidebar.warning("Please enter a website link first.")
4. 数据导入之后,开启聊天
if st.session_state.clicked:

    def conversational_chat(query):
        """Run `query` through the stored chain and return its answer text.

        The session's ``history`` list is passed to the chain as
        ``chat_history``; afterwards the (query, answer) pair is appended
        to that same list.
        """
        print("conversational_chat")
        payload = {"question": query, "chat_history": st.session_state['history']}
        answer = st.session_state.chain.invoke(payload)["answer"]
        st.session_state['history'].append((query, answer))
        return answer

    # Create the chat state on first use only. 'past' and 'generated' start
    # with one canned exchange so the transcript is never empty.
    chat_defaults = (
        ('history', []),
        ('messages', []),
        ('generated', ["Hello ! Ask me anything about 🤗"]),
        ('past', ["Hey!"]),
    )
    for state_key, initial_value in chat_defaults:
        if state_key not in st.session_state:
            st.session_state[state_key] = initial_value

    chat_box_id = 'my_chat_box'

    # Fixed-height container for the chat transcript; it is created first,
    # so it sits above the input form in the page order.
    response_container = st.container(height=400)
    # Container holding the query form.
    container = st.container()

    with container:
        with st.form(key='my_form', clear_on_submit=True):
            user_input = st.text_input("Query:", placeholder="Talk to your csv data here (:", key='my_input')
            submit_button = st.form_submit_button(label='Send')

        # Only call the chain when the form was submitted with non-empty text.
        if submit_button and user_input:
            print("here")
            output = conversational_chat(user_input)
            st.session_state['past'].append(user_input)
            st.session_state['generated'].append(output)

    if st.session_state['generated']:
        with response_container:
            # Render the transcript as alternating user / bot bubbles; the
            # widget keys are derived from the position in the lists.
            for idx, bot_reply in enumerate(st.session_state['generated']):
                message(st.session_state["past"][idx], is_user=True, key=str(idx) + '_user', avatar_style="big-smile")
                message(bot_reply, key=str(idx), avatar_style="thumbs")
5. MyLlm 类定义
class MyLlm:
    """Create conversational retrieval chains on top of a FAISS index.

    The instance holds a chat model (``self.llm``) and an embedding model
    (``self.embedding``). A chain can be built either from fresh documents
    (which also writes the index to ``DB_FAISS_PATH``) or from an index
    previously saved at that path.
    """

    def __init__(
        self,
        model_name="gpt-3.5-turbo",
        embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
        device="cpu",
    ):
        """Load the chat model and the embedding model.

        Args:
            model_name: Chat model name passed to ``ChatOpenAI``.
            embedding_model_name: Model name passed to
                ``HuggingFaceEmbeddings``.
            device: Value for the ``'device'`` entry of the embedding
                model's ``model_kwargs``.

        The defaults reproduce the previously hard-coded configuration, so
        ``MyLlm()`` behaves as before.
        """
        self.llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model_name=model_name)
        self.embedding = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs={'device': device},
        )

    def _build_chain(self, vector):
        """Return a ConversationalRetrievalChain over `vector`'s retriever."""
        return ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=vector.as_retriever(),
        )

    def get_chain(self, documents):
        """Index `documents`, persist the index, and return a chain over it.

        Args:
            documents: Documents to split, embed and index.

        Returns:
            A ``ConversationalRetrievalChain`` whose retriever reads from
            the newly built FAISS index.

        Raises:
            ValueError: If `documents` is empty, or if splitting yields no
                chunks. Raised up front so the caller gets a clear message
                instead of a failure from inside the index construction.
        """
        if not documents:
            raise ValueError("No documents to index: 'documents' is empty.")

        chunks = RecursiveCharacterTextSplitter().split_documents(documents)
        if not chunks:
            raise ValueError("No text chunks were produced from 'documents'.")

        vector = FAISS.from_documents(chunks, self.embedding)
        # Persist the index so get_chain_from_local() can reuse it later.
        vector.save_local(DB_FAISS_PATH)
        return self._build_chain(vector)

    def get_chain_from_local(self):
        """Return a chain backed by the FAISS index saved at DB_FAISS_PATH."""
        # SECURITY NOTE: allow_dangerous_deserialization=True opts in to
        # pickle-based loading of the stored index. Only load index files
        # that this application wrote itself; never point DB_FAISS_PATH at
        # data from an untrusted source.
        vector = FAISS.load_local(
            DB_FAISS_PATH,
            self.embedding,
            allow_dangerous_deserialization=True,
        )
        return self._build_chain(vector)