Bootstrap

Huggingface-4.8.2自定义训练

Huggingface走到4.8.2这个版本,已经有了很好的封装。训练一个语言网络只需要调用Trainer.train(...)即可完成。如果要根据自己的需求修改训练的过程,比如自定义loss,输出梯度,直接修改huggingface的源码显然是不可取的了。好在huggingface提供了相应的接口,让我们可以深入到训练过程中,加入自定义的内容。根据官方的教程,有两种推荐的方法:

  1. 重载trainer中的方法,将其修改为我们需要的内容。比如trainer.compute_loss()这个函数,它定义了如何计算loss,我们只需要修改其中的逻辑,就可以自定义loss的计算。
  2. 使用callbacks。callbacks可以查看训练过程中一些关键变量的值,并根据其状态做出相应的决策,比如early stop。

关于trainer和callbacks这两个的官方文档分别是这里这里,这两个方法都可以很优雅地修改原有的逻辑。但个人感觉重载trainer的方法是一种更灵活也更强大的方法。callbacks其实只能查看提供的一些变量,并且也只是查看,不能做出修改。而重载方法可以定义任意的全新的函数。接下来给出这两种方法的两个例子。

重载方法

在官方给的教程中是一个重载compute loss的例子,这里给一个不一样的,定义trainging_step的例子,代码如下:

class PrintGradientTrainer(Trainer):

    def training_step(self, model, inputs):
        model.train()
        inputs = self._prepare_inputs(inputs)

        loss = self.compute_loss(model, inputs)

        loss.backward()
        
        # ------------------------new added codes.--------------------------
        for name, param in model.named_parameters():
            if param.requires_grad:
                if param.grad is not None:
                    print("{}, gradient: {}".format(name, param.grad.mean()))
                else:
                    print("{} has not gradient".format(name))
        # ------------------------new added codes.--------------------------
        return loss.detach()

# originally the Trainer() is called
#trainer = Trainer(
#    model=model, args=training_args, train_dataset=small_train_dataset, #eval_dataset=small_eval_dataset,
#    tokenizer=tokenizer, data_collator=data_collator
#)

# Now call the new defined PrintGradientTrainer()
trainer = PrintGradientTrainer(
    model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset,
    tokenizer=tokenizer, data_collator=data_collator
)

trainer.train()

只给出了关键部分的代码,其他的就按照正常写即可。

Callbacks

这个方法也需要定义一个原本的TrainerCallback的子类,然后重载原有的空的callbacks方法。代码实例如下,这个例子打出了现在是第几个epoch。

class MyCallback(TrainerCallback):
    def on_step_begin(self, args, state, control, **kwargs):
        print("train step start")
        control.should_log = False
        control.should_evaluate = False
        control.should_save = False
        print('---------------------------------------',state.epoch)
        # return self.call_event("on_step_begin", args, state, control)
trainer = PrintGradientTrainer(
    model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset,
    tokenizer=tokenizer, data_collator=data_collator,callbacks=[MyCallback()]
)

在定义trainer的时候,给callbacks加入自己定义的类就可以了。

;