Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add Dataset Preparer #1484

Merged
merged 20 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add tests for parsers and dumpers
  • Loading branch information
xinke-wang committed Oct 31, 2022
commit 28dc21ef8a2f3d71e430c4a470dd9864d24d50d7
21 changes: 18 additions & 3 deletions mmocr/datasets/preparers/dumpers/dumpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List
from typing import Dict, List

import mmengine

Expand All @@ -13,9 +13,16 @@ class JsonDumper:

def __init__(self, task: str) -> None:
self.task = task
self.format = format

def dump(self, data: List, data_root: str, split: str) -> None:
def dump(self, data: Dict, data_root: str, split: str) -> None:
"""Dump data to json file.

Args:
data (Dict): Data to be dumped.
data_root (str): Root directory of data.
split (str): Split of data.
"""

dst_file = osp.join(data_root, f'{self.task}_{split}.json')
mmengine.dump(data, dst_file)

Expand All @@ -27,4 +34,12 @@ def __init__(self, task: str) -> None:
self.task = task

def dump(self, data: List, data_root: str, split: str) -> None:
"""Dump data to txt file.

Args:
data (List): Data to be dumped.
data_root (str): Root directory of data.
split (str): Split of data.
"""

list_to_file(osp.join(data_root, f'openset_{split}.txt'), data)
2 changes: 1 addition & 1 deletion mmocr/datasets/preparers/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .coco_parser import COCOTextDetAnnParser
from .ic15_parser import ICDAR2015TextDetAnnParser, ICDAR2015TextRecogAnnParser
from .totaltext_parser import TotaltextTextDetAnnParser
from .wildreceipt import WildreceiptKIEAnnParser
from .wildreceipt_parser import WildreceiptKIEAnnParser

