Bootstrap

mindspore打卡21天之 基于MindSpore的GPT2文本摘要

基于MindSpore的GPT2文本摘要

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting tokenizers==0.15.0
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/14/cf/883acc48862589f9d54c239a9108728db5b75cd6c0949b92c72aae8e044c/tokenizers-0.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m

数据集加载与处理

  1. 数据集加载

    本次实验使用的是nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

from mindnlp.utils import http_get

# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.995 seconds.
Prefix dict has been built successfully.
from mindspore.dataset import TextFileDataset

# load dataset
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()
50000
# dataset1 = TextFileDataset(str(path), shuffle=True,num_samples=1000)
# dataset1.get_dataset_size()

# from mindspore.dataset import TextFileDataset

# # 载入数据集
# dataset = TextFileDataset(str(path), shuffle=False)

# # 取前1000条数据
# dataset = dataset.take(1000)

# # 打印数据集大小
# print(dataset.get_dataset_size())

# from mindspore.dataset import TextFileDataset
# import numpy as np
# # 载入数据集
# dataset = TextFileDataset(str(path), shuffle=False)

# # 获取数据集大小
# dataset_size = dataset.get_dataset_size()

# # 随机选择1000条数据
# dataset = dataset.skip(np.random.randint(0, dataset_size - 1000)).take(1000)

# # 或者,使用shuffle然后take
# dataset = dataset.shuffle(buffer_size=dataset_size).take(1000)

# # 打印数据集大小
# print(dataset.get_dataset_size())
dataset = TextFileDataset(str(path), shuffle=True,num_samples=1000)
dataset.get_dataset_size()
1000
for idx, value in enumerate(dataset.create_tuple_iterator()):
    print(value)
    print("--------")
    if idx==5:
        break
