forked from toandaominh1997/EfficientDet.Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualization.py
73 lines (63 loc) · 2.82 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import importlib
from datetime import datetime
class TensorboardWriter:
    """Wrap a TensorBoard ``SummaryWriter`` so callers need not pass step/mode.

    The backend is imported lazily: ``torch.utils.tensorboard`` is tried first,
    then ``tensorboardX``. When logging is disabled, or neither backend can be
    imported, every ``add_*`` method listed in ``tb_writer_ftns`` turns into a
    silent no-op, so training code can call them unconditionally.
    """

    def __init__(self, log_dir, enabled):
        """Create the writer.

        Args:
            log_dir: Directory handed (converted to ``str``) to ``SummaryWriter``.
            enabled: When falsy, no backend is imported and nothing is logged.
        """
        self.writer = None
        # Name of the backend module actually in use; "" when none was loaded.
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)
            # Retrieve visualization writer: first backend that imports wins.
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                except ImportError:
                    continue
                # Record the backend that succeeded (previously this recorded
                # the last backend that *failed*, because `break` skipped it).
                self.selected_module = module
                break
            else:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
                    "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
                print(message)

        # Global step and mode tag (e.g. 'train') attached to every logged value.
        self.step = 0
        self.mode = ''

        # Writer methods exposed through __getattr__ with step/mode filled in.
        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding', 'add_graph'
        }
        # Methods whose tag is passed through without the '/<mode>' suffix.
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
        # Wall-clock time of the previous set_step() call, used for steps_per_sec.
        self.timer = datetime.now()

    def set_step(self, step, mode='train'):
        """Set the step/mode used by subsequent ``add_*`` calls.

        For any ``step`` other than 0, also logs ``steps_per_sec`` (under the
        new ``mode``) as the inverse of the time since the previous call. A
        non-positive interval, from a coarse clock or a wall-clock jump
        backwards, is skipped rather than dividing by zero or logging a
        negative rate.
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer = datetime.now()
            return

        elapsed = (datetime.now() - self.timer).total_seconds()
        if elapsed > 0:
            self.add_scalar('steps_per_sec', 1 / elapsed)
        # Reset after logging so the write itself is excluded from the next interval.
        self.timer = datetime.now()

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing

        Only reached when normal attribute lookup fails, so any name that is
        not in ``tb_writer_ftns`` raises ``AttributeError``.

        NOTE(review): the wrapper always forwards ``(tag, data, self.step)``
        positionally; writer methods whose signature is not
        ``(tag, value, global_step, ...)`` (e.g. ``add_graph``,
        ``add_embedding``) may receive arguments in unexpected slots - verify
        against the SummaryWriter API before relying on them.
        """
        # Read the table through __dict__: if __init__ never ran (instance made
        # via __new__, copy.deepcopy or unpickling), `self.tb_writer_ftns` would
        # re-enter __getattr__ and recurse until RecursionError.
        tb_writer_ftns = self.__dict__.get('tb_writer_ftns', ())
        if name not in tb_writer_ftns:
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))

        # None when logging is disabled or the backend lacks this method.
        add_data = getattr(self.writer, name, None)

        def wrapper(tag, data, *args, **kwargs):
            if add_data is not None:
                # add mode(train/valid) tag
                if name not in self.tag_mode_exceptions:
                    tag = '{}/{}'.format(tag, self.mode)
                add_data(tag, data, self.step, *args, **kwargs)
        return wrapper