前言
本文第二部分,则展示下我司正在做的论文审稿GPT的部分工作 (由于我司每周都有好几个或为申博、或为评职称、或为硕/博毕业而报名论文1V1发表辅导的,比如中文期刊、EI会议、ei期刊/SCI等等,所以对这个方向一直都是高度关注),侧重阐述如何从零实现一个论文审稿GPT,该部分由我和我司第二项目组的阿荀共创
第一部分 论文审稿的项目背景与数据处理
1.1 项目背景:API做论文摘要/对话/翻译可以,但做论文审稿不行
自从去年11月,ChatGPT火爆全球之后,大模型技术正在赋能千行百业,而身处当下的大模型时代,如果不利用大模型做点事情,则深感有负于时代,所以我司七月在线
- 一方面,谋划了35个大模型课程(由我远程带领北京的教育团队研发),帮助各行各业通过大模型技术提升各自的业务
- 二方面,则开始围绕“论文、文档、代码”做一系列LLM项目(由我司的长沙LLM项目团队负责,我目前base长沙兼管该项目团队,目前正在扩人,有兴趣者欢迎私我了解或加入)
对于论文,如本文前两个部分所述,市面上已有几个学术论文GPT了,但实话说,对于论文的摘要/总结、对话、翻译、语法检查而言,市面上的学术论文GPT的效果虽暂未有多好,可至少还过得去,而如果涉及到论文的修订/审稿,则市面上已有的学术论文GPT的效果则大打折扣。
原因在哪呢?本质原因在于无论什么功能,它们基本都是基于OpenAI的API实现的,而关键是API毕竟不是万能的,API做翻译/总结/对话还行,但如果要对论文提出审稿意见,则API就捉襟见肘了。比如当让基于GPT3.5的ChatGPT初版,为经典论文《Attention Is All You Need》提出审稿意见,API(gpt-3.5-turbo,4,097的上下文)最终提出了三点建议(测试时间:23年8月27日),如下图所示:
- 是否可提供更多训练参数细节?
- 是否进行足够的消融实验?
- 是否提供可复现代码?
然而实际情况是,《Attention Is All You Need》中已经给出了模型参数、甚至学习率设置等具体的训练细节,消融实验也是与当时的SOTA进行比较,更是在文末提供了可用的训练、推理代码
so,为实现更好的review效果,需要使用特定的对齐数据集进行微调来获得具备优秀review能力的模型
1.2 数据处理:爬取、PDF解析、清洗、组织
做大模型工作的第一步永远是需要先解决数据的问题
一开始,我们本打算直接用GitHub上相关项目代码及其review数据,但已有的项目存在诸多问题
- 都仅支持爬取单会议单年的数据,数据规模严重不足
- 且部分还是基于Selenium(一个python自动化框架,通过解析网页页面元素,模拟人工点击的操作从网页中取数据)实现的爬虫,该方法效率过低,需要实际打开网页,等待页面元素加载完毕才能进行解析爬取
- 时效性无法保证,项目最近更新时间至今已有些年份,期间review数据难免出现变化,代码是否可用仍存疑
既然GitHub上已有的review数据没法用,那没办法,我们只能从零开始爬取我们需要的数据,那我们需要爬取的数据具体长什么样呢?
1.2.1 数据爬取:论文审稿数据是什么样子的(涵盖paper和review数据)
该例取自:Natural Language Descriptions of Deep Visual Features,具体可结合下方的数据字段释义进行对照查看。
上图中各个数据字段的释义(仅展示关键字段)如下:
字段类别 | 字段名称 | 字段释义 |
---|---|---|
basic(基础信息) | b_forum | 论文讨论页的id |
b_title | 论文的标题 | |
b_url | 论文讨论页的链接 | |
b_abstract | 论文的摘要 | |
b_TL;DR | 论文的极简描述 | |
b_authors | 论文的作者 | |
b_keywords | 论文的关键词 | |
b_venue | 论文所属会议 | |
b_venue_id | 论文所属会议的id | |
b_pdf_url | 论文pdf文件页的链接 | |
b_venue_id | 论文所属会议的id | |
review(review信息,部分论文没有review时此处则均为nan) | r_id | review的id |
r_replyto | review所指向的论文页id | |
r_invitation | review提出者的所属类别(通常为Decision或Official) | |
r_signatures | review提出者的签名(可以理解为提出者在当前paper讨论中的id) | |
content(review具体内容,部分论文没有review时此处则均为nan) | c_content | 完整的review内容(下述字段内容均由此处内容拆分得到) |
c_title | review内容的标题 | |
c_rating | 评级 | |
c_review | 概览性review内容 | |
c_confidence | 可信程度 | |
c_decision | 由主席提出的采纳意见 | |
c_comment | 评论 | |
⋯⋯ | 篇幅所限不再赘述,其余字段可根据字段名称知悉释义 |
论文是有了,但论文这么多篇,怎么批量下载到或爬取下来呢,毕竟我们不可能一篇篇去点击下载
好在论文审稿网站的单篇论文页中,提供了相应PDF文件的跳转链接(如https://julyreview.com/pdf?id=09QFnDWPF8),分析PDF页可知其链接构成与该篇论文在网站中的ID(即字段“b_forum”,上例中即为“09QFnDWPF8”)有关
从而可以通过论文ID,然后去拼出它的PDF所在的网页链接,之后用requests库爬下对应网页的二进制内容,再使用python的文件写入方法将PDF写入本地文件即可,具体如下所示
- 爬取审稿数据:utils/julyreview_crawler.py
通过这份代码来获取review及其paper信息import julyreview import time import requests import jsonlines class JulyreviewCrawler: def __init__(self, baseurl='https://api.julyreview.net'): """后台需挂载代理""" self.client = julyreview.Client(baseurl=baseurl) self.venues = self.client.get_group(id='venues').members def get_and_save_venue(self, venue_id): results_list = self._get_venue_papers(venue_id) if results_list: self._save_results(results_list) return results_list def get_and_save_total(self): total_results_list = [] for idx, venue_id in enumerate(self.venues): print('{}/{}: {}, total_results_list_length: {}'.format(idx + 1, len(self.venues), venue_id, len(total_results_list))) results_list = self._get_venue_papers(venue_id) total_results_list += results_list time.sleep(1) self._save_results(total_results_list, spec_name='total_notes') print('The number of papers is {}.'.format(len(total_results_list))) return total_results_list def _get_venue_papers(self, venue_id): """ 从venues(venues=client.get_group(id='venues').members)中获取指定venue的id来传入, 该函数将返回对应venue_id的论文信息并存储 """ # assert self._existence_check(venue_id), \ # 'This item "{}" is not available in julyviewer.net!'.format(venue_id) # 获取当前venue_id对应的提交论文(双盲) submissions = self.client.get_all_notes(invitation='{}/-/Blind_Submission'.format(venue_id), details='directReplies') # 获取当前venue_id下的论文id specified_forum_ids = self._get_all_forum_ids(submissions) # dict list results_list = [self._format_note(note, venue_id) for note in submissions if note.forum in specified_forum_ids] # if results_list: # for i in range(3): # print(results_list[i]['basic_dict']['forum']) return results_list def _get_specified_forum_ids(self, submissions): forum_ids = set() for note in submissions: for reply in note.details["directReplies"]: forum_ids.add(reply['forum']) return forum_ids def _get_all_forum_ids(self, submissions): """获取所有论文页id,无论是否有reply""" forum_ids = set() for note in submissions: forum_ids.add(note.forum) return forum_ids def _format_note(self, note, venue_id): """单条note的处理方法:提取note中的指定信息""" basic_dict = {} reviews_msg = [] authors_string = ','.join(note.content.get('authors', '--')) keywords_string = ','.join(note.content.get('keywords', '--')) localtime_string = time.strftime('%Y-%m-%d', time.localtime(note.pdate / 1000)) if note.pdate else '--' # basic message basic_dict['forum'] = note.forum if note.forum else '--' basic_dict['title'] = note.content.get('title', '--') basic_dict['url'] = 'https://julyreview.net/forum?id=' + note.forum basic_dict['pub_date'] = localtime_string basic_dict['abstract'] = note.content.get('abstract', '--') basic_dict['TL;DR'] = note.content.get('TL;DR', '--') basic_dict['authors'] = authors_string basic_dict['keywords'] = keywords_string basic_dict['venue'] = note.content.get('venue', '--') basic_dict['venue_id'] = note.content.get('venueid', '--') basic_dict['number'] = note.number if note.number else '--' basic_dict['pdf_url'] = 'https://julyreview.net/pdf?id=' + note.forum basic_dict['signatures'] = note.signatures if note.signatures else '--' basic_dict['bibtex'] = note.content.get('_bibtex', '--') basic_dict['from_venue_id'] = venue_id # reviews message reviews_msg = note.details["directReplies"] result_dict = {'basic_dict': basic_dict, 'reviews_msg': reviews_msg} return result_dict def _existence_check(self, item_id): if requests.get("https://julyreview.net/group?id={}".format(item_id)).status_code == 200: return True else: return False def _save_results(self, results_list, spec_name=None): if spec_name: venue_id = spec_name jsonl_file_name = '{}.jsonl'.format(spec_name) else: venue_id = results_list[0]['basic_dict']['venue_id'] jsonl_file_name = '{}.jsonl'.format(venue_id.replace(r'/', '--').replace(r'.', '__')) for result in results_list: with jsonlines.open(jsonl_file_name, mode='a') as file: file.write(result) print('The item "{}" saved successfully!'.format(venue_id)) return if __name__ == '__main__': orc = JulyreviewCrawler() results_list = orc.get_and_save_venue('ICLR.cc/2023/Workshop/TSRL4H') print(results_list[:3])
- 爬取论文PDF:download_pdfs
具体是通过上步获取到的paper信息里取出对应的论文id,拼成pdf_url,然后爬论文pdf
以下是核心代码,完整代码暂只放在我司针对B端客户的线下公司内训,或我司七月的大模型线上营中import requests import time # 函数用于从给定的URL下载PDF,并以特定论坛名称格式保存 def get_paper_pdf(forum, pdf_url): # 向给定的PDF URL发送请求 response = requests.get(pdf_url) # 打开一个文件用于写入PDF内容,文件名格式为'papers_pdf/{论坛名}.pdf' with open('papers_pdf/{}.pdf'.format(forum), 'wb') as f: # 将请求到的内容写入文件 f.write(response.content) # 函数结束,没有返回值 return # 初始化一个空字典用于存放PDF信息(这部分代码中未使用此字典) pdf_dict = {} # 设定开始索引 start_idx = 5501 # 设定结束索引 end_idx = 5555555 # 获取论坛数据的行数 df_dup_length = df_dup_forum.shape[0] # 遍历论坛数据 for idx, row in df_dup_forum.iterrows(): # 如果当前索引小于开始索引,则跳过当前循环 if idx < start_idx: continue # 每10个索引打印一次进度信息 if idx % 10 == 0: # time.sleep(1.5) # 可以取消注释来减缓请求速度 print('{}/{}'.format(idx, df_dup_length)) try: # 尝试下载PDF get_paper_pdf(row['b_forum'], row['b_pdf_url']) except: # 如果遇到错误,则等待5秒后重试 time.sleep(5) get_paper_pdf(row['b_forum'], row['b_pdf_url']) # 如果达到结束索引,则终止循环 if idx == end_idx: break
- 读取并整理审稿数据: utils/openreview_processor.py
import jsonlines import pandas as pd class JulyreviewProccessor: def __init__(self, jsonl_path): self.df = self._load_jsonl_to_dataframe(jsonl_path) self.df_sub = pd.DataFrame() def _load_jsonl_to_dataframe(self, jsonl_path): msg_list = [] with open(jsonl_path, 'r', encoding='utf-8') as file: for line_dict in jsonlines.Reader(file): msg_dict = {} for k, v in line_dict['basic_dict'].items(): msg_dict['b_' + k] = v msg_list.append(msg_dict) for review_msg in line_dict["reviews_msg"]: msg_dict_copy = msg_dict.copy() pure_review_msg = { 'r_id': review_msg.get('id', None), 'r_number': review_msg.get('number', None), 'r_replyto': review_msg.get('replyto', None), 'r_invitation': review_msg.get('invitation', None), 'r_signatures': ','.join(review_msg['signatures']) if review_msg.get('signatures', None) else None, 'r_readers': review_msg.get('readers', None), 'r_nonreaders': review_msg.get('nonreaders', None), 'r_writers': review_msg.get('writers', None) } pure_content_msg = {} pure_content_msg['c_content'] = review_msg['content'] for k, v in review_msg['content'].items(): pure_content_msg['c_' + k] = v pure_review_msg.update(pure_content_msg) msg_dict_copy.update(pure_review_msg) msg_list.append(msg_dict_copy) dataframe = pd.DataFrame(msg_list) dataframe['c_final_decision'] = self._fill_decision(dataframe) return dataframe def _fill_decision(self, dataframe): return dataframe['c_decision'].map(lambda x: x if pd.isnull(x) else 'Accepted' if 'accept' in x.lower() else 'Rejected' if 'reject' in x.lower() else "Unknown") def get_sub(self, mode=None): # 仅带有review的df df_sub = self.df.dropna(subset=self.df.filter(regex='^(?!b_*)').columns, how='all') if mode == 'decision': # review类型仅为decision的df df_sub = df_sub[df_sub['r_invitation'].str.contains('Decision')] elif mode == 'other': # review类型仅为非decision的df df_sub = df_sub[~df_sub['r_invitation'].str.contains('Decision')] elif mode == 'accepted': # decision中被采纳的df df_sub = df_sub[df_sub['c_final_decision'].isin(['Accepted'])] elif mode == 'rejected': # decision中未被采纳的df df_sub = df_sub[df_sub['c_final_decision'].isin(['Rejected'])] self.df_sub = df_sub return def get_total_shape(self): return self.df.shape def get_sub_shape(self): return self.df_sub.shape if __name__ == '__main__': orp = JulyreviewProccessor('../total_notes.jsonl') orp.get_sub() print(orp.df_sub.iloc[0])
1.2.2 对论文PDF的解析
考虑到论文是PDF形式的,所以爬取完全部论文PDF之后,下一步就涉及到论文PDF的解析了
从头开始编写PDF解析器是一个耗时且需要反复测试的复杂工作,因此在项目周期较为紧凑的情况下倾向于采用开源的解析器来完成PDF解析工作
关于PDF解析器的选型主要考虑有两点:
- 一是PDF发展时至今日仍有效的解析器;
- 二是期望解析器对解析论文类PDF能有所特化
最终参考了ChatPaper中提及的SciPDF Parser以及ChatPaper项目自身实现的ChatPaper Parser。两种解析器各有优劣
- SciPDF切分的粒度更细,甚至独属于某篇论文的小标题都可以识别出来并且以列表的形式进行返回,内容稍显混乱复杂,但保留了小标题间的顺序关系
- ChatPaper根据文章的title、experiment等重要节点关键词来识别并切分正文,切分的粒度更粗,内容更为统一,但提取出的节点内容没有顺序信息
同时两种解析器也都有没法完美识别的地方(比如PDF的title、abstract会因识别不出而为空)。考虑到文本顺序对模型具有指导意义,最终使用上文分析过的SciPDF Parser进行解析
具体代码如下(scipdf_parser.py)
import scipdf
import argparse
from pathlib import Path
import json
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
parser = argparse.ArgumentParser()
parser.add_argument("--dir_path", type=str, default=None, help="The path of the folder about paper pdf.")
args = parser.parse_args()
org_dir_path = Path(args.dir_path).resolve()
trg_dir_path = org_dir_path.with_name("scipdf_parser_results")
error_log = {}
if not trg_dir_path.exists():
trg_dir_path.mkdir()
for pdf_file in tqdm(org_dir_path.glob("*.pdf")):
trg_path = trg_dir_path.joinpath(pdf_file.name).with_suffix(".json")
if trg_path.exists():
continue
try:
article_dict = scipdf.parse_pdf_to_dict(str(pdf_file)) # return dictionary
with open(trg_path, "w") as f:
json.dump(article_dict, f)
except Exception as e:
error_log[str(pdf_file.name)] = str(e)
continue
error_log_path = trg_dir_path.with_name("error_log_scipdf.json")
with open(error_log_path, "w") as fe:
json.dump(error_log, fe)
举个例子,针对下面这篇论文
解析后的论文数据如下
最终,下图是解析后的数据集情况
相当于数据形式分为input和output,其中input为paper数据,output为review数据,其中
- paper数据
原共有30380条paper数据,去除有损文件后解析得到30176条paper数据,涉及各个顶会 - review数据
原共有122892条review数据
对于这个数据集而言,“paper-review”是天然的QA对形式数据,无需借助其他工具构成QA对; 内容专业倾向强,属于领域优质数据,无需采用self-instruct等方法进行专家角色扩写、续写等额外操作; 数据清洗的角度更多在于具体的文本内部,如剔除无效信息等
1.2.3 数据处理:去重、去除无关项/长尾内容/极端项、剔除无效信息
之后做了一系列数据处理,如下图所示,最终得到的paper数从30176变为22966,review数从122892变为106271,数据量虽然变少了,但质量提高了许多
至于上图中各种数据处理如何写代码实现,以及各种细节问题,暂在七月的「大模型项目开发线上营」中见
1.2.4 组织训练格式:单轮与多轮
1.2.4.1 单轮数据组织
当前已设计的数据组织格式如下,需将文本根据该数据格式进行处理:
User: please reivew this paper or give some sugguestion
Assistant: ok, please provide detailed infomation or provide paper to review
User: this is paper/content :\n{paper}
Assistant: this is review/suggestion:\n{reivew}
将paper和review内容填入相应的部分,其中paper的文本内容还可进一步细分为“title: xxxx, abstract: xxxxx, keyword: xxxxx, main: xxxxxxx”,需将paper文本进一步处理成相关的subtitle格式,使得模型更容易辨析相关部分
1.2.4.2 多轮数据组织
当前已设计的数据组织格式如下,需将文本根据该数据格式进行处理:
User: please reivew this paper or give some sugguestion.
Assistant: ok, please provide detailed infomation or provide paper to review.
User: this is paper:\n{paper}
Assistant: this is review/suggestion:\n{reivew1}
User: Any more?
Assistant: this is some more review/suggestion:\n{reivew2}
User: Any more?
Assistant: this is some more review/suggestion:\n{reivew3}
...
现有的数据是多是单paper对应多review的情况「如{paperA-reviewA1, paperA-reviewA2, paperA-reviewA3, ...}, {paperB-reviewB1, paperB-reviewB2, ..}, ...」,考虑能否设计成使用类似“Any more suggestions?”表达希求更多的句子引出另一篇review的多轮场景,其中“希求更多的问句”可以考虑使用ChatGPT来进行同义问句扩充。
至于训练数据的存储,可以是以 jsonl 格式存储组织好的数据
第二部分 Q3第1版之模型的选型
2.1 模型的选型:RWKV PK LLaMA2
在我们得到处理好的数据之后,有3类模型 选择
- LLaMA2
Llama2 虽于23年7月份便已推出,但其上下文长度不够(仅4K)
当然,第二版会尝试LLaMA2-long,LongAlpaca - RWKV
之所以第一版选用这个RWKV,原因在于23年Q3时的长上下文解决方案比较罕见,经典Transformer对16k的长度支持需要耗费很大的资源,而RNN的结构训练和推理占用相对比较便宜(或者说线性Transformer结构占用恒定)
关于什么是RWKV,参见下文的介绍,或RWKV GitHub、RWKV Wiki
关于如何基于RWKV微调,可以用这个RWKV微调库:RWKV-infctx-trainer (for training arbitary context sizes, to 10k and beyond)
但缺点是对于论文这种带有密集知识点的对象而言,遗忘机制比较严重,故最终效果不达预期 - ChatGPT的微调接口,不过其开放的微调接口的上下文长度,截止到10月底暂只有4K
(当然,2023年11.6日,OpenAI在其举办的首届开发者大会上,宣布开放GPT3.5 16K的微调接口)
2.2 (选读)从线性Transformer到RWKV
2.2.1 什么是线性transformer:Transformers are RNNs与cosformer
我们已知,Dot-product attention与softmax归一化是transformer捕捉长程依赖关系的基石。然而,其关于序列长度的二次空间和时间复杂性使其计算开销令人望而却步,特别是对于长输入。为了解决这个问题,最近提出了许多方法,如稀疏注意力矩阵(sparse attention matrix),低秩表示(lowrank representations)或基于核的方法(kernel-based methods)等,让这些方法皆有其各自的局限性
以上之外,另一个重要的方法便是线性Transformer(Linear Transformer),其将transformer的复杂度从O(N^2)降低为O(N),这对加快Transformer整体的加速非常重要
关于线性Transformer,可以看下这两篇论文:《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》、以及友人钟博士团队的《COSFORMER : RETHINKING SOFTMAX IN ATTENTION》
线性Transformer的核⼼思想是通过Kernel trick的⽅式,如下图右侧所示,将QKV的左乘变成右乘,从⽽将理论计算复杂度降为线性
我们已知
- Transformer中self-attention的典型计算如下:
其中矩阵Q、K、V是由输入 x 经线性变化得到的query、key、value - 如果暂不考虑缩放因子,则自注意力的计算可以分解为向量运算
其中,上式的分母是一个归一化因子,确保所有的注意力得分加起来等于1
这一步怎么做到的呢,援引HeptaAI的一个说明图如下
接下来,便有以下一系列推导
- 如果用下标来表示矩阵的第行(如 表示矩阵 的第行),那么可以将上述公式中的计算用如下形式抽象出来:
其中为抽象出的计算Query和Key相似度的函数 - Linear Transformer采用了kernel来定义sim():
其中 是一个特征映射函数,可根据情况自行设计
考虑到矩阵乘法有结合律,softmax只能左乘,linear可以右乘,而右乘更快,正因为矩阵乘积的这个属性可以实现注意力操作的线性复杂度:
相当于不是显式地计算注意力矩阵,而是先计算,然后乘以,从而最终的时间复杂度为
考虑到,在一般的NLP任务中,一个头的特征维度总是比输入序列长度 ()小得多,因此可以忽略,实现的计算复杂度 - 因此,self-attention可以从
转化为:
原始Transformer的计算复杂度之所以随序列长呈二次方增长,这是因为attention的计算包含两层for循环
外层是对于每一个Query,我们需要计算它对应token的新表征
内层for循环是为了计算每一个Query对应的新表征,需要让该Query与每一个Key进行计算
所以外层是 for q in Queries,内层是 for k in Keys,Queries数量和Keys数量都是N,从而复杂度是
好比军训时,甲乙丙丁4个人列成一队,计算注意力机制的过程相当于
首先把甲站到队伍的前面,算“其”与“自己在内所有人”的相似度,即计算这些的内积值:
甲q甲k、甲q乙k、甲q丙k、甲q丁k
接着,再乙站到队伍的前面,算“其”与“自己在内所有人”的相似度,即计算这些的内积值:
乙q甲k、乙q乙k、乙q丙k、乙q丁k
丙、丁以此类推,即分别计算这两批内积值:
丙q甲k、丙q乙k、丙q丁k、丙q丙k
丁q甲k、丁q乙k、丁q丙k、丁q丁k
而Linear Transformer,它只有外层for q in Queries这个循环了,因为求和项的计算与 无关,所以所有的 可以共享求和项的值。换言之,求和项的值可以只计算一次,然后存在内存中供所有 去使用,所以Linear Transformer的计算复杂度是O(N) - 引入以下两个新符号:
稍作变换,可以将Si 和Zi 写作递归形式:
因此,在inference阶段,当需要计算第i时刻的输出时,Linear Transformer可以复用之前的状态 Si−1 和 Zi−1 ,再额外加上一个与当前时刻相关的计算量即可。而Transformer在计算第i时刻的输出时,它在第i-1个时刻的所有计算都无法被i时刻所复用。因此,Linear Transformer更加高效
总结一下:
- Linear Transformer的计算复杂度为 O(N) (不考虑embedding的维度的情况下)
- 因为Si可由Si−1计算得到(Zi同理),所以它可实现Sequential Decoding(先算S1,由S1算S2,以此类推)。能Sequential Decoding是让这类Transformer看起来像RNN的核心原因
2.2.2 TransnormerLLM
友人钟博士曾评论,不带点积注意力机制的开源模型中,有希望超越带注意力机制的Llama架构的,一个是mamba,一个便是TransnormerLLM
如qinzhen所说,transnomerLLM相比cosformer,最本质的区别是其位置编码的不同,剩下就是结构细微的优化以及工程
// 待更
2.2.3 AFT(Attention Free Transformer)
Attention Free Transformer (AFT) 是Apple公司提出的一种新型的神经网络模型,它在传统的 Transformer 模型的基础上,通过使用像Residual Connection之类的技术来消除注意力机制,从而减少计算量和提升性能
AFT在不同的资料中有不同的表达形式
- 比如有的资料会写成
其中是sigmoid函数;⊙是逐元素相乘(element-wise product), 是待训练的参数
AFT采用的形式和上面的Linear Transformer不一样
首先是attention score,Linear Transformer仍然是同Transformer一样,为每一个Value赋予一个weight,而AFT会为每个dimension赋予weight
换言之,在Linear Transformer中,同一个Value中不同dimension的weight是一致的,而AFT同一Value中不同dimension的weight不同
此外,attention score的计算也变得格外简单,用K去加一个可训练的bias。Q的用法很像一个gate
可以很容易把AFT也写成递归形式,这样容易看出,AFT也可以像Linear Transformer,在inference阶段复用前面时刻的计算结果,表现如RNN形式,从而相比于Transformer变得更加高效 - 还有的资料比如RWKV论文会写成(和上式一个意思)
其中,其中 是学习的pair-wise位置偏差,每个 是一个标量
下图是对该式的解释说明 其实从式子上看,AFT无非是将矩阵乘改成了矩阵加,加上模型只能看到前面的token。注意这里的 是一个二维矩阵,和attention中的positional encoding作用相似,都是为了给模型输入位置信息
2.2.4 RWKV:试图在Transformer时代重塑RNN
RWKV其实是我司论文审稿GPT第一版一开始就考虑的模型,虽然当时第一版用RWKV的效果没符合预期,但在有些任务上的表现还是不错的,加之因为写mamba模型而再次关注到有点类似的RWKV,故本文也顺带讲一下
据RWKV论文可知,RWKV 架构的名称源自timemixing和channel-mixing模块中使用的四个主要模型元素(defined by four fundamental elements that are intrinsic to the timemixing and channel-mixing blocks):
- R:表示过去的信息,用的sigmoid激活函数
- W:权重是位置权重衰减向量,是可训练的模型参数(后面还会再出来个U,是对当前位置信号的补偿)
- K:Key 是类似于传统注意力中的K 的向量
- V :value 是类似于传统注意力中的V 的向量
每个时间步的主要元素之间的相互作用都是乘法的,如下图所示
在RWKV的结构中,其中的递归被表述为当前输入和前一个时间步的输入之间的线性插值(我们将这种技术称为time-shift mixing或token shift,如下图中的对角线所示)
- 可以表示为针对输入嵌入的每个线性投影(例如,timemixing中的 R、K、V,以及channel-mixing中的 R、K)进行独立调整,并作为 WKV 的时间相关更新
- WKV 计算与 AFT 类似,但 W 现在是“通道向量”乘以“相对位置”(下文详述),而不是 AFT 中的pairwise position matrix。我们还引入了一个向量 U 来单独关注当前token,以补偿 W 的潜在退化
一看有点懵,没事,因为其中有不少细节,咱们来逐一阐述
2.2.4.1 RWKV的时间混合(time mix)模块与通道混合(channel mix)模块
如下图所示,假设输入sequence是My name is,目前 ,则这里 是上一个输入token(My), 是这个输入token(name)
是遗忘因子,越大对上个token(My)就忘的越多,也就是对这个token(name)更专注,黄色(μ)表示token shift「至于红色(1)表示分母,蓝色(2)表示分子,粉色(3)表示16种分数计算,h代表了分子和分母的元组」
可有以下五个公式
先解释前三个公式
- 在传统Transformer中, ,, 本质上都是 的线性变换,可以用来动态调整表示的子空间维度且增大参数量
- 在RWKV中, ,, 本质上都是 , 线性组合的变换,且作为计算RKV的输入的:不再是当前token的embedding,而是当前token与上一个token embedding的加权和
接下来 重点解释下其中最难的部分第4个公式:
- 原始的attention是这样的:
- AFT的attention
- RWKV的attention
怎么理解这个RWKV attention的这个表达式呢?
受 AFT 的启发,RWKV 中的每个 都代表一个「通道时间衰减向量」,该向量乘以相对位置,并且在衰减时从当前时间开始向后追踪(Each wt,i in RWKV is a channelwise time decay vector multiplied by the relative position and traced backward from current time as it decays):
其中 , 是通道数,RWKV要求 为非负数,以确保 并且确保每个通道的权重在时间上向后衰减(ensure that e wt,i ≤ 1 and the per-channel weights decay backwards in time)
这个操作与后面的 都是用来建模序列的time decay的
以上可能解释的比较绕,不够通俗,其实说白了,相比AFT,原来的依靠绝对位置的偏置没有了,改成了相对位置,并且只有一个参数向量需要训练
其次,对当前位置单独处理,增加了参数
最后,再解释第5个公式
- 其中 计算, , 在 Transformers 中扮演 的角色,而不会产生quadratic成本,因为计算的都是标量,这就是上面的第5个公式
- 直观上,随着时间 的增加,向量 取决于较长的历史,由越来越多的项的总和表示。对于目标位置 ,RWKV在 的位置区间进行加权求和,然后乘以接受度
因此,交互作用在给定的时间步长内是乘法的,并在不同的时间步长上求和
最后,通道混合块(channel mix block)根据time-mixing block的输出,然后使用下述三个公式的前两个公式计算一组心的R、K,最后根据下面第三个公式计算最终输出
2.2.4.2 RWKV的训练阶段与推理阶段
训练阶段:时间并行模式
在训练复杂度上,我们对比下标准注意力与RWKV
- 对于标准注意力而言,假设是个最大token,因为RWKV只需要上一时刻的state vector和这一时刻的输入。因此,生成的每一个token只要考虑常数个变量,所以复杂度为
如果是个通道,则每个 需要进行 次求和,每次求和都涉及一维向量分别点乘,复杂度为,因此对于整个序列的复杂度为
当然,如果是个序列,则复杂度为 - 对于RWKV而言
针对, 不是向量下标,意味着对每个 ,我们知道 , 是复用的,因此, → 时复杂度为
针对 , 不是向量下标,意味着对每个 ,我们知道 , 是复用的,因此,时间复杂度为
也就是说,在内层循环,算出的可以直接存起来供外层循环使用。即,RWKV的内外层循环是解耦的
当然,如果是个序列,则复杂度为
推理阶段:时间顺序模式
在循环网络中,使用状态 的输出作为状态 的输入是很常见的。这在语言模型的自回归解码推理中尤其明显,要求每个标记在输入下一步之前进行计算,从而使得RWKV 利用其类似 RNN 的结构,称为时间顺序模式(time-sequence mode),如下图所示(来自小冬瓜AIGC)
- 在这种情况下,可以方便地递归地制定 RWKV 以便在推理过程中进行解码,它利用了每个输出token仅依赖于最新状态的优点,该状态具有恒定的大小,而与序列长度无关
- 然后,它充当 RNN 解码器,根据序列长度产生恒定的速度和内存占用,从而能够更有效地处理较长的序列。相比之下,自注意力通常需要 KV 缓存相对于序列长度线性增长,从而导致效率下降,并且随着序列变长而增加内存占用和时间
第三部分 RWKV的具体训练与推理
以下是训练的一些细节
- 所用GPU:用了8块A800
- 训练时间:4天左右
接下来,我们来看如何针对推理数据的处理与最终推理:给定paper,让训练好的模型输出审稿意见
3.1 推理数据处理-主要针对paper
Paper内容主要被明确划分为了3部分:
- Title:论文标题
- Abstract:论文摘要
- Main:论文正文,包括Introduction、Methodology、Conclusion等内容
故可依赖3种途径接收用户传入的Paper内容:
- 纯解析:预留上传框支持用户上传论文的PDF,使用SciPDF解析出Title、Abstract以及Main(其他部分)
- 纯输入:预留输入框支持用户手动输入论文的Title、Abstract以及Main(其他部分)
- 输入+解析(推荐):预留上述两者,鼓励用户手动输入Title和Abstract,并同时上传论文PDF文件,这样设计是考虑到解析器可能无法准确解析出Title和Abstract,通过用户手动输入来获取Title和Abstract即可,故最终Paper文本的Title和Abstract以用户输入为准、Main以解析为准
3.2 RWKV-light推理
相关代码的具体实现,暂在七月的大模型项目开发线上营中见
3.3 后续的第二版:微调llama2最终反超GPT4
总之,我们在第一版中,做了以下三件事
- 爬取了3万多篇paper、十几万的review数据,并对3万多篇PDF形式的paper做解析(review数据爬下来之后就是文本数据,不用做解析)
当然,paper中有被接收的、也有被拒绝的 - 为提高数据质量,针对paper和review做了一系列数据处理
当然,主要是针对review数据做处理 - 基于RWKV进行微调,然因其遗忘机制比较严重,故最终效果不达预期
所以我们后续马上开始做论文审稿GPT第二版:《七月论文审稿GPT第2版:用一万多条paper-review数据集微调LLaMA2最终反超GPT4》,再更多则暂在七月的「大模型项目开发线上营」中见