- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
本周任务:
- 结合Word2Vec词向量,根据文本内容预测文本标签
加载数据
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings
import pandas as pd
warnings.filterwarnings('ignore')

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Read texts and labels from the local tab-separated file (no header row):
# column 0 is used as the text and column 1 as the label further down.
train_data = pd.read_csv("D:/桌面/365/train.csv", sep='\t', header=None)
train_data.head()
def coustom_data_iter(texts, labels):
    """Lazily yield ``(text, label)`` pairs from two parallel sequences.

    Iteration stops at the end of the shorter input, exactly like ``zip``.
    """
    yield from zip(texts, labels)
# Raw texts (column 0) and their labels (column 1) as numpy arrays.
x, y = (train_data[column].values[:] for column in (0, 1))
构建词典
from gensim.models.word2vec import Word2Vec
import numpy as np
# 100-dimensional embeddings; tokens occurring fewer than 3 times are ignored.
w2v = Word2Vec(vector_size=100, min_count=3)
# Every element of x is passed as one "sentence". If x holds raw strings (as the
# later text_pipeline('你在干嘛') call suggests), Word2Vec iterates each one
# character by character, so these are character-level embeddings — confirm.
w2v.build_vocab(corpus_iterable=x)
w2v.train(corpus_iterable=x, total_examples=w2v.corpus_count, epochs=20)
def average_vec(text, wv=None, dim=100):
    """Return the mean embedding of the tokens in ``text`` as a ``(1, dim)`` array.

    Each element produced by iterating ``text`` is looked up as a token; for a
    raw string that means one lookup per character. Tokens missing from the
    vocabulary are skipped. The sum of the found vectors is divided by how many
    were found, so short and long texts produce vectors on the same scale.

    Args:
        text: iterable of tokens (e.g. a raw string).
        wv: token -> vector mapping supporting ``wv[token]`` and raising
            ``KeyError`` for unknown tokens. Defaults to the module-level
            ``w2v.wv``.
        dim: embedding length; must match the vectors in ``wv``. Defaults to
            100, the ``vector_size`` used when training ``w2v``.

    Returns:
        numpy float array of shape ``(1, dim)``; all zeros when no token of
        ``text`` is in the vocabulary (including empty ``text``).
    """
    if wv is None:
        wv = w2v.wv
    total = np.zeros((1, dim))
    found = 0
    for word in text:
        try:
            vector = wv[word]
        except KeyError:
            # Out-of-vocabulary token (e.g. below min_count): contributes nothing.
            continue
        total += np.asarray(vector).reshape((1, dim))
        found += 1
    if found:
        total /= found
    return total
# One sentence vector per training text, stacked row-wise into a single matrix.
# NOTE(review): x_vec is not referenced anywhere else in this notebook.
x_vec = np.concatenate(list(map(average_vec, x)))
# Persist the trained Word2Vec model to disk.
w2v.save('./w2c_model.pkl')
生成数据批次和迭代器
# Ordered list of the distinct labels; a label's position is its class index.
# Sorted (not list(set(...))) so indices are reproducible across runs, since
# set iteration order for strings depends on the per-process hash seed.
label_name = sorted(set(train_data[1].values))


def text_pipeline(x):
    """Map a raw text to its sentence vector via ``average_vec``."""
    return average_vec(x)


def label_pipeline(x):
    """Map a label to its integer class index in ``label_name`` (ValueError if unknown)."""
    return label_name.index(x)


