为了使界面在模型训练时不卡顿,我们使用多线程或多进程来处理模型训练任务。这样可以避免主线程被阻塞,保持界面的响应性。这里我们使用Python的threading模块来实现这一点。
以下是优化后的代码:
import json
import logging
import os
import queue
import threading
import tkinter as tk
from datetime import datetime
from difflib import SequenceMatcher
from tkinter import filedialog, messagebox, simpledialog, ttk

import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
# Every relative path in this script (data, models, logs) hangs off the script's own folder.
PROJECT_ROOT = os.path.split(os.path.abspath(__file__))[0]
# Log files live in <PROJECT_ROOT>/logs; the folder is created at import time if absent.
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)
def setup_logging(level=logging.INFO):
    """Configure the root logger to log to a timestamped file in LOGS_DIR and to stderr.

    The file is named ``YYYY-mm-dd_HH-MM-SS_羲和.txt`` and is written as UTF-8, so
    the Chinese log messages used throughout the app are stored correctly whatever
    the OS locale encoding is. Like ``logging.basicConfig`` itself, this is a no-op
    if the root logger already has handlers.

    Args:
        level: Root logger level (defaults to INFO, as before).

    Returns:
        Path of the log file.
    """
    # Keep the non-ASCII suffix out of the strftime format string; only the
    # timestamp needs formatting.
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_羲和.txt')
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
    return log_file


