探索混合专家(MoE)模型预训练:开源项目实操
Mantaverse 来自知乎
目录
收起
MOE模型是什么
实现Moe 模型
实现步骤拆解
1. 初始化和形状调整
2. 计算路由器的logits
3. 初始化和创建专家掩码
4. 循环计算专家层输出
5. 恢复形状并返回结果
预训练效果对比
Deepseek MoE
结语
MOE模型是什么
相比于传统的Dense模型,MoE(Mixture of Experts)模型在结构上进行了优化,特别是在线性投影层方面。MoE模型将单一的全连接层替换成多个专家层(例如,Mixtral使用了8个专家层)。在Switch Transformer的论文中,我们了解到,每次进行token预测时,模型会从这8个专家层中选出两个用于线性推理。这种方法旨在提高模型的性能和效率。
Switch Transformer
这种设计有什么优势呢?首先,它通过引入专家层,能够在每次计算中仅激活部分网络,从而减少计算资源的消耗。具体来说,MoE模型在推理阶段仅需计算两个被选中的专家层,而不是激活所有的专家层或整个网络。这使得计算量显著减少,从而降低了推理成本。
此外,虽然MoE模型整体参数量较大,但由于每次推理只使用部分专家层,实际参与计算的参数量远小于整个模型的参数总量。这意味着,即使MoE模型的参数量比传统Dense模型大得多,其实际计算成本却要低得多。例如,在Mixtral 8X7B中,每个专家层有7亿个参数,但每次推理只使用两个专家层,因此实际计算的参数量是14亿个,而不是所有8个专家层的总参数量。
更重要的是,MoE模型通过选择最适合当前任务的专家层,可以在不同的任务或数据输入下表现出更强的适应性和泛化能力。由于每个专家层可以专注于处理特定类型的数据或任务,整体模型的性能能够显著提升。结果是,MoE模型在推理阶段不仅具有更低的计算成本,还能在许多情况下比参数量更大的Dense模型表现更好。
具体到Mixtral 8X7B的实现上,Mixtral 8X7B使用了8个独立的专家层,每个专家层都有7亿个参数。在模型进行推理时,会动态选择两个最合适的专家层进行计算,从而实现高效的推理过程。这种设计不仅提高了模型的计算效率,还增强了其在处理大规模数据和复杂任务时的能力。
实现Moe 模型
基于我们在上一篇文章(Mantaverse:从零预训练LLAMA3的完整指南:一个文件,探索Scaling Law)中实现的Llama3模型,只需要对模型做些许修改,便能够将Dense模型转化为MoE模型。
实现结果开源在:https://github.com/hengjiUSTC/learn-llm/blob/main/pretrain/train_mixtral.py
实现步骤拆解
在Dense模型中,MLP的实现如下:
我们希望把这个组件替换成8个专家层。首先,我们复用MLP的定义,因为本质上,专家层的结构和MLP是一样的:
在此基础上,我们定义了一个专门处理8个专家层选择逻辑的gate层。在初始化中,我们定义了多个专家层和一个gate。gate的实现非常简单。在每次预测任务中,针对当前的hidden state,通过gate映射到一个(1,x)维度的结果中,结果中较大的值对应我们要选择的专家编号。
向前传播过程定义如下:
我们接下来拆解向前传播的每一步的效果
1. 初始化和形状调整
获取输入 hidden_states
的形状信息,并在训练时添加抖动噪声以增强模型的鲁棒性:
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
hidden_states = hidden_states.view(-1, hidden_dim)
2. 计算路由器的logits
通过一个门控网络(gate)计算出每个输入token对应的专家权重:
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
3. 初始化和创建专家掩码
初始化 final_hidden_states
用于存储计算结果,并创建专家掩码以便索引:
final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
4. 循环计算专家层输出
遍历所有专家层,计算每个专家层的输出,并累加到最终隐藏状态:
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
5. 恢复形状并返回结果
将 final_hidden_states
重新调整为原始输入的形状,并返回最终隐藏状态和路由logits:
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
实现完成最重要的MoE层后,其他的网络结构和llam3基本没有区别,我们就不重新赘述了,感兴趣的朋友可以看我之前的文章: Mantaverse:从零预训练LLAMA3的完整指南:一个文件,探索Scaling Law
模型实现我开源在:
预训练效果对比
本次试验的结果开源在:
我们对比了在不同setup下MoE模型的效果,对比了三个模型在同样的训练配置下的hellaswag评估效果:
- Dense模型,12层,12个attention head,4个kv head,推理时的激活参数180M
- Moe模型A,12层,12个attention head,4个kv head,8个expert 选择两个激活,推理时的激活参数266M
- Moe模型B,10层,8个attention head,4个kv head,8个expert 选择两个激活,推理时的激活参数150M
从结果可以看出,在相同的模型结构下,MoE模型A的效果远远超过了Dense模型。此外,MoE模型B在激活参数更少的情况下,依然能够在训练效果上超过Dense模型。这突显了专家层的优势:MoE模型能够以更少的激活参数实现更好的模型效果。
Deepseek MoE
在MoE模型的最新进展中,DeepSeek MoE模型进一步优化了专家的分配和选择机制。相比于Mixtral的8个专家层,DeepSeek通过引入更多且更小的专家,将8个专家层扩展到了64个专家层,并且每个专家层的大小缩小了4倍。
此外,DeepSeek定义了一个名为share expert的特殊专家层,这个专家层会参与任何输入的计算,从而保证了计算的一致性和效果。在实验中,DeepSeek MoE模型展现出了极其优异的性能。即使使用接近Llama2 7B模型四分之一的激活参数量,通过2B的激活参数,DeepSeek MoE模型依然达到了与Dense模型相当的效果。
结语
通过对比不同配置下的Dense模型和MoE模型,我们清楚地看到了MoE架构在提升性能和优化计算资源方面的巨大潜力。MoE模型不仅在相同参数量下表现优异,更在激活参数减少的情况下依然保持了高效的训练效果。
特别是DeepSeek MoE模型,通过增加专家层数量和引入share expert的创新机制,大幅提升了计算效率和模型效果。DeepSeek MoE在使用更少激活参数的前提下,依然能够达到与大型Dense模型相当的性能,展示了其在处理复杂任务中的独特优势。
本次试验的训练代码和结果全部开源,欢迎大家关注我的github repo,里面会发布更多模型相关的实操代码:https://github.com/hengjiUSTC/learn-llm/tree/main
编辑于 2024-07-15 00:14・IP 属地北京