# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import tempfile
from pathlib import Path
from typing import List, Union

import gradio as gr
import rapidocr
from omegaconf import OmegaConf
from rapidocr import (
    EngineType,
    LangCls,
    LangDet,
    LangRec,
    ModelType,
    OCRVersion,
    RapidOCR,
)
def get_ocr_result(
    img_input,
    text_score,
    box_thresh,
    unclip_ratio,
    max_side_len,
    det_engine,
    lang_det,
    det_model_type,
    det_ocr_version,
    cls_engine,
    lang_cls,
    cls_model_type,
    cls_ocr_version,
    rec_engine,
    lang_rec,
    rec_model_type,
    rec_ocr_version,
    is_word,
    use_module,
):
    """Run RapidOCR on one image using the settings collected from the UI.

    A new ``RapidOCR`` engine is built on every call from the Det/Cls/Rec
    engine, language, model-type and OCR-version selections, then invoked on
    ``img_input`` with the thresholds and module switches.

    Args:
        img_input: Image forwarded unchanged to the engine.
        text_score: Forwarded to the engine call as ``text_score``.
        box_thresh: Forwarded to the engine call as ``box_thresh``.
        unclip_ratio: Forwarded to the engine call as ``unclip_ratio``.
        max_side_len: Stored under ``Global.max_side_len`` in the engine params.
        det_engine, lang_det, det_model_type, det_ocr_version: Detection
            settings; converted with ``EngineType``, ``LangDet``, ``ModelType``
            and ``OCRVersion`` respectively.
        cls_engine, lang_cls, cls_model_type, cls_ocr_version: Same for the
            classification stage (``LangCls`` for the language).
        rec_engine, lang_rec, rec_model_type, rec_ocr_version: Same for the
            recognition stage (``LangRec`` for the language).
        is_word: Word boxes are requested when this contains ``"Yes"``.
        use_module: Collection of ``"use_det"`` / ``"use_cls"`` / ``"use_rec"``
            flags naming the stages to run.

    Returns:
        ``(vis_img, ocr_txts, elapse)`` where ``vis_img`` is whatever
        ``ocr_result.vis()`` returns, ``ocr_txts`` is a list of
        ``[index, text, score]`` rows (empty when nothing is available) and
        ``elapse`` is ``ocr_result.elapse``.
    """
    # Guard against a cleared widget handing over None instead of a value.
    return_word_box = bool(is_word) and "Yes" in is_word
    selected_modules = use_module or ()
    use_det = "use_det" in selected_modules
    use_cls = "use_cls" in selected_modules
    use_rec = "use_rec" in selected_modules

    ocr_engine = RapidOCR(
        params={
            "Global.max_side_len": max_side_len,
            "Det.engine_type": EngineType(det_engine),
            "Det.lang_type": LangDet(lang_det),
            "Det.model_type": ModelType(det_model_type),
            "Det.ocr_version": OCRVersion(det_ocr_version),
            "Cls.engine_type": EngineType(cls_engine),
            "Cls.lang_type": LangCls(lang_cls),
            "Cls.model_type": ModelType(cls_model_type),
            "Cls.ocr_version": OCRVersion(cls_ocr_version),
            "Rec.engine_type": EngineType(rec_engine),
            "Rec.lang_type": LangRec(lang_rec),
            "Rec.model_type": ModelType(rec_model_type),
            "Rec.ocr_version": OCRVersion(rec_ocr_version),
        }
    )

    ocr_result = ocr_engine(
        img_input,
        use_det=use_det,
        use_cls=use_cls,
        use_rec=use_rec,
        text_score=text_score,
        box_thresh=box_thresh,
        unclip_ratio=unclip_ratio,
        return_word_box=return_word_box,
    )

    vis_img = ocr_result.vis()

    if return_word_box:
        # Build the rows directly instead of transposing with zip(*...):
        # the transpose cannot be unpacked into three names when there are
        # no word results. The attribute is looked up defensively because
        # the result type may differ when recognition is switched off
        # (NOTE(review): confirm against the rapidocr output classes).
        word_results = getattr(ocr_result, "word_results", None)
        if word_results is None:
            word_results = ()
        ocr_txts = [[i, row[0], row[1]] for i, row in enumerate(word_results)]
        return vis_img, ocr_txts, ocr_result.elapse

    if use_rec:
        # NOTE(review): txts/scores appear to be None when nothing was
        # recognised - treat that as "no rows" rather than crashing in zip().
        txts = ocr_result.txts
        scores = ocr_result.scores
        if txts is None or scores is None:
            ocr_txts = []
        else:
            ocr_txts = [
                [i, txt, score] for i, (txt, score) in enumerate(zip(txts, scores))
            ]
    else:
        ocr_txts = []
    return vis_img, ocr_txts, ocr_result.elapse