setup_logging()
# Dataset class
class XihuaDataset(Dataset):
    """Question / human-answer / ChatGPT-answer triples loaded from a JSON or JSONL file.

    A usable record is a JSON object with a string ``question`` and non-empty
    ``human_answers`` and ``chatgpt_answers`` lists whose first element is a
    string; only those first answers are used. Malformed lines and records are
    logged and dropped at load time, so ``__getitem__`` never sees them.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    @staticmethod
    def _is_valid_item(item):
        """Return True if *item* has every field ``__getitem__`` reads."""
        if not isinstance(item, dict) or not isinstance(item.get('question'), str):
            return False
        for key in ('human_answers', 'chatgpt_answers'):
            answers = item.get(key)
            if not isinstance(answers, list) or not answers or not isinstance(answers[0], str):
                return False
        return True

    @staticmethod
    def _read_jsonl(file_path):
        """Parse a JSON-lines file, skipping blank lines and logging undecodable ones."""
        records = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, 1):
                if not line.strip():
                    continue
                # The parse is what can fail, so it is what the try guards.
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效行 {line_no}: {e}")
        return records

    @staticmethod
    def _read_json(file_path):
        """Parse a JSON file whose top level must be a list of records."""
        with open(file_path, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError as e:
                logging.warning(f"跳过无效文件 {file_path}: {e}")
                return []
        if not isinstance(data, list):
            logging.warning(f"跳过无效文件 {file_path}: 顶层不是列表")
            return []
        return data

    def load_data(self, file_path):
        """Read records from a ``.jsonl`` or ``.json`` file and keep only valid ones.

        Files are decoded as UTF-8 explicitly rather than with the platform
        default encoding, which may not be UTF-8 (e.g. on Windows). Other
        extensions are logged and yield an empty list.
        """
        lowered = file_path.lower()
        if lowered.endswith('.jsonl'):
            records = self._read_jsonl(file_path)
        elif lowered.endswith('.json'):
            records = self._read_json(file_path)
        else:
            logging.warning(f"不支持的数据文件类型: {file_path}")
            records = []
        data = [item for item in records if self._is_valid_item(item)]
        if len(data) < len(records):
            logging.warning(f"跳过 {len(records) - len(data)} 条缺少必要字段的记录")
        return data

    def __len__(self):
        return len(self.data)

    def _tokenize(self, text):
        """Tokenize *text* to fixed-length 1-D ``(input_ids, attention_mask)`` tensors."""
        encoded = self.tokenizer(text, return_tensors='pt', padding='max_length',
                                 truncation=True, max_length=self.max_length)
        # squeeze(0) drops only the batch dim; a bare squeeze() would also
        # collapse the sequence dim when max_length == 1.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def _encode(self, item):
        """Build the training sample dict for one validated record."""
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]
        input_ids, attention_mask = self._tokenize(item['question'])
        human_ids, human_mask = self._tokenize(human_answer)
        chatgpt_ids, chatgpt_mask = self._tokenize(chatgpt_answer)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'human_input_ids': human_ids,
            'human_attention_mask': human_mask,
            'chatgpt_input_ids': chatgpt_ids,
            'chatgpt_attention_mask': chatgpt_mask,
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer,
        }

    def __getitem__(self, idx):
        """Return sample *idx*, falling forward to the next item the tokenizer accepts.

        Raises:
            IndexError: If *idx* is out of range (keeps the sequence protocol intact).
            RuntimeError: If no item in the dataset can be tokenized.
        """
        n = len(self.data)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        # Bounded iteration instead of recursion: at most one pass over the data.
        for offset in range(n):
            j = (idx + offset) % n
            try:
                return self._encode(self.data[j])
            except Exception as e:
                logging.warning(f"跳过无效项 {j}: {e}")
        raise RuntimeError("数据集中没有可编码的有效样本")
# Data loader factory
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    """Wrap an XihuaDataset read from *file_path* in a shuffling DataLoader.

    Args:
        file_path: Path to the .jsonl/.json training data.
        tokenizer: Tokenizer passed through to XihuaDataset.
        batch_size: Samples per batch.
        max_length: Padding/truncation length passed through to XihuaDataset.

    Returns:
        A DataLoader with shuffling enabled.
    """
    loader_options = {'batch_size': batch_size, 'shuffle': True}
    return DataLoader(XihuaDataset(file_path, tokenizer, max_length=max_length), **loader_options)
# Model definition
class XihuaModel(torch.nn.Module):
    """BERT encoder followed by a linear layer that maps the pooled output to one logit."""

    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        hidden_size = self.bert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one logit per input sequence, computed from BERT's ``pooler_output``."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)
# Training loop
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    """Run one epoch of human-vs-ChatGPT answer classification training.

    For every batch the model scores the human answers against target 1 and the
    ChatGPT answers against target 0. The two losses are summed and
    back-propagated. A batch that raises is logged and skipped, but it still
    advances the progress indicator.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that returns logits.
        data_loader: Sized iterable of dict batches with ``human_*`` and ``chatgpt_*``
            ``input_ids`` / ``attention_mask`` tensors.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, targets)``, e.g. ``BCEWithLogitsLoss``.
        device: Device the batch tensors are moved to.
        progress_var: Optional object with a ``set(percent)`` method.

    Returns:
        Mean loss over the batches that trained successfully, or ``nan`` when none
        did (including an empty ``data_loader``).
    """
    model.train()
    total_loss = 0.0
    trained_batches = 0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)
            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)
            # Targets match the logits' shape, dtype and device: 1 = human, 0 = ChatGPT.
            loss = (criterion(human_logits, torch.ones_like(human_logits))
                    + criterion(chatgpt_logits, torch.zeros_like(chatgpt_logits)))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            trained_batches += 1
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")
        # Advance progress for skipped batches too, so the bar still reaches 100%.
        if progress_var is not None:
            progress_var.set((batch_idx + 1) / num_batches * 100)
    if trained_batches == 0:
        logging.warning("本轮没有成功训练的批次")
        return float('nan')
    # Average only over batches that produced a loss; skipped ones must not dilute it.
    return total_loss / trained_batches
# Main training entry point
def main_train(retrain=False, progress_var=None, log_text=None, file_path=None,
               num_epochs=30, show_message=True):
    """Fine-tune XihuaModel and save its weights to ``models/xihua_model.pth``.

    Tk widgets, variables and dialogs must only be used from the Tk main thread.
    When running on a worker thread, pass thread-safe stand-ins for
    ``progress_var``/``log_text`` and set ``show_message=False``.

    Args:
        retrain: If True, start from the previously saved weights when that file
            exists; otherwise start from the pretrained BERT weights.
        progress_var: Optional object with ``set(percent)``, updated every batch.
        log_text: Optional Tk-Text-like object. Per-epoch losses are appended via
            ``insert``/``see``.
        file_path: Training data (.jsonl/.json). Defaults to
            ``data/train_data.jsonl`` under PROJECT_ROOT.
        num_epochs: Number of passes over the training data.
        show_message: Show a completion message box when done.

    Returns:
        Path of the saved model weights.

    Raises:
        ValueError: If the training file yields no usable samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device: {device}')
    tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
    if file_path is None:
        file_path = os.path.join(PROJECT_ROOT, 'data/train_data.jsonl')
    train_data_loader = get_data_loader(file_path, tokenizer, batch_size=8, max_length=128)
    # Fail fast with a clear message instead of dividing by zero inside train().
    if len(train_data_loader) == 0:
        raise ValueError(f"训练数据为空: {file_path}")
    model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)
    model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
    if retrain:
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()
    for epoch in range(num_epochs):
        train_loss = train(model, train_data_loader, optimizer, criterion, device, progress_var)
        message = f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}'
        logging.info(message)
        if log_text is not None:
            log_text.insert(tk.END, message + '\n')
            log_text.see(tk.END)
    # The models/ folder may not exist yet; without this, saving fails after all epochs.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)
    logging.info("模型训练完成并保存")
    if log_text is not None:
        log_text.insert(tk.END, "模型训练完成并保存\n")
        log_text.see(tk.END)
    if show_message:
        messagebox.showinfo("训练完成", "模型训练完成并保存")
    return model_path


