[From Theory to Practice] CRNN+CTC, a Classic Text Recognition Algorithm

Paper:
http://arxiv.org/abs/1507.05717
GitHub projects:
https://github.com/bgshih/crnn#train-a-new-model
https://github.com/meijieru/crnn.pytorch

Theory

(1) Network Architecture

[Figure: CRNN network architecture]
1. First the image is preprocessed: the height must be a multiple of 16 (32 here), and the input image is scaled to 32×W×3.
2. A CNN then extracts convolutional features, producing a feature map of size 512×1×(W/4).
3. With seq_len = W/4 and input_size = 512, the features are fed into an LSTM to extract sequence features, yielding a (W/4)×n posterior probability matrix, where n is the number of character classes. For example, a 32×160 input gives a sequence of length 40 and a 40×n matrix.
4. Finally, CTC makes training possible even though the labels and the outputs are not aligned one-to-one.
The network itself is very simple, just the usual CNN plus BiLSTM, so it needs little explanation; let's look at what CTC is.

(2) CTC

CTC is short for Connectionist Temporal Classification. It solves the problem of hard-to-align labels in speech and character recognition by introducing a blank symbol '-', so that many different output sequences can correspond to the same label.
The difficult part is the computation and training of the CTC loss, which borrows the forward-backward algorithm from HMMs; for a very thorough walkthrough see the article 一文读懂CRNN+CTC文字识别.
In practice, PyTorch already provides a ready-made CTCLoss:

'''
Initialization
blank: the label index of the blank symbol. Defaults to 0; set it according to how your labels are actually defined.
reduction: how to reduce the output losses (a string):
'none' applies no reduction,
'mean' averages the output losses,
'sum' sums them. The default is 'mean'.
'''
ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')
# computing the loss
'''
log_probs: model output tensor of shape (T, N, C), where
T is the CTC input length, i.e., the output sequence length,
N is the batch size,
C is the size of the full character set including the blank label.

targets: tensor of shape (N, S) or (sum(target_lengths)).
In the first form, N is the batch size and S the label length;
in the second form, it is all the labels concatenated, with total length sum(target_lengths).

input_lengths: tensor or tuple of shape (N),
where every element must equal T, the output sequence length; once the model's output length is fixed, all elements are usually identical.

target_lengths: tensor or tuple of shape (N),
where each element gives the label length of the corresponding training sample; the label lengths may vary.
'''
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
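
To make the shapes concrete, here is a minimal, self-contained sketch; the sizes T, N, C, S below are made up for illustration and are not from the repo:

import torch
import torch.nn as nn

T, N, C, S = 50, 4, 5990, 10   # output length, batch size, classes (incl. blank), max label length

ctc_loss = nn.CTCLoss(blank=C - 1, reduction='mean')

log_probs = torch.randn(T, N, C).log_softmax(2)                # (T, N, C)
targets = torch.randint(0, C - 1, (N, S), dtype=torch.long)    # (N, S), no blank in targets
input_lengths = torch.full((N,), T, dtype=torch.long)          # every output sequence has length T
target_lengths = torch.randint(1, S + 1, (N,), dtype=torch.long)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss.item())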

Practice

Code walkthrough
The author's GitHub repository is https://github.com/meijieru/crnn.pytorch

(1) Dataset

Let's look at the Chinese text recognition dataset used for training.
The training set contains 3,279,606 images and the test set 364,400 images, in the following form:
[Figure: sample dataset images]
An alphabet txt file lists all the characters in the dataset: Chinese characters, digits, English letters, punctuation, and the blank, 5990 characters in total:
[Figure: excerpt of the alphabet file]
The label txt files for the training and test sets give, for each image, its file name and the alphabet indices of the characters it contains:
[Figure: excerpt of a label file]

Preprocessing the data is mainly about handling the label txt files. The author defines a strLabelConverter class that converts in both directions: encode turns a str into a label, decode turns a label back into a str, and decode applies the CTC correspondence rules.

class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """

        length = []
        result = []
        decode_flag = True if type(text[0])==bytes else False

        for item in text:

            if decode_flag:
                item = item.decode('utf-8','strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                # ''.join joins the elements of the sequence into a single string
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
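                    # CTC collapse rule: skip blanks (index 0) and merge consecutive repeats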
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
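
A quick round-trip shows the two directions together; the toy alphabet 'abc' here is made up for illustration (assumes torch is imported):

converter = strLabelConverter('abc')

text, length = converter.encode(['ab', 'cab'])
print(text)    # tensor([1, 2, 3, 1, 2], dtype=torch.int32)  ('a'->1, 'b'->2, 'c'->3; 0 is the blank)
print(length)  # tensor([2, 3], dtype=torch.int32)

print(converter.decode(text, length, raw=False))  # ['ab', 'cab']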

(2) Model

Next, the CRNN model.
First, a class for the RNN part: its constructor takes the input, hidden, and output feature dimensions, and it consists of one BiLSTM followed by a fully connected layer, so it can be called directly in the next step.

class BidirectionalLSTM(nn.Module):
    # Inputs hidden units Out
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        #seq_len, batch, hidden_size * num_directions
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output
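
A quick shape check on this block, with illustrative sizes (assumes torch is imported):

rnn = BidirectionalLSTM(nIn=512, nHidden=256, nOut=5990)
x = torch.randn(41, 4, 512)   # [seq_len, batch, input_size]
print(rnn(x).shape)           # torch.Size([41, 4, 5990])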

Then the class for the model as a whole.
Its inputs are the image height (32 in the config), the number of input channels (1 here), and the dimension of the RNN output, which is the alphabet size.
The author writes a convRelu helper: when i == 0 the input channels equal the image's channel count, otherwise the previous layer's output channels; each layer's output channels are listed in nm; the kernels are 3×3 (2×2 for the last layer) with stride 1 and padding 1 (0 for the last layer), with ReLU as the activation.
The final CNN is essentially the same as VGG16; the RNN is two cascaded BiLSTMs.
The CNN features are fed into the RNN with width as the sequence dimension, batch unchanged, and channels as the input feature dimension, producing a [seq_len, batch, nclass] matrix of log-probabilities.

class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            # nIn = nc when i == 0, otherwise the previous layer's output channels
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):

        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        # print(conv.size())  # debug: conv feature shape, [b, 512, 1, w]
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2) # b *512 * width
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        output = F.log_softmax(self.rnn(conv), dim=2)

        return output
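
A sanity check on the end-to-end shapes, a minimal sketch with made-up sizes (grayscale 32×160 input, 5990 classes, 256 hidden units):

import torch

model = CRNN(imgH=32, nc=1, nclass=5990, nh=256)
x = torch.randn(4, 1, 32, 160)   # [batch, channel, H, W]
out = model(x)
print(out.shape)                 # torch.Size([41, 4, 5990]) -> [seq_len, batch, nclass]

Note the sequence length is 41 rather than exactly W/4 = 40; the asymmetric padding in the last two pooling layers adds one extra column.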

Next, weight initialization and model instantiation:

def weights_init(m):
    #get class name
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def get_crnn(config):

    model = CRNN(config.MODEL.IMAGE_SIZE.H, 1, config.MODEL.NUM_CLASSES + 1, config.MODEL.NUM_HIDDEN)
    model.apply(weights_init)

    return model
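
get_crnn only reads a few config fields, so for a quick test you can feed it a hand-built EasyDict; the values below are illustrative:

from easydict import EasyDict as edict

cfg = edict({'MODEL': {'IMAGE_SIZE': {'H': 32},
                       'NUM_CLASSES': 5989,   # get_crnn adds +1 for the blank internally
                       'NUM_HIDDEN': 256}})
model = get_crnn(cfg)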

(3) Training

First, a function that reads the config is defined.
argparse handles arguments passed in from the command line; for its usage see https://zhuanlan.zhihu.com/p/56922793

def parse_arg():
    parser = argparse.ArgumentParser(description="train crnn")

    parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)

    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        config = edict(config)

    config.DATASET.ALPHABETS = alphabets.alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)

    return config

Next, the log folders are created; this part is analyzed later: [1]

    # create output folder
    output_dict = utils.create_log_folder(config, phase='train')

    # cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # writer dict
    writer_dict = {
        'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

Next come some initializations: the model, the loss function, and the optimizer.
With an NVIDIA GPU available, the model runs on the GPU; the loss function is the CTC loss described above.

    model = crnn.get_crnn(config)

    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu:0")

    model = model.to(device)

    # define loss function
    criterion = torch.nn.CTCLoss()

 

The optimizer setup deserves a closer look.
For background on optimizers see https://blog.csdn.net/weixin_40170902/article/details/80092628
In the config the optimizer is set to Adam,
and torch.optim.lr_scheduler is used to adjust the learning rate, lowering lr by a given factor after the specified epochs; see https://blog.csdn.net/qyhaill/article/details/103043637

    optimizer = utils.get_optimizer(config, model)
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch-1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )
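
To see what MultiStepLR does, a small standalone sketch; the milestones and factor here are made up:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)

for epoch in range(8):
    print(epoch, optimizer.param_groups[0]['lr'])
    # ... train one epoch ...
    scheduler.step()
# prints lr = 1e-3 for epochs 0-2, 1e-4 for epochs 3-5, 1e-5 for epochs 6-7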

Next comes the choice between finetuning and resuming, and the loading of the pretrained model.
For an introduction to finetuning see https://zhuanlan.zhihu.com/p/35890660. The finetuning here freezes the CNN, so its parameters are not updated.

    if config.TRAIN.FINETUNE.IS_FINETUNE:
        model_state_file = config.TRAIN.FINETUNE.FINETUNE_CHECKPOINIT
        if model_state_file == '':
            print(" => no checkpoint found")
        checkpoint = torch.load(model_state_file, map_location='cpu')
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']

        from collections import OrderedDict
        model_dict = OrderedDict()
        for k, v in checkpoint.items():
            if 'cnn' in k:
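                # drop the leading 'cnn.' prefix so the keys match model.cnn's state_dict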
                model_dict[k[4:]] = v
        model.cnn.load_state_dict(model_dict)
        if config.TRAIN.FINETUNE.FREEZE:
            for p in model.cnn.parameters():
                p.requires_grad = False

    elif config.TRAIN.RESUME.IS_RESUME:
        model_state_file = config.TRAIN.RESUME.FILE
        if model_state_file == '':
            print(" => no checkpoint found")
        checkpoint = torch.load(model_state_file, map_location='cpu')
        if 'state_dict' in checkpoint.keys():
            model.load_state_dict(checkpoint['state_dict'])
            last_epoch = checkpoint['epoch']
            # optimizer.load_state_dict(checkpoint['optimizer'])
            # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        else:
            model.load_state_dict(checkpoint)

Then a function is used to print the model's parameters; this part is also covered later: [2]

model_info(model)

Next, the training and validation sets are loaded, with all parameters read from the config file; the dataset code is analyzed further later on.

    train_dataset = get_dataset(config)(config, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

Finally, the training loop.
utils.strLabelConverter is the CTC-related piece that converts between the dataset labels and character strings; it is analyzed later [3].
Then there are two important functions, function.train and function.validate, used for training and evaluation respectively.
At the end the model is saved; only the model's parameters (the state_dict) are stored.

    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):

        function.train(config, train_loader, train_dataset, converter, model, criterion, optimizer, device, epoch, writer_dict, output_dict)
        lr_scheduler.step()

        acc = function.validate(config, val_loader, val_dataset, converter, model, criterion, device, epoch, writer_dict, output_dict)

        is_best = acc > best_acc
        best_acc = max(acc, best_acc)

        print("is best:", is_best)
        print("best acc is:", best_acc)
        # save checkpoint

        torch.save(
            {
                "state_dict": model.state_dict(),
                "epoch": epoch + 1,
                # "optimizer": optimizer.state_dict(),
                # "lr_scheduler": lr_scheduler.state_dict(),
                "best_acc": best_acc,
            },  os.path.join(output_dict['chs_dir'], "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc))
        )

Now let's unpack function.train.
enumerate() turns an iterable into an indexed sequence, yielding both the data and its index; here inp is the input image batch and idx its label indices.
The function breaks down into measuring time, computing the inference (the model output), computing the loss, and updating the parameters.
Note that before computing the CTC loss you must first build the lengths of the inference outputs (seq_len repeated batch_size times) and of the labels (the total label length of the batch).

def train(config, train_loader, dataset, converter, model, criterion, optimizer, device, epoch, writer_dict=None, output_dict=None):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    model.train()

    end = time.time()
    for i, (inp, idx) in enumerate(train_loader):
        # measure data time
        data_time.update(time.time() - end)

        labels = utils.get_batch_label(dataset, idx)
        inp = inp.to(device)

        # inference
        preds = model(inp).cpu()

        # compute loss
        batch_size = inp.size(0)
        text, length = converter.encode(labels)                    # length = 一个batch中的总字符长度, text = 一个batch中的字符所对应的下标
        preds_size = torch.IntTensor([preds.size(0)] * batch_size) # timestep * batchsize
        loss = criterion(preds, text, preds_size, length)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item(), inp.size(0))

        batch_time.update(time.time()-end)
        if i % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{0}][{1}/{2}]\t' \
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                  'Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                  'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      speed=inp.size(0)/batch_time.val,
                      data_time=data_time, loss=losses)
            print(msg)

            if writer_dict:
                writer = writer_dict['writer']
                global_steps = writer_dict['train_global_steps']
                writer.add_scalar('train_loss', losses.avg, global_steps)
                writer_dict['train_global_steps'] = global_steps + 1

        end = time.time()
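
train relies on an AverageMeter helper that the article does not show. Here is a minimal sketch consistent with how it is called above (update(value, n) plus .val and .avg attributes); the real implementation in the repo may differ:

class AverageMeter(object):
    """Tracks the most recent value and a running average."""
    def __init__(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count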

Now function.validate, which measures accuracy on the test set after each epoch. The flow is basically the same as train, except that model.eval() and torch.no_grad() are added up front so parameters stay fixed and no gradients are tracked during testing.
A note on how accuracy is computed: the model output is [seq_len, batch, nclass]; torch.max takes the argmax over the class dimension as the output at each time step, converter.decode turns that into a str, and it is compared with the label string. A prediction counts as correct only when it matches the label exactly, so accuracy = number of correct predictions / number of test samples.
First a group of examples is printed, config.TEST.NUM_TEST_DISP of them, with three columns: the raw output before merging, the output after CTC merging, and the ground-truth label.
[Figure: sample validation output]
Then the number of correct predictions and the total sample count are printed, along with whether this epoch is the best and the best accuracy so far.
[Figure: accuracy log output]
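
The difference between the raw and merged outputs is exactly the CTC collapse rule. With strLabelConverter from earlier, a made-up alphabet, and a made-up index sequence (assumes torch is imported):

converter = strLabelConverter('helo')   # h->1, e->2, l->3, o->4; 0 is the blank '-'
t = torch.IntTensor([1, 1, 0, 2, 0, 3, 3, 0, 3, 4])
length = torch.IntTensor([10])

print(converter.decode(t, length, raw=True))   # hh-e-ll-lo
print(converter.decode(t, length, raw=False))  # hello

Now the validate function itself: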

def validate(config, val_loader, dataset, converter, model, criterion, device, epoch, writer_dict, output_dict):

    losses = AverageMeter()
    model.eval()

    n_correct = 0
    with torch.no_grad():
        for i, (inp, idx) in enumerate(val_loader):

            labels = utils.get_batch_label(dataset, idx)
            inp = inp.to(device)

            # inference
            preds = model(inp).cpu()

            # compute loss
            batch_size = inp.size(0)
            text, length = converter.encode(labels)
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            loss = criterion(preds, text, preds_size, length)

            losses.update(loss.item(), inp.size(0))
            # argmax over the class dimension (dim 2); max returns (values, indices)
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            for pred, target in zip(sim_preds, labels):
                if pred == target:
                    n_correct += 1

            if (i + 1) % config.PRINT_FREQ == 0:
                print('Epoch: [{0}][{1}/{2}]'.format(epoch, i, len(val_loader)))

            if i == config.TEST.NUM_TEST_BATCH:
                break

    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.TEST.NUM_TEST_DISP]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    num_test_sample = config.TEST.NUM_TEST_BATCH * config.TEST.BATCH_SIZE_PER_GPU
    if num_test_sample > len(dataset):
        num_test_sample = len(dataset)

    print("[#correct:{} / #total:{}]".format(n_correct, num_test_sample))
    accuracy = n_correct / float(num_test_sample)
    print('Test loss: {:.4f}, accuray: {:.4f}'.format(losses.avg, accuracy))

    if writer_dict:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_acc', accuracy, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1

    return accuracy

(4) Demo

This part uses a trained model to run recognition.
First, argparse takes in your config file, the trained checkpoint, and the image you want to recognize.

def parse_arg():
    parser = argparse.ArgumentParser(description="demo")

    parser.add_argument('--cfg', help='experiment configuration filename', type=str, default='lib/config/360CC_config.yaml')
    parser.add_argument('--image_path', type=str, default='images/test.png', help='the path to your image')
    parser.add_argument('--checkpoint', type=str, default='output/checkpoints/mixed_second_finetune_acc_97P7.pth',
                        help='the path to your checkpoints')

    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        config = edict(config)

    config.DATASET.ALPHABETS = alphabets.alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)

    return config, args

The next function preprocesses the image and then runs recognition on the processed image.
Preprocessing first resizes the image height to 32, scaling the width by the same ratio; then, keeping the height fixed, the width is rescaled once more so the aspect ratio matches training.
Finally the pixel values are standardized by subtracting the mean and dividing by the standard deviation.


def recognition(config, img, model, converter, device):

    # github issues: https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211
    h, w = img.shape
    # first step: resize the height and width of the image to (32, x)
    img = cv2.resize(img, (0, 0), fx=config.MODEL.IMAGE_SIZE.H / h, fy=config.MODEL.IMAGE_SIZE.H / h, interpolation=cv2.INTER_CUBIC)

    # second step: keep the ratio of image's text same with training
    h, w = img.shape
    w_cur = int(img.shape[1] / (config.MODEL.IMAGE_SIZE.OW / config.MODEL.IMAGE_SIZE.W))
    img = cv2.resize(img, (0, 0), fx=w_cur / w, fy=1.0, interpolation=cv2.INTER_CUBIC)
    img = np.reshape(img, (config.MODEL.IMAGE_SIZE.H, w_cur, 1))

    # normalize
    img = img.astype(np.float32)
    img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
    img = img.transpose([2, 0, 1])
    img = torch.from_numpy(img)

    img = img.to(device)
    img = img.view(1, *img.size())
    model.eval()
    preds = model(img)
    print(preds.shape)
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)

    preds_size = torch.IntTensor([preds.size(0)])  # Variable is deprecated; a plain tensor works
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)

    print('results: {0}'.format(sim_pred))



After that, the main program is just the routine loading of the model and reading of the image to complete the recognition.

if __name__ == '__main__':

    config, args = parse_arg()
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    model = crnn.get_crnn(config).to(device)
    print('loading pretrained model from {0}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint)
    if 'state_dict' in checkpoint.keys():
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)

    started = time.time()

    img = cv2.imread(args.image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)

    recognition(config, img, model, converter, device)

    finished = time.time()
    print('elapsed time: {0}'.format(finished - started))

Summary

Experimental results

Training on the roughly 3,000,000-image training set (on a 2080 Ti with batch size 128, GPU memory usage is about 6 GB and one epoch takes roughly 2 hours), accuracy on the test set (32×10000 images) reaches 99.22% after 17 epochs. Below is the result after 25 epochs of training; the loss has clearly converged.
[Figure: training loss curve]
My trained model: https://pan.baidu.com/s/1a0BKxFEOKXXnL8wgHbbkJg
extraction code: qwer

Some examples:
[Figure: recognition examples]

Remaining problems

(1) Since the training set is almost entirely black text on a white background, recognition is poor on images like these:
[Figure: failure cases with complex backgrounds]
(2) Large variation in character scale or spacing within a single image also hurts accuracy:
[Figure: failure case with varying text scale]
This can be mitigated at the text detection stage, by making sure the cropped regions contain text at a consistent scale.
(3) The model currently works well only for Chinese; its performance on multilingual text is mediocre.