def create_examples() -> List[List[Union[str, float, List[str]]]]:
    """Build the example rows offered below the demo.

    Each row lists one value per UI input, in the same order as the
    parameters of ``get_ocr_result``: image path, ``text_score``,
    ``box_thresh``, ``unclip_ratio``, ``max_side_len``, the four Det settings,
    the four Cls settings, the four Rec settings, the word-box choice and the
    list of enabled modules.

    Enum settings are emitted as their ``.value`` so they match the dropdown
    choices in this file, which are built from ``v.value``.

    Returns:
        Five example rows; they differ only in image, detection
        language/version and recognition language/version.
    """
    onnx = EngineType.ONNXRUNTIME.value
    mobile = ModelType.MOBILE.value
    v4 = OCRVersion.PPOCRV4.value
    v5 = OCRVersion.PPOCRV5.value

    # (image, det language, det OCR version, rec language, rec OCR version)
    variants = [
        ("images/multi.jpg", LangDet.CH, v5, LangRec.CH, v5),
        ("images/ch_en_num.jpg", LangDet.CH, v4, LangRec.CH, v4),
        ("images/hand_writen.jpeg", LangDet.CH, v5, LangRec.CH, v5),
        ("images/japan.jpg", LangDet.MULTI, v4, LangRec.JAPAN, v4),
        ("images/korean.jpg", LangDet.MULTI, v4, LangRec.KOREAN, v4),
    ]

    examples = [
        [
            image,
            0.5,  # text_score
            0.5,  # box_thresh
            1.6,  # unclip_ratio
            2000,  # max_side_len
            # Det
            onnx,
            det_lang.value,
            mobile,
            det_version,
            # Cls (identical in every example)
            onnx,
            LangCls.CH.value,
            mobile,
            v4,
            # Rec
            onnx,
            rec_lang.value,
            mobile,
            rec_version,
            "No",  # is_word
            ["use_det", "use_cls", "use_rec"],  # fresh list per row
        ]
        for image, det_lang, det_version, rec_lang, rec_version in variants
    ]
    return examples
def export_yaml(
    img_input,
    text_score,
    box_thresh,
    unclip_ratio,
    max_side_len,
    det_engine,
    lang_det,
    det_model_type,
    det_ocr_version,
    cls_engine,
    lang_cls,
    cls_model_type,
    cls_ocr_version,
    rec_engine,
    lang_rec,
    rec_model_type,
    rec_ocr_version,
    is_word,
    use_module,
):
    """Write the current UI settings to a RapidOCR-style ``config.yaml``.

    The ``config.yaml`` shipped next to the installed ``rapidocr`` package is
    loaded as the base, the UI selections are merged over it with
    ``OmegaConf.merge`` and the result is saved to disk.

    The signature mirrors ``get_ocr_result`` so both can share one input
    list; ``img_input`` is accepted but not used here.

    Args:
        img_input: Unused.
        text_score: Written to ``Global.text_score``.
        box_thresh: Written to ``Global.box_thresh`` and ``Det.box_thresh``.
        unclip_ratio: Written to ``Det.unclip_ratio``.
        max_side_len: Written to ``Global.max_side_len``.
        det_engine, lang_det, det_model_type, det_ocr_version: Written to the
            ``Det`` section as ``engine_type`` / ``lang_type`` /
            ``model_type`` / ``ocr_version``.
        cls_engine, lang_cls, cls_model_type, cls_ocr_version: Same for ``Cls``.
        rec_engine, lang_rec, rec_model_type, rec_ocr_version: Same for ``Rec``.
        is_word: ``Global.return_word_box`` is true when this contains ``"Yes"``.
        use_module: Collection of ``"use_det"`` / ``"use_cls"`` / ``"use_rec"``
            flags, written as booleans to the ``Global`` section.

    Returns:
        Path of the saved ``config.yaml``.
    """
    # Guard against a cleared widget handing over None instead of a value.
    selected_modules = use_module or ()
    return_word_box = bool(is_word) and "Yes" in is_word

    default_yaml_path = Path(rapidocr.__file__).parent / "config.yaml"
    cfg = OmegaConf.load(default_yaml_path)

    params = {
        "Global": {
            "max_side_len": max_side_len,
            "use_det": "use_det" in selected_modules,
            "use_cls": "use_cls" in selected_modules,
            "use_rec": "use_rec" in selected_modules,
            "return_word_box": return_word_box,
            "text_score": text_score,
            "box_thresh": box_thresh,
        },
        "Det": {
            "engine_type": det_engine,
            "lang_type": lang_det,
            "model_type": det_model_type,
            "ocr_version": det_ocr_version,
            "box_thresh": box_thresh,
            "unclip_ratio": unclip_ratio,
        },
        "Cls": {
            "engine_type": cls_engine,
            "lang_type": lang_cls,
            "model_type": cls_model_type,
            "ocr_version": cls_ocr_version,
        },
        "Rec": {
            "engine_type": rec_engine,
            "lang_type": lang_rec,
            "model_type": rec_model_type,
            "ocr_version": rec_ocr_version,
        },
    }
    cfg = OmegaConf.merge(cfg, params)

    # Each export gets its own directory: a single shared file next to the
    # app would let concurrent sessions overwrite each other's download.
    # The file name stays "config.yaml" so the downloaded file is named as before.
    save_dir = Path(tempfile.mkdtemp(prefix="rapidocr_cfg_"))
    save_path = save_dir / "config.yaml"
    OmegaConf.save(cfg, save_path)
    return save_path
