`search_baidu_baike` 方法(Python)
def search_baidu_baike(self, query):
    """Fetch the summary paragraph for *query* from Baidu Baike.

    Returns the lemma-summary text, or a fallback message when the request
    fails or the page has no summary block.
    """
    from urllib.parse import quote

    # Percent-encode the query so non-ASCII (Chinese) terms form a valid URL.
    url = f"https://baike.baidu.com/item/{quote(query)}"
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    try:
        # A timeout keeps the caller (a GUI) from hanging indefinitely.
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
    except requests.RequestException:
        return "没有找到相关信息"
    soup = BeautifulSoup(response.text, 'html.parser')
    summary = soup.find('div', class_='lemma-summary')
    if summary:
        return summary.get_text().strip()
    return "没有找到相关信息"
完整的 `XihuaChatbotGUI` 类(Python)
class XihuaChatbotGUI:
    """Tkinter front-end for the Xihua chatbot.

    Lets the user ask questions, routes each question through a BERT-based
    classifier to choose an answer source, and supports rating answers,
    (re)training the model, and viewing/saving the chat history.
    """

    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")
        # NOTE(review): model paths are hard-coded to F:/models — confirm for deployment.
        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()

        # Training data kept in memory for fuzzy lookup of specific answers.
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        # Chat history: one dict per question
        # (keys: question / answer_type / specific_answer / accuracy).
        self.history = []

        self.create_widgets()

    def create_widgets(self):
        """Build the three frames (input, chat transcript, controls/log)."""
        style = ttk.Style()
        style.theme_use('clam')

        # Top frame: question entry + submit button.
        top_frame = ttk.Frame(self.root)
        top_frame.pack(pady=10)
        self.question_label = ttk.Label(top_frame, text="问题:", font=("Arial", 12))
        self.question_label.grid(row=0, column=0, padx=10)
        self.question_entry = ttk.Entry(top_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=1, padx=10)
        self.answer_button = ttk.Button(top_frame, text="获取回答", command=self.get_answer, style='TButton')
        self.answer_button.grid(row=0, column=2, padx=10)

        # Middle frame: chat transcript (user right/blue, bot left/green).
        middle_frame = ttk.Frame(self.root)
        middle_frame.pack(pady=10)
        self.chat_text = tk.Text(middle_frame, height=20, width=100, font=("Arial", 12), wrap='word')
        self.chat_text.grid(row=0, column=0, padx=10, pady=10)
        self.chat_text.tag_configure("user", justify='right', foreground='blue')
        self.chat_text.tag_configure("xihua", justify='left', foreground='green')

        # Bottom frame: rating, training, progress bar, log, history controls.
        bottom_frame = ttk.Frame(self.root)
        bottom_frame.pack(pady=10)
        self.correct_button = ttk.Button(bottom_frame, text="准确", command=self.mark_correct, style='TButton')
        self.correct_button.grid(row=0, column=0, padx=10)
        self.incorrect_button = ttk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, style='TButton')
        self.incorrect_button.grid(row=0, column=1, padx=10)
        self.train_button = ttk.Button(bottom_frame, text="训练模型", command=self.train_model, style='TButton')
        self.train_button.grid(row=0, column=2, padx=10)
        self.retrain_button = ttk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), style='TButton')
        self.retrain_button.grid(row=0, column=3, padx=10)

        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200, mode='determinate')
        self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)

        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))
        self.log_text.grid(row=2, column=0, columnspan=4, pady=10)

        self.evaluate_button = ttk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, style='TButton')
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)
        self.history_button = ttk.Button(bottom_frame, text="查看历史记录", command=self.view_history, style='TButton')
        self.history_button.grid(row=3, column=1, padx=10, pady=10)
        self.save_history_button = ttk.Button(bottom_frame, text="保存历史记录", command=self.save_history, style='TButton')
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)

    def get_answer(self):
        """Classify the typed question, display the matched answer, record history."""
        question = self.question_entry.get()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)

        # Positive logit -> human-style ("羲和") answer, otherwise "零" answer.
        answer_type = "羲和回答" if logits.item() > 0 else "零回答"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.chat_text.insert(tk.END, f"用户: {question}\n", "user")
        self.chat_text.insert(tk.END, f"羲和: {specific_answer}\n", "xihua")

        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None  # not yet rated by the user
        })

    def get_specific_answer(self, question, answer_type):
        """Fuzzy-match *question* against the loaded data; return the best answer text."""
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item
        if best_match:
            if answer_type == "羲和回答":
                return best_match['human_answers'][0]
            return best_match['chatgpt_answers'][0]
        return "这个我也不清楚,你问问零吧"

    def load_data(self, file_path):
        """Load training examples from a .jsonl or .json file.

        Invalid lines/files are skipped with a warning instead of aborting.
        Returns a list of example dicts (possibly empty).
        """
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                # InvalidLineError is raised while *iterating*, so it must be
                # caught around next() — a try/except around append() inside a
                # for-loop body can never fire for it.
                iterator = iter(reader)
                i = 0
                while True:
                    try:
                        item = next(iterator)
                    except StopIteration:
                        break
                    except jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
                        i += 1
                        continue
                    data.append(item)
                    i += 1
        elif file_path.endswith('.json'):
            # Explicit utf-8: the data contains Chinese text and must not depend
            # on the platform's default encoding.
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def load_model(self):
        """Load saved weights if present; otherwise keep the pretrained model."""
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        """Train the model on a user-chosen data file.

        With retrain=True, training resumes from the previously saved weights
        instead of the model's current state.
        """
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return
        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

            model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
            if retrain:
                self.model.load_state_dict(torch.load(model_path, map_location=self.device))
                self.model.to(self.device)

            self.model.train()
            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 30
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
                self.log_text.insert(tk.END, f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}\n')
                self.log_text.see(tk.END)

            # Ensure the target directory exists before saving the checkpoint.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            torch.save(self.model.state_dict(), model_path)
            logging.info("模型训练完成并保存")
            self.log_text.insert(tk.END, "模型训练完成并保存\n")
            self.log_text.see(tk.END)
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            # Broad catch is deliberate at this GUI boundary: report instead of crash.
            logging.error(f"模型训练失败: {e}")
            self.log_text.insert(tk.END, f"模型训练失败: {e}\n")
            self.log_text.see(tk.END)
            messagebox.showerror("训练失败", f"模型训练失败: {e}")

    def evaluate_model(self):
        # Placeholder: evaluation logic not implemented yet.
        messagebox.showinfo("评估结果", "模型评估功能暂未实现")

    def mark_correct(self):
        """Mark the most recent answer as accurate."""
        if self.history:
            self.history[-1]['accuracy'] = True
            messagebox.showinfo("评价成功", "您认为这次回答是准确的")

    def mark_incorrect(self):
        """Mark the most recent answer as inaccurate and show a Baidu Baike lookup."""
        if self.history:
            self.history[-1]['accuracy'] = False
            question = self.history[-1]['question']
            baike_answer = self.search_baidu_baike(question)
            self.chat_text.insert(tk.END, f"百度百科结果: {baike_answer}\n", "xihua")
            messagebox.showinfo("评价成功", "您认为这次回答是不准确的")

    def search_baidu_baike(self, query):
        """Fetch the Baidu Baike summary for *query*, or a fallback message."""
        from urllib.parse import quote

        # Percent-encode the query so Chinese terms form a valid URL.
        url = f"https://baike.baidu.com/item/{quote(query)}"
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
        }
        try:
            # Timeout keeps the GUI responsive on a dead network.
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
        except requests.RequestException:
            return "没有找到相关信息"
        soup = BeautifulSoup(response.text, 'html.parser')
        summary = soup.find('div', class_='lemma-summary')
        if summary:
            return summary.get_text().strip()
        return "没有找到相关信息"

    def view_history(self):
        """Open a window listing every question, answer, and its rating."""
        history_window = tk.Toplevel(self.root)
        history_window.title("历史记录")
        history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))
        history_text.pack(padx=10, pady=10)
        for entry in self.history:
            history_text.insert(tk.END, f"问题: {entry['question']}\n")
            history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")
            history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")
            if entry['accuracy'] is None:
                history_text.insert(tk.END, "评价: 未评价\n")
            elif entry['accuracy']:
                history_text.insert(tk.END, "评价: 准确\n")
            else:
                history_text.insert(tk.END, "评价: 不准确\n")
            history_text.insert(tk.END, "-" * 50 + "\n")

    def save_history(self):
        """Save the chat history to a user-chosen JSON file."""
        file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])
        if not file_path:
            return
        # utf-8 is required: ensure_ascii=False writes raw Chinese characters,
        # which fails under a legacy platform default encoding (e.g. gbk).
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(self.history, f, ensure_ascii=False, indent=4)
        messagebox.showinfo("保存成功", "历史记录已保存到文件")
解释
mark_incorrect 方法:当用户点击“不准确”按钮时,会调用 search_baidu_baike 方法来从百度百科获取更详细的信息,并将其显示在回答组件中。
search_baidu_baike 方法:这个方法通过发送 HTTP 请求到百度百科,解析返回的 HTML 内容,提取出摘要信息并返回。
这样,当用户认为回答不准确时,程序会自动从百度百科获取更详细的信息,并显示在聊天窗口中。