forked from clovaai/CRAFT-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pipeline.py
135 lines (101 loc) · 4.5 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import sys
import os
import time
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from PIL import Image
import cv2
from skimage import io
import numpy as np
import craft_utils
import test
import imgproc
import file_utils
import json
import zipfile
import pandas as pd
from craft import CRAFT
from collections import OrderedDict
# from google.colab.patches import cv2_imshow
def str2bool(v):
    """Interpret a command-line string as a boolean flag value.

    The comparison is case-insensitive: "yes", "y", "true", "t" and "1"
    map to True, and every other string maps to False.
    """
    normalized = v.lower()
    return normalized in {"yes", "y", "true", "t", "1"}
# Command-line options for the CRAFT detector (and the optional link refiner).
# Parsed at import time because the module-level setup below reads `args`.
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument(
    '--trained_model', type=str, default='weights/craft_mlt_25k.pth',
    help='pretrained model',
)
parser.add_argument(
    '--text_threshold', type=float, default=0.7,
    help='text confidence threshold',
)
parser.add_argument(
    '--low_text', type=float, default=0.4,
    help='text low-bound score',
)
parser.add_argument(
    '--link_threshold', type=float, default=0.4,
    help='link confidence threshold',
)
parser.add_argument(
    '--cuda', type=str2bool, default=True,
    help='Use cuda for inference',
)
parser.add_argument(
    '--canvas_size', type=int, default=1280,
    help='image size for inference',
)
parser.add_argument(
    '--mag_ratio', type=float, default=1.5,
    help='image magnification ratio',
)
parser.add_argument(
    '--poly', action='store_true', default=False,
    help='enable polygon type',
)
parser.add_argument(
    '--show_time', action='store_true', default=False,
    help='show processing time',
)
parser.add_argument(
    '--test_folder', type=str, default='/data/',
    help='folder path to input images',
)
parser.add_argument(
    '--refine', action='store_true', default=False,
    help='enable link refiner',
)
parser.add_argument(
    '--refiner_model', type=str, default='weights/craft_refiner_CTW1500.pth',
    help='pretrained refiner model',
)
args = parser.parse_args()
# For test images in a folder: collect every input image found by
# file_utils.get_files and record each one's path relative to the input
# folder (these become the `image_name` column of the output CSV).
image_list, _, _ = file_utils.get_files(args.test_folder)
image_paths = []  # NOTE(review): never populated or read in this file
#CUSTOMISE START
start = args.test_folder
image_names = [os.path.relpath(path, start) for path in image_list]

# Directory for the per-image mask / result files written by the main loop.
result_folder = 'Results'
# exist_ok=True avoids the race between an isdir() check and mkdir(); it
# still raises if 'Results' exists as a non-directory, as before.
os.makedirs(result_folder, exist_ok=True)
def _load_checkpoint(model, checkpoint_path, use_cuda):
    """Load the weights at *checkpoint_path* into *model* in place.

    The raw checkpoint is passed through ``test.copyStateDict`` before
    ``load_state_dict``. When *use_cuda* is False the tensors are mapped to
    the CPU so a checkpoint saved on a GPU can be read on a CPU-only machine.
    """
    if use_cuda:
        state = torch.load(checkpoint_path)
    else:
        state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(test.copyStateDict(state))


if __name__ == '__main__':
    # One row per input image; pred_words / align_text are never filled here
    # and are written to the CSV as 'Unknown' (na_rep below).
    data = pd.DataFrame(columns=['image_name', 'word_bboxes', 'pred_words', 'align_text'])
    data['image_name'] = image_names

    # --cuda defaults to True; fall back to CPU instead of crashing inside
    # torch.load when no CUDA device is available.
    if args.cuda and not torch.cuda.is_available():
        print('CUDA requested but not available; falling back to CPU')
        args.cuda = False

    # load net
    net = CRAFT()  # initialize
    print('Loading weights from checkpoint (' + args.trained_model + ')')
    _load_checkpoint(net, args.trained_model, args.cuda)
    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False
    net.eval()

    # LinkRefiner (optional)
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        _load_checkpoint(refine_net, args.refiner_model, args.cuda)
        if args.cuda:
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        refine_net.eval()
        # The refiner path always produces polygon output.
        args.poly = True

    t = time.time()

    # Collected per image (same order as image_list / image_names) and
    # assigned as a whole column after the loop. Writing cells with
    # data['word_bboxes'][k] = ... is chained assignment, which does not
    # update `data` under pandas Copy-on-Write.
    word_bboxes = []

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)
        bboxes, polys, score_text, det_scores = test.test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, args, refine_net)

        # Map each box's detection score (as a string) to the box itself.
        # NOTE(review): boxes with identical scores overwrite each other here;
        # the key format is kept unchanged, presumably because whatever reads
        # data.csv depends on it. Confirm before changing.
        bbox_score = {}
        for box_num, box in enumerate(bboxes):
            bbox_score[str(det_scores[box_num])] = box
        word_bboxes.append(bbox_score)

        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)
        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    data['word_bboxes'] = word_bboxes

    # os.path.join keeps this portable; the old r'content\Pipeline' literal
    # produced a single oddly named directory on POSIX systems.
    output_dir = os.path.join('content', 'Pipeline')
    os.makedirs(output_dir, exist_ok=True)
    data.to_csv(os.path.join(output_dir, 'data.csv'), sep=',', na_rep='Unknown')
    print("elapsed time : {}s".format(time.time() - t))