Bootstrap

Yolov8如何在训练意外中断后接续训练

0. 2023年11月更新

请使用第四节的新方法,不需要修改代码,更加简单。

1.错误尝试

在训练YOLOv8的时候,因为开太多其他程序,导致在100多次的时候崩溃,查询网上相关知识如何接着训练,在yolo5中把resume改成True就可以。
在yolov8中也这样尝试,将ultralytics/yolo/cfg/default.yaml中的resume改成True发现并没有作用,感觉yolov8代码还是有很多bug

2.成功的方法

2.1 ultralytics/yolo/engine/model.py

打开ultralytics/yolo/engine/model.py代码,找到train方法,如下
将self.trainer.model = self.model注释掉

    def train(self, **kwargs):
        """
        Trains the model on a given dataset.

        Args:
            **kwargs (Any): Any number of arguments representing the training configuration.
        """
        overrides = self.overrides.copy()
        overrides.update(kwargs)
        if kwargs.get("cfg"):
            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
            overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=False)
        overrides["task"] = self.task
        overrides["mode"] = "train"
        if not overrides.get("data"):
            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
        if overrides.get("resume"):
            overrides["resume"] = self.ckpt_path

        self.trainer = self.TrainerClass(overrides=overrides)
        # if not overrides.get("resume"):  # manually set model only if not resuming
        #     self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
        #     self.model = self.trainer.model
        #下面一行代码在正常情况下需要开启
        # self.trainer.model = self.model
        self.trainer.train()
        # update model and cfg after training
        if RANK in {0, -1}:
            self.model, _ = attempt_load_one_weight(str(self.trainer.best))
            self.overrides = self.model.args

2.2 ultralytics/yolo/engine/trainer.py

找到check_resume和resume_training方法
在check_resume方法里面将resume=中断地方的last.pt
在resume_training里面将ckpt=中断地方的last.pt

def check_resume(self):
        # resume = self.args.resume
        resume = 'runs/detect/train49/weights/last.pt';
        if resume:
            try:
                last = Path(
                    check_file(resume) if isinstance(resume, (str,
                                                              Path)) and Path(resume).exists() else get_latest_run())
                self.args = get_cfg(attempt_load_weights(last).args)
                self.args.model, resume = str(last), True  # reinstate
            except Exception as e:
                raise FileNotFoundError("Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                                        "i.e. 'yolo train resume model=path/to/last.pt'") from e
        self.resume = resume

    def resume_training(self, ckpt):
        ckpt = torch.load('runs/detect/train49/weights/last.pt')
        if ckpt is None:
            return
        best_fitness = 0.0
        start_epoch = ckpt['epoch'] + 1
        if ckpt['optimizer'] is not None:
            self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
            best_fitness = ckpt['best_fitness']
        if self.ema and ckpt.get('ema'):
            self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
            self.ema.updates = ckpt['updates']
        if self.resume:
            assert start_epoch > 0, \
                f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
                f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
            LOGGER.info(
                f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
        if self.epochs < start_epoch:
            LOGGER.info(
                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
            self.epochs += ckpt['epoch']  # finetune additional epochs
        self.best_fitness = best_fitness
        self.start_epoch = start_epoch

3.运行代码

没有在中断的train49训练,新开了一个文件夹,但是实现了功能
在这里插入图片描述

重要提示

训练完成后请把所有代码复原!!!
训练完成后请把所有代码复原!!!
训练完成后请把所有代码复原!!!

4. 评论区方法

官方给的代码,使用的比较新的代码,老代码没有尝试过,现在很多东西都可以在官网找到解决办法!
使用python代码如下

from ultralytics import YOLO

# Load a model
model = YOLO('path/to/last.pt')  # load a partially trained model

# Resume training
results = model.train(resume=True)

使用命令行代码如下:

# Resume an interrupted training
yolo train resume model=path/to/last.pt

测试如下图,直接就在原文件夹接着训练:
在这里插入图片描述

;