text_pipeline('你在干嘛')
array([[ 0.78121352, 1.93111382, 0.96291968, 0.39362412, -1.67714586,
-0.55152619, 1.7284598 , 0.69204517, 1.1396839 , -0.9755076 ,
-0.55864345, -3.68676656, 1.41707338, -0.44626126, 0.2580443 ,
1.09325009, 2.28043211, -2.26334408, 3.32311766, -1.24760717,
2.2325974 , -0.48408172, -0.55063696, 0.36853465, -1.32127168,
-0.53377433, -1.48909409, -0.5050023 , 1.42371842, -0.4252875 ,
2.52355766, 0.60818394, -1.68924798, -0.16912293, 1.26915893,
-0.4575564 , 0.02507078, 3.33139969, -2.1995108 , 0.44307417,
-0.41596803, 1.39861814, -0.58643346, 0.91654699, -0.08089826,
0.08773175, 1.51611513, -0.22212304, -3.55333737, 1.93851076,
0.42497785, -1.47862379, -0.96684674, 1.20408788, -0.86870126,
-1.12228102, 1.67186388, -1.11024326, -0.18936946, 1.0811481 ,
1.82965288, -0.78202841, 2.17574303, -1.03871018, -0.51042572,
0.40746585, -1.70572275, 1.3409467 , 1.38298857, 1.11757374,
-0.8333215 , 0.04856796, 1.43110101, -0.02333559, 0.82732772,
-0.9469737 , -4.43783602, -0.20290428, 1.04759257, -1.21757071,
-1.30356295, 0.50049417, -1.87846385, 2.47995635, -2.41918275,
-1.72291106, 2.65663178, -0.96948189, -1.30033612, -0.37353188,
0.53420451, -1.99955091, 0.12223354, 1.74861516, 0.99491888,
-1.43117569, 0.063243 , 0.84598846, -2.79536995, 0.02697589]])
from torch.utils.data import DataLoader
def collate_batch(batch):
    """Collate ``(text, label)`` samples into a ``(features, labels)`` pair on ``device``.

    Each text goes through ``text_pipeline`` and becomes a float32 tensor; the
    tensors are concatenated along dim 0. Each label goes through
    ``label_pipeline``, and the resulting indices form one int64 tensor.
    """
    label_ids, vectors = [], []
    for sample_text, sample_label in batch:
        label_ids.append(label_pipeline(sample_label))
        vectors.append(torch.tensor(text_pipeline(sample_text), dtype=torch.float32))

    labels = torch.tensor(label_ids, dtype=torch.int64)
    features = torch.cat(vectors)
    # Features first, labels second: the training/evaluation loops unpack in this order.
    return features.to(device), labels.to(device)
# Demo loader over the raw (text, label) pairs. It is built from x/y directly
# because `train_iter` is only created later, in the training section. A list of
# pairs is a valid map-style dataset for DataLoader.
dataloader = DataLoader(list(zip(x, y)), batch_size=8, shuffle=False, collate_fn=collate_batch)
datalodaer = dataloader  # keep the originally (mis)spelled name available
模型构建
from torch import nn
class TextClassificationModel(nn.Module):
    """Single linear layer mapping fixed-size sentence vectors to class logits.

    Args:
        num_class: number of output classes (logits per sample).
        embed_dim: length of each input sentence vector. Defaults to 100, the
            Word2Vec ``vector_size`` used in this notebook.
    """

    def __init__(self, num_class, embed_dim=100):
        super().__init__()
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, text):
        """Return logits of shape ``(batch, num_class)`` for ``text`` of shape ``(batch, embed_dim)``."""
        # Cast defensively: float64 inputs (numpy's default dtype) would not
        # match the float32 parameters of nn.Linear.
        text = text.float()
        return self.fc(text)
