
[GENERAL] A quick guide to OCR with Transformer

Published at 2022-01-19

Classification tasks seem to have become the hallmark of Transformer applications in computer vision, which makes us wonder what else Transformer is capable of. This article walks you through how to use Transformer for a simple optical character recognition (OCR) task on an English word-recognition dataset. We have prepared a complete set of code with detailed discussion of how to design the model architecture and train it. You will learn how Transformer can be applied to CV tasks that are more complex than classification.

A project for this OCR task typically involves the following files:
analysis_recognition_dataset.py (dataset analysis script)
ocr_by_transformer.py (OCR training script)
transformer.py (model file of the Transformer)
train_utils.py (complementary functions for model training, such as loss and optimizer)
Among these files, ocr_by_transformer.py is our main training script, which depends on train_utils.py and transformer.py to build the Transformer and train it for our OCR task.

Dataset

The dataset used in this article is based on Task 4.3: Word Recognition of ICDAR2015 Incidental Scene Text, a well-known OCR dataset featuring text captured in natural scenes. For learning purposes, we use a simplified version of the dataset, which we have hosted on Graviti, an open-dataset platform that supports dataset sharing and version control. If you have any questions about the dataset, feel free to leave a comment on Graviti's page.

This dataset consists of text regions cropped from images collected in real-world scenes, with 4326 images in the training set and 1992 in the validation set, all cropped from the original images using the provided bounding boxes. Most of the text is located near the center of the images.

A typical image in the dataset has the following format:

Labels are stored in CLASSIFICATION, and later in the article we'll show that we can extract a complete list of all labels in text form with a few lines of code.

We will now go over a quick guide to using this dataset:

  • Install TensorBay
pip3 install tensorbay
  • Open this link which will take you to the dataset
  • Fork this dataset to your own account
  • On Graviti, go to Developer Tools --> AccessKey --> Create AccessKey, and then copy the Key you just generated
from tensorbay import GAS
from tensorbay.dataset import Dataset

# GAS authorization
KEY = 'Accesskey-***************80a'  # Replace with your own AccessKey
gas = GAS(KEY)

# Acquire dataset
dataset = Dataset("ICDAR2015", gas)
# dataset.enable_cache('./data')  # Uncomment to enable local cache for the data

# Training set and validation set
train_segment = dataset["train"]
valid_segment = dataset['valid']

# Data and labels
for data in train_segment:
    # Image data
    img = data.open()
    # Label data
    label = data.label.classification.category
    break

From the above code we will acquire image-label pairs in the following format:

In this way we can get the image data and their labels very quickly. However, since the program downloads the data every time it runs, it is recommended that you enable the local cache to save time. Once you are done with the data, you can delete the cache whenever you like.

Data analysis

One more thing before we begin: we should conduct a simple analysis of the data at hand so that we can build a better baseline and reduce unnecessary training work.

Run the following to conduct a simple analysis of the dataset:

python analysis_recognition_dataset.py

This script counts the characters that appear in the data and their occurrences, the length of the longest label, the image sizes, etc., and exports a text file lbl2id_map.txt that records the mapping between characters and their ids.

Let's go over the script.

Note: all the code in this article are stored in this GitHub repo.

First, the script does its preparations: it imports the libraries it needs and configures paths for the directories and files it will use.

import os
from PIL import Image
import tqdm

from tensorbay import GAS
from tensorbay.dataset import Dataset

# GAS Authorization
KEY = 'Accesskey-************************480a'  # Replace with your own AccessKey
gas = GAS(KEY)
# Acquire dataset and enable local cache
dataset = Dataset("ICDAR2015", gas)
dataset.enable_cache('./data')  # Directory of cache

# Acquire training and validation set
train_segment = dataset["train"]
valid_segment = dataset['valid']

# path for mapping file that stores the mapping between label characters and their id
base_data_dir = './'
lbl2id_map_path = os.path.join(base_data_dir, 'lbl2id_map.txt')

Maximum length of the labels

This function counts the maximum label length within a segment; we then take the maximum over the training and validation sets to obtain the longest label in the entire dataset.

def statistics_max_len_label(segment):
    """
    Calculate maximum label length
    """
    max_len = -1
    for data in segment:
        lbl_str = data.label.classification.category  # Get label
        lbl_len = len(lbl_str)
        max_len = max_len if max_len > lbl_len else lbl_len
    return max_len

train_max_label_len = statistics_max_len_label(train_segment)  # Longest label in the training set
valid_max_label_len = statistics_max_len_label(valid_segment)  # Longest label in the validation set
max_label_len = max(train_max_label_len, valid_max_label_len)  # Longest label in the entire dataset
print(f"Maximum label length in the dataset is {max_label_len}")

In our current dataset, the maximum label length is 21, which will be an important reference for setting the sequence length (number of time steps) when we build the model.

All characters and their number of occurrences

The following code records all the characters in the dataset.

def statistics_label_cnt(segment, lbl_cnt_map):
    """
    Calculate what characters are in the labels and their occurrnces
    lbl_cnt_map : dictionary for occurrence counts
    """
    for data in segment:
        lbl_str = data.label.classification.category  # Get label
        for lbl in lbl_str:
                if lbl not in lbl_cnt_map.keys():
                    lbl_cnt_map[lbl] = 1
                else:
                    lbl_cnt_map[lbl] += 1

lbl_cnt_map = dict()  # dictionary for occurrence counts
statistics_label_cnt(train_segment, lbl_cnt_map)  # Occurrence count in the training set
print("All characters in the training set:")
print(lbl_cnt_map)
statistics_label_cnt(valid_segment, lbl_cnt_map)  # Occurrence count in the training and validation sets
print("All characters in the training and validation sets:")
print(lbl_cnt_map)

The output will be:

All characters in the training set:
{'C': 593, 'A': 1189, 'U': 319, 'T': 896, 'I': 861, 'O': 965, 'N': 785, 'D': 383, 'W': 179, 'M': 367, 'E': 1423, 'X': 110, '$': 46, '2': 121, '4': 44, 'L': 745, 'F': 259, 'P': 389, 'R': 836, 'S': 1164, 'a': 843, 'v': 123, 'e': 1057, 'G': 345, "'": 51, 'r': 655, 'k': 96, 's': 557, 'i': 651, 'c': 318, 'V': 158, 'H': 391, '3': 50, '.': 95, '"': 8, '-': 68, ',': 19, 'Y': 229, 't': 563, 'y': 161, 'B': 332, 'u': 293, 'x': 27, 'n': 605, 'g': 171, 'o': 659, 'l': 408, 'd': 258, 'b': 88, 'p': 197, 'K': 163, 'J': 72, '5': 80, '0': 203, '1': 186, 'h': 299, '!': 51, ':': 19, 'f': 133, 'm': 202, '9': 66, '7': 45, 'j': 15, 'z': 12, '´': 3, 'Q': 19, 'Z': 29, '&': 9, ' ': 50, '8': 47, '/': 24, '#': 16, 'w': 97, '?': 5, '6': 40, '[': 2, ']': 2, 'É': 1, 'q': 3, ';': 3, '@': 4, '%': 28, '=': 1, '(': 6, ')': 5, '+': 1}
All characters in the training and validation sets:
{'C': 893, 'A': 1827, 'U': 467, 'T': 1315, 'I': 1241, 'O': 1440, 'N': 1158, 'D': 548, 'W': 288, 'M': 536, 'E': 2215, 'X': 181, '$': 57, '2': 141, '4': 53, 'L': 1120, 'F': 402, 'P': 582, 'R': 1262, 'S': 1752, 'a': 1200, 'v': 169, 'e': 1536, 'G': 521, "'": 70, 'r': 935, 'k': 137, 's': 793, 'i': 924, 'c': 442, 'V': 224, 'H': 593, '3': 69, '.': 132, '"': 8, '-': 87, ',': 25, 'Y': 341, 't': 829, 'y': 231, 'B': 469, 'u': 415, 'x': 38, 'n': 880, 'g': 260, 'o': 955, 'l': 555, 'd': 368, 'b': 129, 'p': 317, 'K': 253, 'J': 100, '5': 105, '0': 258, '1': 231, 'h': 417, '!': 65, ':': 24, 'f': 203, 'm': 278, '9': 76, '7': 62, 'j': 19, 'z': 14, '´': 3, 'Q': 28, 'Z': 36, '&': 15, ' ': 82, '8': 58, '/': 29, '#': 24, 'w': 136, '?': 7, '6': 46, '[': 2, ']': 2, 'É': 2, 'q': 3, ';': 3, '@': 9, '%': 42, '=': 1, '(': 7, ')': 5, '+': 2, 'é': 1}

In the above code, lbl_cnt_map is a dictionary that maps each character to its number of occurrences; it will be used next to build the mapping between characters and their ids.

From the results, we can see that the validation set contains characters that do not appear in the training set, such as a single occurrence of "é". These differences are minor enough to ignore here, but it is good practice to run such a diff check on any new dataset, as sketched below.
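If you want to make the check explicit, here is a minimal sketch that reuses statistics_label_cnt (the variable names are illustrative):

# Count characters per split separately, then diff the character sets
train_cnt_map = dict()
statistics_label_cnt(train_segment, train_cnt_map)

valid_cnt_map = dict()
statistics_label_cnt(valid_segment, valid_cnt_map)

# Characters that appear only in the validation set
valid_only = set(valid_cnt_map) - set(train_cnt_map)
print("Characters unique to the validation set:", valid_only)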

Building mappings between char and id

An OCR task requires predicting every character in a given image, and for this purpose we need to establish mappings between characters and their ids, so that characters can be converted into numeric form that the model can consume. This process is similar to building a vocabulary in NLP.

When we build our mappings, we need to initialize three special characters for sentence beginning, end, and padding. Details will be provided later.

After the script is run, the mapping of all characters will be stored in lbl2id_map.txt.

# build char-id mappings
print("Char-id mappings in labels:")

lbl2id_map = dict()
# Initialize 3 special characters
lbl2id_map['☯'] = 0    # padding identifier
lbl2id_map['■'] = 1    # sentence beginning
lbl2id_map['□'] = 2    # sentence end
# generate mappings for the remaining characters
cur_id = 3
for lbl in lbl_cnt_map.keys():
    lbl2id_map[lbl] = cur_id
    cur_id += 1

# save mappings to txt file
with open(lbl2id_map_path, 'w', encoding='utf-8') as writer:  # the encoding is optional, but some devices do not have utf-8 as default
    for lbl in lbl2id_map.keys():
        cur_id = lbl2id_map[lbl]
        print(lbl, cur_id)
        line = lbl + '\t' + str(cur_id) + '\n'
        writer.write(line)

The outputs are:

☯ 0
■ 1
□ 2
C 3
A 4
...
= 85
( 86
) 87
+ 88
é 89

Moreover, analysis_recognition_dataset.py contains another function that reads the mapping txt file and generates dictionaries of char-id and id-char, which will facilitate our model training later.

def load_lbl2id_map(lbl2id_map_path):
    """
    Read the char-id mapping file and return the lbl->id and id->lbl dicts
    lbl2id_map_path : path of char-id mapping file
    """

    lbl2id_map = dict()
    id2lbl_map = dict()
    with open(lbl2id_map_path, 'r') as reader:
        for line in reader:
            items = line.rstrip().split('\t')
            label = items[0]
            cur_id = int(items[1])
            lbl2id_map[label] = cur_id
            id2lbl_map[cur_id] = label
    return lbl2id_map, id2lbl_map
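For reference, a quick usage example (the ids correspond to the mapping printed above):

lbl2id_map, id2lbl_map = load_lbl2id_map('./lbl2id_map.txt')
print(lbl2id_map['C'])   # 3
print(id2lbl_map[3])     # 'C'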

Image size

Image size matters because we need an appropriate way to preprocess the images. For example, in object detection we analyze the sizes and scales of the images and bounding boxes to decide on suitable cropping and anchor strategies.

We will thus analyze the width, height, and aspect ratio of the images for our future reference.

def read_gas_image(data):
    with data.open() as fp:
        image = Image.open(fp)
    return image

# Get image size
print("Get image size:")

# Initialize parameters
min_h = 1e10
min_w = 1e10
max_h = -1
max_w = -1
min_ratio = 1e10
max_ratio = 0
# Traverse the dataset to collect size info
for data in tqdm.tqdm(train_segment):
    img = read_gas_image(data)  # Read image
    w, h = img.size  # Extract size
    ratio = w / h  # Calculate ratio
    min_h = min_h if min_h <= h else h  # Minimum height
    max_h = max_h if max_h >= h else h  # Maximum height
    min_w = min_w if min_w <= w else w  # Minimum width
    max_w = max_w if max_w >= w else w  # Maximum width
    min_ratio = min_ratio if min_ratio <= ratio else ratio  # Minimum ratio
    max_ratio = max_ratio if max_ratio >= ratio else ratio  # Maximum ratio
# Print info
print('min_h:', min_h)
print('max_h:', max_h)
print('min_w:', min_w)
print('max_w:', max_w)
print('min_ratio:', min_ratio)
print('max_ratio:', max_ratio)

The results for our current dataset are:

min_h: 9
max_h: 295
min_w: 16
max_w: 628
min_ratio: 0.6666666666666666
max_ratio: 8.619047619047619

We can see that most of our images are wide rectangles, and a max_ratio above 8 tells us that some images are extremely elongated horizontally.

So far we have conducted a simple analysis of the dataset and prepared the char-id mapping file for model training. We will now see how to bring Transformer to bear on a CV task such as OCR.

Transformer and OCR

Most algorithms are not as difficult as they may appear; the tricky part is finding ways to apply existing solutions to new problems. So before we go through the code, let's talk about what motivates using Transformer for OCR.

First, we know that Transformer has been widely used in NLP, and can solve sequence-to-sequence problems like machine translation, as shown below.

NLP machine translation

An OCR task, on the other hand, can also be seen as a sequence-to-sequence task. For example, we need to read the following image as "Share"; the only difference from machine translation is that the input sequence is represented as an image.

Share

Therefore, a sequence-to-sequence perspective naturally leads us to use Transformer for OCR. What remains to be resolved is how to turn the image input into the kind of sequence of embeddings the Transformer expects, analogous to word embeddings in NLP.

Since almost all images in our dataset are horizontal strips and the text in them runs horizontally as well, we can slice each image along the horizontal axis and regard each embedding as the feature of one vertical slice. These embeddings are then fed into the Transformer, which uses its attention mechanism to produce the predictions.

Based on the analysis above, we thus define our model pipeline as illustrated below:

Model Pipeline

The illustration shows that the pipeline is basically the same as machine translation, except that there is now a convolutional network as our backbone to extract image features and produce the embeddings. The backbone is the essential part of our pipeline design and will be the focus of the rest of the article.

Implementation

The code for model training is stored in ocr_by_transformer.py, and contains the following components:

  • Dataset construction → image pre-processing, label processing, etc.
  • Model construction → backbone + Transformer;
  • Model training
  • Inference → Greedy Decoder

Preliminaries

We first import the libraries we need.

import os
import time
import copy
from PIL import Image

# dataset-related packages from TensorBay
from tensorbay import GAS
from tensorbay.dataset import Dataset

# torch packages
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as transforms

# toolkit packages
from analysis_recognition_dataset import load_lbl2id_map, statistics_max_len_label
from transformer import *
from train_utils import *

Then we configure some basic parameters.

device = torch.device('cuda')  # 'cpu' or 'cuda'
nrof_epochs = 1500  # number of training epochs; can be modified if needed
batch_size = 64     # batch size; can be modified if needed
model_save_path = './log/ex1_ocr_model.pth'

Now we acquire the image data online and read the mapping file for later needs.

# GAS Authorization
KEY = 'Accesskey-************************480a'  # Replace with your own AccessKey
gas = GAS(KEY)
# Acquire dataset from Graviti
dataset_online = Dataset("ICDAR2015", gas)
dataset_online.enable_cache('./data')  # Enable local cache

# Acquire training and validation sets
train_data_online = dataset_online["train"]
valid_data_online = dataset_online['valid']

# Read label-id mapping file
lbl2id_map_path = os.path.join('./', 'lbl2id_map.txt')
lbl2id_map, id2lbl_map = load_lbl2id_map(lbl2id_map_path)

# Calculate maximum label length for ground truth (gt)
train_max_label_len = statistics_max_len_label(train_data_online)
valid_max_label_len = statistics_max_len_label(valid_data_online)
# Maximum label length then becomes the sequence_len for making gt
sequence_len = max(train_max_label_len, valid_max_label_len)  

Dataset creation

Image preprocessing

For dataset creation, we need to first think about how we should pre-process the images.

Suppose the input image has the shape [batch_size, 3, H_i, W_i] and the backbone produces a feature map of shape [batch_size, C_f, H_f, W_f]. Based on our previous analysis, almost all images are horizontal strips with horizontally arranged characters, so each vertical slice normally contains at most one character. We therefore do not need a large vertical resolution and set $$H_f = 1$$. The horizontal resolution needs to be larger, because we need distinct embeddings to encode the character features along the horizontal axis.

Here we will use the classic ResNet-18 as our backbone. Since its downsampling factor is 32 and its last feature map has 512 channels, we have:

$$ H_i = H_f \times 32 = 32, \quad C_f = 512 $$

How, then, should we choose the width of the input image? We have two options:

Two Solutions

Method 1: Set a fixed target size, resize the image while keeping its aspect ratio, and pad the right side.

Method 2: Directly resize the original image to a preset size.

Which of the two do you think is better?
We prefer Method 1, because the aspect ratio of an image is roughly proportional to the number of characters in it. If the aspect ratio is preserved during preprocessing, each position on the feature map corresponds to a character region of roughly constant width in the original image, which should yield better predictions.

Here is another detail. In the figure above you can see that each square region (width:height = 1:1) contains roughly 2-3 characters, so in practice we do not keep the aspect ratio exactly unchanged but multiply it by 3: we stretch the width to 3 times its proportional value while resizing the height to 32.

Why do we need this bit of detail?
Our goal is for each character in the image to occupy at least one position along the width of the feature map, rather than having a single feature-map position encode several characters at once, which would add unnecessary difficulty to the Transformer's prediction. Of course, this is only our own view, and we welcome discussion in the comment section below.

Now that we have decided on a resizing strategy, what values should we use? Based on our earlier analysis, the maximum label length is 21 and the largest aspect ratio is about 8.6, so we set the final (width-tripled) aspect ratio to 24:1:

$$ H_i = H_f \times 32 = 32, \quad W_i = 24 \times H_i = 768, \quad C_f = 512, \quad H_f = 1, \quad W_f = 24 $$

The relevant code is shown below.

# ----------------
# Image preprocessing
# ----------------
# load image
with img_data.open() as fp:
    img = Image.open(fp).convert('RGB')

# Resize the image while keeping roughly 3x the original aspect ratio:
# the height becomes 32 and the width becomes a multiple of 32
w, h = img.size
ratio = round((w / h) * 3)   # Multiply the aspect ratio by 3 and round to the nearest integer
if ratio == 0:
    ratio = 1 
if ratio > self.max_ratio:
    ratio = self.max_ratio
h_new = 32
w_new = h_new * ratio
img_resize = img.resize((w_new, h_new), Image.BILINEAR)

# Pad the right side of the image so that the final width corresponds to self.max_ratio
img_padd = Image.new('RGB', (32*self.max_ratio, 32), (0,0,0))
img_padd.paste(img_resize, (0, 0)) 

Image augmentation

Augmentation is not our focus here; we merely apply standard augmentations such as random color jitter and normalization.

Complete code

class Recognition_Dataset(object):

    def __init__(self, segment, lbl2id_map, sequence_len, max_ratio, pad=0):
        self.data = segment
        self.lbl2id_map = lbl2id_map
        self.pad = pad   # padding identifier's id; default is 0
        self.sequence_len = sequence_len    # sequence length
        self.max_ratio = max_ratio * 3      # multiply the width by 3

        # define randomized color transformation
        self.color_trans = transforms.ColorJitter(0.1, 0.1, 0.1)
        # define Normalize
        self.trans_Normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
        ])

    def __getitem__(self, index):
        """ 
        acquire image with corresponding index and its gt label, and conduct data augmentation if needed
        """
        img_data = self.data[index]
        lbl_str = img_data.label.classification.category  # Image label

        # ----------------
        # Image preprocessing
        # ----------------
        # load image
        with img_data.open() as fp:
            img = Image.open(fp).convert('RGB')

        # Resize the image while keeping roughly 3x the original aspect ratio:
        # the height becomes 32 and the width becomes a multiple of 32
        w, h = img.size
        ratio = round((w / h) * 3)   # Multiply the aspect ratio by 3 and round to the nearest integer
        if ratio == 0:
            ratio = 1
        if ratio > self.max_ratio:
            ratio = self.max_ratio
        h_new = 32
        w_new = h_new * ratio
        img_resize = img.resize((w_new, h_new), Image.BILINEAR)

        # Pad the right side of the image so that the final width corresponds to self.max_ratio
        img_padd = Image.new('RGB', (32*self.max_ratio, 32), (0,0,0))
        img_padd.paste(img_resize, (0, 0))

        # randomized color transformation
        img_input = self.color_trans(img_padd)
        # Normalize
        img_input = self.trans_Normalize(img_input)

        # ----------------
        # label processing
        # ----------------

        # create mask for the encoder
        encode_mask = [1] * ratio + [0] * (self.max_ratio - ratio)
        encode_mask = torch.tensor(encode_mask)
        encode_mask = (encode_mask != 0).unsqueeze(0)

        # create ground truth label
        gt = []
        gt.append(1)    # first add sentence beginning identifier
        for lbl in lbl_str:
            gt.append(self.lbl2id_map[lbl])
        gt.append(2)
        for i in range(len(lbl_str), self.sequence_len):   # pad the remaining positions (up to sequence_len) with the padding id
            gt.append(0)
        # truncate to the preset maximum sequence length
        gt = gt[:self.sequence_len]

        # input for decoder
        decode_in = gt[:-1]
        decode_in = torch.tensor(decode_in)
        # output of decoder
        decode_out = gt[1:]
        decode_out = torch.tensor(decode_out)
        # mask  for decoder
        decode_mask = self.make_std_mask(decode_in, self.pad)
        # number of valid tokens
        ntokens = (decode_out != self.pad).data.sum()

        return img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens

    @staticmethod
    def make_std_mask(tgt, pad):
        """
        Create a mask to hide padding and future words, which are represented as 0 in the mask.
        """
        tgt_mask = (tgt != pad)
        tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
        tgt_mask = tgt_mask.squeeze(0)   # subsequent_mask() returns shape (1, N, N), so squeeze the leading dim
        return tgt_mask

    def __len__(self):
        return len(self.data)

The above code involves several details of label processing that are part of the Transformer's training logic; we briefly go over them below.

encode_mask

Since the image has been resized and padded (the padded region carries no valid information), we construct the corresponding encode_mask from the resize ratio so that the Transformer can ignore the meaningless padded area during its computation.
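For intuition, here is a tiny worked example with illustrative values: an image whose resized width fills only 5 of the 24 horizontal slots produces the following mask.

import torch

ratio, max_ratio = 5, 24   # illustrative values
encode_mask = torch.tensor([1] * ratio + [0] * (max_ratio - ratio))
encode_mask = (encode_mask != 0).unsqueeze(0)   # shape [1, 24]; True marks valid slices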

Label processing

The label format used in this experiment is essentially the same as that used when training machine translation models, so the processing differs only in minor details.

During label processing, the characters in the label are converted into their corresponding ids, and identifiers are added at the beginning and end of the sentence. Padding is applied to the remaining positions when the label is shorter than sequence_len, as in the sketch below.
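To make these steps concrete, here is the same logic applied by hand to the label "Share" (a minimal sketch reusing lbl2id_map and the sequence_len of 21 found earlier):

sequence_len = 21      # maximum label length found in the data analysis
lbl_str = "Share"

gt = [1]                                    # sentence-beginning id
gt += [lbl2id_map[c] for c in lbl_str]      # character ids
gt.append(2)                                # sentence-end id
gt += [0] * (sequence_len - len(lbl_str))   # pad with the padding id
gt = gt[:sequence_len]                      # truncate to the preset maximum length

decode_in  = gt[:-1]   # decoder input : begin id, character ids, end id, padding
decode_out = gt[1:]    # decoder target: character ids, end id, padding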

decode_mask

In the decoder we generate an upper-triangular mask based on the label's sequence_len. Each row of the mask allows the decoder to attend only to characters at or before the current time step and blocks access to future characters, which prevents the model from cheating.

The decode_mask is generated by a special function make_std_mask().

At the same time, the decoder also needs to mask the padded part of the label, so decode_mask is set to False at the padded positions.

The generated decode_mask is illustrated below:

The Generated Decode_mask
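For reference, subsequent_mask() lives in transformer.py; since that file follows The Annotated Transformer, its implementation looks roughly like this (a sketch, not guaranteed to be character-for-character identical):

import numpy as np
import torch

def subsequent_mask(size):
    "Mask out future positions: True where attention is allowed (on and below the diagonal)."
    attn_shape = (1, size, size)
    mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(mask) == 0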

Those are all the details needed to create the Dataset; on this basis, we can create the DataLoader we need for model training.

# Create dataloader
max_ratio = 8    # Maximum aspect ratio during preprocessing; images below this ratio are resized normally, wider ones are compressed down to it
train_dataset = Recognition_Dataset(train_data_online, lbl2id_map, sequence_len, max_ratio, pad=0)
valid_dataset = Recognition_Dataset(valid_data_online, lbl2id_map, sequence_len, max_ratio, pad=0)
# loader size info:
# --> img_input: [batch_size, c, h, w] --> [64, 3, 32, 32*8*3]
# --> encode_mask: [batch_size, h/32, w/32] --> [64, 1, 24] The backbone has 32x downsampling, so we divide by 32 here
# --> decode_in: [bs, sequence_len-1] --> [64, 20]
# --> decode_out: [bs, sequence_len-1] --> [64, 20]
# --> decode_mask: [bs, sequence_len-1, sequence_len-1] --> [64, 20, 20]
# --> ntokens: [bs] --> [64]
train_loader = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=4)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                        batch_size=batch_size,
                                        shuffle=False,
                                        num_workers=4)

Model construction

The model construction is handled by the make_ocr_model function and the OCR_EncoderDecoder class.

We can start from the function make_ocr_model, which first loads PyTorch's pre-trained ResNet-18 as the backbone to extract image features. You may use other networks according to your needs, but keep in mind the network's downsampling factor and the channel number of its last feature map, and adjust the parameters of the related modules accordingly. Then the OCR_EncoderDecoder class is instantiated to complete the construction of the Transformer. Finally, the model parameters are initialized.
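To verify those numbers for ResNet-18, here is a quick sanity check with a dummy input of the size chosen earlier (illustrative; not part of the training script):

import torch
import torch.nn as nn
import torchvision.models as models

# A 3 x 32 x 768 input should yield a [1, 512, 1, 24] feature map once the
# global average pooling and fc layers are stripped from ResNet-18
backbone = models.resnet18(pretrained=True)
backbone = nn.Sequential(*list(backbone.children())[:-2])

dummy = torch.randn(1, 3, 32, 768)
with torch.no_grad():
    feat = backbone(dummy)
print(feat.shape)   # torch.Size([1, 512, 1, 24])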

The OCR_EncoderDecoder class assembles the basic components of a Transformer, including the encoder and decoder. Its constructor takes these pre-built components as arguments; their code lives in the transformer.py file and will not be elaborated in this article.

Here we review how the image is constructed as input to the Transformer after passing through the backbone.

After the image goes through the backbone, a feature map of dimension [batch_size, 512, 1, 24] is output. Regardless of the batch_size, each image yields a 1×24 feature map with 512 channels, as shown below. The red boxes in the figure illustrate how the feature values of all channels at the same horizontal position are stacked into a vector and used as the input for one time step. This produces an input of dimension [batch_size, 24, 512], which satisfies the Transformer's input requirements.

Feature Map

Let's take a look at the complete code for model construction:

# Model architecture
class OCR_EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture.
    Base for this and many other models.
    """
    def __init__(self, encoder, decoder, src_embed, src_position, tgt_embed, generator):
        super(OCR_EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed    # input embedding module
        self.src_position = src_position
        self.tgt_embed = tgt_embed    # output embedding module
        self.generator = generator    # output generation module

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        # src --> [bs, 3, 32, 768]  [bs, c, h, w]
        # src_mask --> [bs, 1, 24]  [bs, h/32, w/32]
        memory = self.encode(src, src_mask)
        # memory --> [bs, 24, 512]
        # tgt --> decode_in [bs, 20]  [bs, sequence_len-1]
        # tgt_mask --> decode_mask [bs, 20]  [bs, sequence_len-1]
        res = self.decode(memory, src_mask, tgt, tgt_mask)  # [bs, 20, 512]
        return res

    def encode(self, src, src_mask):
        # feature extract
        # src --> [bs, 3, 32, 768]
        src_embedds = self.src_embed(src)
        # ResNet18 is used as backbone. Output-->[batchsize, c, h, w] --> [bs, 512, 1, 24]
        # Process src_embedds from shape (bs, model_dim, 1, max_ratio) to shape (bs, time_step, model_dim) that the transformer expects
        # [bs, 512, 1, 24] --> [bs, 24, 512]
        src_embedds = src_embedds.squeeze(-2)
        src_embedds = src_embedds.permute(0, 2, 1)

        # position encode
        src_embedds = self.src_position(src_embedds)  # [bs, 24, 512]

        return self.encoder(src_embedds, src_mask)  # [bs, 24, 512]

    def decode(self, memory, src_mask, tgt, tgt_mask):
        target_embedds = self.tgt_embed(tgt)  # [bs, 20, 512]
        return self.decoder(target_embedds, memory, src_mask, tgt_mask)


def make_ocr_model(tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """
    Model construction
    params:
        tgt_vocab: size of the output dictionary
        N: Number of encoder and decoder stacking base modules
        d_model: Size of embedding in the model; default is 512
        d_ff: Size of the embedding in the FeedForward layer; default is 2048
        h: Number of heads in MultiHeadedAttention; d_model must be divisible by h
        dropout: ratio of dropout
    """
    c = copy.deepcopy

    # Use pretrained resnet18 in torch as feature extraction network and our backbone
    backbone = models.resnet18(pretrained=True)
    backbone = nn.Sequential(*list(backbone.children())[:-2])    # dispense with the last two layers (global average pooling and fc layer)

    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    # model construction
    model = OCR_EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        backbone,
        c(position),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab))  # the Generator is stored here but not called inside forward(); it is applied in the loss computation

    # Initialize parameters with Glorot / fan_avg.
    for child in model.children():
        if child is backbone:
            # Freeze the backbone: its weights do not require gradients
            for param in child.parameters():
                param.requires_grad = False
            # The pre-trained backbone is not randomly initialized; the rest of the modules are randomly initialized
            continue
        for p in child.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    return model

A Transformer model can be easily constructed by using the two classes mentioned above.

# build model
# use transformer as the ocr recognition model
# Note: ocr_model's forward() does not apply the Generator; the Generator is applied later, in the loss computation
tgt_vocab = len(lbl2id_map.keys()) 
d_model = 512
ocr_model = make_ocr_model(tgt_vocab, N=5, d_model=d_model, d_ff=2048, h=8, dropout=0.1)
ocr_model.to(device)

Model training

Before training, we also need to define the loss criterion, the optimizer, and so on. In this experiment we use label smoothing and a warmup learning-rate schedule; both are implemented in train_utils.py, and their principles and code are not covered in detail here.

Label smoothing converts the original hard labels into soft ones, which increases fault tolerance and improves the generalization ability of the model. The LabelSmoothing() implementation in the code provides this strategy, using a KL-divergence (relative entropy) loss internally to measure the gap between the predicted and smoothed target distributions.
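As a quick illustration of what smoothing does to a single target distribution (simplified; the LabelSmoothing() in train_utils.py additionally handles the padding index):

import torch

num_classes, smoothing = 5, 0.1       # illustrative values
hard_label = 2                        # the true class id
soft = torch.full((num_classes,), smoothing / (num_classes - 1))
soft[hard_label] = 1.0 - smoothing
print(soft)   # tensor([0.0250, 0.0250, 0.9000, 0.0250, 0.0250])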

The warmup strategy controls the optimizer's learning rate during training: the rate first increases from a small value and then gradually decays, which stabilizes training and helps the loss converge quickly. The NoamOpt() implementation in the code provides this schedule, wrapping an Adam optimizer and adjusting the learning rate automatically according to the iteration count.
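The schedule itself is simple. Here is a sketch of the learning-rate rule that NoamOpt() applies, following The Annotated Transformer and using the d_model=512, factor=1, warmup=400 values passed below:

def noam_rate(step, model_size=512, factor=1, warmup=400):
    # lr grows linearly for `warmup` steps, then decays as 1/sqrt(step)
    return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)))

for step in (1, 100, 400, 1000, 4000):
    print(step, noam_rate(step))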

# train prepare
criterion = LabelSmoothing(size=tgt_vocab, padding_idx=0, smoothing=0.0)  # label smoothing
optimizer = torch.optim.Adam(ocr_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
model_opt = NoamOpt(d_model, 1, 400, optimizer)  # warmup

The code for model training is shown below, with validation performed every 10 epochs, and the calculation of individual epochs is encapsulated in the run_epoch() function.

# train & valid ...
for epoch in range(nrof_epochs):
    print(f"\nepoch {epoch}")

    print("train...")  # training
    ocr_model.train()
    loss_compute = SimpleLossCompute(ocr_model.generator, criterion, model_opt)
    train_mean_loss = run_epoch(train_loader, ocr_model, loss_compute, device)

    if epoch % 10 == 0:
        print("valid...")  # validation
        ocr_model.eval()
        valid_loss_compute = SimpleLossCompute(ocr_model.generator, criterion, None)
        valid_mean_loss = run_epoch(valid_loader, ocr_model, valid_loss_compute, device)
        print(f"valid loss: {valid_mean_loss}")

        # save model
        torch.save(ocr_model.state_dict(), model_save_path)

The SimpleLossCompute() class implements the loss calculation for the Transformer's output. When called directly, it takes three arguments (x, y, norm): x is the decoder output, y is the label data, and norm is the normalization factor for the loss, namely the number of valid tokens in the batch. Note that it is only here, where the Generator is applied, that the full Transformer network and its data flow are complete.

The run_epoch() function carries out all the work of one training epoch, including data loading, model inference, loss calculation, and backpropagation, while printing progress information along the way.

def run_epoch(data_loader, model, loss_compute, device=None):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0

    for i, batch in enumerate(data_loader):
        img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
        img_input = img_input.to(device)
        encode_mask = encode_mask.to(device)
        decode_in = decode_in.to(device)
        decode_out = decode_out.to(device)
        decode_mask = decode_mask.to(device)
        ntokens = torch.sum(ntokens).to(device)

        out = model.forward(img_input, decode_in, encode_mask, decode_mask)
        # out --> [bs, 20, 512]  prediction
        # decode_out --> [bs, 20]  gt
        # ntokens --> number of valid characters in the label

        loss = loss_compute(out, decode_out, ntokens)  # loss computation
        total_loss += loss
        total_tokens += ntokens
        tokens += ntokens
        if i % 50 == 1:
            elapsed = time.time() - start
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, loss / ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    return total_loss / total_tokens


class SimpleLossCompute:
    "A simple loss compute and train function."
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        """
        norm: the normalization factor of the loss, which uses the number of all valid tokens in the batch
        """
        # x --> out --> [bs, 20, 512]  prediction
        # y --> decode_out --> [bs, 20]  gt
        # norm --> ntokens --> number of valid characters in the label
        x = self.generator(x)
        # label smoothing needs to correspond to dimension changes
        x_ = x.contiguous().view(-1, x.size(-1))  # [20bs, 512]
        y_ = y.contiguous().view(-1)  # [20bs]
        loss = self.criterion(x_, y_)
        loss /= norm
        loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.optimizer.zero_grad()
        #return loss.data[0] * norm 
        return loss.item() * norm

Greedy Decoder

For learning purposes, we will use the simplest greedy decoder for OCR prediction. Since the model produces only one output at a time, we take the character with the highest probability in the output distribution as the prediction for the current step and then go on to predict the next character. This is the so-called greedy decoding; see the greedy_decode() function in the code.

In our experiment, each image is fed to the model individually and its greedy-decoded prediction is checked one by one; the overall accuracy on the training and validation sets is reported at the end.

# When the training is done, use Greedy Decoder to predict the training and validation sets, and then calculate accuracy
ocr_model.eval()

print("\n------------------------------------------------")
print("greedy decode trainset")
total_img_num = 0
total_correct_num = 0
for batch_idx, batch in enumerate(train_loader):
    img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
    img_input = img_input.to(device)
    encode_mask = encode_mask.to(device)

    # Get single image info
    bs = img_input.shape[0]
    for i in range(bs):
        cur_img_input = img_input[i].unsqueeze(0)
        cur_encode_mask = encode_mask[i].unsqueeze(0)
        cur_decode_out = decode_out[i]
        # Greedy decoder
        pred_result = greedy_decode(ocr_model, cur_img_input, cur_encode_mask, max_len=sequence_len, start_symbol=1, end_symbol=2)
        pred_result = pred_result.cpu()
        # Judge if the prediction is correct
        is_correct = judge_is_correct(pred_result, cur_decode_out)
        total_correct_num += is_correct
        total_img_num += 1
        if not is_correct:
            # Print wrong cases
            print("----")
            print(cur_decode_out)
            print(pred_result)
total_correct_rate = total_correct_num / total_img_num * 100
print(f"total correct rate of trainset: {total_correct_rate}%")

# Same as decoding training set
print("\n------------------------------------------------")
print("greedy decode validset")
total_img_num = 0
total_correct_num = 0
for batch_idx, batch in enumerate(valid_loader):
    img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
    img_input = img_input.to(device)
    encode_mask = encode_mask.to(device)

    bs = img_input.shape[0]
    for i in range(bs):
        cur_img_input = img_input[i].unsqueeze(0)
        cur_encode_mask = encode_mask[i].unsqueeze(0)
        cur_decode_out = decode_out[i]

        pred_result = greedy_decode(ocr_model, cur_img_input, cur_encode_mask, max_len=sequence_len, start_symbol=1, end_symbol=2)
        pred_result = pred_result.cpu()

        is_correct = judge_is_correct(pred_result, cur_decode_out)
        total_correct_num += is_correct
        total_img_num += 1
        if not is_correct:
            # Print the wrong cases
            print("----")
            print(cur_decode_out)
            print(pred_result)
total_correct_rate = total_correct_num / total_img_num * 100
print(f"total correct rate of validset: {total_correct_rate}%")

greedy_decode() is implemented as follows.

# greedy decode
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol):
    memory = model.encode(src, src_mask)
    # ys represents the sequence that has been generated so far, initially as a sequence containing only one start character, and the prediction is continuously appended to the end of the sequence.
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data).long()
    for i in range(max_len-1):
        out = model.decode(memory, src_mask,
                           Variable(ys),
                           Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.data[0]
        next_word = torch.ones(1, 1).type_as(src.data).fill_(next_word).long()
        ys = torch.cat([ys, next_word], dim=1)

        next_word = int(next_word)
        if next_word == end_symbol:
            break
        #ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    ys = ys[0, 1:]
    return ys


def judge_is_correct(pred, label):
    # Judge if the prediction is the same as the label
    pred_len = pred.shape[0]
    label = label[:pred_len]
    is_correct = 1 if label.equal(pred) else 0
    return is_correct

The following command starts the training process.

python ocr_by_transformer.py

A sample training log is shown below:

epoch 0
train...
Epoch Step: 1 Loss: 5.142612 Tokens per Sec: 852.649109
Epoch Step: 51 Loss: 3.064528 Tokens per Sec: 2709.471436
valid...
Epoch Step: 1 Loss: 3.018526 Tokens per Sec: 1413.900391
valid loss: 2.7769546508789062

epoch 1
train...
Epoch Step: 1 Loss: 3.440590 Tokens per Sec: 1303.567993
Epoch Step: 51 Loss: 2.711708 Tokens per Sec: 2743.414307

...

epoch 1499
train...
Epoch Step: 1 Loss: 0.005739 Tokens per Sec: 1232.602783
Epoch Step: 51 Loss: 0.013249 Tokens per Sec: 2765.866211

------------------------------------------------
greedy decode trainset
----
tensor([17, 32, 18, 19, 31, 50, 30, 10, 30, 10, 17, 32, 41, 55, 55, 55,  2,  0,
         0,  0])
tensor([17, 32, 18, 19, 31, 50, 30, 10, 30, 10, 17, 32, 41, 55, 55, 55, 55, 55,
        55, 55])
----
tensor([17, 32, 18, 19, 31, 50, 30, 10, 17, 32, 41, 55, 55,  2,  0,  0,  0,  0,
         0,  0])
tensor([17, 32, 18, 19, 31, 50, 30, 10, 17, 32, 41, 55, 55, 55, 55,  2])
total correct rate of trainset: 99.95376791493297%

------------------------------------------------
greedy decode validset
----
tensor([10, 11, 28, 27, 25, 11, 47, 45,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0])
tensor([10, 11, 28, 27, 25, 11, 62,  2])

...

tensor([20, 12, 24, 35,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0])
tensor([20, 12, 21, 12, 22, 23, 34,  2])
total correct rate of validset: 92.72088353413655%

Wrap-up

That's all there is to this blog. Congratulations on making it this far!

In this blog, we first introduced the word-recognition dataset from ICDAR2015, briefly analyzed its characteristics, and built a character-to-id mapping table for recognition. After that, we focused on the motivation for applying Transformer to OCR tasks and then walked through the code in detail. Finally, we went over the training-related logic and its code.

The main purpose of this blog is to help you understand what applications Transformer has in CV beyond serving as a backbone. The Transformer implementation code is credited to The Annotated Transformer, and the way it is applied to OCR here reflects the author's personal understanding of the model; there is no guarantee that this exact approach will carry over to more complex engineering problems.