[Tensor(shape=[], dtype=String, value= '{"summarization": "惠州医护人员抽签进入ICU特护MERS患者,未婚者为第一批,有人向亲友道别(图)", "article": "最新:广东惠州已追踪MERS密切接触者61人5月31日凌晨,惠州市卫生和计划生育局发布最新消息,截至昨晚22时,惠州市追踪到的MERS密切接触者累计61人,其中在惠57人居家或医学观察,均未有异常报告。截至5月30日晚9时30分,确诊病例金某血压、心率正常,生命体征平稳,神志清晰。30日,曾途经香港的韩国男子在广东惠州确诊中东呼吸综合征(MERS),相关密接者抵达麦理浩夫人度假村接受隔离。图/CFP已有9人与惠州疾控中心取得联系,目前正隔离观察;香港已锁定200位密切接触者新京报讯<Paragraph>(记者李丹丹)据新华社报道,广东省卫计委30日晚通报,我国首例韩国MERS患者目前症状仍以发热为主。密切接触者已由29日的38人增至47人。自广东省卫计委5月29日发布寻找首例输入性中东呼吸综合征病例同乘乘客公告后,已有9名乘客与疾控中心取得联系,目前正隔离观察。截至记者发稿时止,已追踪密切接触者47人,暂未出现不适。MERS患者密切接触者主动现身前日晚间,广东省卫生计生委呼吁,曾在5月26日乘坐韩亚航空OZ723航班;下午三时从香港国际机场至沙头角的巴士(香港车牌号码:PJ2595);下午四时四十六分从沙头角至惠州的大巴(广东车牌:粤ZCH70港/香港车牌号码:HN5211)的乘客主动与广东省疾控中心或惠州市疾控中心取得联系。记者昨日从广东省疾控中心和惠州市疾控中心获悉,自消息发布以来,已经有一些密切接触者与其取得联系。惠州疾控中心的邱医生对记者表示,下一步将按照规定对密切接触者启动相应的程序并且进行追踪管理。昨日中午,广东省疾控中心、广州市第八人民医院专家再次对惠州市首例中东呼吸综合征确诊病例及其密切接触者采集咽试子、血液等样本,专车送往广州检测。此外,昨日国家卫计委专家组对已经确诊的韩国籍MERS患者进行临床会诊,目前该患者体温38.5左右,精神尚可,正采取抗病毒治疗。网传“惠州救护车司机被传染”不实昨天以来,有人通过网络散布“惠州中心医院已经把ICU封科。接诊的救护车司机已发热,疑似被传染”等消息。针对以上信息,国家、广东省、惠州市三级联合防控工作组表示该消息失实。据介绍,目前该确诊患者生命体征暂时稳定。惠州市中心人民医院医疗秩序正常,ICU正常运作。当天接诊的救护车司机没有出现任何异常。登记在册的密切接触者也没有出现异常。截至目前,惠州市疾控部门对确诊病例活动场所进行了全面彻底的清洁消毒。香港已锁定200名密接者据媒体报道,韩国患者26日抵达香港。与其同机的乘客共158人,其中80人跟患者坐在同一个机舱,患者附近共29名乘客。在这29名密切接触者中,18人并没有出现症状,正在进行为期十四天的隔离。另外11人已离开香港。香港卫生防护中心的总监梁挺雄对媒体表示,跟患者接触过的人当中有三个人出现了不适的症状,已经送院隔离。梁挺雄介绍,有两位乘客坐在距离韩国人五行的位置,他们并不是密切接触者。这两人都有轻微的上呼吸道症状,另外一辆巴士的票务员也有轻微症状,上述人员均已安排送去玛嘉烈医院接受检查。截至29日夜,香港方面已经锁定了跟患者同机同车的将近200人。梁挺雄认为,患者本身是从患有中东呼吸综合征的亲属身上遭到了感染,本身就属于“第二代传染”,再由他去传染别人的概率会比较小。他认为,暂时来说,看不到疾病传播能力会上升。该病主要是由“第一代”传给“第二代”。再传给“第三代”,几乎很少。现在病毒没有持续的人传人的能力。相关新闻医护人员抽签进MERS患者病房惠州中心人民医院的医护人员已抽签进入ICU特护输入性中东呼吸综合征患者。未婚的医护人员作为第一梯队参与抽签先上战场,<Paragraph>已婚的医护人员作为第二梯队。有人在朋友圈和亲友道别……小编看哭了。致敬!祝安好!(广州日报)韩国网民:强行出境MERS患者令韩国人蒙羞新华网首尔5月30日电(记者姚琪琳)26日进入广东省惠州市的一名韩国人已被确诊为中国首例输入性中东呼吸综合征病例。这位曾在韩国密切接触中东呼吸综合征患者并已出现发烧症状者,是怎样脱离韩方监控来华的呢?这一疑问在韩国社会引发强烈关注。为何能出境?当这名患者在广东惠州被隔离治疗后,便有韩国媒体指责这一男子强行出国将病情扩散到中国极不负责任。这名患者的妻子在接受媒体采访时解释说,她丈夫的工作十分繁忙,是在迫不得已的情况下才出差。韩国保健福祉部发言人在5月28日的通报中承认,在初期流行病学调查中并未将该男子列入确诊患者的“密切接触者”并进行隔离观察。但韩方同时认为,该男子刻意隐瞒“密切接触”经过,导致事态扩大。这名发言人解释说,在初期流行病学调查中,这名男子并未说明其在5月16日探望过后来被证实患有中东呼吸综合征的父亲。他在5月19日开始发烧后和22日在某医院急诊室接受治疗时,并未向医生说明自己曾密切接触过确诊患者,并且是确诊患者的家属。事实上,其父亲在20日就已被确诊为韩国第三例中东呼吸综合征患者。5月25日这名男子第二次接受治疗时仍然否认其父亲为确诊患者,而且也未听从医生劝其取消出差计划的建议。与此同时,当事医生25日了解到该男子有中东呼吸综合征密切接触史后也未立即向有关部门报告,一直拖延到该男子出国后的27日才向其所在地区保健部门报告。这些消极应对的做法,使得这名男患者摆脱了韩国医疗和防疫部门的监管。韩国保健福祉部决定今后将对违反“配合义务”的医患双方采取严厉处罚措施。若医护人员不及时举报,或者与病人密切接触者拒绝接受检查,可被处以200万韩元(约合人民币1.12万元)罚款。如果与病人密切接触者拒绝自行隔离,可被罚款300万韩元(约合人民币1.68万元)。韩国社会怎么看?韩国保健福祉部发言人表示,这是韩国首次发现输入性中东呼吸综合征病例,因此遇到了许多困难。该部门承认其防疫系统出现漏洞,责任全在防疫当局。疫情迅速蔓延使韩国政府的疫情应急处置能力备受韩国社会指责。5月29日,韩国保健福祉部长官文亨杓就政府部门未能充分准备好应急对策导致事态扩大公开道歉。韩国《中央日报》评论说,这名赴华韩国男子曾去医院探望了身为确诊患者的父亲,却未被列入隔离对象,负有报告义务的医疗人员也没有及时报告疫情,这些管理上的漏洞导致疫情不断扩散。不仅如此,韩国保健机构一直表示中东呼吸综合征“传染性不强”,未能切实做好应对,可见相关机构在应对疫情管理方面有很大漏洞。还有韩国网友称,这名韩国男子不顾劝告,强行出国还将疾病传染到中国,令韩国人蒙羞。"}')]
--------