__all__ = [
'ICDAR2015TextDetAnnParser', 'ICDAR2015TextRecogAnnParser',
Expand Down
38 changes: 38 additions & 0 deletions tests/test_datasets/test_preparers/test_dumpers/test_dumpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import tempfile
import unittest

from mmocr.datasets.preparers.dumpers import (JsonDumper,
WildreceiptOpensetDumper)


class TestDumpers(unittest.TestCase):

def setUp(self) -> None:
self.root = tempfile.TemporaryDirectory()

def test_json_dumpers(self):
task, split = 'textdet', 'train'
fake_data = dict(
metainfo=dict(
dataset_type='TextDetDataset',
task_name='textdet',
category=[dict(id=0, name='text')]))

dumper = JsonDumper(task)
dumper.dump(fake_data, self.root.name, split)
with open(osp.join(self.root.name, f'{task}_{split}.json'), 'r') as f:
data = json.load(f)
self.assertEqual(data, fake_data)

def test_wildreceipt_dumper(self):
task, split = 'kie', 'train'
fake_data = ['test1', 'test2']

dumper = WildreceiptOpensetDumper(task)
dumper.dump(fake_data, self.root.name, split)
with open(osp.join(self.root.name, f'openset_{split}.txt'), 'r') as f:
data = f.read().splitlines()
self.assertEqual(data, fake_data)
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import unittest

from mmocr.datasets.preparers.parsers.ic15_parser import (
ICDAR2015TextDetAnnParser, ICDAR2015TextRecogAnnParser)
from mmocr.utils import list_to_file


class TestIC15Parsers(unittest.TestCase):

def setUp(self) -> None:
self.root = tempfile.TemporaryDirectory()

def _create_dummy_ic15_det(self):
fake_anno = [
'377,117,463,117,465,130,378,130,Genaxis Theatre',
'493,115,519,115,519,131,493,131,[06]',
'374,155,409,155,409,170,374,170,###',
]
ann_file = osp.join(self.root.name, 'ic15_det.txt')
list_to_file(ann_file, fake_anno)
return (osp.join(self.root.name, 'ic15_det.jpg'), ann_file)

def _create_dummy_ic15_recog(self):
fake_anno = [
'word_1.png, "Genaxis Theatre"',
'word_2.png, "[06]"',
'word_3.png, "62-03"',
]
ann_file = osp.join(self.root.name, 'ic15_recog.txt')
list_to_file(ann_file, fake_anno)
return ann_file

def test_textdet_parsers(self):
parser = ICDAR2015TextDetAnnParser()
file = self._create_dummy_ic15_det()
img, instances = parser.parse_file(file, 'train')
self.assertEqual(img, file[0])
self.assertEqual(len(instances), 3)
self.assertIn('poly', instances[0])
self.assertIn('text', instances[0])
self.assertIn('ignore', instances[0])
self.assertEqual(instances[0]['text'], 'Genaxis Theatre')
self.assertEqual(instances[2]['ignore'], True)

def test_textrecog_parsers(self):
parser = ICDAR2015TextRecogAnnParser()
file = self._create_dummy_ic15_recog()
samples = parser.parse_files(file, 'train')
self.assertEqual(len(samples), 3)
img, text = samples[0]
self.assertEqual(img, 'word_1.png')
self.assertEqual(text, 'Genaxis Theatre')
36 changes: 36 additions & 0 deletions tests/test_datasets/test_preparers/test_parsers/test_tt_parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import unittest

from mmocr.datasets.preparers.parsers.totaltext_parser import \
TotaltextTextDetAnnParser
from mmocr.utils import list_to_file


class TestTTParsers(unittest.TestCase):

def setUp(self) -> None:
self.root = tempfile.TemporaryDirectory()

def _create_dummy_tt_det(self):
fake_anno = [
"x: [[ 53 120 121 56]], y: [[446 443 456 458]], ornt: [u'h'], transcriptions: [u'PERUNDING']", # noqa: E501
"x: [[123 165 166 125]], y: [[443 440 453 455]], ornt: [u'h'], transcriptions: [u'PENILAI']", # noqa: E501
"x: [[168 179 179 167]], y: [[439 439 452 453]], ornt: [u'#'], transcriptions: [u'#']", # noqa: E501
]
ann_file = osp.join(self.root.name, 'tt_det.txt')
list_to_file(ann_file, fake_anno)
return (osp.join(self.root.name, 'tt_det.jpg'), ann_file)

def test_textdet_parsers(self):
parser = TotaltextTextDetAnnParser(self.root.name)
file = self._create_dummy_tt_det()
img, instances = parser.parse_file(file, 'train')
self.assertEqual(img, file[0])
self.assertEqual(len(instances), 3)
self.assertIn('poly', instances[0])
self.assertIn('text', instances[0])
self.assertIn('ignore', instances[0])
self.assertEqual(instances[0]['text'], 'PERUNDING')
self.assertEqual(instances[2]['ignore'], True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import tempfile
import unittest

from mmocr.datasets.preparers.parsers.wildreceipt_parser import (
WildreceiptKIEAnnParser, WildreceiptTextDetAnnParser)
from mmocr.utils import list_to_file


class TestWildReceiptParsers(unittest.TestCase):
gaotongxiao marked this conversation as resolved.
Show resolved Hide resolved

def setUp(self) -> None:
self.root = tempfile.TemporaryDirectory()
fake_sample = dict(
file_name='test.jpg',
height=100,
width=100,
annotations=[
dict(
box=[
550.0, 190.0, 937.0, 190.0, 937.0, 104.0, 550.0, 104.0
],
text='test',
label=1,
),
dict(
box=[
1048.0, 211.0, 1074.0, 211.0, 1074.0, 196.0, 1048.0,
196.0
],
text='ATOREMGRTOMMILAZZO',
label=0,
)
])
fake_sample = [json.dumps(fake_sample)]
self.anno = osp.join(self.root.name, 'wildreceipt.txt')
list_to_file(self.anno, fake_sample)

def test_textdet_parsers(self):
parser = WildreceiptTextDetAnnParser(self.root.name)
samples = parser.parse_files(self.anno, 'train')
self.assertEqual(len(samples), 1)
self.assertEqual(osp.basename(samples[0][0]), 'test.jpg')
instances = samples[0][1]
self.assertEqual(len(instances), 2)
self.assertIn('poly', instances[0])
self.assertIn('text', instances[0])
self.assertIn('ignore', instances[0])
self.assertEqual(instances[0]['text'], 'test')
self.assertEqual(instances[1]['ignore'], True)

def test_kie_parsers(self):
parser = WildreceiptKIEAnnParser(self.root.name)
samples = parser.parse_files(self.anno, 'train')
self.assertEqual(len(samples), 1)