# Extra stylesheet for the demo page (meant for the `css` argument of
# gr.Blocks). The first rule previously contained a duplicated
# "body {font-family:" fragment, which left the braces unbalanced.
custom_css = """
body {font-family: 'Helvetica Neue', Helvetica;}
.gr-button {background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px;}
.gr-button:hover {background-color: #45a049;}
.gr-textbox {margin-bottom: 15px;}
.example-button {background-color: #1E90FF; color: white; border: none; padding: 8px 15px; border-radius: 5px; margin: 5px;}
.example-button:hover {background-color: #FF4500;}
.tall-radio .gr-radio-item {padding: 15px 0; min-height: 50px; display: flex; align-items: center;}
.tall-radio label {font-size: 16px;}
.output-image, .input-image, .image-preview {height: 300px !important}
"""
# Gradio UI: image input, OCR parameters, Run / Export buttons, outputs and
# clickable examples.
#
# Fix: `css` now receives the `custom_css` stylesheet; it used to receive the
# literal text "custom_css", so none of the rules were ever applied.
#
# NOTE(review): the nesting of the layout containers (Accordion / Row) below
# is a best-effort reconstruction - confirm it against the intended page
# layout.
with gr.Blocks(
    title="Rapid⚡OCR Demo", css=custom_css, theme=gr.themes.Soft()
) as demo:
    gr.HTML(
        """
        <h1 style='text-align: center;font-size:40px'>Rapid⚡OCRv3</h1>

        <div style="display: flex; justify-content: center; gap: 10px;">
            <a href=""><img src="https://img.shields.io/badge/Python->=3.6-aff.svg"></a>
            <a href="https://rapidai.github.io/RapidOCRDocs"><img src="https://img.shields.io/badge/Docs-link-aff.svg"></a>
            <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
            <a href="https://pepy.tech/project/rapidocr"><img src="https://static.pepy.tech/personalized-badge/rapidocr?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20rapidocr"></a>
            <a href="https://pypi.org/project/rapidocr/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapidocr"></a>
            <a href="https://github.com/RapidAI/RapidOCR"><img src="https://img.shields.io/github/stars/RapidAI/RapidOCR?color=ccf"></a>
        </div>
        """
    )

    img_input = gr.Image(label="Upload or Select Image", sources="upload")

    with gr.Accordion("Parameter Setting", open=False):
        # Thresholds and size limit.
        with gr.Row():
            text_score = gr.Slider(
                label="text_score",
                minimum=0,
                maximum=1.0,
                value=0.5,
                step=0.1,
                info="文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
            )
            box_thresh = gr.Slider(
                label="box_thresh",
                minimum=0,
                maximum=1.0,
                value=0.5,
                step=0.1,
                info="检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
            )
            unclip_ratio = gr.Slider(
                label="unclip_ratio",
                minimum=1.5,
                maximum=2.0,
                value=1.6,
                step=0.1,
                info="控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
            )
            max_side_len = gr.Number(
                value=2000,
                label="max_side_len",
                info="如果输入图像的最大边大于`max_side_len`,则会按宽高比,将最大边缩放到`max_side_len`。默认为2000px",
                interactive=True,
                minimum=20,
            )

        # Detection stage settings.
        with gr.Row():
            with gr.Row():
                gr.Markdown("Det")
                det_engine = gr.Dropdown(
                    choices=[v.value for v in EngineType],
                    label="EngineType",
                    value=EngineType.ONNXRUNTIME.value,
                    interactive=True,
                    scale=0,
                )
                lang_det = gr.Dropdown(
                    choices=[v.value for v in LangDet],
                    label="LangDet",
                    value=LangDet.CH.value,
                    interactive=True,
                    scale=1,
                )
                det_model_type = gr.Dropdown(
                    choices=[v.value for v in ModelType],
                    label="ModelType",
                    value=ModelType.MOBILE.value,
                    interactive=True,
                    scale=1,
                )
                det_ocr_version = gr.Dropdown(
                    choices=[v.value for v in OCRVersion],
                    label="OCR Version",
                    value=OCRVersion.PPOCRV4.value,
                    interactive=True,
                    scale=1,
                )

        # Classification stage settings; model type and OCR version each
        # offer a single choice.
        with gr.Row():
            gr.Markdown("Cls")
            cls_engine = gr.Dropdown(
                choices=[v.value for v in EngineType],
                label="EngineType",
                value=EngineType.ONNXRUNTIME.value,
                interactive=True,
            )
            lang_cls = gr.Dropdown(
                choices=[v.value for v in LangCls],
                label="LangCls",
                value=LangCls.CH.value,
                interactive=True,
            )
            cls_model_type = gr.Dropdown(
                choices=[ModelType.MOBILE.value],
                label="ModelType",
                value=ModelType.MOBILE.value,
                interactive=True,
            )
            cls_ocr_version = gr.Dropdown(
                choices=[OCRVersion.PPOCRV4.value],
                label="OCR Version",
                value=OCRVersion.PPOCRV4.value,
                interactive=True,
            )

        # Recognition stage settings.
        with gr.Row():
            gr.Markdown("Rec")
            rec_engine = gr.Dropdown(
                choices=[v.value for v in EngineType],
                label="EngineType",
                value=EngineType.ONNXRUNTIME.value,
                interactive=True,
            )
            lang_rec = gr.Dropdown(
                choices=[v.value for v in LangRec],
                label="LangRec",
                value=LangRec.CH.value,
                interactive=True,
            )
            rec_model_type = gr.Dropdown(
                choices=[v.value for v in ModelType],
                label="ModelType",
                value=ModelType.MOBILE.value,
                interactive=True,
            )
            rec_ocr_version = gr.Dropdown(
                choices=[v.value for v in OCRVersion],
                label="OCR Version",
                value=OCRVersion.PPOCRV4.value,
                interactive=True,
            )

        # Which stages to run, and whether to return word boxes.
        with gr.Row():
            use_module = gr.CheckboxGroup(
                ["use_det", "use_cls", "use_rec"],
                label="Use module (使用哪些模块)",
                value=["use_det", "use_cls", "use_rec"],
                interactive=True,
            )
            is_word = gr.Radio(
                ["Yes", "No"], label="Return word box (返回单字符)", value="No"
            )

    with gr.Row():
        run_btn = gr.Button("Run")
        btn_export_cfg = gr.Button("Export Config YAML")
        # Invisible download button; it receives the exported file and is
        # clicked from JavaScript once the export has finished.
        download_btn_hidden = gr.DownloadButton(
            visible=False, elem_id="download_btn_hidden"
        )

    img_output = gr.Image(label="Output Image")
    elapse = gr.Textbox(label="Elapse(s)")
    ocr_results = gr.Dataframe(
        label="OCR Txts",
        headers=["Index", "Txt", "Score"],
        datatype=["number", "str", "number"],
        show_copy_button=True,
    )

    # Order must match the parameter order of get_ocr_result / export_yaml
    # and the column order of the example rows.
    ocr_inputs = [
        img_input,
        text_score,
        box_thresh,
        unclip_ratio,
        max_side_len,
        det_engine,
        lang_det,
        det_model_type,
        det_ocr_version,
        cls_engine,
        lang_cls,
        cls_model_type,
        cls_ocr_version,
        rec_engine,
        lang_rec,
        rec_model_type,
        rec_ocr_version,
        is_word,
        use_module,
    ]

    run_btn.click(
        get_ocr_result, inputs=ocr_inputs, outputs=[img_output, ocr_results, elapse]
    )

    # Two steps: write the YAML and hand its path to the hidden download
    # button, then trigger that button in the browser to start the download.
    btn_export_cfg.click(
        fn=export_yaml, inputs=ocr_inputs, outputs=[download_btn_hidden]
    ).then(
        fn=None,
        inputs=None,
        outputs=None,
        js="() => document.querySelector('#download_btn_hidden').click()",
    )

    examples = gr.Examples(
        examples=create_examples(),
        examples_per_page=5,
        inputs=ocr_inputs,
        fn=get_ocr_result,
        outputs=[img_output, ocr_results, elapse],
        cache_examples=False,
    )
if __name__ == "__main__":
    # Only start the server when this file is executed as a script.
    launch_kwargs = {"debug": True}
    demo.launch(**launch_kwargs)