--------
[Tensor(shape=[], dtype=String, value= '{"summarization": "瑞信预计中国中车在全球市占率达15%,在国内具市场领导地位,给予其跑赢大盘评级,目标价18元。 ", "article": "瑞信发表报告,就中国中车(01766.HK)今日挂牌上市,成为全球最大铁路设备制造业,在内地具领导市场地位,料其在全球市占率亦达15%,予其H股「跑赢大市」评级,及目标价18元,指其今早复牌可追回同业表现。该行称,对内地电力动车组(EMU)需求前景乐观,同时料公司扩张海外策略上,可享强劲的协同效应,包括中央采购节省成本、规模及研发能力优势等。瑞信估计,中国中车於2015年至2017年每股盈利预测各为0.48、0.74及0.86元人民币,按年各增长9.9%、52.6%及16.5%。"}')]
--------
# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)
[WARNING] ME(19698:281473738053936,MainProcess):2024-07-11-03:36:55.227.373 [mindspore/dataset/engine/datasets.py:1203] Dataset is shuffled before split.
  1. 数据预处理

    原始数据格式:

    article: [CLS] article_context [SEP]
    summary: [CLS] summary_context [SEP]
    

    预处理后的数据格式:

    [CLS] article_context [SEP] summary_context [SEP]
    
import json
import numpy as np

# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # tokenization
        # pad to max_seq_length, only truncate the article
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['input_ids']
    
    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    # change column names to input_ids and labels for the following training
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])

    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset

因GPT2无中文的tokenizer,我们使用BertTokenizer替代。

from mindnlp.transformers import BertTokenizer

# We use BertTokenizer for tokenizing chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)
21128
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)
next(train_dataset.create_tuple_iterator())
[Tensor(shape=[4, 1024], dtype=Int64, value=
 [[ 101, 8127, 2399 ... 7370,  511,  102],
  [ 101, 8119,  118 ...    0,    0,    0],
  [ 101, 8119, 2399 ...    0,    0,    0],
  [ 101, 6598, 3160 ...  100,  511,  102]]),
 Tensor(shape=[4, 1024], dtype=Int64, value=
 [[ 101, 8127, 2399 ... 7370,  511,  102],
  [ 101, 8119,  118 ...    0,    0,    0],
  [ 101, 8119, 2399 ...    0,    0,    0],
  [ 101, 6598, 3160 ...  100,  511,  102]])]

模型构建

  1. 构建GPT2ForSummarization模型,注意shift right的操作。