class _MainThreadProxy:
    """Stand-in for a Tk widget or variable that defers each method call to the Tk main thread.

    Worker threads call methods on the proxy (e.g. ``set``, ``insert``, ``see``).
    Each call is queued as a zero-argument callable on *ui_queue*, which the GUI
    drains on its main loop.
    """

    def __init__(self, target, ui_queue):
        self._target = target
        self._ui_queue = ui_queue

    def __getattr__(self, name):
        # Called only for names not found on the proxy itself. Private names are
        # refused so that a half-initialised proxy cannot recurse here.
        if name.startswith('_'):
            raise AttributeError(name)
        method = getattr(self._target, name)

        def deferred(*args, **kwargs):
            self._ui_queue.put(lambda: method(*args, **kwargs))

        return deferred


# GUI
class XihuaChatbotGUI:
    """Tk front end: ask questions, rate answers, browse history and run training."""

    # How often (ms) the Tk main loop drains callbacks queued by the training thread.
    _UI_POLL_MS = 100
    # Keywords accepted by filter_history, mapped to the stored 'accuracy' value.
    _ACCURACY_FILTERS = {'准确': True, '不准确': False, '未评价': None}

    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")
        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()
        # The training data doubles as the answer bank for get_specific_answer.
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))
        # Q&A history. 'accuracy' stays None until the user rates the answer.
        self.history = []
        # Callables posted by the training thread, run on the Tk main thread.
        self._ui_queue = queue.Queue()
        self._training_thread = None
        self.create_widgets()
        self.root.after(self._UI_POLL_MS, self._poll_ui_queue)

    def create_widgets(self):
        """Lay out the question row, the answer box and the control/log area."""
        # Top frame: question input
        top_frame = tk.Frame(self.root)
        top_frame.pack(pady=10)
        self.question_label = tk.Label(top_frame, text="问题:", font=("Arial", 12))
        self.question_label.grid(row=0, column=0, padx=10)
        self.question_entry = tk.Entry(top_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=1, padx=10)
        self.answer_button = tk.Button(top_frame, text="获取回答", command=self.get_answer, font=("Arial", 12))
        self.answer_button.grid(row=0, column=2, padx=10)
        # Middle frame: answer display
        middle_frame = tk.Frame(self.root)
        middle_frame.pack(pady=10)
        self.answer_label = tk.Label(middle_frame, text="回答:", font=("Arial", 12))
        self.answer_label.grid(row=0, column=0, padx=10)
        self.answer_text = tk.Text(middle_frame, height=10, width=70, font=("Arial", 12))
        self.answer_text.grid(row=1, column=0, padx=10)
        # Bottom frame: rating, training, progress, log and history controls
        bottom_frame = tk.Frame(self.root)
        bottom_frame.pack(pady=10)
        self.correct_button = tk.Button(bottom_frame, text="准确", command=self.mark_correct, font=("Arial", 12))
        self.correct_button.grid(row=0, column=0, padx=10)
        self.incorrect_button = tk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, font=("Arial", 12))
        self.incorrect_button.grid(row=0, column=1, padx=10)
        self.train_button = tk.Button(bottom_frame, text="训练模型", command=self.train_model, font=("Arial", 12))
        self.train_button.grid(row=0, column=2, padx=10)
        self.retrain_button = tk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), font=("Arial", 12))
        self.retrain_button.grid(row=0, column=3, padx=10)
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200)
        self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)
        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))
        self.log_text.grid(row=2, column=0, columnspan=4, pady=10)
        self.evaluate_button = tk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, font=("Arial", 12))
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)
        self.history_button = tk.Button(bottom_frame, text="查看历史记录", command=self.view_history, font=("Arial", 12))
        self.history_button.grid(row=3, column=1, padx=10, pady=10)
        self.save_history_button = tk.Button(bottom_frame, text="保存历史记录", command=self.save_history, font=("Arial", 12))
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)
        self.search_button = tk.Button(bottom_frame, text="搜索历史记录", command=self.search_history, font=("Arial", 12))
        self.search_button.grid(row=3, column=3, padx=10, pady=10)
        self.export_log_button = tk.Button(bottom_frame, text="导出日志", command=self.export_log, font=("Arial", 12))
        self.export_log_button.grid(row=4, column=0, padx=10, pady=10)
        self.clear_log_button = tk.Button(bottom_frame, text="清空日志", command=self.clear_log, font=("Arial", 12))
        self.clear_log_button.grid(row=4, column=1, padx=10, pady=10)
        self.filter_history_button = tk.Button(bottom_frame, text="筛选历史记录", command=self.filter_history, font=("Arial", 12))
        self.filter_history_button.grid(row=4, column=2, padx=10, pady=10)

    def get_answer(self):
        """Classify the entered question, show the matching stored answer and record it."""
        question = self.question_entry.get().strip()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return
        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            logits = self.model(inputs['input_ids'].to(self.device), inputs['attention_mask'].to(self.device))
        # A positive logit favours the class trained with target 1 (human answers in train()).
        answer_type = "羲和回答" if logits.item() > 0 else "零回答"
        specific_answer = self.get_specific_answer(question, answer_type)
        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")
        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None,  # not rated yet
        })

    def get_specific_answer(self, question, answer_type):
        """Return an answer from the most similar stored question (difflib ratio).

        The human answer is returned for "羲和回答" and the ChatGPT answer
        otherwise. A fixed fallback is returned when nothing matches at all.
        """
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item
        if best_match is None:
            return "这个我也不清楚,你问问零吧"
        if answer_type == "羲和回答":
            return best_match['human_answers'][0]
        return best_match['chatgpt_answers'][0]

    @staticmethod
    def _is_valid_record(item):
        """Return True if *item* has the fields get_specific_answer indexes."""
        if not isinstance(item, dict) or not isinstance(item.get('question'), str):
            return False
        for key in ('human_answers', 'chatgpt_answers'):
            answers = item.get(key)
            if not isinstance(answers, list) or not answers or not isinstance(answers[0], str):
                return False
        return True

    def load_data(self, file_path):
        """Load the Q&A answer bank from a .jsonl/.json file (UTF-8).

        Returns [] if the file is missing, so the GUI can still start. Undecodable
        lines and records missing required fields are logged or dropped.
        """
        if not os.path.exists(file_path):
            logging.warning(f"数据文件不存在: {file_path}")
            return []
        records = []
        lowered = file_path.lower()
        if lowered.endswith('.jsonl'):
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_no, line in enumerate(f, 1):
                    if not line.strip():
                        continue
                    try:
                        records.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logging.warning(f"跳过无效行 {line_no}: {e}")
        elif lowered.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    loaded = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
                    loaded = []
            records = loaded if isinstance(loaded, list) else []
        return [item for item in records if self._is_valid_record(item)]

    def load_model(self):
        """Load saved weights into self.model if models/xihua_model.pth exists."""
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def _poll_ui_queue(self):
        """Run every callback queued by worker threads, then reschedule itself."""
        try:
            while True:
                try:
                    callback = self._ui_queue.get_nowait()
                except queue.Empty:
                    break
                callback()
        finally:
            # Reschedule even if a callback raised, so later updates are still drained.
            self.root.after(self._UI_POLL_MS, self._poll_ui_queue)

    def _set_training_controls(self, state):
        """Set the state (tk.NORMAL / tk.DISABLED) of both training buttons."""
        for button in (self.train_button, self.retrain_button):
            button.config(state=state)

    def train_model(self, retrain=False):
        """Ask for a data file and train on it in a background thread."""
        if self._training_thread is not None and self._training_thread.is_alive():
            messagebox.showinfo("训练进行中", "模型正在训练,请等待当前训练完成")
            return
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return
        self._set_training_controls(tk.DISABLED)
        self.progress_var.set(0)
        # The worker only ever sees proxies, so every Tk call runs on the main thread.
        progress = _MainThreadProxy(self.progress_var, self._ui_queue)
        log = _MainThreadProxy(self.log_text, self._ui_queue)
        self._training_thread = threading.Thread(
            target=self._run_training,
            args=(retrain, file_path, progress, log),
            daemon=True,  # closing the window should not leave training running invisibly
        )
        self._training_thread.start()

    def _run_training(self, retrain, file_path, progress, log):
        """Worker-thread body: train, then post the outcome to the main thread."""
        try:
            main_train(retrain, progress, log, file_path=file_path, show_message=False)
        except Exception as exc:
            logging.exception("模型训练失败")
            # Bind exc now: the name is cleared when the except block ends.
            self._ui_queue.put(lambda err=exc: self._on_training_finished(err))
        else:
            self._ui_queue.put(lambda: self._on_training_finished(None))

    def _on_training_finished(self, error):
        """Main-thread completion handler: re-enable controls and report the result."""
        self._set_training_controls(tk.NORMAL)
        if error is not None:
            messagebox.showerror("训练失败", f"模型训练失败: {error}")
            return
        # Training used its own model instance; reload the saved weights so answers use them.
        self.load_model()
        self.model.eval()
        messagebox.showinfo("训练完成", "模型训练完成并保存")

    def evaluate_model(self):
        # Placeholder: model evaluation is not implemented yet.
        messagebox.showinfo("评估结果", "模型评估功能暂未实现")

    def mark_correct(self):
        """Rate the most recent answer as accurate."""
        if self.history:
            self.history[-1]['accuracy'] = True
            messagebox.showinfo("评价成功", "您认为这次回答是准确的")

    def mark_incorrect(self):
        """Rate the most recent answer as inaccurate."""
        if self.history:
            self.history[-1]['accuracy'] = False
            messagebox.showinfo("评价成功", "您认为这次回答是不准确的")

    @staticmethod
    def _format_entry(entry):
        """Render one history entry as the text block shown in history windows."""
        if entry['accuracy'] is None:
            rating = "未评价"
        elif entry['accuracy']:
            rating = "准确"
        else:
            rating = "不准确"
        return (
            f"问题: {entry['question']}\n"
            f"回答类型: {entry['answer_type']}\n"
            f"具体回答: {entry['specific_answer']}\n"
            f"评价: {rating}\n"
            + "-" * 50 + "\n"
        )

    def _show_entries_window(self, title, entries):
        """Open a Toplevel window that lists *entries*."""
        window = tk.Toplevel(self.root)
        window.title(title)
        text = tk.Text(window, height=20, width=80, font=("Arial", 12))
        text.pack(padx=10, pady=10)
        for entry in entries:
            text.insert(tk.END, self._format_entry(entry))

    def view_history(self):
        self._show_entries_window("历史记录", self.history)

    def save_history(self):
        """Save the history as UTF-8 JSON to a user-chosen file."""
        file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])
        if not file_path:
            return
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(self.history, f, ensure_ascii=False, indent=4)
        except OSError as e:
            messagebox.showerror("保存失败", f"无法保存历史记录: {e}")
            return
        messagebox.showinfo("保存成功", "历史记录已保存到文件")

    def search_history(self):
        """Show history entries whose question contains the query (case-insensitive)."""
        search_query = simpledialog.askstring("搜索历史记录", "请输入要搜索的问题或关键词:")
        if not search_query:
            return
        results = [entry for entry in self.history if search_query.lower() in entry['question'].lower()]
        if not results:
            messagebox.showinfo("搜索结果", "没有找到相关的历史记录")
            return
        self._show_entries_window("搜索结果", results)

    def export_log(self):
        """Write the log pane's contents to a user-chosen UTF-8 text file."""
        file_path = filedialog.asksaveasfilename(defaultextension=".txt", filetypes=[("Text files", "*.txt")])
        if not file_path:
            return
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(self.log_text.get(1.0, tk.END))
        except OSError as e:
            messagebox.showerror("导出失败", f"无法导出日志: {e}")
            return
        messagebox.showinfo("导出成功", "日志已导出到文件")

    def clear_log(self):
        self.log_text.delete(1.0, tk.END)
        messagebox.showinfo("清空日志", "日志已清空")

    @classmethod
    def _filter_by_accuracy(cls, history, query):
        """Return entries whose rating matches *query*, or None if *query* is not a known keyword."""
        key = query.strip()
        if key not in cls._ACCURACY_FILTERS:
            return None
        wanted = cls._ACCURACY_FILTERS[key]
        return [entry for entry in history if entry['accuracy'] is wanted]

    def filter_history(self):
        """Show history entries rated 准确, 不准确 or 未评价 (unrated)."""
        filter_query = simpledialog.askstring("筛选历史记录", "请输入要筛选的条件 (准确/不准确/未评价):")
        if not filter_query:
            return
        filtered_results = self._filter_by_accuracy(self.history, filter_query)
        if filtered_results is None:
            messagebox.showwarning("筛选条件错误", "请输入 准确、不准确 或 未评价")
            return
        if not filtered_results:
            messagebox.showinfo("筛选结果", "没有找到符合条件的历史记录")
            return
        self._show_entries_window("筛选结果", filtered_results)
# Entry point: create the Tk root window, attach the chatbot GUI, and run the event loop.
if __name__ == "__main__":
    window = tk.Tk()
    app = XihuaChatbotGUI(window)
    window.mainloop()
改进点总结
多线程处理:
使用threading.Thread创建一个新的线程来执行模型训练任务,避免主线程被阻塞,保持界面的响应性。
在train_model方法中,创建并启动一个新的线程来调用main_train函数。
界面布局:
保持原有的界面布局,确保用户界面的美观和易用性。
功能增强:
保留了所有现有的功能,包括日志导出、历史记录搜索和筛选等。
用户体验:
通过多线程处理,确保在模型训练过程中界面不会卡顿,提高用户体验。
希望这些改进能进一步提升你的聊天机器人的用户体验和功能性!