forked from clovaai/CRAFT-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
56f71a3
commit 72793b4
Showing
1 changed file
with
138 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import sys | ||
import os | ||
import time | ||
import argparse | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.backends.cudnn as cudnn | ||
from torch.autograd import Variable | ||
|
||
from PIL import Image | ||
|
||
import cv2 | ||
from skimage import io | ||
import numpy as np | ||
import craft_utils | ||
import test | ||
import imgproc | ||
import file_utils | ||
import json | ||
import zipfile | ||
import pandas as pd | ||
|
||
from craft import CRAFT | ||
|
||
from collections import OrderedDict | ||
|
||
# from google.colab.patches import cv2_imshow | ||
|
||
def str2bool(v):
    """Parse a command-line flag string as a boolean (case-insensitive)."""
    truthy = ("yes", "y", "true", "t", "1")
    return v.lower() in truthy
|
||
# CRAFT command-line configuration.
# NOTE: this runs at import time (the script is not wrapped in a main()),
# so importing this module will parse sys.argv and scan the test folder.
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')

args = parser.parse_args()


""" For test images in a folder """
image_list, _, _ = file_utils.get_files(args.test_folder)

#CUSTOMISE START
# Image names are the paths relative to the input folder.
start = args.test_folder
image_names = [os.path.relpath(path, start) for path in image_list]
image_paths = []


# Output directory for detection results; created up front so later writes
# cannot fail on a missing folder. exist_ok avoids the check-then-create race.
result_folder = 'Results'
os.makedirs(result_folder, exist_ok=True)
|
||
if __name__ == '__main__':

    # Result accumulator: one row per input image. Only 'image_name' is
    # populated here; the remaining columns appear unused in this script
    # (kept for compatibility with downstream code). NOTE(review): confirm.
    data = pd.DataFrame(columns=['image_name', 'word_bboxes', 'pred_words', 'align_text'])
    data['image_name'] = image_names

    # Load the CRAFT detection network and its pretrained weights.
    net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # Optional LinkRefiner: refines link scores; forces polygon output.
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(test.copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(test.copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True

    t = time.time()

    # Open the default webcam for live inference.
    cap = cv2.VideoCapture(0)

    if not cap.isOpened():
        print("Cannot open camera")
        # sys.exit instead of the interactive-only exit() builtin; nonzero
        # status signals failure to the caller.
        sys.exit(1)

    try:
        while True:
            # Capture frame-by-frame
            ret, frame = cap.read()
            # if frame is read correctly ret is True
            if not ret:
                print("Can't receive frame (stream end?). Exiting ...")
                break

            # Run the frame through the detector (plus optional refiner).
            bboxes, polys, score_text, det_scores = test.test_net(
                net, frame, args.text_threshold, args.link_threshold,
                args.low_text, args.cuda, args.poly, args, refine_net)

            # Draw each detected text polygon on the frame in red.
            for box in polys:
                poly = np.array(box).astype(np.int32).reshape((-1))
                poly = poly.reshape(-1, 2)
                cv2.polylines(frame, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=1)

            cv2.imshow('webcam_inference.py', frame)
            if cv2.waitKey(1) == ord('q'):
                break
    finally:
        # Always release the capture device, even if inference raises.
        cap.release()
        cv2.destroyAllWindows()