from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel

class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids = None,
        attention_mask = None,
        labels = None,
    ):
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # Flatten the tokens
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss
  1. 动态学习率
from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate

模型训练

num_epochs = 1
warmup_steps = 200
learning_rate = 1.5e-4

num_training_steps = num_epochs * train_dataset.get_dataset_size()
from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)

lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)
# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))
number of model parameters: 102068736
from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)

trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度
Trainer will use 'StaticLossScaler' with `scale_value=2 ** 10` when `loss_scaler` is None.

注:建议使用较高规格的算力,训练时间较长

trainer.run(tgt_columns="labels")
The train will start from the checkpoint saved in 'checkpoint'.



  0%|          | 0/225 [00:00<?, ?it/s]


|

[ERROR] CORE(19698,ffffb62b9930,python):2024-07-11-03:38:35.667.067 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_19698/729418715.py]
[ERROR] CORE(19698,ffffb62b9930,python):2024-07-11-03:38:35.667.166 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_19698/729418715.py]
[ERROR] CORE(19698,ffffb62b9930,python):2024-07-11-03:38:35.667.705 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_19698/729418715.py]
[ERROR] CORE(19698,ffffb62b9930,python):2024-07-11-03:38:35.667.760 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_19698/729418715.py]


Checkpoint: 'gpt2_summarization_epoch_0.ckpt' has been saved in epoch: 0.

模型推理

