forked from PaddlePaddle/PaddleSeg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
125 lines (107 loc) · 4.88 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import filelock
import os
import tempfile
import numpy as np
import random
from urllib.parse import urlparse, unquote
import paddle
from paddleseg.utils import logger, seg_env
from paddleseg.utils.download import download_file_and_uncompress
@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
    '''Yield the path of a freshly created temporary directory.

    The directory is managed by ``tempfile.TemporaryDirectory``, so it is
    cleaned up when the ``with`` block using this context manager exits.

    Args:
        directory: Parent directory in which the temporary directory is
            created. Any falsy value (``None`` or an empty string) selects
            ``seg_env.TMP_HOME`` instead.
        **kwargs: Passed through unchanged to
            ``tempfile.TemporaryDirectory``.

    Yields:
        The path of the temporary directory.
    '''
    parent = directory if directory else seg_env.TMP_HOME
    with tempfile.TemporaryDirectory(dir=parent, **kwargs) as tmp_path:
        yield tmp_path
def load_entire_model(model, pretrained):
    '''Load pretrained weights into ``model`` when a source is given.

    Args:
        model: The model to load weights into; forwarded to
            ``load_pretrained_model``.
        pretrained: Location of the pretrained weights, forwarded to
            ``load_pretrained_model``. When ``None`` nothing is loaded and
            only a warning is logged.
    '''
    if pretrained is None:
        # The template is assembled first and formatted as a whole, so the
        # placeholder in its first half receives the class name.
        message = ('Not all pretrained params of {} are loaded, '
                   'training from scratch or a pretrained backbone.')
        logger.warning(message.format(model.__class__.__name__))
        return

    load_pretrained_model(model, pretrained)
def load_pretrained_model(model, pretrained_model):
    '''Copy matching parameters from a pretrained checkpoint into ``model``.

    Args:
        model: Object exposing ``state_dict()`` and ``set_dict()``.
        pretrained_model: ``None``, a local path, or a URL. ``None`` only
            logs a message. A value whose parsed URL has a network location
            is downloaded via ``download_file_and_uncompress`` and the path
            ``<returned path>/model.pdparams`` is loaded; any other value is
            treated as a local path and loaded directly.

    Raises:
        ValueError: If the resolved checkpoint path does not exist.
    '''
    if pretrained_model is None:
        logger.info(
            'No pretrained model to load, {} will be trained from scratch.'.
            format(model.__class__.__name__))
        return

    logger.info('Loading pretrained model from {}'.format(pretrained_model))

    # A non-empty network location means the argument is a URL to download.
    if urlparse(pretrained_model).netloc:
        url = unquote(pretrained_model)
        url_parts = url.split('/')
        last_part = url_parts[-1]
        if last_part.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
            # Archive file: keep the text before its first dot.
            savename = last_part.split('.')[0]
        else:
            # Not an archive name: use the second-to-last URL component.
            savename = url_parts[-2]

        lock_path = os.path.join(seg_env.TMP_HOME, savename)
        with generate_tempdir() as download_dir, \
                filelock.FileLock(lock_path):
            # NOTE(review): the returned value is assumed to be the
            # directory holding the extracted files -- confirm against
            # download_file_and_uncompress.
            extracted = download_file_and_uncompress(
                url,
                savepath=download_dir,
                extrapath=seg_env.PRETRAINED_MODEL_HOME,
                extraname=savename)
            pretrained_model = os.path.join(extracted, 'model.pdparams')

    if not os.path.exists(pretrained_model):
        raise ValueError(
            'The pretrained model directory is not Found: {}'.format(
                pretrained_model))

    para_state_dict = paddle.load(pretrained_model)
    model_state_dict = model.state_dict()
    num_params_loaded = 0
    for name in model_state_dict.keys():
        if name not in para_state_dict:
            logger.warning("{} is not in pretrained model".format(name))
            continue

        pretrained_shape = para_state_dict[name].shape
        actual_shape = model_state_dict[name].shape
        if list(pretrained_shape) != list(actual_shape):
            logger.warning(
                "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
                .format(name, pretrained_shape, actual_shape))
            continue

        # Only existing keys are overwritten, so the dict size is unchanged
        # while it is being iterated.
        model_state_dict[name] = para_state_dict[name]
        num_params_loaded += 1

    model.set_dict(model_state_dict)
    logger.info("There are {}/{} variables loaded into {}.".format(
        num_params_loaded, len(model_state_dict),
        model.__class__.__name__))
def resume(model, optimizer, resume_model):
    '''Restore model and optimizer state from a checkpoint directory.

    Args:
        model: Object exposing ``set_state_dict``; receives the contents of
            ``model.pdparams``.
        optimizer: Object exposing ``set_state_dict``; receives the contents
            of ``model.pdopt``.
        resume_model: Path of the checkpoint directory, or ``None`` to skip
            resuming. The text after the last ``'_'`` of the normalised path
            must parse as an integer.

    Returns:
        The integer parsed from the end of the directory path, or ``None``
        when ``resume_model`` is ``None``.

    Raises:
        ValueError: If ``resume_model`` does not exist, or if the path
            suffix cannot be converted by ``int``.
    '''
    if resume_model is None:
        logger.info('No model needed to resume.')
        return None

    logger.info('Resume model from {}'.format(resume_model))
    if not os.path.exists(resume_model):
        raise ValueError(
            'Directory of the model needed to resume is not Found: {}'.
            format(resume_model))

    # Normalising removes any trailing separator, so the suffix taken below
    # comes from the directory name itself.
    checkpoint_dir = os.path.normpath(resume_model)
    para_state_dict = paddle.load(
        os.path.join(checkpoint_dir, 'model.pdparams'))
    opti_state_dict = paddle.load(
        os.path.join(checkpoint_dir, 'model.pdopt'))
    model.set_state_dict(para_state_dict)
    optimizer.set_state_dict(opti_state_dict)

    return int(checkpoint_dir.split('_')[-1])
def worker_init_fn(worker_id):
    '''Reseed NumPy's global random generator.

    The seed is an integer drawn from Python's ``random`` module in the
    inclusive range 0 to 100000.

    Args:
        worker_id: Not used by this function.
            NOTE(review): presumably the worker index supplied by a data
            loader -- confirm against the caller.
    '''
    seed = random.randint(0, 100000)
    np.random.seed(seed)