在这组词汇中不匹配的词汇:书
初始化模型
# One output logit per distinct label.
num_class = len(label_name)
model = TextClassificationModel(num_class).to(device)
# NOTE(review): vocab_size and em_size are not referenced anywhere else in this
# notebook; the classifier consumes Word2Vec sentence vectors directly.
vocab_size = 100000
em_size = 12
定义训练及评估函数
import time
def train(dataloader):
    """Run one training epoch over ``dataloader`` and periodically log metrics.

    Uses the module-level ``model``, ``optimizer`` and ``criterion``. The epoch
    number in the log line is read from the module-level ``epoch`` set by the
    training loop.

    Every ``log_interval`` batches it prints the accuracy and the mean
    per-sample loss accumulated since the last report, then resets the counters.
    """
    model.train()
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 50

    # Unpack as (text, label): collate_batch returns features first, labels
    # second; swapping them feeds label tensors into the model and errors out.
    for idx, (text, label) in enumerate(dataloader):
        predicted_label = model(text)
        optimizer.zero_grad()
        loss = criterion(predicted_label, label)
        loss.backward()
        # Clip the total gradient norm to 0.1 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        batch_size = label.size(0)
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        # criterion returns the batch *mean* (CrossEntropyLoss default
        # reduction in this notebook), so weight it by the batch size to make
        # train_loss / total_count a true per-sample average.
        train_loss += loss.item() * batch_size
        total_count += batch_size

        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:1d} | {:4d}/{:4d} batches'
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                                              total_acc/total_count, train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0
def evaluate(dataloader):
    """Return ``(accuracy, mean per-sample loss)`` of ``model`` over ``dataloader``.

    Uses the module-level ``model`` and ``criterion``; gradients are disabled.
    Raises ZeroDivisionError if ``dataloader`` yields no samples.
    """
    model.eval()
    total_acc, eval_loss, total_count = 0, 0, 0
    with torch.no_grad():
        for text, label in dataloader:
            predicted_label = model(text)
            loss = criterion(predicted_label, label)
            batch_size = label.size(0)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            # loss is a batch mean; weight by batch size so the ratio below is
            # a per-sample average rather than mean / batch_size.
            eval_loss += loss.item() * batch_size
            total_count += batch_size
    return total_acc / total_count, eval_loss / total_count
训练模型
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
EPOCHS = 10      # number of passes over the training split
LR = 5           # initial SGD learning rate
BATCH_SIZE = 64

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Each scheduler.step() multiplies the LR by 0.1; it is only called below when
# validation accuracy falls under the best seen so far.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
total_accu = None  # best validation accuracy so far

train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

