Extract Information using ‘Donut’ 🍩. Will it be the OCR Killer? 🔪
Part #01: Introduction
While I was still confused about how to improve the performance of Tesseract OCR, I found an article that explained Donut. Yes, you read that right. We will talk about Donut 🍩, but not the food. Donut 🍩 is short for Document Understanding Transformer. Based on what I read from the official site, Donut 🍩 is a new method of document understanding that utilizes an OCR-free end-to-end Transformer model. Donut does not require off-the-shelf OCR engines/APIs, yet it shows state-of-the-art performance on various visual document understanding tasks, such as visual document classification or information extraction (a.k.a. document parsing). Note that it says Donut 🍩 is OCR-free. So, the next question in my mind: can Donut outperform Tesseract OCR, especially for extracting information from ID cards?
To answer that question, I did some research using Indonesian ID cards (KTP) as my dataset. Inspired by this article, I collected around 100+ different KTP images to fine-tune the Donut 🍩 model. You can get the base model from Hugging Face. But first, let's prepare the dataset.
Part #02: Data Preprocessing
As you can see in the picture above, a KTP contains information such as:
NIK: Unique Identity Number
Nama: Name
Tempat/Tgl Lahir: place/date of birth
Jenis Kelamin: gender
Alamat, RT/RW, Kel/Desa, Kecamatan: address
Status perkawinan: marital status
Pekerjaan: job
Kewarganegaraan: nationality
Berlaku Hingga: valid until
In my case, I only use NIK, Nama, Tempat/Tgl Lahir, Jenis Kelamin, and Alamat (RT/RW, Kel/Desa, Kecamatan). Then, let's create a spreadsheet and fill it with the information from the KTP.
Repeat these steps until you have filled in all of the data (in my case, I filled in 150 records).
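For reference, these are the column names that my splitting code below expects. The header names are my own convention, so feel free to rename them as long as you keep the code consistent. Here is a minimal sketch of one placeholder row, built with pandas:
import pandas as pd

# One placeholder row using the column names expected by the splitting code below.
# The values here are dummies, not real KTP data.
example = pd.DataFrame([{
    "filename": "ktp_001.jpg",        # image file name inside the images/ folder
    "nik": "XXXXXXXXXXXXXXXX",        # 16-digit unique identity number
    "nama": "FULL NAME",
    "tempat": "PLACE OF BIRTH",
    "tanggal_lahir": "DD-MM-YYYY",
    "jenis_kelamin": "GENDER",
    "alamat": "STREET ADDRESS",
    "rt_rw": "008/001",
    "kel/desa": "SUB DISTRICT",
    "kecamatan": "DISTRICT",
}])
example.to_csv("labels.csv", index=False)   # the same file we load later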
FYI, I collected both clear and blurry images because, in reality, a KTP photo is not always good quality. I hope this gives Donut enough varied references to understand the images.
Then, open your Jupyter Notebook (or JupyterLab), create a new notebook, and type this script to create a function that splits the dataset.
import json
import os
import shutil
from tqdm import tqdm
import pandas as pd
def create_sets(df, train=0.7, val=0.2, test=0.1):
    """
    train, val and test are the proportions of the images
    that go in each split.
    """
    if round(train + val + test) != 1:
        raise ValueError("train + val + test != 1")
    # print(f"Train: {train}, Val: {val}, Test: {test}, Sum: {train + val + test}")

    # Create the folders
    train_folder = "data/train"
    val_folder = "data/val"
    test_folder = "data/test"
    os.makedirs(train_folder, exist_ok=True)
    os.makedirs(val_folder, exist_ok=True)
    os.makedirs(test_folder, exist_ok=True)

    # Shuffle your data
    samples = df.sample(frac=1.).reset_index(drop=True)

    # Compute the number of images for each split
    n = len(samples)
    n_train = train * n
    n_val = val * n
    n_test = test * n

    for idx, row in tqdm(samples.iterrows(), total=samples.shape[0]):
        data = {
            "nik": row["nik"],
            "nama": row["nama"],
            "Tempat Lahir": row["tempat"],
            "tgl_lahir": row["tanggal_lahir"],
            "Jenis Kelamin": row["jenis_kelamin"],
            "alamat": row["alamat"],
            "RT/RW": row["rt_rw"],
            "kel_desa": row["kel/desa"],
            "kecamatan": row["kecamatan"]
        }
        file_name = row["filename"]
        gt_parse = {"gt_parse": data}
        line = {
            "file_name": file_name,
            "ground_truth": json.dumps(gt_parse)
        }
        # We assume that your images are in
        # a folder named "images/"; correct if necessary
        image_path = os.path.join("images", file_name)
        # Copy the image into one of the folders
        # and append a line to metadata.jsonl
        if idx < n_train:
            dest_path = os.path.join("data/train/", file_name)
            shutil.copyfile(image_path, dest_path)
            with open("data/train/metadata.jsonl", "a") as f:
                f.write(json.dumps(line) + "\n")
        elif n_train <= idx < n_train + n_val:
            dest_path = os.path.join("data/val/", file_name)
            shutil.copyfile(image_path, dest_path)
            with open("data/val/metadata.jsonl", "a") as f:
                f.write(json.dumps(line) + "\n")
        elif n_train + n_val <= idx < n_train + n_val + n_test:
            dest_path = os.path.join("data/test/", file_name)
            shutil.copyfile(image_path, dest_path)
            with open("data/test/metadata.jsonl", "a") as f:
                f.write(json.dumps(line) + "\n")
This code will split your dataset into 70% for the training set, 20% for the validation set, and 10% for the testing set. Next, download your spreadsheet as a .csv file and put it in the same place as your notebook. Then let's write the code to load the CSV file and split it.
df = pd.read_csv("labels.csv")
create_sets(df)
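Before moving on, it is worth sanity-checking the split. This is just a quick count of the images and metadata.jsonl lines in each folder created by create_sets:
import os

for split in ["train", "val", "test"]:
    folder = os.path.join("data", split)
    images = [f for f in os.listdir(folder) if f != "metadata.jsonl"]
    with open(os.path.join(folder, "metadata.jsonl")) as f:
        lines = sum(1 for _ in f)
    # Every copied image should have exactly one matching line in metadata.jsonl
    print(f"{split}: {len(images)} images, {lines} metadata lines")
Note that create_sets opens metadata.jsonl in append mode, so delete the data/ folder before re-running the split, otherwise the counts will not match.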
Part #03: Create the Model
Before we create the model, let’s install the Donut library first.
pip install donut-python
Then install the other libraries with these specific versions:
torch==2.0.1
torchvision==0.15.2
donut-python==1.0.9
pytorch-lightning==1.8.5
transformers==4.11.3
timm==0.5.4
Pillow==9.5
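If you want to double-check that your environment really matches these pins, here is a small sketch using importlib.metadata from the Python standard library; the package list simply mirrors the versions above:
from importlib.metadata import version

pinned = {
    "torch": "2.0.1",
    "torchvision": "0.15.2",
    "donut-python": "1.0.9",
    "pytorch-lightning": "1.8.5",
    "transformers": "4.11.3",
    "timm": "0.5.4",
    "Pillow": "9.5",
}
for package, expected in pinned.items():
    installed = version(package)
    status = "OK" if installed.startswith(expected) else "MISMATCH"
    print(f"{package}: pinned {expected}, installed {installed} -> {status}")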
Next, we need to clone the Donut 🍩 repository because we need its configuration file.
git clone https://github.com/clovaai/donut.git
Go into the repository, copy the train_cord.yaml config file outside the folder, and rename it (in my case, I renamed it to train_id.yaml).
Open the YAML file and change some of the lines like this:
result_path: "/home/rizkynindra/result"
dataset_name_or_paths: ["/home/rizkynindra/data"]
train_batch_sizes: [2]
warmup_steps: 222 # 10% of total steps, equals to num_training_samples_per_epoch / train_batch_sizes * max_epochs / 10
num_training_samples_per_epoch: 148
max_epochs: 30
num_workers: 4
I set train_batch_sizes lower than in the original file because my machine's memory is limited.
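If you change the dataset size, batch size, or number of epochs, you can recompute warmup_steps with the formula from the YAML comment. The arithmetic for my values looks like this:
num_training_samples_per_epoch = 148
train_batch_size = 2
max_epochs = 30

steps_per_epoch = num_training_samples_per_epoch / train_batch_size  # 74
total_steps = steps_per_epoch * max_epochs                           # 2220
warmup_steps = int(total_steps / 10)                                  # 10% of total steps = 222
print(warmup_steps)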
After all the setup is done, we move on to the fine-tuning part. Open your terminal, go to the Donut 🍩 directory, and run this command:
python train.py --config train_id.yaml
Wait until the process is done, then go back to your notebook and add this script.
from donut import DonutModel
from PIL import Image
import torch

# Change the path here:
model = DonutModel.from_pretrained("/path/to/result/of/fine-tuned model")

if torch.cuda.is_available():
    model.half()
    device = torch.device("cuda")
    model.to(device)
else:
    model.encoder.to(torch.bfloat16)

model.eval()

image = Image.open("/directory/of/testing_image.jpg").convert("RGB")  # Change here

with torch.no_grad():
    output = model.inference(image=image, prompt="<s_data>")
    data = output.get('predictions')

from collections import namedtuple

Item = namedtuple('Item', ['data'])
objects = [Item(data) for data in data]

for obj in objects:
    print(f'data: {obj.data}')
    json_dict = {
        "data": obj.data
    }

json_dict
Run the script above and you will get a result like this:
{'data': {'nik': '(unique_ID_number)',
'nama': '(NAME)',
'Tempat Lahir': '(PLACE OF BIRTH)',
'tgl_lahir': 'DD-MM-YYYY',
'Jenis Kelamin': '(GENDER)',
'alamat': '(ADDRESS)',
'RT/RW': '008/001',
'kel_desa': '(SUB ADDRESS DETAIL)',
'kecamatan': '(SUB ADDRESS DETAIL 2)'}}
I don't show the real values for privacy reasons.
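If you want to keep the parsed result, you can simply dump json_dict to a file. A minimal sketch (the output file name is just an example):
import json

with open("ktp_result.json", "w") as f:
    json.dump(json_dict, f, ensure_ascii=False, indent=2)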
Well done! We have built a model that can replace the OCR 👓. Based on my experiment, as long as the image quality is good (there is no blur in the image), the model extracts the information with very few mistyped characters.
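If you want to put a rough number on that for your own data, you can compare the model's predictions against the ground truth stored in data/test/metadata.jsonl. This is only a sketch of a field-level exact-match check that reuses the model and inference call from the script above; it is not part of my original experiment:
import json
import os
import torch
from PIL import Image

test_dir = "data/test"
correct, total = 0, 0

with open(os.path.join(test_dir, "metadata.jsonl")) as f:
    for line in f:
        record = json.loads(line)
        truth = json.loads(record["ground_truth"])["gt_parse"]
        image = Image.open(os.path.join(test_dir, record["file_name"])).convert("RGB")
        with torch.no_grad():
            pred = model.inference(image=image, prompt="<s_data>")["predictions"][0]
        # Count a field as correct only if the predicted string matches exactly
        for field, expected in truth.items():
            total += 1
            if str(pred.get(field, "")).strip() == str(expected).strip():
                correct += 1

print(f"Field-level exact match: {correct}/{total} = {correct / total:.2%}")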
Part #04: Conclusion
So, do you need a conclusion? #haha. What I can say is that Donut 🍩 can be an alternative solution for extracting information, instead of tweaking the performance of an OCR engine.