Skip to content

Commit

Permalink
update 0914 comboBox
Browse files Browse the repository at this point in the history
  • Loading branch information
JackieZhai committed Sep 14, 2021
1 parent ef3f952 commit 1deb8b8
Showing 1 changed file with 49 additions and 81 deletions.
130 changes: 49 additions & 81 deletions pyqt5/window_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,111 +32,79 @@ def __init__(self, mainWin, progressSignal, parent=None):
super().__init__(parent=parent)
self.mainWin = mainWin
self.progressSignal = progressSignal

def _pysot_init(self, config, snapshot):
# Import depending packages
from pysot.core.config import cfg
from pysot.models.model_builder import ModelBuilder
from pysot.tracker.tracker_builder import build_tracker
from pysot.tracker.multiple_tracker import MultiTracker
cfg.merge_from_file(config)
cfg.CUDA = torch.cuda.is_available()
device = torch.device('cuda' if cfg.CUDA else 'cpu')
self.progressSignal.emit(20)
model = ModelBuilder()
self.progressSignal.emit(40)
model.load_state_dict(torch.load(snapshot, map_location=lambda storage, loc: storage.cpu()))
self.progressSignal.emit(60)
model.eval().to(device)
self.progressSignal.emit(80)
self.mainWin.tracker = MultiTracker(build_tracker(model), cfg)

def _pytracking_init(self, tracker_name, parameter_name):
# Import depending packages
from pytracking.evaluation.multi_object_wrapper import MultiObjectWrapper
tracker_module_abspath = path.abspath(path.join('./pytracking', 'tracker', tracker_name))
if path.isdir(tracker_module_abspath):
tracker_module = importlib.import_module('pytracking.tracker.{}'.format(tracker_name))
tracker_class = tracker_module.get_tracker_class()
else:
tracker_class = None
param_module = importlib.import_module('pytracking.parameter.{}.{}'.format(tracker_name, parameter_name))
self.progressSignal.emit(20)
params = param_module.parameters()
self.progressSignal.emit(40)
params.tracker_name = tracker_name
self.progressSignal.emit(60)
params.param_name = parameter_name
self.progressSignal.emit(80)
multiobj_mode = getattr(params, 'multiobj_mode', getattr(tracker_class, 'multiobj_mode', 'default'))
if multiobj_mode == 'default':
self.mainWin.tracker = tracker_class(params)
if hasattr(self.mainWin.tracker, 'initialize_features'):
self.mainWin.tracker.initialize_features()
elif multiobj_mode == 'parallel':
self.mainWin.tracker = MultiObjectWrapper(tracker_class, params, fast_load=True)
else:
raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))