数据处理,将向量数据变为中文数据

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
        return tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(pad, 'article', ['input_ids'])
    
    dataset = dataset.batch(batch_size)

    return dataset
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))
[array([[  101,  3219,  1921,   678,  1286,  8024,  3959,  2336,  2356,
         3959,  2336,  1920,  1196,  7368,  7027,  8024,  6225,   830,
         2375,   677,  3198,   679,  3198,  4255,  1355,  1139,  4178,
         4164,  4638,  2958,  1898,  8024,  1762,  1378,   677,  2970,
         1358,  6821,  3416,  4851,  6614,  4638,  8024,  3221,   671,
         5408,  6132,  4708,  6574,  3320,  4638,  3249,  6858,   782,
          511,   800,   812,  3300,  5125,  4868,  4752,  4161,  4638,
         4635,  1355,  5439,  5442,  8024,  3300,  7481,  2425,  7471,
         3886,  4638,  2110,  4495,  8024,  3300,  2692,  3698,  7599,
         1355,  4638,   704,  2399,  4511,  2094,  8024,  3300,  5010,
         1898,  4272,  3306,  4638,  1920,  1995,  8024,   800,   812,
         3221,  3295,  5307,  5815,  2533,  6814,  1059,  1744,  6887,
         2548,  3563,  5745,  3654,  5783,  1469,   100,  6716,  6804,
         1962,   782,   100,  4917,  1384,  4638,  6887,  2548,  3563,
         5745,   807,  6134,   812,  8024,   800,   812,  7970,  5471,
         1762,   704,  1744,  1962,   782,  3528,  1724,  3299,  1057,
         6848,  1399,  1296,  1355,  2357,   811,  2466,  4385,  1767,
         8024,  4867,  6590,  4708,  1057,  6848,   704,  1744,  1962,
          782,  3528,  4638,  3173,  3301,  1351,  8024,   738,  1469,
         1168,  1767,  4638,  6225,   830,   671,  6629,   769,  3837,
         1146,   775,   749,  5632,  2346,  4638,  3633,  5543,  7030,
         3125,   752,   511,  3851,  3736,   758,   782,  1057,  6848,
          704,  1744,  1962,   782,  1724,  3299,  3528,  1296,  4507,
          704,  1925,  3152,  3209,  1215,   712,  1215,  4638,   100,
         2769,  2972,  5773,  2769,  6397,  6379,  6716,  6804,  1962,
          782,   100,  3833,  1220,  8024,  5632,  8182,  2399,   126,
         3299,  6629,  5635,   791,  1762,  1059,  1744,  2898,  5330,
         2458,  2245,   673,  2399,   749,  8024,  1920,  2157,  6858,
         6814,  4510,  6413,   510,  4764,   928,   510,  4510,  2094,
         6934,   816,   510,  5381,  5317,  6389,  1781,  5023,  3175,
         2466,   715,  5773,  6716,  6804,  4638,  1127,   782,  1587,
          715,   510,  5770,  3418,  5739,  7413,   511,   791,  2399,
          125,  3299,  4638,   704,  1744,  1962,   782,  3528,  6397,
         3683,   898,  3191,  4080,  4164,  8024,  1059,  1744,  2519,
         7415,  1168,  1962,   782,  1962,   752,  5296,  5164,  3144,
         1283,  3340,  8024,  8273,  8159,   855,   100,  6716,  6804,
         1962,   782,   100,   868,   711,   704,  1744,  1962,   782,
         3528,   952,  6848,   782,   511,  1762,  5307,  6814,   671,
          702,  3299,  4638,  7415,   704,  2245,  4850,  2146,   837,
         1469,  4178,  2552,  5408,   830,  4638,  2832,  4873,  6397,
         6379,  1400,  8024,  8965,   855,  1221,   782,   711,   727,
          510,  6411,  2141,  2127,   928,   510,  6224,   721,  1235,
          711,   510,  2105,  5439,  4263,   779,  4638,  2398,  1127,
         1962,   782,  5783,  4633,   100,   704,  1744,  1962,   782,
         3528,   100,   511,  5445,  6821,   671,  3613,  8024,  3851,
         3736,  1348,  3300,   126,   855,  1962,   782,   677,  3528,
         8024,  2779,  5635,  3315,  3299,  8024,  2769,  4689,  1066,
         3300, 11056,   782,  1057,  6848,   704,  1744,  1962,   782,
         3528,   511,  6821,   671,  3613,  5815,  2533,   100,   704,
         1744,  1962,   782,   100,  4917,  1384,  4638,  3851,  3736,
         1962,   782,  8024,  3221,  3341,  5632,  3959,  2336,  2128,
         1395,   711, 10220,  1399,   856,   924,  2787,  1048,  6589,
         6981,  5790,  4638,  5790,  2421,  5439,  3352,  2528,  6871,
          756,   510,  3341,  5632,  1378,  2336,  1921,  1378,  4638,
         3297,  5401,   721,  2339,  6145,  3236,  2617,   510,  3341,
         5632,  5305,  1069,  3392,  3441,  5815,  2533,   100,  6411,
         2141,  2127,   928,   100,  4638,   100,   704,  1744,  1962,
          782,   100,   960,  2300,  3239,   510,  3341,  5632,  3959,
         2336,  1298,  3850,  4638,   100,  3143,   689,  1938,  4346,
          100,  1962,   782,  7178,  2207,  5739,   510,  3341,  5632,
         1378,  2336,  3492,  3736,  4638,   100,  2105,  5439,  4263,
          779,   100,  1962,   782,   960,  3424,  1235,   511,   129,
         2259,  1036,  2094,  3136,  3678,  6371,  2099,  6432,  6413,
         3219,  1921,  1762,  6887,  2548,  3563,  5745,   812,  1146,
          775,   769,  3837,  3198,  8024,   671,   819,  4294,  3654,
         4638,  4851,  4289,   837,  6853,  1168,   800,   812,  2797,
          704,   511,  6929,  2218,  3221,  3959,  2336,  4294,  3300,
         4638,   692,  5339,  1469,  3959,  5011,   741,  3791,  5445,
         2768,  4638,   100,  6887,  2548,   837,  2157,  2140,   100,
         8024,   671,  1066,  3300,  1063,   819,  8024,  1146,  1166,
         1091,   677,   100,   785,   510,   721,   510,  6411,   510,
         1249,   510,  2105,   510,  1587,   100,  1469,  2190,  2418,
         4638,  8121,  2099,  2157,  6378,   511,  2970,  6814,   100,
         2105,   100,  2099,   100,  6887,  2548,   837,  2157,  2140,
          100,  4638,  2528,  4207,  2263,  8024,   800,  4638,  3125,
          752,  8024,  6375,  1762,  1767,   679,  2208,   782,  3837,
          678,   749,  4706,  3801,   511,  2528,  4207,  2263,  8024,
          791,  2399,  8111,  2259,  4638,  3959,  2336,  2548,  3926,
         4511,  2111,  8024,  2595,  3419,  1079,  1403,   511,  3680,
         1921,  3123,  2110,  1400,  8024,   800,   794,   679,  1762,
         1912,  4381,  5446,  8024,  5445,  3221,  2593,  2593,  1765,
         6628,  1726,  2157,  8024,  1728,   711,  2157,  7027,  8024,
          800,  4638,  1968,  1968,  2001,  2456,  6004,  3633,  5023,
         4708,   800,  1726,  2157,  3136,  1961,  2573,  2099,   511,
         2001,  2456,  6004,  2400,   679,  3221,  3152,  4683,  8024,
         4685,  1353,  8024,  8151,  2399,  1184,  1961,  6820,  3221,
          671,  1399,  1139,  5682,  4638,  1278,  4495,   511,  1377,
         3221,   124,  2399,  1184,  4638,  6929,  6629,  2692,  1912,
         3121,  1359,   749,  6821,  3678,  2094,   930,  4638,  1462,
         6817,   511,   124,  2399,  1184,  8024,  2001,  2456,  6004,
         1762,   677,  4408,  1139,  2157,  7305,  3198,   679,  2708,
          794,  3517,  3461,   677,  3035,   749,   678,  3341,  8024,
         1920,  5554,  6427,  6241,  1216,  5543,  1358,  2938,  8024,
         2496,  3198,  1278,  4495,  6402,  3171,  8024,  6206,  2612,
         1908,  6427,  6241,  5543,  1213,  1126,   725,   679,  1377,
         5543,   511,  1377,  3221,  6821,  3416,  4638,  6402,  3171,
         5310,  3362,  8024,  2400,  3766,  3300,  6375,  2528,  4207,
         2263,   700,  1927,   928,  2552,  8024,   800,  1353,  5445,
          678,   749,  1104,  2552,  8038,   100,  1968,  1968,   679,
         3221,   679,   833,  6432,  6413,  8024,  2769,  3341,  3136,
         1968,  1968,  6432,  6413,   511,   100,  6821,   702,  2496,
         3198,   129,  2259,  4638,  2207,  2207,  4511,  2094,  3727,
         6432,  1168,   976,  1168,   511,   102]], dtype=int64), array(['湖州:医生摔伤后语言受损,8岁儿子教其认字说话,每天放学后,便直奔家里照顾母亲'], dtype='<U39')]
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)
    i += 1
    if i == 1:
        break
[CLS] 民 警 告 诉 记 者 , 男  字 < [UNK] > 不 、 核 查 , 无 论 你 走 到 什 么 地 方 , 都 能 被 公 安 机 关 核 查 到 , 所 以 你 及 时 到 当 地 公 安 机 关 投 案 自 首 , 争 取 宽 大 处 理 。 [UNK] ( 来 源 : 齐 鲁 网 ) [SEP] [UNK] , 司 , 在 一 的 , 中 的 的 在 [UNK] < > [UNK] [SEP] 在 在 中 , 是 在 是 [UNK] 在 司 的 是 中 。 他 。 在 有 , 经 的 有 。 [SEP] 一 , 国 的 。 但 在 他 ,
print("yanggemindspore打卡21天之基于MindSpore的GPT2文本摘要   2024  07 11")
yangge mindspore打卡21天之基于MindSpore的GPT2文本摘要  2024  07 11


;