Skip to content

Commit

Permalink
Create webcam_inference.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hojinchang committed Oct 28, 2022
1 parent 56f71a3 commit 72793b4
Showing 1 changed file with 138 additions and 0 deletions.
138 changes: 138 additions & 0 deletions webcam_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import sys
import os
import time
import argparse

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

from PIL import Image

import cv2
from skimage import io
import numpy as np
import craft_utils
import test
import imgproc
import file_utils
import json
import zipfile
import pandas as pd

from craft import CRAFT

from collections import OrderedDict

# from google.colab.patches import cv2_imshow

def str2bool(v):
return v.lower() in ("yes", "y", "true", "t", "1")

#CRAFT
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')

args = parser.parse_args()


""" For test images in a folder """
image_list, _, _ = file_utils.get_files(args.test_folder)

image_names = []
image_paths = []

#CUSTOMISE START
start = args.test_folder

for num in range(len(image_list)):
image_names.append(os.path.relpath(image_list[num], start))


result_folder = 'Results'
if not os.path.isdir(result_folder):
os.mkdir(result_folder)

if __name__ == '__main__':

data=pd.DataFrame(columns=['image_name', 'word_bboxes', 'pred_words', 'align_text'])
data['image_name'] = image_names

# load net
net = CRAFT() # initialize

print('Loading weights from checkpoint (' + args.trained_model + ')')
if args.cuda:
net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))
else:
net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))

if args.cuda:
net = net.cuda()
net = torch.nn.DataParallel(net)
cudnn.benchmark = False

net.eval()

# LinkRefiner
refine_net = None
if args.refine:
from refinenet import RefineNet
refine_net = RefineNet()
print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
if args.cuda:
refine_net.load_state_dict(test.copyStateDict(torch.load(args.refiner_model)))
refine_net = refine_net.cuda()
refine_net = torch.nn.DataParallel(refine_net)
else:
refine_net.load_state_dict(test.copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

refine_net.eval()
args.poly = True

t = time.time()


cap = cv2.VideoCapture(0)

if not cap.isOpened():
print("Cannot open camera")
exit()

while True:
# Capture frame-by-frame
ret, frame = cap.read()
# if frame is read correctly ret is True
if not ret:
print("Can't receive frame (stream end?). Exiting ...")
break

# Run through model
bboxes, polys, score_text, det_scores = test.test_net(net, frame, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, args, refine_net)

for i, box in enumerate(polys):
poly = np.array(box).astype(np.int32).reshape((-1))
strResult = ','.join([str(p) for p in poly]) + '\r\n'

poly = poly.reshape(-1, 2)
cv2.polylines(frame, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=1)


cv2.imshow('webcam_inference.py', frame)
if cv2.waitKey(1) == ord('q'):
break
# When everything done, release the capture
cap.release()
cv2.destroyAllWindows()

0 comments on commit 72793b4

Please sign in to comment.