Paper:
http://arxiv.org/abs/1507.05717
GitHub projects:
https://github.com/bgshih/crnn#train-a-new-model
https://github.com/meijieru/crnn.pytorch
Theory
(1) Network structure
1. First the image is preprocessed: the height must be a multiple of 16 (32 here), so the input image is scaled to 32×W×3.
2. A CNN then extracts convolutional features, producing a feature map of size 1×(W/4)×512.
3. This is fed into the LSTM with seq_len = W/4 and input_size = 512 to extract sequence features, yielding a (W/4)×n posterior probability matrix, where n is the number of character classes. For example, a 32×160 input gives a sequence of roughly 40 time steps.
4. Finally, CTC is applied so that training is possible even though the labels and the per-timestep outputs are not aligned one-to-one.
The network itself is very simple: a standard CNN followed by a BiLSTM, so little needs to be said about it. Let's look at what CTC is.
(2) CTC
CTC stands for Connectionist Temporal Classification. It addresses the difficulty of frame-level labeling in speech and character recognition: by introducing a blank symbol '-', many different output paths can map to the same label.
The hard part is computing and training with the CTC loss, which borrows the forward-backward algorithm from HMMs. For a very thorough treatment, see:
一文读懂CRNN+CTC文字识别
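To make the blank-based mapping concrete, here is a minimal sketch (my illustration, not repo code) of CTC's collapsing rule: merge consecutive repeated labels, then drop blanks. The blank is what allows genuinely repeated characters, like the "ll" in "hello", to survive the merge.
def ctc_collapse(path, blank=0):
    """CTC collapsing rule: merge consecutive repeats, then remove blanks."""
    out, prev = [], None
    for p in path:
        if p != prev and p != blank:
            out.append(p)
        prev = p
    return out

# With a toy alphabet 1='h', 2='e', 3='l', 4='o', two different paths
# collapse to the same label [1, 2, 3, 3, 4] -> "hello":
print(ctc_collapse([1, 1, 2, 3, 0, 3, 4]))           # [1, 2, 3, 3, 4]
print(ctc_collapse([0, 1, 2, 0, 3, 3, 0, 3, 4, 4]))  # [1, 2, 3, 3, 4]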
For actual use, PyTorch ships a ready-made CTCLoss:
'''
Initialization
blank: index of the blank label; defaults to 0 and must be set to match your label definition.
reduction: how to reduce the per-sample output losses (string):
'none' leaves the losses unreduced,
'mean' averages them,
'sum' sums them; the default is 'mean'.
'''
ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')
# compute the loss
'''
log_probs: model output tensor of shape (T, N, C), where
T is the input length of CTCLoss, i.e. the output sequence length,
N is the batch size,
C is the size of the character set including the blank label.
targets: tensor of shape (N, S) or (sum(target_lengths));
in the first form N is the batch size and S the label length,
in the second form it is the concatenation of all labels.
input_lengths: tensor or tuple of shape (N);
each element must equal T, the output sequence length, so once the model's output length is fixed all elements are the same.
target_lengths: tensor or tuple of shape (N);
each element gives the label length of the corresponding sample, and these lengths may vary.
'''
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
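Here is a minimal runnable example with dummy tensors (shapes only, random data). Note it keeps blank at its default of 0, which is the convention the repo below uses in strLabelConverter, whereas the initialization snippet above put the blank at len(CHARS)-1; either works as long as the encoder and the loss agree.
import torch
import torch.nn as nn

T, N, C, S = 50, 4, 20, 10  # time steps, batch, classes (incl. blank), max label length
ctc_loss = nn.CTCLoss(blank=0, reduction='mean')

log_probs = torch.randn(T, N, C).log_softmax(2)          # (T, N, C) model output
targets = torch.randint(1, C, (N, S), dtype=torch.long)  # labels in [1, C-1]; 0 = blank
input_lengths = torch.full((N,), T, dtype=torch.long)    # every output sequence has length T
target_lengths = torch.randint(1, S + 1, (N,), dtype=torch.long)  # variable label lengths

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss.item())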
Hands-On
Code walkthrough
The author's GitHub repo is https://github.com/meijieru/crnn.pytorch
(1) Dataset
Let's look at the Chinese text recognition dataset used for training.
The training set contains 3,279,606 images and the test set 364,400 images, in the following form:
An alphabet txt holds every character that appears in the dataset — Chinese characters, digits, English letters, punctuation, and the blank — 5,990 characters in total, in the following form:
The train and test label txts map each image name to the positions of its characters in the alphabet.
Data preprocessing mainly concerns this label txt. The author defines a strLabelConverter class to convert between the two representations: encode turns a str into a label, and decode turns a label back into a str, applying the CTC collapsing rules.
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=True): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
        length = []
        result = []
        decode_flag = True if type(text[0]) == bytes else False

        for item in text:
            if decode_flag:
                item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                # ''.join concatenates the mapped characters into one string
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    # CTC collapsing: skip blanks (0) and repeated labels
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
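A quick round trip with a toy three-character alphabet (my example, not from the repo; assumes torch and the class above are in scope) shows the interface and the collapsing rule in decode:
converter = strLabelConverter('abc')  # dict: a->1, b->2, c->3; 0 is the blank
text, length = converter.encode(['ab', 'cab'])
print(text, length)  # tensor([1, 2, 3, 1, 2]) and tensor([2, 3])

t = torch.IntTensor([1, 1, 0, 2, 2])  # a raw per-timestep prediction
print(converter.decode(t, torch.IntTensor([5]), raw=True))   # 'aa-bb'
print(converter.decode(t, torch.IntTensor([5]), raw=False))  # 'ab'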
(2) Model
Next, the CRNN model itself.
First a BiLSTM building block is defined. Its constructor takes the input, hidden, and output feature dimensions; it consists of one BiLSTM and one fully connected layer, so it can be stacked directly in the next step.
class BidirectionalLSTM(nn.Module):
    # Inputs    hidden units    Out
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        # seq_len, batch, hidden_size * num_directions
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output
Then the full model class is defined.
Its arguments are the image height (32 in the config), the number of input channels (1 here, grayscale), the number of output classes (the alphabet size plus the blank), and the RNN hidden size.
The author writes a convRelu helper: when i == 0 the input channel count is that of the input image, otherwise it is the previous layer's output channel count; the per-layer output channels are listed in nm; each conv has kernel size 3, stride 1, padding 1, with ReLU as the activation.
The resulting CNN is essentially VGG-style, and the RNN is two BiLSTMs in cascade.
The CNN feature map is reinterpreted with width as the sequence dimension, batch unchanged, and channels as the input feature dimension, then fed to the RNN, whose output is a [seq_len, batch, nclass] log-probability matrix.
class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            # nIn = nc for the first layer, otherwise the previous layer's output channels
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        print(conv.size())
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)  # b * 512 * width
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        output = F.log_softmax(self.rnn(conv), dim=2)
        return output
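A quick shape check with a dummy batch (assuming the two classes above plus torch, torch.nn as nn, and torch.nn.functional as F are in scope; nclass=100 and nh=256 are arbitrary values for illustration):
model = CRNN(imgH=32, nc=1, nclass=100, nh=256)
x = torch.randn(4, 1, 32, 160)  # 4 grayscale 32x160 text crops
y = model(x)
print(y.shape)  # torch.Size([41, 4, 100]) -> [seq_len, batch, nclass], seq_len ≈ W/4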
Then come weight initialization and instantiation of the class.
def weights_init(m):
    # get class name
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_crnn(config):
    model = CRNN(config.MODEL.IMAGE_SIZE.H, 1, config.MODEL.NUM_CLASSES + 1, config.MODEL.NUM_HIDDEN)
    model.apply(weights_init)
    return model
(3) Train
First a function reads the config.
argparse handles command-line arguments; see https://zhuanlan.zhihu.com/p/56922793 for its usage.
def parse_arg():
    parser = argparse.ArgumentParser(description="train crnn")
    parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        # recent PyYAML versions require an explicit Loader
        config = yaml.load(f, Loader=yaml.FullLoader)
        config = edict(config)

    config.DATASET.ALPHABETS = alphabets.alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)

    return config
Next the log folders are created; this part is dissected later: [1]
# create output folder
output_dict = utils.create_log_folder(config, phase='train')

# cudnn
cudnn.benchmark = config.CUDNN.BENCHMARK
cudnn.deterministic = config.CUDNN.DETERMINISTIC
cudnn.enabled = config.CUDNN.ENABLED

# writer dict
writer_dict = {
    'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
    'train_global_steps': 0,
    'valid_global_steps': 0,
}
Then some initialization: the model, the loss function, and the optimizer.
With an NVIDIA GPU available the model runs on the GPU; the loss function is the CTC loss described above.
model = crnn.get_crnn(config)

# get device
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(config.GPUID))
else:
    device = torch.device("cpu:0")
model = model.to(device)

# define loss function
criterion = torch.nn.CTCLoss()
The optimizer initialization deserves a closer look.
For background on optimizers, see: https://blog.csdn.net/weixin_40170902/article/details/80092628
The config sets the optimizer to Adam,
and torch.optim.lr_scheduler adjusts the learning rate, dropping it by a given factor at the specified epochs; see https://blog.csdn.net/qyhaill/article/details/103043637, as well as the sketch after the code below.
optimizer = utils.get_optimizer(config, model)

# last_epoch is initialized earlier in the repo's train.py from the config
# (and overwritten when resuming from a checkpoint, see below)
last_epoch = config.TRAIN.BEGIN_EPOCH
if isinstance(config.TRAIN.LR_STEP, list):
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP,
        config.TRAIN.LR_FACTOR, last_epoch - 1
    )
else:
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, config.TRAIN.LR_STEP,
        config.TRAIN.LR_FACTOR, last_epoch - 1
    )
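For intuition, here is a self-contained toy run of MultiStepLR (made-up milestones and a dummy parameter, not the repo's values): the learning rate is multiplied by gamma at every milestone epoch.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)

for epoch in range(8):
    print(epoch, optimizer.param_groups[0]['lr'])
    optimizer.step()   # stands in for one epoch of training
    scheduler.step()
# lr is 0.1 for epochs 0-2, 0.01 for epochs 3-5, 0.001 for epochs 6-7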
Next comes the choice between finetuning and resuming, and the loading of the pretrained model.
For an introduction to finetuning, see https://zhuanlan.zhihu.com/p/35890660; here finetuning freezes the CNN so that its parameters are not updated.
if config.TRAIN.FINETUNE.IS_FINETUNE:
    model_state_file = config.TRAIN.FINETUNE.FINETUNE_CHECKPOINIT
    if model_state_file == '':
        print(" => no checkpoint found")
    checkpoint = torch.load(model_state_file, map_location='cpu')
    if 'state_dict' in checkpoint.keys():
        checkpoint = checkpoint['state_dict']

    # keep only the CNN weights, stripping the 'cnn.' prefix from the keys
    from collections import OrderedDict
    model_dict = OrderedDict()
    for k, v in checkpoint.items():
        if 'cnn' in k:
            model_dict[k[4:]] = v
    model.cnn.load_state_dict(model_dict)
    if config.TRAIN.FINETUNE.FREEZE:
        for p in model.cnn.parameters():
            p.requires_grad = False

elif config.TRAIN.RESUME.IS_RESUME:
    model_state_file = config.TRAIN.RESUME.FILE
    if model_state_file == '':
        print(" => no checkpoint found")
    checkpoint = torch.load(model_state_file, map_location='cpu')
    if 'state_dict' in checkpoint.keys():
        model.load_state_dict(checkpoint['state_dict'])
        last_epoch = checkpoint['epoch']
        # optimizer.load_state_dict(checkpoint['optimizer'])
        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    else:
        model.load_state_dict(checkpoint)
A helper function then prints the model's parameters; this is also covered later: [2]
model_info(model)
Then the training and test sets are loaded, with all parameters read from the config file; the datasets are analyzed further later on.
train_dataset = get_dataset(config)(config, is_train=True)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
    shuffle=config.TRAIN.SHUFFLE,
    num_workers=config.WORKERS,
    pin_memory=config.PIN_MEMORY,
)

val_dataset = get_dataset(config)(config, is_train=False)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=config.TEST.BATCH_SIZE_PER_GPU,
    shuffle=config.TEST.SHUFFLE,
    num_workers=config.WORKERS,
    pin_memory=config.PIN_MEMORY,
)
Finally, the training itself.
utils.strLabelConverter is the CTC-related converter between dataset labels and character strings, dissected later [3].
Two key functions follow, function.train and function.validate, used for training and evaluation respectively.
At the end the checkpoint is saved; only the model's parameters are stored.
best_acc = 0.5  # initial best-accuracy threshold, set earlier in the repo's train.py
converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
    function.train(config, train_loader, train_dataset, converter, model, criterion, optimizer, device, epoch, writer_dict, output_dict)
    lr_scheduler.step()

    acc = function.validate(config, val_loader, val_dataset, converter, model, criterion, device, epoch, writer_dict, output_dict)

    is_best = acc > best_acc
    best_acc = max(acc, best_acc)

    print("is best:", is_best)
    print("best acc is:", best_acc)
    # save checkpoint
    torch.save(
        {
            "state_dict": model.state_dict(),
            "epoch": epoch + 1,
            # "optimizer": optimizer.state_dict(),
            # "lr_scheduler": lr_scheduler.state_dict(),
            "best_acc": best_acc,
        }, os.path.join(output_dict['chs_dir'], "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc))
    )
Now let's dissect function.train.
enumerate() wraps an iterable into an indexed sequence, yielding both the data and its index; inp is the input image batch and idx indexes the corresponding labels.
The function breaks down into timing, running inference (the model's forward pass), computing the loss, and updating the parameters.
Note that before computing the CTC loss we must first build the prediction lengths (the time-step count T repeated batch_size times) and the label lengths (the per-sample label lengths for the whole batch).
def train(config, train_loader, dataset, converter, model, criterion, optimizer, device, epoch, writer_dict=None, output_dict=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    model.train()

    end = time.time()
    for i, (inp, idx) in enumerate(train_loader):
        # measure data time
        data_time.update(time.time() - end)

        labels = utils.get_batch_label(dataset, idx)
        inp = inp.to(device)

        # inference
        preds = model(inp).cpu()

        # compute loss
        batch_size = inp.size(0)
        # length = total character count in the batch; text = the characters' alphabet indices
        text, length = converter.encode(labels)
        preds_size = torch.IntTensor([preds.size(0)] * batch_size)  # timestep * batchsize
        loss = criterion(preds, text, preds_size, length)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item(), inp.size(0))

        batch_time.update(time.time() - end)
        if i % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{0}][{1}/{2}]\t' \
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                  'Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                  'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      speed=inp.size(0) / batch_time.val,
                      data_time=data_time, loss=losses)
            print(msg)

            if writer_dict:
                writer = writer_dict['writer']
                global_steps = writer_dict['train_global_steps']
                writer.add_scalar('train_loss', losses.avg, global_steps)
                writer_dict['train_global_steps'] = global_steps + 1

        end = time.time()
Next, function.validate, which measures accuracy on the test set after each epoch. The flow mirrors train, except that model.eval() and torch.no_grad() are added up front to keep the parameters fixed during testing.
A word on how accuracy is computed: the model output is [seq_len, batch, nclass]; torch.max picks the most likely class at each time step, converter.decode turns that into a string, and the string is compared against the label's string. A sample counts as correct only when the whole predicted string matches the label exactly, so accuracy = correct samples / total samples.
The function first prints a set of config.TEST.NUM_TEST_DISP examples, each with three fields: the raw (uncollapsed) output, the output after CTC collapsing, and the ground-truth label.
It then prints the number of correct samples and the total, whether this epoch is the best so far, and the best accuracy.
def validate(config, val_loader, dataset, converter, model, criterion, device, epoch, writer_dict, output_dict):
    losses = AverageMeter()
    model.eval()

    n_correct = 0
    with torch.no_grad():
        for i, (inp, idx) in enumerate(val_loader):
            labels = utils.get_batch_label(dataset, idx)
            inp = inp.to(device)

            # inference
            preds = model(inp).cpu()

            # compute loss
            batch_size = inp.size(0)
            text, length = converter.encode(labels)
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            loss = criterion(preds, text, preds_size, length)

            losses.update(loss.item(), inp.size(0))

            # take the argmax over the class dimension at each time step
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            for pred, target in zip(sim_preds, labels):
                if pred == target:
                    n_correct += 1

            if (i + 1) % config.PRINT_FREQ == 0:
                print('Epoch: [{0}][{1}/{2}]'.format(epoch, i, len(val_loader)))

            if i == config.TEST.NUM_TEST_BATCH:
                break

    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.TEST.NUM_TEST_DISP]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    num_test_sample = config.TEST.NUM_TEST_BATCH * config.TEST.BATCH_SIZE_PER_GPU
    if num_test_sample > len(dataset):
        num_test_sample = len(dataset)

    print("[#correct:{} / #total:{}]".format(n_correct, num_test_sample))
    accuracy = n_correct / float(num_test_sample)
    print('Test loss: {:.4f}, accuracy: {:.4f}'.format(losses.avg, accuracy))

    if writer_dict:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_acc', accuracy, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1

    return accuracy
(4) Demo
This part runs recognition with the trained model.
argparse is used again to pass in the config, the trained checkpoint, and the folder of images to recognize.
def parse_arg():
    parser = argparse.ArgumentParser(description="demo")
    parser.add_argument('--cfg', help='experiment configuration filename', type=str, default='lib/config/360CC_config.yaml')
    parser.add_argument('--image_path', type=str, default='images/test.png', help='the path to your image')
    parser.add_argument('--checkpoint', type=str, default='output/checkpoints/mixed_second_finetune_acc_97P7.pth',
                        help='the path to your checkpoints')
    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        # recent PyYAML versions require an explicit Loader
        config = yaml.load(f, Loader=yaml.FullLoader)
        config = edict(config)

    config.DATASET.ALPHABETS = alphabets.alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)

    return config, args
The next function preprocesses an image and then runs recognition on the result.
Preprocessing first resizes the image height to 32, scaling the width by the same ratio; then, keeping the height fixed, it rescales the width once more to match the aspect-ratio convention used in training.
Finally the pixel values are standardized by subtracting the mean and dividing by the standard deviation.
def recognition(config, img, model, converter, device):
    # github issues: https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211
    h, w = img.shape
    # first step: resize the height and width of image to (32, x)
    img = cv2.resize(img, (0, 0), fx=config.MODEL.IMAGE_SIZE.H / h, fy=config.MODEL.IMAGE_SIZE.H / h, interpolation=cv2.INTER_CUBIC)

    # second step: keep the ratio of image's text same with training
    h, w = img.shape
    w_cur = int(img.shape[1] / (config.MODEL.IMAGE_SIZE.OW / config.MODEL.IMAGE_SIZE.W))
    img = cv2.resize(img, (0, 0), fx=w_cur / w, fy=1.0, interpolation=cv2.INTER_CUBIC)
    img = np.reshape(img, (config.MODEL.IMAGE_SIZE.H, w_cur, 1))

    # normalize
    img = img.astype(np.float32)
    img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
    img = img.transpose([2, 0, 1])
    img = torch.from_numpy(img)

    img = img.to(device)
    img = img.view(1, *img.size())
    model.eval()
    preds = model(img)
    print(preds.shape)

    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)
    preds_size = torch.IntTensor([preds.size(0)])  # Variable is deprecated; a plain tensor works
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)

    print('results: {0}'.format(sim_pred))
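To see the two resize steps with concrete numbers, here is a standalone sketch on a dummy crop. H, W, and OW mimic the config fields; the values 32/160/280 are my assumption for illustration only, as the real values come from the yaml config:
import numpy as np
import cv2

H, W, OW = 32, 160, 280                    # assumed config values, illustration only
img = np.zeros((48, 240), dtype=np.uint8)  # dummy 48x240 grayscale crop

img = cv2.resize(img, (0, 0), fx=H / 48, fy=H / 48)             # step 1: -> 32 x 160
w_cur = int(img.shape[1] / (OW / W))                            # 160 / 1.75 = 91
img = cv2.resize(img, (0, 0), fx=w_cur / img.shape[1], fy=1.0)  # step 2: -> 32 x 91
print(img.shape)  # (32, 91)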
The main program is then the usual routine of loading the model and reading the image to perform recognition.
if __name__ == '__main__':
    config, args = parse_arg()
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    model = crnn.get_crnn(config).to(device)
    print('loading pretrained model from {0}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint)
    if 'state_dict' in checkpoint.keys():
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)

    started = time.time()

    img = cv2.imread(args.image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)

    recognition(config, img, model, converter, device)

    finished = time.time()
    print('elapsed time: {0}'.format(finished - started))
Summary
Experimental results
Training on the ~3 million image training set (on a 2080 Ti, a batch size of 128 uses about 6 GB of GPU memory and one epoch takes roughly 2 hours), after 17 epochs the model reaches 99.22% accuracy on the test set (320,000 images). Below is a screenshot of the result after 25 epochs of training; the loss has clearly converged.
My trained model: https://pan.baidu.com/s/1a0BKxFEOKXXnL8wgHbbkJg
Extraction code: qwer
Some examples:
Known issues
(1) Since the training set consists almost entirely of black text on white backgrounds, recognition is poor on images like the following.
(2) Large variation in character scale or spacing within a single image also hurts accuracy badly.
This can of course be mitigated at the text detection stage, by cropping regions so that the text scale within each crop is consistent.
(3) The current model works very well on Chinese-only text, but its performance on multilingual text is mediocre.