Bootstrap

search_baidu_baike 方法

search_baidu_baike 方法(Python 代码):

def search_baidu_baike(self, query):
    """Fetch the summary text of a Baidu Baike entry.

    Args:
        query: Entry title to look up. It is percent-encoded before being
            placed in the URL path so characters such as ``/``, ``?`` or
            ``#`` cannot alter the request target.

    Returns:
        The stripped text of the first ``div`` with class ``lemma-summary``,
        or the fallback message "没有找到相关信息" when the request fails or
        the page contains no such element.
    """
    from urllib.parse import quote

    not_found = "没有找到相关信息"
    url = f"https://baike.baidu.com/item/{quote(query, safe='')}"
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    try:
        # A timeout keeps a stalled connection from blocking the caller forever.
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
    except requests.RequestException as e:
        logging.warning(f"百度百科请求失败: {e}")
        return not_found

    soup = BeautifulSoup(response.text, 'html.parser')
    summary = soup.find('div', class_='lemma-summary')
    if summary:
        return summary.get_text().strip()
    return not_found
完整的 XihuaChatbotGUI 类(Python 代码):
class XihuaChatbotGUI:
    def __init__(self, root):
        """Set up the tokenizer, model, reference data, history and widgets.

        Args:
            root: The top-level window; its title is set here and it becomes
                the parent of every frame built by ``create_widgets``.
        """
        pretrained_path = 'F:/models/bert-base-chinese'
        dataset_path = os.path.join(PROJECT_ROOT, 'data/train_data.jsonl')

        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained(pretrained_path)
        # Prefer the GPU when torch reports one, otherwise run on the CPU.
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)
        network = XihuaModel(pretrained_model_name=pretrained_path)
        self.model = network.to(self.device)
        self.load_model()
        self.model.eval()

        # Training data is kept in memory so answers can be looked up later.
        self.data = self.load_data(dataset_path)

        # Question/answer records accumulated during this session.
        self.history = []

        self.create_widgets()

    def create_widgets(self):
        """Create and lay out all widgets in three vertically packed frames.

        The top frame holds the question label, entry and answer button; the
        middle frame holds the chat transcript; the bottom frame holds the
        feedback/training buttons, the progress bar, the log box and the
        evaluation/history buttons.
        """
        base_font = ("Arial", 12)

        def add_button(parent, text, command, row, column, **grid_extra):
            # Build a 'TButton'-styled button and grid it with 10px side padding.
            button = ttk.Button(parent, text=text, command=command, style='TButton')
            button.grid(row=row, column=column, padx=10, **grid_extra)
            return button

        # Styling
        theme = ttk.Style()
        theme.theme_use('clam')

        # Top frame: question input row
        top_frame = ttk.Frame(self.root)
        top_frame.pack(pady=10)

        self.question_label = ttk.Label(top_frame, text="问题:", font=base_font)
        self.question_label.grid(row=0, column=0, padx=10)

        self.question_entry = ttk.Entry(top_frame, width=50, font=base_font)
        self.question_entry.grid(row=0, column=1, padx=10)

        self.answer_button = add_button(top_frame, "获取回答", self.get_answer, 0, 2)

        # Middle frame: chat transcript with one tag per speaker
        middle_frame = ttk.Frame(self.root)
        middle_frame.pack(pady=10)

        self.chat_text = tk.Text(middle_frame, height=20, width=100, font=base_font, wrap='word')
        self.chat_text.grid(row=0, column=0, padx=10, pady=10)
        self.chat_text.tag_configure("user", justify='right', foreground='blue')
        self.chat_text.tag_configure("xihua", justify='left', foreground='green')

        # Bottom frame: feedback, training, progress, log and history controls
        bottom_frame = ttk.Frame(self.root)
        bottom_frame.pack(pady=10)

        self.correct_button = add_button(bottom_frame, "准确", self.mark_correct, 0, 0)
        self.incorrect_button = add_button(bottom_frame, "不准确", self.mark_incorrect, 0, 1)
        self.train_button = add_button(bottom_frame, "训练模型", self.train_model, 0, 2)
        self.retrain_button = add_button(
            bottom_frame, "重新训练模型", lambda: self.train_model(retrain=True), 0, 3
        )

        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(
            bottom_frame,
            variable=self.progress_var,
            maximum=100,
            length=200,
            mode='determinate',
        )
        self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)

        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=base_font)
        self.log_text.grid(row=2, column=0, columnspan=4, pady=10)

        self.evaluate_button = add_button(bottom_frame, "评估模型", self.evaluate_model, 3, 0, pady=10)
        self.history_button = add_button(bottom_frame, "查看历史记录", self.view_history, 3, 1, pady=10)
        self.save_history_button = add_button(bottom_frame, "保存历史记录", self.save_history, 3, 2, pady=10)

    def get_answer(self):
        """Classify the entered question and show the matching stored answer.

        The question is tokenized, passed through ``self.model``, and the sign
        of the single output logit selects the answer type ("羲和回答" when
        positive, otherwise "零回答"). The concrete answer text comes from
        ``get_specific_answer``. Both the question and the answer are appended
        to the chat box and a record is added to ``self.history``.

        Blank or whitespace-only input is rejected with a warning dialog.
        """
        # Strip first so that input consisting only of spaces is treated as empty.
        question = self.question_entry.get().strip()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)

        answer_type = "羲和回答" if logits.item() > 0 else "零回答"
        specific_answer = self.get_specific_answer(question, answer_type)

        self.chat_text.insert(tk.END, f"用户: {question}\n", "user")
        self.chat_text.insert(tk.END, f"羲和: {specific_answer}\n", "xihua")
        # Keep the newest exchange visible once the transcript outgrows the widget.
        self.chat_text.see(tk.END)

        # Record the exchange; 'accuracy' stays None until the user rates it.
        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None,
        })

    def get_specific_answer(self, question, answer_type):
        # 使用模糊匹配查找最相似的问题
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item

        if best_match:
            if answer_type == "羲和回答":
                return best_match['human_answers'][0]
            else:
                return best_match['chatgpt_answers'][0]
        return "这个我也不清楚,你问问零吧"

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def load_model(self):
        """Load saved weights into ``self.model`` when a checkpoint file exists.

        The checkpoint is looked up at ``models/xihua_model.pth`` under
        ``PROJECT_ROOT``. When it is absent the model is left untouched and
        only an info message is logged.
        """
        checkpoint = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if not os.path.exists(checkpoint):
            logging.info("没有找到现有模型,将使用预训练模型")
            return

        # map_location places the loaded tensors on the device already in use.
        state_dict = torch.load(checkpoint, map_location=self.device)
        self.model.load_state_dict(state_dict)
        logging.info("加载现有模型")

    def train_model(self, retrain=False):
        """Train the model on a user-selected data file and save the weights.

        Args:
            retrain: When True, the saved checkpoint is loaded into the model
                before training starts; otherwise training continues from the
                weights currently in memory.

        The user picks a ``.jsonl``/``.json`` file; the model is trained for a
        fixed number of epochs with Adam and ``BCEWithLogitsLoss``, per-epoch
        loss is written to the log and the log box, and the weights are saved
        to ``models/xihua_model.pth`` under ``PROJECT_ROOT``. Any failure is
        logged and reported in a dialog rather than raised.
        """
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return

        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

            # Start from the saved checkpoint when a retrain is requested.
            if retrain:
                self.model.load_state_dict(torch.load(model_path, map_location=self.device))
                self.model.to(self.device)

            # __init__ leaves the model in eval mode; switch to training mode
            # on both code paths, not only when retraining.
            self.model.train()

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 30
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
                self.log_text.insert(tk.END, f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}\n')
                self.log_text.see(tk.END)
                # Training runs inside the Tk callback, so pending redraws are
                # flushed here to make the log and progress bar visibly update.
                self.root.update_idletasks()

            # Ensure the target directory exists before writing the checkpoint.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            torch.save(self.model.state_dict(), model_path)
            logging.info("模型训练完成并保存")
            self.log_text.insert(tk.END, "模型训练完成并保存\n")
            self.log_text.see(tk.END)
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            # Top-level UI boundary: report the failure instead of crashing the callback.
            logging.error(f"模型训练失败: {e}")
            self.log_text.insert(tk.END, f"模型训练失败: {e}\n")
            self.log_text.see(tk.END)
            messagebox.showerror("训练失败", f"模型训练失败: {e}")
        finally:
            # get_answer runs inference on this model, so return it to eval
            # mode whether training succeeded or failed.
            self.model.eval()

    def evaluate_model(self):
        """Placeholder: inform the user that evaluation is not implemented."""
        # Model evaluation logic can be added here.
        title = "评估结果"
        message = "模型评估功能暂未实现"
        messagebox.showinfo(title, message)

    def mark_correct(self):
        if self.history:
            self.history[-1]['accuracy'] = True
            messagebox.showinfo("评价成功", "您认为这次回答是准确的")

    def mark_incorrect(self):
        if self.history:
            self.history[-1]['accuracy'] = False
            question = self.history[-1]['question']
            baike_answer = self.search_baidu_baike(question)
            self.chat_text.insert(tk.END, f"百度百科结果: {baike_answer}\n", "xihua")
            messagebox.showinfo("评价成功", "您认为这次回答是不准确的")

    def search_baidu_baike(self, query):
        """Fetch the summary text of a Baidu Baike entry.

        Args:
            query: Entry title to look up. It is percent-encoded before being
                placed in the URL path so characters such as ``/``, ``?`` or
                ``#`` cannot alter the request target.

        Returns:
            The stripped text of the first ``div`` with class
            ``lemma-summary``, or the fallback message "没有找到相关信息" when
            the request fails or the page contains no such element.
        """
        from urllib.parse import quote

        not_found = "没有找到相关信息"
        url = f"https://baike.baidu.com/item/{quote(query, safe='')}"
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
        }
        try:
            # A timeout keeps a stalled connection from freezing the GUI callback.
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
        except requests.RequestException as e:
            logging.warning(f"百度百科请求失败: {e}")
            return not_found

        soup = BeautifulSoup(response.text, 'html.parser')
        summary = soup.find('div', class_='lemma-summary')
        if summary:
            return summary.get_text().strip()
        return not_found

    def view_history(self):
        """Open a new window listing every history record.

        Each record shows its question, answer type, answer text and rating
        (unrated, accurate or inaccurate), followed by a 50-dash separator.
        """
        window = tk.Toplevel(self.root)
        window.title("历史记录")

        listing = tk.Text(window, height=20, width=80, font=("Arial", 12))
        listing.pack(padx=10, pady=10)

        separator = "-" * 50 + "\n"
        for record in self.history:
            listing.insert(tk.END, f"问题: {record['question']}\n")
            listing.insert(tk.END, f"回答类型: {record['answer_type']}\n")
            listing.insert(tk.END, f"具体回答: {record['specific_answer']}\n")

            # None means the user has not rated this answer yet.
            rating = record['accuracy']
            if rating is None:
                rating_line = "评价: 未评价\n"
            else:
                rating_line = "评价: 准确\n" if rating else "评价: 不准确\n"
            listing.insert(tk.END, rating_line)
            listing.insert(tk.END, separator)

    def save_history(self):
        """Write the session history to a JSON file chosen by the user.

        Does nothing when the save dialog is cancelled. The file is written as
        UTF-8 with non-ASCII characters kept readable. A write failure is
        reported in an error dialog instead of being raised.
        """
        file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])
        if not file_path:
            return

        try:
            # ensure_ascii=False emits raw non-ASCII text, so the encoding is
            # fixed to UTF-8 rather than left to the platform default.
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(self.history, f, ensure_ascii=False, indent=4)
        except OSError as e:
            logging.error(f"保存历史记录失败: {e}")
            messagebox.showerror("保存失败", f"保存历史记录失败: {e}")
            return

        messagebox.showinfo("保存成功", "历史记录已保存到文件")

解释
mark_incorrect 方法:当用户点击“不准确”按钮时,会调用 search_baidu_baike 方法从百度百科获取更详细的信息,并将其显示在聊天窗口(chat_text 文本框)中。
search_baidu_baike 方法:这个方法通过发送 HTTP 请求到百度百科,解析返回的 HTML 内容,提取出摘要信息并返回。
这样,当用户认为回答不准确时,程序会自动从百度百科获取更详细的信息,并显示在聊天窗口中。

;