def run(self):
self.progressSignal.emit(0)
# Initialize trackers
if self.mainWin.model_name == 'SiamMask_E':
self.progressSignal.emit(10)
# Import depending packages
from pysot.core.config import cfg
from pysot.models.model_builder import ModelBuilder
from pysot.tracker.tracker_builder import build_tracker
from pysot.tracker.multiple_tracker import MultiTracker
model_location = './pysot/experiments/siammaske_r50_l3'
config = model_location + '/config.yaml'
snapshot = model_location + '/model.pth'
cfg.merge_from_file(config)
cfg.CUDA = torch.cuda.is_available()
device = torch.device('cuda' if cfg.CUDA else 'cpu')
self.progressSignal.emit(20)
model = ModelBuilder()
self.progressSignal.emit(40)
model.load_state_dict(torch.load(snapshot, map_location=lambda storage, loc: storage.cpu()))
self.progressSignal.emit(60)
model.eval().to(device)
self.progressSignal.emit(80)
self.mainWin.tracker = MultiTracker(build_tracker(model), cfg)
self._pysot_init(config, snapshot)
self.progressSignal.emit(100)
elif self.mainWin.model_name == 'KYS':
self.progressSignal.emit(10)
# Import depending packages
from pytracking.evaluation.multi_object_wrapper import MultiObjectWrapper
tracker_name = 'kys'
parameter_name = 'default'
tracker_module_abspath = path.abspath(path.join('./pytracking', 'tracker', tracker_name))
if path.isdir(tracker_module_abspath):
tracker_module = importlib.import_module('pytracking.tracker.{}'.format(tracker_name))
tracker_class = tracker_module.get_tracker_class()
else:
tracker_class = None
param_module = importlib.import_module('pytracking.parameter.{}.{}'.format(tracker_name, parameter_name))
params = param_module.parameters()
params.tracker_name = tracker_name
params.param_name = parameter_name
multiobj_mode = getattr(params, 'multiobj_mode', getattr(tracker_class, 'multiobj_mode', 'default'))
if multiobj_mode == 'default':
self.mainWin.tracker = tracker_class(params)
if hasattr(self.mainWin.tracker, 'initialize_features'):
self.mainWin.tracker.initialize_features()
elif multiobj_mode == 'parallel':
self.mainWin.tracker = MultiObjectWrapper(tracker_class, params, fast_load=True)
else:
raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))
self._pytracking_init(tracker_name, parameter_name)
self.progressSignal.emit(100)
elif self.mainWin.model_name == 'LWL':
self.progressSignal.emit(10)
# Import depending packages
from pytracking.evaluation.multi_object_wrapper import MultiObjectWrapper
tracker_name = 'lwl'
parameter_name = 'lwl_boxinit'
tracker_module_abspath = path.abspath(path.join('./pytracking', 'tracker', tracker_name))
if path.isdir(tracker_module_abspath):
tracker_module = importlib.import_module('pytracking.tracker.{}'.format(tracker_name))
tracker_class = tracker_module.get_tracker_class()
else:
tracker_class = None
param_module = importlib.import_module('pytracking.parameter.{}.{}'. \
format(tracker_name, parameter_name))
params = param_module.parameters()
params.tracker_name = tracker_name
params.param_name = parameter_name
multiobj_mode = getattr(params, 'multiobj_mode', getattr(tracker_class, 'multiobj_mode', 'default'))
if multiobj_mode == 'default':
self.mainWin.tracker = tracker_class(params)
if hasattr(self.mainWin.tracker, 'initialize_features'):
self.mainWin.tracker.initialize_features()
elif multiobj_mode == 'parallel':
self.mainWin.tracker = MultiObjectWrapper(tracker_class, params, fast_load=True)
else:
raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))
self._pytracking_init(tracker_name, parameter_name)
self.progressSignal.emit(100)
elif self.mainWin.model_name == 'KeepTrack':
self.progressSignal.emit(10)
# Import depending packages
from pytracking.evaluation.multi_object_wrapper import MultiObjectWrapper
tracker_name = 'keep_track'
parameter_name = 'default_fast'
tracker_module_abspath = path.abspath(path.join('./pytracking', 'tracker', tracker_name))
if path.isdir(tracker_module_abspath):
tracker_module = importlib.import_module('pytracking.tracker.{}'.format(tracker_name))
tracker_class = tracker_module.get_tracker_class()
else:
tracker_class = None
param_module = importlib.import_module('pytracking.parameter.{}.{}'. \
format(tracker_name, parameter_name))
params = param_module.parameters()
params.tracker_name = tracker_name
params.param_name = parameter_name
multiobj_mode = getattr(params, 'multiobj_mode', getattr(tracker_class, 'multiobj_mode', 'default'))
if multiobj_mode == 'default':
self.mainWin.tracker = tracker_class(params)
if hasattr(self.mainWin.tracker, 'initialize_features'):
self.mainWin.tracker.initialize_features()
elif multiobj_mode == 'parallel':
self.mainWin.tracker = MultiObjectWrapper(tracker_class, params, fast_load=True)
else:
raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))
self._pytracking_init(tracker_name, parameter_name)
self.progressSignal.emit(100)
else:
self.progressSignal.emit(-1)
Expand Down

0 comments on commit 1deb8b8

Please sign in to comment.