Bootstrap

BERT的中文问答系统42

我们将对现有的代码进行扩展,以支持360百科的功能。这包括修改XihuaChatbotGUI类中的相关方法,以及添加一个新的搜索360百科的函数。此外,我们还需要更新历史记录的保存格式,以包含360百科的结果。

项目结构
code
project_root/

├── data/
│ └── train_data.jsonl

├── logs/
│ └── [log_files]

├── models/
│ └── xihua_model.pth

├── main.py
└── README.md
README.md
markdown

羲和聊天机器人

项目介绍

羲和聊天机器人是一个基于BERT模型的问答系统。它可以从训练数据中学习,并能够回答用户提出的问题。此外,用户可以通过界面评价机器人的回答是否准确,并提供百度百科和360百科的参考答案。

目录结构

project_root/

├── data/
│ └── train_data.jsonl

├── logs/
│ └── [log_files]

├── models/
│ └── xihua_model.pth

├── main.py
└── README.md

code

依赖

  • Python 3.7+
  • PyTorch
  • Transformers
  • Tkinter
  • Requests
  • BeautifulSoup

安装

pip install torch transformers requests beautifulsoup4

运行

python main.py

功能
用户输入问题,机器人给出回答。
用户可以评价回答是否准确。
如果回答不准确,可以选择查看百度百科或360百科的结果。
训练和重新训练模型。
查看和保存历史记录。

main.py

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {
     i + 1}: {
     e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {
     file_path}: {
     e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {
     idx}: {
     e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
   
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def 

悦读

道可道,非常道;名可名,非常名。 无名,天地之始,有名,万物之母。 故常无欲,以观其妙,常有欲,以观其徼。 此两者,同出而异名,同谓之玄,玄之又玄,众妙之门。

;