Skip to content

Commit

Permalink
Merge pull request #157 from GNEHUY/dev
Browse files Browse the repository at this point in the history
Update OCR module
  • Loading branch information
KenelmQLH authored Mar 9, 2024
2 parents 504d147 + d235439 commit 855e250
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 13 deletions.
2 changes: 2 additions & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@

[Shangzi Xue](https://github.com/ShangziXue)

[Heng Yu](https://github.com/GNEHUY)

The stared contributors are the corresponding authors.
156 changes: 156 additions & 0 deletions EduNLP/SIF/parser/ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# coding: utf-8
# 2024/3/5 @ yuheng
import json
import requests
from EduNLP.utils import image2base64


class FormulaRecognitionError(Exception):
"""Exception raised when formula recognition fails."""
def __init__(self, message="Formula recognition failed"):
self.message = message
super().__init__(self.message)


def ocr_formula_figure(image_PIL_or_base64, is_base64=False):
"""
Recognizes mathematical formulas in an image and returns their LaTeX representation.
Parameters
----------
image_PIL_or_base64 : PngImageFile or str
The PngImageFile if is_base64 is False, or the base64 encoded string of the image if is_base64 is True.
is_base64 : bool, optional
Indicates whether the image_PIL_or_base64 parameter is an PngImageFile or a base64 encoded string.
Returns
-------
latex : str
The LaTeX representation of the mathematical formula recognized in the image.
Raises an exception if the image is not recognized as containing a mathematical formula.
Raises
------
FormulaRecognitionError
If the HTTP request does not return a 200 status code,
if there is an error processing the response,
if the image is not recognized as a mathematical formula.
Examples
--------
>>> import os
>>> from PIL import Image
>>> from EduNLP.utils import abs_current_dir, path_append
>>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
>>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
>>> print(ocr_formula_figure(image_PIL))
f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
>>> import os
>>> from PIL import Image
>>> from EduNLP.utils import abs_current_dir, path_append, image2base64
>>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
>>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
>>> image_base64 = image2base64(image_PIL)
>>> print(ocr_formula_figure(image_base64, is_base64=True))
f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
Notes
-----
This function relies on an external service "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1",
and the `requests` library to make HTTP requests. Make sure the required libraries are installed before use.
"""
url = "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1"

if is_base64:
image = image_PIL_or_base64
else:
image = image2base64(image_PIL_or_base64)

data = [{
'qid': 0,
'image': image
}]

resp = requests.post(url, data=json.dumps(data))

if resp.status_code != 200:
raise FormulaRecognitionError(f"HTTP error {resp.status_code}: {resp.text}")

try:
res = json.loads(resp.content)
except Exception as e:
raise FormulaRecognitionError(f"Error processing response: {e}")

res = json.loads(resp.content)
data = res['data']
if data['success'] == 1 and data['is_formula'] == 1 and data['detect_formula'] == 1:
latex = data['latex']
else:
latex = None
raise FormulaRecognitionError("Image is not recognized as a formula")

return latex


def ocr(src, is_base64=False, figure_instances: dict = None):
"""
Recognizes mathematical formulas within figures from a given source,
which can be either a base64 string or an identifier for a figure within a provided dictionary.
Parameters
----------
src : str
The source from which the figure is to be recognized.
It can be a base64 encoded string of the image if is_base64 is True,
or an identifier for the figure if is_base64 is False.
is_base64 : bool, optional
Indicates whether the src parameter is a base64 encoded string or an identifier, by default False.
figure_instances : dict, optional
A dictionary mapping figure identifiers to their corresponding PngImageFile, by default None.
This is only required and used if is_base64 is False.
Returns
-------
forumla_figure_latex : str or None
The LaTeX representation of the mathematical formula recognized within the figure.
Returns None if no formula is recognized or
if the figure_instances dictionary does not contain the specified figure identifier when is_base64 is False.
Examples
--------
>>> import os
>>> from PIL import Image
>>> from EduNLP.utils import abs_current_dir, path_append
>>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
>>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
>>> figure_instances = {"1": image_PIL}
>>> src_id = r"$\\FormFigureID{1}$"
>>> print(ocr(src_id[1:-1], figure_instances=figure_instances))
f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
>>> import os
>>> from PIL import Image
>>> from EduNLP.utils import abs_current_dir, path_append, image2base64
>>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
>>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
>>> image_base64 = image2base64(image_PIL)
>>> src_base64 = r"$\\FormFigureBase64{%s}$" % (image_base64)
>>> print(ocr(src_base64[1:-1], is_base64=True, figure_instances=True))
f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
Notes
-----
This function relies on `ocr_formula_figure` for the actual OCR (Optical Character Recognition) process.
Ensure that `ocr_formula_figure` is correctly implemented and can handle base64 encoded strings and PngImageFile.
"""
forumla_figure_latex = None
if is_base64:
figure = src[len(r"\FormFigureBase64") + 1: -1]
if figure_instances is not None:
forumla_figure_latex = ocr_formula_figure(figure, is_base64)
else:
figure = src[len(r"\FormFigureID") + 1: -1]
if figure_instances is not None:
figure = figure_instances[figure]
forumla_figure_latex = ocr_formula_figure(figure, is_base64)

return forumla_figure_latex
17 changes: 12 additions & 5 deletions EduNLP/SIF/segment/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from contextlib import contextmanager
from ..constants import Symbol, TEXT_SYMBOL, FORMULA_SYMBOL, FIGURE_SYMBOL, QUES_MARK_SYMBOL, TAG_SYMBOL, SEP_SYMBOL
from ..parser.ocr import ocr


class TextSegment(str):
Expand Down Expand Up @@ -93,7 +94,7 @@ class SegmentList(object):
>>> SegmentList(test_item)
['如图所示,则三角形', 'ABC', '的面积是', '\\\\SIFBlank', '。', \\FigureID{1}]
"""
def __init__(self, item, figures: dict = None):
def __init__(self, item, figures: dict = None, convert_image_to_latex=False):
self._segments = []
self._text_segments = []
self._formula_segments = []
Expand All @@ -112,9 +113,15 @@ def __init__(self, item, figures: dict = None):
if not re.match(r"\$.+?\$", segment):
self.append(TextSegment(segment))
elif re.match(r"\$\\FormFigureID\{.+?}\$", segment):
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
if convert_image_to_latex:
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=False, figure_instances=figures)))
else:
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
elif re.match(r"\$\\FormFigureBase64\{.+?}\$", segment):
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
if convert_image_to_latex:
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=True, figure_instances=figures)))
else:
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
elif re.match(r"\$\\FigureID\{.+?}\$", segment):
self.append(FigureSegment(segment[1:-1], is_base64=False, figure_instances=figures))
elif re.match(r"\$\\FigureBase64\{.+?}\$", segment):
Expand Down Expand Up @@ -271,7 +278,7 @@ def describe(self):
}


def seg(item, figures=None, symbol=None):
def seg(item, figures=None, symbol=None, convert_image_to_latex=False):
r"""
It is a interface for SegmentList. And show it in an appropriate way.
Expand Down Expand Up @@ -346,7 +353,7 @@ def seg(item, figures=None, symbol=None):
>>> s2.text_segments
['已知', ',则以下说法中正确的是']
"""
segments = SegmentList(item, figures)
segments = SegmentList(item, figures, convert_image_to_latex)
if symbol is not None:
segments.symbolize(symbol)
return segments
4 changes: 2 additions & 2 deletions EduNLP/SIF/sif.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def to_sif(item, check_formula=True, parser: Parser = None):


def sif4sci(item: str, figures: (dict, bool) = None, mode: int = 2, symbol: str = None, tokenization=True,
tokenization_params=None, errors="raise"):
tokenization_params=None, convert_image_to_latex=False, errors="raise"):
r"""
Default to use linear Tokenizer, change the tokenizer by specifying tokenization_params
Expand Down Expand Up @@ -260,7 +260,7 @@ def sif4sci(item: str, figures: (dict, bool) = None, mode: int = 2, symbol: str
"Unknown mode %s, use only 0 or 1 or 2." % mode
)

ret = seg(item, figures, symbol)
ret = seg(item, figures, symbol, convert_image_to_latex)

if tokenization is True:
ret = tokenize(ret, **(tokenization_params if tokenization_params is not None else {}))
Expand Down
18 changes: 12 additions & 6 deletions EduNLP/Tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def __call__(self, items: Iterable, key=lambda x: x, **kwargs):
for item in items:
yield self._tokenize(item, key=key, **kwargs)

def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None, **kwargs):
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
convert_image_to_latex=False, **kwargs):
"""Tokenize one item, return token list
Parameters
Expand All @@ -67,7 +68,8 @@ def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
determine how to get the text of item, by default lambdax: x
"""
symbol = self.symbol if symbol is None else symbol
return tokenize(seg(key(item), symbol=symbol, figures=self.figures),
return tokenize(seg(key(item), symbol=symbol, figures=self.figures,
convert_image_to_latex=convert_image_to_latex),
**self.tokenization_params, **kwargs).tokens


Expand Down Expand Up @@ -191,9 +193,11 @@ def __call__(self, items: Iterable, key=lambda x: x, **kwargs):
for item in items:
yield self._tokenize(item, key=key, **kwargs)

def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None, **kwargs):
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
convert_image_to_latex=False, **kwargs):
symbol = self.symbol if symbol is None else symbol
return tokenize(seg(key(item), symbol=symbol), **self.tokenization_params, **kwargs).tokens
return tokenize(seg(key(item), symbol=symbol, convert_image_to_latex=convert_image_to_latex),
**self.tokenization_params, **kwargs).tokens


class AstFormulaTokenizer(Tokenizer):
Expand Down Expand Up @@ -235,11 +239,13 @@ def __call__(self, items: Iterable, key=lambda x: x, **kwargs):
for item in items:
yield self._tokenize(item, key=key, **kwargs)

def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None, **kwargs):
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
convert_image_to_latex=False, **kwargs):
mode = kwargs.pop("mode", 0)
symbol = self.symbol if symbol is None else symbol
ret = sif4sci(key(item), figures=self.figures, mode=mode, symbol=symbol,
tokenization_params=self.tokenization_params, errors="ignore", **kwargs)
tokenization_params=self.tokenization_params, convert_image_to_latex=convert_image_to_latex,
errors="ignore", **kwargs)
ret = [] if ret is None else ret.tokens
return ret

Expand Down
Binary file added asset/_static/item_ocr_formula.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
62 changes: 62 additions & 0 deletions tests/test_sif/test_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# 2024/3/5 @ yuheng

import pytest
import json

from EduNLP.SIF.segment import seg
from EduNLP.SIF.parser.ocr import ocr_formula_figure, FormulaRecognitionError
from unittest.mock import patch


def test_ocr(figure0, figure1, figure0_base64, figure1_base64):
seg(
r"如图所示,则$\FormFigureID{0}$的面积是$\SIFBlank$。$\FigureID{1}$",
figures={
"0": figure0,
"1": figure1
},
convert_image_to_latex=True
)
s = seg(
r"如图所示,则$\FormFigureBase64{%s}$的面积是$\SIFBlank$。$\FigureBase64{%s}$" % (figure0_base64, figure1_base64),
figures=True,
convert_image_to_latex=True
)
with pytest.raises(TypeError):
s.append("123")
seg_test_text = seg(
r"如图所示,有三组$\textf{机器人,bu}$在踢$\textf{足球,b}$",
figures=True
)
assert seg_test_text.text_segments == ['如图所示,有三组机器人在踢足球']


def test_ocr_formula_figure_exceptions(figure0_base64):
"""Simulate a non-200 status code"""
with patch('EduNLP.SIF.parser.ocr.requests.post') as mock_post:
mock_post.return_value.status_code = 404
with pytest.raises(FormulaRecognitionError) as exc_info:
ocr_formula_figure(figure0_base64, is_base64=True)
assert "HTTP error 404" in str(exc_info.value)

"""Simulate an invalid JSON response"""
with patch('EduNLP.SIF.parser.ocr.requests.post') as mock_post:
mock_post.return_value.status_code = 200
mock_post.return_value.content = b"invalid_json_response"
with pytest.raises(FormulaRecognitionError) as exc_info:
ocr_formula_figure(figure0_base64, is_base64=True)
assert "Error processing response" in str(exc_info.value)

"""Simulate image not recognized as a formula"""
with patch('EduNLP.SIF.parser.ocr.requests.post') as mock_post:
mock_post.return_value.status_code = 200
mock_post.return_value.content = json.dumps({
"data": {
'success': 1,
'is_formula': 0,
'detect_formula': 0
}
}).encode('utf-8')
with pytest.raises(FormulaRecognitionError) as exc_info:
ocr_formula_figure(figure0_base64, is_base64=True)
assert "Image is not recognized as a formula" in str(exc_info.value)

0 comments on commit 855e250

Please sign in to comment.