# 80/20 train/validation split. The validation size is the remainder so the two
# lengths always sum to len(train_dataset); truncating both 0.8*n and 0.2*n
# independently can drop a sample and make random_split raise ValueError.
num_train = int(len(train_dataset) * 0.8)
num_valid = len(train_dataset) - num_train
split_train_, split_valid_ = random_split(train_dataset, [num_train, num_valid])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
# Aggregated validation metrics do not depend on order, so no shuffling needed.
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)  # train() reads this module-level `epoch` for its log
    val_acc, val_loss = evaluate(valid_dataloader)

    # LR used during this epoch, captured before any decay below.
    lr = optimizer.param_groups[0]['lr']
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc

    print('-' * 69)
    print('|epoch {:1d} | time: {:4.2f}s |'
          'valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch, time.time() - epoch_start_time, val_acc, val_loss, lr))
    print('-' * 69)
| epoch 1 | 50/ 152 batches| train_acc 0.752 train_loss 0.02433
| epoch 1 | 100/ 152 batches| train_acc 0.836 train_loss 0.01740
| epoch 1 | 150/ 152 batches| train_acc 0.831 train_loss 0.01821
---------------------------------------------------------------------
|epoch 1 | time: 2.83s |valid_acc 0.847 valid_loss 0.016 | lr 5.000000
---------------------------------------------------------------------
| epoch 2 | 50/ 152 batches| train_acc 0.843 train_loss 0.01709
| epoch 2 | 100/ 152 batches| train_acc 0.835 train_loss 0.01863
| epoch 2 | 150/ 152 batches| train_acc 0.854 train_loss 0.01577
---------------------------------------------------------------------
|epoch 2 | time: 1.28s |valid_acc 0.852 valid_loss 0.017 | lr 5.000000
---------------------------------------------------------------------
| epoch 3 | 50/ 152 batches| train_acc 0.854 train_loss 0.01663
| epoch 3 | 100/ 152 batches| train_acc 0.855 train_loss 0.01743
| epoch 3 | 150/ 152 batches| train_acc 0.846 train_loss 0.01738
---------------------------------------------------------------------
|epoch 3 | time: 1.34s |valid_acc 0.862 valid_loss 0.017 | lr 5.000000
---------------------------------------------------------------------
| epoch 4 | 50/ 152 batches| train_acc 0.862 train_loss 0.01514
| epoch 4 | 100/ 152 batches| train_acc 0.854 train_loss 0.01638
| epoch 4 | 150/ 152 batches| train_acc 0.854 train_loss 0.01920
---------------------------------------------------------------------
|epoch 4 | time: 1.18s |valid_acc 0.847 valid_loss 0.018 | lr 5.000000
---------------------------------------------------------------------
| epoch 5 | 50/ 152 batches| train_acc 0.898 train_loss 0.00902
| epoch 5 | 100/ 152 batches| train_acc 0.897 train_loss 0.00885
| epoch 5 | 150/ 152 batches| train_acc 0.900 train_loss 0.00893
---------------------------------------------------------------------
|epoch 5 | time: 1.37s |valid_acc 0.879 valid_loss 0.011 | lr 0.500000
---------------------------------------------------------------------
| epoch 6 | 50/ 152 batches| train_acc 0.900 train_loss 0.00788
| epoch 6 | 100/ 152 batches| train_acc 0.904 train_loss 0.00703
| epoch 6 | 150/ 152 batches| train_acc 0.901 train_loss 0.00681
---------------------------------------------------------------------
|epoch 6 | time: 1.33s |valid_acc 0.883 valid_loss 0.010 | lr 0.500000
---------------------------------------------------------------------
| epoch 7 | 50/ 152 batches| train_acc 0.922 train_loss 0.00573
| epoch 7 | 100/ 152 batches| train_acc 0.901 train_loss 0.00728
| epoch 7 | 150/ 152 batches| train_acc 0.894 train_loss 0.00702
---------------------------------------------------------------------
|epoch 7 | time: 1.12s |valid_acc 0.879 valid_loss 0.009 | lr 0.500000
---------------------------------------------------------------------
| epoch 8 | 50/ 152 batches| train_acc 0.908 train_loss 0.00630
| epoch 8 | 100/ 152 batches| train_acc 0.905 train_loss 0.00593
| epoch 8 | 150/ 152 batches| train_acc 0.911 train_loss 0.00526
---------------------------------------------------------------------
|epoch 8 | time: 1.11s |valid_acc 0.881 valid_loss 0.009 | lr 0.050000
---------------------------------------------------------------------
| epoch 9 | 50/ 152 batches| train_acc 0.911 train_loss 0.00580
| epoch 9 | 100/ 152 batches| train_acc 0.905 train_loss 0.00611
| epoch 9 | 150/ 152 batches| train_acc 0.917 train_loss 0.00516
---------------------------------------------------------------------
|epoch 9 | time: 1.12s |valid_acc 0.881 valid_loss 0.009 | lr 0.005000
---------------------------------------------------------------------
| epoch 10 | 50/ 152 batches| train_acc 0.912 train_loss 0.00564
| epoch 10 | 100/ 152 batches| train_acc 0.905 train_loss 0.00575
| epoch 10 | 150/ 152 batches| train_acc 0.916 train_loss 0.00565
---------------------------------------------------------------------
|epoch 10 | time: 1.12s |valid_acc 0.881 valid_loss 0.009 | lr 0.000500
---------------------------------------------------------------------
测试指定数据
def predict(text, text_pipeline):
    """Return the predicted class index (into ``label_name``) for one raw ``text``.

    Args:
        text: raw input text.
        text_pipeline: callable mapping ``text`` to a single-row feature array
            accepted by ``torch.tensor``.

    Uses the module-level ``model`` and puts it in eval mode.
    """
    model.eval()
    with torch.no_grad():
        # Build the input on whatever device the model currently lives on, so
        # prediction works whether or not the model was moved back to the CPU.
        model_device = next(model.parameters()).device
        text = torch.tensor(text_pipeline(text), dtype=torch.float32, device=model_device)
        print(text.shape)
        output = model(text)
        return output.argmax(1).item()
# Classify one hand-written query with the model on the CPU.
model = model.to('cpu')
ex_text_str = '还有双鸭山到淮阴的汽车票吗13号的'
print('该文本的类别是: %s' % label_name[predict(ex_text_str, text_pipeline)])
torch.Size([1, 100])
该文本的类别是: Travel-Query
总结
- 本周结合前几周的内容,先使用Word2Vec进行词嵌入,再实现中文文本分类
- 本次自己的错误:将 `for idx, (text, label) in enumerate(dataloader):` 中 text 和 label 的顺序写反了,导致送入模型的是标签张量而不是文本向量,其形状与模型的输入维度不匹配而报错,因此花费了很多时间排查