Skip to content

Commit

Permalink
Updated columns orientation classifier (#335)
Browse files (browse the repository at this point in the history)
* updated txt columns orientation classifier

* deleted the "delete_lines" constructor parameter (and the "no_lines" attribute it set)

---------

Co-authored-by: Alexander Golodkov <golodkov@ispras.ru>
  • Loading branch information
alexander1999-hub and Alexander Golodkov authored Sep 26, 2023
1 parent 35c1c41 commit 95c38ac
Show file tree
Hide file tree
Showing 9 changed files with 8 additions and 44 deletions.
2 changes: 1 addition & 1 deletion dedoc/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""
model_hash_dict = dict(
txtlayer_classifier="94e27e184fa2876883d260e0aa58b042e6ab3e35",
scan_orientation_efficient_net_b0="0160965f8a920d12afacf62b8a5a8a3b365b11ef",
scan_orientation_efficient_net_b0="9ea283f3d346ae4fdd82463a9f60b5369a3ffb58",
font_classifier="db4481ad60ab050cbb42079b64f97f9e431feb07",
paragraph_classifier="00bf989876cec171c1cf9859a6b712af6445e864",
line_type_classifiers="2e498d1ec82b72c1a96ba0d25344b71402997013"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@
from torchvision import transforms
from torchvision.transforms.functional import resize

from dedoc.config import get_config
from dedoc.download_models import download_from_hub
from dedoc.readers.pdf_reader.pdf_image_reader.columns_orientation_classifier.model import ClassificationModelTorch
from dedoc.readers.pdf_reader.pdf_image_reader.table_recognizer.table_utils.img_processing import \
__detect_horizontal_and_vertical_lines as detect_horizontal_and_vertical_lines


class ColumnsOrientationClassifier(object):
Expand All @@ -25,13 +22,12 @@ class ColumnsOrientationClassifier(object):

_nets = {}

def __init__(self, on_gpu: bool, checkpoint_path: Optional[str], delete_lines: bool, *, config: dict) -> None:
def __init__(self, on_gpu: bool, checkpoint_path: Optional[str], *, config: dict) -> None:
self.logger = config.get("logger", logging.getLogger())
self._set_device(on_gpu)
self._set_transform_image()
self.checkpoint_path = path.abspath(checkpoint_path)
self.classes = [1, 2, 0, 90, 180, 270]
self.no_lines = delete_lines

@property
def net(self) -> ClassificationModelTorch:
Expand Down Expand Up @@ -93,24 +89,9 @@ def _set_transform_image(self) -> None:
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def get_features_no_lines(self, image: np.array) -> torch.Tensor:
"""
Get features for image without horizontal and vertical lines
"""
image = self.transform(Image.fromarray(np.uint8(image)))
item_np = (image.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
img = cv2.cvtColor(item_np, cv2.COLOR_BGR2GRAY)
(thresh, img_bin) = cv2.threshold(img, 245, 255, cv2.THRESH_BINARY)
img_bin = 255 - img_bin
img_lines_bin = detect_horizontal_and_vertical_lines(img_bin, get_config(), "orientation")
img_final = 255 - img * img_lines_bin
backtorgb = cv2.cvtColor(img_final, cv2.COLOR_GRAY2RGB).transpose(2, 0, 1) / 255.
image = torch.tensor(backtorgb).unsqueeze(0).float().to(self.device)
return image

def get_features(self, image: np.array) -> torch.Tensor:
"""
Get features for image with horizontal and vertical lines
Get features for the image
"""
image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(np.uint8(image)).convert("RGB")
Expand All @@ -123,11 +104,7 @@ def predict(self, image: np.ndarray) -> Tuple[int, int]:
"""
self.net.eval()
with torch.no_grad():

if self.no_lines:
tensor_image = self.get_features_no_lines(image)
else:
tensor_image = self.get_features(image)
tensor_image = self.get_features(image)
outputs = self.net(tensor_image)
# first 2 classes mean columns number
# last 4 classes mean orientation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def __init__(self, *, config: dict) -> None:
self.scan_rotator = ScanRotator(config=config)
self.column_orientation_classifier = ColumnsOrientationClassifier(on_gpu=False,
checkpoint_path=get_config()["resources_path"],
config=config,
delete_lines=False)
config=config)
self.binarizer = AdaptiveBinarizer()
self.ocr = OCRLineExtractor(config=config)
self.logger = config.get("logger", logging.getLogger())
Expand Down
12 changes: 0 additions & 12 deletions dedoc/scripts/train/train_acc_orientation_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from time import time
from typing import List

import numpy as np
import torch
from torch import nn
from torch import optim
Expand Down Expand Up @@ -75,11 +74,6 @@ def calc_accuracy_by_classes(testloader: DataLoader,
images, orientation, columns = data['image'], data['orientation'], data['columns']
time_begin = time()

if classifier.no_lines:
for i in range(len(images)):
image_w_o_lines = classifier.get_features_no_lines((images[i].numpy().transpose(1, 2, 0) * 255).astype(np.uint8))
images[i] = torch.squeeze(image_w_o_lines, 0)

outputs = classifier.net(images.float().to(classifier.device))
time_predict += time() - time_begin
cnt_predict += len(images)
Expand Down Expand Up @@ -135,11 +129,6 @@ def train_model(trainloader: DataLoader,
# get the inputs; data is a list of [inputs, labels]
inputs, orientation, columns = data['image'], data['orientation'], data['columns']

if classifier.no_lines:
for j in range(len(inputs)):
image_w_o_lines = classifier.get_features_no_lines((inputs[j].numpy().transpose(1, 2, 0) * 255).astype(np.uint8))
inputs[j] = torch.squeeze(image_w_o_lines, 0)

# zero the parameter gradients
optimizer.zero_grad()

Expand Down Expand Up @@ -193,7 +182,6 @@ def train_step(data_executor: DataLoaderImageOrient, classifier: ColumnsOrientat
data_executor = DataLoaderImageOrient()
net = ColumnsOrientationClassifier(on_gpu=True,
checkpoint_path=checkpoint_path if not args.train else None,
delete_lines=False,
config=config)

if args.train:
Expand Down
Binary file added tests/data/scanned/orient_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/scanned/orient_6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/scanned/orient_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/scanned/orient_8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions tests/unit_tests/test_format_pdf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class TestPDFReader(unittest.TestCase):
checkpoint_path = get_test_config()["resources_path"]
config = get_test_config()
orientation_classifier = ColumnsOrientationClassifier(on_gpu=False, checkpoint_path=checkpoint_path, delete_lines=True, config=config)
orientation_classifier = ColumnsOrientationClassifier(on_gpu=False, checkpoint_path=checkpoint_path, config=config)

def _split_lines_on_pages(self, lines: List[LineWithMeta]) -> List[List[str]]:
pages = set(map(lambda x: x.metadata.page_id, lines))
Expand All @@ -41,8 +41,8 @@ def test_scan_rotator(self) -> None:

def test_scan_orientation(self) -> None:
scan_rotator = ScanRotator(config=get_test_config())
imgs_path = [f"../data/scanned/orient_{i}.png"for i in range(1, 5)]
angles = [90.0, 90.0, 270.0, 270.0]
imgs_path = [f"../data/scanned/orient_{i}.png"for i in range(1, 9)]
angles = [90.0, 90.0, 270.0, 270.0, 180.0, 270.0, 180.0, 270.0]
max_delta = 10.0
for i in range(len(imgs_path)):
path = os.path.join(os.path.dirname(__file__), imgs_path[i])
Expand Down

0 comments on commit 95c38ac

Please sign in to comment.