Reflections
Training for this course takes a very long time.
The progress bar initially estimated 8 hours 49 minutes of training time remaining.
After consulting the instructor, I changed one line of code in cell 4 to
dataset = TextFileDataset(str(path), shuffle=True, num_samples=10000)
See the actual execution result below:
Epoch 0: 100% 2250/2250 [1:47:33<00:00, 2.83s/it, loss=5.5219264]
That is 1 hour 47 minutes.
Keep this in mind when working through the course.
The instructor also said there is another approach: train in several sessions, using ModelCheckpoint to save an intermediate ckpt during training, then loading the ckpt back before resuming next time. A rough sketch of what that could look like is below; I will give it a try when I get the chance.
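A minimal, untested sketch, using MindSpore's generic save/load utilities rather than the ModelCheckpoint callback itself (the file name is hypothetical, and `model` refers to the network built in the notebook below):

import mindspore

# Session 1: save the partially trained weights (hypothetical file name)
mindspore.save_checkpoint(model, 'gpt2_summarization_partial.ckpt')

# Session 2: load them back into a freshly built model, then run the trainer again
param_dict = mindspore.load_checkpoint('gpt2_summarization_partial.ckpt')
mindspore.load_param_into_net(model, param_dict)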
Check-in screenshot
GPT2 Text Summarization with MindSpore
[1]:
%%capture captured_output
# The lab environment comes with mindspore==2.2.14 preinstalled; to switch MindSpore versions, change the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
[2]:
!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# This case was adapted to mindnlp 0.3.1; if it fails to run, pin the version with `!pip install mindnlp==0.3.1`
!pip install mindnlp
Looking in indexes: Simple Index
Requirement already satisfied: tokenizers==0.15.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (0.15.0)
Requirement already satisfied: mindnlp in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (0.3.1)
(remaining "Requirement already satisfied" dependency lines omitted)
[notice] A new release of pip is available: 24.1 -> 24.1.2
[notice] To update, run: python -m pip install --upgrade pip
Dataset Loading and Processing
- Dataset loading
This experiment uses the nlpcc2017 summarization dataset: news articles paired with their summaries, 50,000 samples in total.
[3]:
from mindnlp.utils import http_get
# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')
Building prefix dict from the default dictionary ... Loading model from cache /tmp/jieba.cache Loading model cost 1.078 seconds. Prefix dict has been built successfully.
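Each line of the downloaded file is one JSON object. A quick peek at the raw format (the field names are the ones the preprocessing code below relies on):

import json

# Read a single raw sample; str(path) is the file downloaded above
with open(str(path), encoding='utf-8') as f:
    sample = json.loads(f.readline())
print(sample.keys())  # expected to include 'article' and 'summarization'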
[4]:
from mindspore.dataset import TextFileDataset
# load dataset
dataset = TextFileDataset(str(path), shuffle=True, num_samples=10000)
dataset.get_dataset_size()
[4]:
10000
[5]:
# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)
[WARNING] ME(10913:281472969922864,MainProcess):2024-07-18-03:37:03.859.612 [mindspore/dataset/engine/datasets.py:1203] Dataset is shuffled before split.
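The warning is expected: the dataset was created with shuffle=True in cell 4, so the split operates on already-shuffled data even though randomize=False.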
- Data preprocessing
Raw data format:
article: [CLS] article_context [SEP] summary: [CLS] summary_context [SEP]
Preprocessed data format:
[CLS] article_context [SEP] summary_context [SEP]
[6]:
import json
import numpy as np
# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # tokenization
        # pad to max_seq_length, only truncate the article
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        # input_ids double as labels; the shift happens inside the loss
        return tokenized['input_ids'], tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    # change column names to input_ids and labels for the following training
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])
    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)
    return dataset
Since there is no Chinese tokenizer for GPT2, we use BertTokenizer instead.
[7]:
from mindnlp.transformers import BertTokenizer
# We use BertTokenizer for tokenizing Chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)
[7]:
21128
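A quick sanity check of the [CLS]/[SEP] layout described above (the sample strings here are made up):

# Tokenizing a text pair yields one sequence: [CLS] article [SEP] summary [SEP]
sample = tokenizer(text='新闻正文', text_pair='摘要')
print(tokenizer.decode(sample['input_ids']))  # -> [CLS] 新 闻 正 文 [SEP] 摘 要 [SEP]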
[8]:
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)
[9]:
next(train_dataset.create_tuple_iterator())
[9]:
[Tensor(shape=[4, 1024], dtype=Int64, value= [[ 101, 2769, 4689 ... 749, 1408, 102], [ 101, 704, 1744 ... 2658, 1105, 102], [ 101, 3173, 1290 ... 3624, 511, 102], [ 101, 4294, 1166 ... 0, 0, 0]]), Tensor(shape=[4, 1024], dtype=Int64, value= [[ 101, 2769, 4689 ... 749, 1408, 102], [ 101, 704, 1744 ... 2658, 1105, 102], [ 101, 3173, 1290 ... 3624, 511, 102], [ 101, 4294, 1166 ... 0, 0, 0]])]
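Note that the two tensors are identical: the labels are simply a copy of input_ids, and the shift-right alignment happens inside the loss function defined next.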
Model Construction
- Build the GPT2ForSummarization model; note the shift-right operation on the labels.
[10]:
from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel
class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
    ):
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # Flatten the tokens
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]),
                                 shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss
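A toy illustration of the shift (plain Python, made-up token ids): the logits at position t score the token at position t+1, so we drop the last position of the logits and the first position of the labels.

# Made-up ids standing in for [CLS] w1 w2 [SEP]
tokens = [101, 704, 1744, 102]
logit_positions = tokens[:-1]   # positions whose logits are kept
targets = tokens[1:]            # the token each kept position should predict
print(list(zip(logit_positions, targets)))  # [(101, 704), (704, 1744), (1744, 102)]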
- Dynamic learning rate
[11]:
from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate
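A pure-Python mirror of the same formula, convenient for sanity-checking a few values without constructing Tensors (the probe steps below are arbitrary):

def lr_at(step, lr=1.5e-4, warmup=2000, total=2250):
    # linear warmup to lr, then linear decay to 0
    if step < warmup:
        return step / max(1, warmup) * lr
    return max(0.0, (total - step) / max(1, total - warmup)) * lr

for s in (0, 1000, 2000, 2250):
    print(s, lr_at(s))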
Model Training
[12]:
num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4
num_training_steps = num_epochs * train_dataset.get_dataset_size()
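With num_samples=10000 and the 9:1 split, the training set holds 9,000 samples; at batch_size=4 that gives 9000 / 4 = 2250 steps per epoch, matching the 2250/2250 progress bar shown earlier.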
[13]:
from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel
config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)
[14]:
# print the number of model parameters
print('number of model parameters: {}'.format(model.num_parameters()))
number of model parameters: 102068736
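The count checks out against GPT2-base geometry (12 layers, hidden size 768) with this 21128-token vocabulary, assuming tied input/output embeddings:

# Back-of-the-envelope check of the parameter count
h, layers, vocab, n_pos = 768, 12, 21128, 1024
per_block = (2 * h                # ln_1
             + h * 3 * h + 3 * h  # attention qkv projection
             + h * h + h          # attention output projection
             + 2 * h              # ln_2
             + h * 4 * h + 4 * h  # mlp up-projection
             + 4 * h * h + h)     # mlp down-projection
total = layers * per_block + vocab * h + n_pos * h + 2 * h  # last term: final layer norm
print(total)  # 102068736, matching the output above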
[15]:
from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)
trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # enable mixed precision
Trainer will use 'StaticLossScaler' with `scale_value=2 ** 10` when `loss_scaler` is None.
Note: higher-spec compute is recommended, as training takes a long time.
[16]:
trainer.run(tgt_columns="labels")
The train will start from the checkpoint saved in 'checkpoint'.
Epoch 0: 100% 2250/2250 [1:47:33<00:00, 2.83s/it, loss=5.5219264]
[ERROR] CORE(10913,ffff8862d930,python):2024-07-18-03:38:02.598.500 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_10913/729418715.py] (message repeated 4 times)
Checkpoint: 'gpt2_summarization_epoch_0.ckpt' has been saved in epoch: 0.
Model Inference
Data processing: convert the token-id vectors back into Chinese text.
[17]:
def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        # leave max_summary_len tokens of headroom for the generated summary
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len - max_summary_len)
        return tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(pad, 'article', ['input_ids'])
    dataset = dataset.batch(batch_size)
    return dataset
[18]:
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
[19]:
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))
[array([[ 101, 704, 1744, 2255, 691, 5381, 126, 3299, 8153, 3189, 6380, 113, 6381, 5442, 133, 100, 135, 2101, 3345, 114, 6381, 5442, 791, 1921, 794, 2255, 691, 4689, 2141, 7741, 704, 2110, 704, 5401, 7770, 704, 6440, 4923, 2141, 7741, 4408, 3684, 689, 1073, 4851, 677, 749, 6237, 1168, 8024, 6421, 3413, 5018, 753, 2237, 704, 5401, 7770, 704, 6440, 4923, 2141, 7741, 4408, 8259, 855, 2110, 4495, 1059, 6956, 6158, 5401, 1744, 510, 1217, 2897, 1920, 510, 5739, 1744, 510, 4078, 1920, 1164, 762, 510, 3173, 1217, 1786, 510, 7676, 3949, 5023, 1744, 2157, 1469, 1765, 1277, 686, 4518, 1399, 3413, 2497, 1357, 8024, 2398, 1772, 3680, 855, 2110, 4495, 3119, 1168, 126, 702, 1920, 2110, 2497, 1357, 6858, 4761, 741, 511, 6381, 5442, 749, 6237, 1168, 8024, 6421, 4408, 2110, 4495, 8135, 110, 3119, 1168, 5401, 1744, 2961, 1399, 1184, 8188, 855, 4638, 1920, 2110, 8024, 8394, 119, 122, 110, 2110, 4495, 6158, 5401, 1744, 2961, 1399, 1184, 8145, 855, 1920, 2110, 2497, 1357, 8024, 122, 120, 124, 4638, 2110, 4495, 3119, 1168, 749, 5401, 1744, 2961, 1399, 1184, 8114, 855, 4638, 1920, 2110, 2497, 1357, 6858, 4761, 741, 511, 6821, 1071, 704, 1259, 1419, 5401, 1744, 3336, 1046, 1920, 2110, 510, 1506, 867, 4886, 2548, 2110, 7368, 510, 3727, 2166, 2209, 7561, 2110, 7368, 510, 5745, 2548, 1836, 1920, 2110, 510, 7931, 1395, 2209, 1920, 2110, 5023, 686, 4518, 7553, 2211, 1920, 2110, 511, 2255, 691, 4689, 2141, 7741, 704, 2110, 6566, 6569, 782, 1440, 6401, 6381, 5442, 8024, 1184, 679, 719, 1157, 5815, 2533, 809, 5632, 2346, 1399, 2099, 1462, 1399, 2207, 6121, 3215, 4638, 100, 3215, 3215, 4511, 100, 6950, 5735, 2212, 738, 3221, 6421, 4408, 2110, 4495, 722, 671, 511, 1398, 3198, 8024, 6950, 5735, 2212, 5632, 2346, 738, 1357, 2533, 1259, 2886, 3336, 1046, 1920, 2110, 1762, 1079, 4638, 130, 2792, 686, 4518, 1399, 3413, 4638, 10038, 511, 102]], dtype=int64), array(['山东实验中学一毕业班56位学生全被世界名校录取,每个学生平均收到5个大学录取通知书,包括杜克大学、汉密尔顿学院等。'], dtype='<U57')]
[20]:
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
[21]:
model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)
    i += 1
    if i == 1:
        break
[CLS] 人 民 网 合 肥 5 月 15 日 电 ( 韩 畅 < [UNK] > 实 习 生 李 彬 ) 5 月 14 日 晚, 家 住 合 肥 葛 大 店 附 近 的 吴 师 傅 关 节 炎 发 作, 到 包 河 大 道 滨 江 花 月 小 区 附 近 的 康 利 诊 所 接 受 注 射 治 疗, 随 后 出 现 不 良 反 应, 经 抢 救 无 效 死 亡 。 目 前, 当 地 卫 生 部 门 已 介 入 调 查 。 康 利 诊 所 位 于 滨 江 花 月 小 区 二 期, 门 帘 写 着 [UNK] 包 河 区 卫 生 局 核 发 [UNK] 字 样, 主 治 内 科 、 牙 科 。 5 月 13 日 上 午, 人 民 网 安 徽 频 道 记 者 赶 到 时, 诊 所 大 门 紧 闭, 玻 璃 门 上 写 着 [UNK] 家 里 有 事 [UNK] 。 吴 师 傅 家 人 说, 当 晚 9 时 许, 53 岁 的 吴 师 傅 关 节 炎 发 作, 到 康 利 门 诊 看 病, 然 而 接 受 注 射 治 疗 后, [UNK] 整 个 人 就 不 行 了 [UNK] 。 合 肥 市 第 三 人 民 医 院 急 诊 科 接 诊 记 录 上 记 者 看 到, 吴 师 傅 送 来 时 全 身 抽 搐, 经 抢 救 无 效 死 亡 。 事 发 后, [UNK] 河 派 出 所 民 警 很 快 赶 到, 将 康 利 诊 所 主 治 医 生 及 一 名 护 士 带 回 派 出 所 调 查, 并 将 门 诊 钥 匙 暂 时 保 管 。 包 河 区 卫 生 局 办 公 室 负 责 人 周 女 士 介 绍, 当 晚 包 河 区 卫 生 局 对 此 事 进 行 了 初 步 调 查, 康 利 诊 所 证 照 齐 全, 是 合 法 经 营 。 至 于 给 吴 师 傅 注 射 的 针 剂 是 否 合 规, 是 否 过 量, 周 女 士 称 要 依 鉴 定 结 果 而 定, 待 调 查 结 果 出 来 将 第 一 时 间 向 社 会 公 布 。 [SEP] 广 州 市 公 安 局 长 王 先 生 命 危 险 品 。 ( 图 ) [SEP] 河 北 京 报 警 方 微 信 息 称, 其 中, 一 个 月 28 日 凌 晨 1 日 下 午 5 名 女 儿 子 被
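The decoded string contains the article followed by the generated continuation, so the model's summary is the text after the article's closing [SEP]. A hypothetical post-processing step, assuming the plain string layout shown above:

# The segment between the first and second [SEP] is the generated summary
summary_part = output_text.split('[SEP]')[1].strip()
print(summary_part)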
[22]:
import time
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'guojun0718')
2024-07-18 06:39:11 guojun0718
[ ]: