
Introduction to the P-tuning algorithm and its PyTorch implementation

Introduction to P-tuning

The idea of P-tuning is to keep the weights of the pretrained language model frozen and to learn a small set of continuous prompt embeddings that are prepended to the input sequence; during training only these prompt parameters are updated. The code below implements this idea for a BERT-based sentence classifier.
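
In the original P-tuning paper the continuous prompts are not optimized directly but are produced by a small prompt encoder (a bidirectional LSTM followed by an MLP); the implementation in the next section takes the simpler route of training the prompt embeddings themselves. Below is a minimal sketch of such an encoder; the class name PromptEncoder and the sizes used are illustrative assumptions, not part of the code that follows.

import torch
import torch.nn as nn

class PromptEncoder(nn.Module):
    # Maps a fixed set of prompt indices to continuous prompt embeddings through a
    # bidirectional LSTM + MLP, following the original P-tuning formulation.
    def __init__(self, prompt_length: int, hidden_size: int):
        super().__init__()
        self.embedding = nn.Embedding(prompt_length, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size // 2, num_layers=2,
                            bidirectional=True, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                 nn.ReLU(),
                                 nn.Linear(hidden_size, hidden_size))
        self.register_buffer("indices", torch.arange(prompt_length))

    def forward(self):
        x = self.embedding(self.indices).unsqueeze(0)  # [1, prompt_length, hidden_size]
        x, _ = self.lstm(x)                            # bidirectional halves concatenate back to hidden_size
        return self.mlp(x).squeeze(0)                  # [prompt_length, hidden_size]

# Usage sketch: encoder = PromptEncoder(prompt_length=5, hidden_size=768)
#               prompt_embeddings = encoder()  # used in place of a raw prompt tensor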

Code implementation

import torch
from transformers import BertTokenizer, BertForSequenceClassification
import matplotlib.pyplot as plt

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)


def train(tokenizer, model, prompt_length, prompt, data):
    # Freeze all BERT parameters; only the prompt embeddings are trained
    for param in model.bert.parameters():
        param.requires_grad = False

    # Optimizer that updates only the prompt tensor
    optimizer = torch.optim.Adam([prompt], lr=1e-3)

    # Training loop
    num_epochs = 8
    losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for text, label in data:
            # Tokenize the input text and build the label tensor
            inputs = tokenizer(text, return_tensors='pt')
            labels = torch.tensor([label])  # label tensor of shape [batch_size]

            # Access BERT's embedding layer
            bert_model = model.bert
            input_ids = inputs['input_ids']

            # Look up only the word embeddings of the input tokens; BertModel adds the
            # position and token-type embeddings itself when it receives inputs_embeds
            with torch.no_grad():
                input_embeddings = bert_model.embeddings.word_embeddings(input_ids)

            # Expand the prompt and concatenate it with the input embeddings.
            # unsqueeze(0) adds a batch dimension; expand(input_ids.size(0), -1, -1) repeats the
            # prompt along the batch dimension (-1 keeps the remaining dimensions unchanged).
            prompt_embeddings = prompt.unsqueeze(0).expand(input_ids.size(0), -1, -1)
            prompted_input = torch.cat((prompt_embeddings, input_embeddings), dim=1)

            # Forward pass: extend the attention mask so it also covers the prompt positions
            attention_mask = torch.cat((torch.ones(prompt_embeddings.size()[:2], dtype=torch.long), inputs['attention_mask']), dim=1)
            outputs = bert_model(inputs_embeds=prompted_input, attention_mask=attention_mask)
            sequence_output = outputs.last_hidden_state

            # Classification head: skip the prompt positions
            logits = model.classifier(sequence_output[:, prompt_length:, :])

            # Make the logits shape match the labels:
            logits = logits[:, 0, :]  # keep only the first remaining token's logits (the [CLS] token)

            # Compute the cross-entropy loss
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))

            # Backward pass and optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        avg_loss = total_loss / len(data)
        losses.append(avg_loss)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}')
    # Save the trained prompt embeddings for later inference
    torch.save(prompt, 'path_to_trained_prompt.pt')
    return losses
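
The training loop above processes one example at a time. As a possible extension (not part of the original code), the tokenizer can pad several texts into a single batch and the same prompt-concatenation logic still applies; a brief sketch, assuming the tokenizer and prompt defined in this post:

# Hypothetical batched variant: pad several texts into one batch
texts = ["This movie is great", "This movie is bad"]
batch = tokenizer(texts, return_tensors='pt', padding=True)
# batch['input_ids'] has shape [2, seq_len]; prompt.unsqueeze(0).expand(batch['input_ids'].size(0), -1, -1)
# makes the shapes line up, and batch['attention_mask'] already masks the padding tokens.
labels = torch.tensor([1, 0])  # one label per text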

def plot_loss(losses):
    plt.figure()
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('Average loss')
    plt.show()

def predict_classify(tokenizer, model, prompt_length, trained_prompt, data):
    predict_list = []
    for input_text in data:
        inputs = tokenizer(input_text, return_tensors='pt')

        # Access BERT's embedding layer
        bert_model = model.bert
        input_ids = inputs['input_ids']

        # Look up only the word embeddings of the input tokens (position and token-type
        # embeddings are added inside BertModel when it receives inputs_embeds)
        with torch.no_grad():
            input_embeddings = bert_model.embeddings.word_embeddings(input_ids)

        # Expand the prompt and concatenate it with the input embeddings
        prompt_embeddings = trained_prompt.unsqueeze(0).expand(input_ids.size(0), -1, -1)
        prompted_input = torch.cat((prompt_embeddings, input_embeddings), dim=1)

        # Build the attention mask that also covers the prompt positions
        attention_mask = torch.cat((torch.ones(prompt_embeddings.size()[:2], dtype=torch.long), inputs['attention_mask']), dim=1)

        # Forward pass for inference
        with torch.no_grad():
            outputs = bert_model(inputs_embeds=prompted_input, attention_mask=attention_mask)
            sequence_output = outputs.last_hidden_state

            # Classification head: skip the prompt positions
            logits = model.classifier(sequence_output[:, prompt_length:, :])
            logits = logits[:, 0, :]  # keep only the first remaining token's logits (the [CLS] token)

        # Take the most probable class as the prediction
        predicted_label = torch.argmax(logits, dim=-1).item()
        print(f"Input data: {input_text}, Predicted label: {predicted_label}")
        predict_list.append(predicted_label)
    return predict_list

# P-tuning training
# Define the learnable prompt embeddings
prompt_length = 5
prompt = torch.nn.Parameter(torch.randn(prompt_length, model.config.hidden_size))
# Training set
data = [("This movie is great", 1), ("This movie is bad", 0)]
# Train
losses = train(tokenizer, model, prompt_length, prompt, data)
# Plot the loss curve
plot_loss(losses)
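
As a quick sanity check (a sketch assuming the variables defined just above), train has frozen BERT's weights and the optimizer was built only over the prompt, so neither the encoder nor the classification head is updated during training; only the prompt_length * hidden_size prompt values are:

# Count trainable (requires_grad) parameters in the frozen BERT encoder
trainable_bert = sum(p.numel() for p in model.bert.parameters() if p.requires_grad)
print(f"Trainable BERT parameters: {trainable_bert}")    # expected: 0 after train()
print(f"Optimized prompt parameters: {prompt.numel()}")  # prompt_length * hidden_size = 5 * 768 = 3840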

# P-tuning prediction
prompt_length = 5
trained_prompt = torch.load('path_to_trained_prompt.pt')  # load the trained prompt embeddings
input_text = ["This movie is good", "This movie is bad", "This movie is not good"]
predict_list = predict_classify(tokenizer, model, prompt_length, trained_prompt, input_text)

Further reading: Chapter 7, Adaptation of Large Models
