-
Notifications
You must be signed in to change notification settings - Fork 17
/
fuseformer.py
402 lines (345 loc) · 15.5 KB
/
fuseformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
''' Fuseformer for Video Inpainting
'''
import numpy as np
import time
import math
from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from core.spectral_norm import spectral_norm as _spectral_norm
class BaseNetwork(nn.Module):
    """Shared base class providing a parameter report and weight initialisation."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        """Print the class name and the total number of parameters (in millions)."""
        net = self[0] if isinstance(self, list) else self
        total = sum(p.numel() for p in net.parameters())
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).' % (type(net).__name__, total / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        Initialise the weights of every sub-module.

        init_type: normal | xavier | xavier_uniform | kaiming | orthogonal | none
        gain: std for 'normal', gain for 'xavier' and 'orthogonal'
        Raises NotImplementedError when a Conv/Linear module is reached with
        an unknown init_type.
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''

        def init_func(m):
            classname = m.__class__.__name__

            # InstanceNorm2d: affine parameters go to identity (weight 1, bias 0).
            if 'InstanceNorm2d' in classname:
                if getattr(m, 'weight', None) is not None:
                    nn.init.constant_(m.weight.data, 1.0)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias.data, 0.0)
                return

            # Only modules with a weight whose class name mentions Conv/Linear
            # are touched beyond this point.
            is_conv_or_linear = 'Conv' in classname or 'Linear' in classname
            if not (hasattr(m, 'weight') and is_conv_or_linear):
                return

            if init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight.data, gain=1.0)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight.data, gain=gain)
            elif init_type == 'none':  # uses pytorch's default init method
                m.reset_parameters()
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' % init_type)

            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children that expose their own init_weights
        for child in self.children():
            if hasattr(child, 'init_weights'):
                child.init_weights(init_type, gain)
class Encoder(nn.Module):
    """Convolutional frame encoder with grouped skip concatenation.

    The first eight entries of ``layers`` (four conv + LeakyReLU pairs)
    downsample the input with two stride-2 convolutions and produce a
    256-channel feature map. Each of the following conv stages receives the
    previous stage's output concatenated, group by group, with that
    256-channel feature map; ``group`` lists the number of groups per stage.
    The final output has 128 channels.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        # Number of channel groups used when concatenating the skip feature
        # before layers 8, 10, 12, 14, 16 (index 0 is never read in forward).
        self.group = [1, 2, 4, 8, 1]
        self.layers = nn.ModuleList([
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True)
        ])

    def forward(self, x):
        """Encode a batch of frames.

        x: 4-D tensor (batch*time, 3, H, W).
        Returns a tensor with 128 channels at the downsampled resolution.
        """
        bt, _, _, _ = x.size()
        out = x
        skip = None
        h = w = None
        for i, layer in enumerate(self.layers):
            if i == 8:
                skip = out
                # Read the spatial size from the feature map itself. The
                # stride-2 convs yield ceil(H/4) x ceil(W/4), so H//4 is only
                # right when H and W are multiples of 4.
                h, w = skip.shape[-2:]
            if i > 8 and i % 2 == 0:
                g = self.group[(i - 8) // 2]
                # Interleave skip and current features group by group so each
                # conv group sees its own slice of both.
                skip_grouped = skip.view(bt, g, -1, h, w)
                out_grouped = out.view(bt, g, -1, h, w)
                out = torch.cat([skip_grouped, out_grouped], 2).view(bt, -1, h, w)
            out = layer(out)
        return out
class InpaintGenerator(BaseNetwork):
    """Video inpainting generator: encoder -> soft split -> transformer -> soft composite -> decoder.

    All hyper-parameters are fixed in ``__init__``. The soft-composite fold is
    built for 60x108 feature maps; with the two stride-2 convolutions in
    ``Encoder`` this corresponds to 240x432 input frames.

    init_weights: when true, ``self.init_weights()`` is called after the
        sub-modules have been created.
    """

    def __init__(self, init_weights=True):
        super(InpaintGenerator, self).__init__()
        channel = 256
        hidden = 512
        stack_num = 8
        num_head = 4
        kernel_size = (7, 7)
        padding = (3, 3)
        stride = (3, 3)
        output_size = (60, 108)
        blocks = []
        dropout = 0.
        t2t_params = {'kernel_size': kernel_size, 'stride': stride, 'padding': padding, 'output_size': output_size}

        # Number of patches (tokens) per frame: product over both spatial
        # axes of the sliding-window count.
        n_vecs = 1
        for i, d in enumerate(kernel_size):
            n_vecs *= (output_size[i] + 2 * padding[i] - (d - 1) - 1) // stride[i] + 1

        for _ in range(stack_num):
            blocks.append(TransformerBlock(hidden=hidden, num_head=num_head, dropout=dropout, n_vecs=n_vecs,
                                           t2t_params=t2t_params))
        self.transformer = nn.Sequential(*blocks)
        self.ss = SoftSplit(channel // 2, hidden, kernel_size, stride, padding, dropout=dropout)
        self.add_pos_emb = AddPosEmb(n_vecs, hidden)
        self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size, stride, padding)

        self.encoder = Encoder()

        # decoder: decode frames from features
        self.decoder = nn.Sequential(
            deconv(channel // 2, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
        )

        if init_weights:
            self.init_weights()

    def forward(self, masked_frames):
        """Inpaint a batch of masked video clips.

        masked_frames: 5-D tensor (b, t, c, h, w).
        Returns the tanh-activated decoder output with the batch and time
        axes merged into the leading dimension.
        """
        # extracting features
        b, t, c, h, w = masked_frames.size()
        enc_feat = self.encoder(masked_frames.view(b * t, c, h, w))

        # tokens: split into patches, add positions, run the transformer,
        # then fold the patches back into per-frame feature maps
        trans_feat = self.ss(enc_feat, b)
        trans_feat = self.add_pos_emb(trans_feat)
        trans_feat = self.transformer(trans_feat)
        trans_feat = self.sc(trans_feat, t)

        # residual connection around the whole transformer stage
        enc_feat = enc_feat + trans_feat
        output = self.decoder(enc_feat)
        output = torch.tanh(output)
        return output
class deconv(nn.Module):
    """Upsampling block: 2x bilinear interpolation followed by a stride-1 convolution."""

    def __init__(self, input_channel, output_channel, kernel_size=3, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channel,
            output_channel,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

    def forward(self, x):
        """Double the spatial resolution of ``x`` and convolve the result."""
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        return self.conv(upsampled)
# #############################################################################
# ############################# Transformer ##################################
# #############################################################################
class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention'

    p: dropout probability applied to the attention weights.
    """

    def __init__(self, p=0.1):
        super(Attention, self).__init__()
        self.dropout = nn.Dropout(p=p)

    def forward(self, query, key, value, m=None):
        """Return ``(weighted values, attention weights)``.

        query, key, value: tensors whose last two dimensions are
            (tokens, features); leading dimensions must be compatible for
            ``torch.matmul``.
        m: optional boolean mask broadcastable to the score matrix; positions
            where it is true are excluded from attention.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if m is not None:
            # Use the most negative finite value of the score dtype. A fixed
            # -1e9 cannot be represented in float16 and makes masked_fill
            # raise under half precision. For float32 the softmax result is
            # identical.
            scores = scores.masked_fill(m, torch.finfo(scores.dtype).min)
        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.dropout(p_attn)
        p_val = torch.matmul(p_attn, value)
        return p_val, p_attn
class AddPosEmb(nn.Module):
    """Add a learned positional embedding of shape (1, 1, n, c) to a token sequence.

    n: number of tokens per group (stored as ``num_vecs``).
    c: token feature size.
    """

    def __init__(self, n, c):
        super(AddPosEmb, self).__init__()
        initial = torch.zeros(1, 1, n, c).float()
        initial.normal_(mean=0, std=0.02)
        self.pos_emb = nn.Parameter(initial, requires_grad=True)
        self.num_vecs = n

    def forward(self, x):
        """x: (b, tokens, c) with tokens a multiple of ``num_vecs``; returns the same shape."""
        b, n, c = x.size()
        # Expose the per-group token axis so the embedding broadcasts over
        # batch and group.
        grouped = x.view(b, -1, self.num_vecs, c)
        return (grouped + self.pos_emb).view(b, n, c)
class SoftSplit(nn.Module):
    """Cut feature maps into sliding patches and embed each patch as a token.

    channel: channels of the incoming feature map.
    hidden: token feature size produced by the linear embedding.
    kernel_size, stride, padding: forwarded to ``nn.Unfold``.
    dropout: dropout probability applied to the tokens.
    """

    def __init__(self, channel, hidden, kernel_size, stride, padding, dropout=0.1):
        super(SoftSplit, self).__init__()
        self.kernel_size = kernel_size
        self.t2t = nn.Unfold(kernel_size=kernel_size, stride=stride, padding=padding)
        # A flattened patch holds (product of kernel dims) * channel values.
        patch_area = reduce(lambda a, b: a * b, kernel_size)
        self.embedding = nn.Linear(patch_area * channel, hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, b):
        """x: (frames, channel, H, W); b: batch size used to regroup the tokens.

        Returns tokens of shape (b, -1, hidden).
        """
        patches = self.t2t(x).permute(0, 2, 1)
        tokens = self.embedding(patches)
        tokens = tokens.view(b, -1, tokens.size(2))
        return self.dropout(tokens)
class SoftComp(nn.Module):
    """Project tokens back to flattened patches and fold them into feature maps.

    channel: channels of the produced feature map.
    hidden: incoming token feature size.
    output_size, kernel_size, stride, padding: forwarded to ``nn.Fold``.
    """

    def __init__(self, channel, hidden, output_size, kernel_size, stride, padding):
        super(SoftComp, self).__init__()
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        patch_area = reduce(lambda a, b: a * b, kernel_size)
        self.embedding = nn.Linear(hidden, patch_area * channel)
        self.t2t = torch.nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)
        height, width = output_size
        # Learned per-pixel bias added after folding.
        self.bias = nn.Parameter(torch.zeros((channel, height, width), dtype=torch.float32), requires_grad=True)

    def forward(self, x, t):
        """x: (b, tokens, hidden); t: frames per sample.

        Returns folded feature maps with b * t as the leading dimension.
        """
        feat = self.embedding(x)
        b, _, c = feat.size()
        # Regroup the tokens per frame and move patch values to the channel
        # axis, as nn.Fold expects.
        per_frame = feat.view(b * t, -1, c).permute(0, 2, 1)
        return self.t2t(per_frame) + self.bias[None]
class MultiHeadedAttention(nn.Module):
    """
    Multi-head self-attention over a token sequence.

    d_model: token feature size.
    head: number of attention heads.
    p: dropout probability passed to the inner ``Attention``.
    """

    def __init__(self, d_model, head, p=0.1):
        super().__init__()
        self.query_embedding = nn.Linear(d_model, d_model)
        self.value_embedding = nn.Linear(d_model, d_model)
        self.key_embedding = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention(p=p)
        self.head = head

    def forward(self, x):
        """x: (b, n, c); returns a tensor of the same shape."""
        b, n, c = x.size()
        c_h = c // self.head

        def split_heads(projected):
            # (b, n, c) -> (b, head, n, c_h)
            return projected.view(b, n, self.head, c_h).permute(0, 2, 1, 3)

        key = split_heads(self.key_embedding(x))
        query = split_heads(self.query_embedding(x))
        value = split_heads(self.value_embedding(x))

        attended, _ = self.attention(query, key, value)
        # (b, head, n, c_h) -> (b, n, c)
        merged = attended.permute(0, 2, 1, 3).contiguous().view(b, n, c)
        return self.output_linear(merged)
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    d_model: token feature size; the inner width is ``4 * d_model``.
    p: dropout probability.
    """

    def __init__(self, d_model, p=0.1):
        super(FeedForward, self).__init__()
        inner = d_model * 4
        stages = [
            nn.Linear(d_model, inner),
            nn.ReLU(inplace=True),
            nn.Dropout(p=p),
            nn.Linear(inner, d_model),
            nn.Dropout(p=p),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the feed-forward stack to ``x``."""
        return self.conv(x)
class FusionFeedForward(nn.Module):
    """Feed-forward network that folds and unfolds tokens between its two linear layers.

    After the first linear layer the hidden tokens of each frame are folded
    into a feature map (overlapping patch values are summed), divided by the
    per-pixel patch coverage count, and unfolded back into tokens before the
    second linear layer.

    d_model: token feature size.
    p: dropout probability.
    n_vecs: number of tokens per frame (required).
    t2t_params: dict with ``kernel_size``, ``stride``, ``padding`` and
        ``output_size`` for ``nn.Fold`` / ``nn.Unfold`` (required). The hidden
        width (1960) must be a multiple of the kernel area for the fold to be
        valid.
    """

    def __init__(self, d_model, p=0.1, n_vecs=None, t2t_params=None):
        super(FusionFeedForward, self).__init__()
        # We set d_ff as a default to 1960
        hd = 1960
        self.conv1 = nn.Sequential(
            nn.Linear(d_model, hd))
        self.conv2 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Dropout(p=p),
            nn.Linear(hd, d_model),
            nn.Dropout(p=p))
        assert t2t_params is not None and n_vecs is not None
        tp = t2t_params.copy()
        self.fold = nn.Fold(**tp)
        del tp['output_size']  # nn.Unfold takes no output_size
        self.unfold = nn.Unfold(**tp)
        self.n_vecs = n_vecs
        # Values per single-channel patch. Derived from the configured kernel
        # so the normaliser matches kernels other than 7x7.
        ks = tp['kernel_size']
        self.kernel_area = ks * ks if isinstance(ks, int) else reduce(lambda a, b: a * b, ks)

    def forward(self, x):
        """x: (b, n, d_model) with n a multiple of ``n_vecs``; returns the same shape."""
        x = self.conv1(x)
        b, n, c = x.size()
        # Folding an all-ones single-channel patch set counts how many
        # patches cover each pixel. The count is the same for every frame, so
        # compute it once for one sample and let the division broadcast.
        coverage = self.fold(x.new_ones(1, self.kernel_area, self.n_vecs))
        per_frame = x.view(-1, self.n_vecs, c).permute(0, 2, 1)
        fused = self.unfold(self.fold(per_frame) / coverage)
        x = fused.permute(0, 2, 1).contiguous().view(b, n, c)
        x = self.conv2(x)
        return x
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block: multi-head attention and a fusion feed-forward
    network, each wrapped in a residual connection.

    hidden: token feature size.
    num_head: number of attention heads.
    dropout: dropout probability for attention, the feed-forward network and
        the attention residual branch.
    n_vecs, t2t_params: forwarded to ``FusionFeedForward``.
    """

    def __init__(self, hidden=128, num_head=4, dropout=0.1, n_vecs=None, t2t_params=None):
        super().__init__()
        self.attention = MultiHeadedAttention(d_model=hidden, head=num_head, p=dropout)
        self.ffn = FusionFeedForward(hidden, p=dropout, n_vecs=n_vecs, t2t_params=t2t_params)
        self.norm1 = nn.LayerNorm(hidden)
        self.norm2 = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, input):
        """Apply attention then the feed-forward network, normalising before each."""
        attended = input + self.dropout(self.attention(self.norm1(input)))
        return attended + self.ffn(self.norm2(attended))
# ######################################################################
# ######################################################################
class Discriminator(BaseNetwork):
    """Discriminator built from a stack of 3-D convolutions over a clip.

    in_channels: channels of the input frames.
    use_sigmoid: apply a sigmoid to the final feature map.
    use_spectral_norm: wrap every conv except the last in spectral norm; those
        convs are then created without bias.
    init_weights: when true, ``self.init_weights()`` is called after construction.
    """

    def __init__(self, in_channels=3, use_sigmoid=False, use_spectral_norm=True, init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 32
        kernel = (3, 5, 5)
        stride = (1, 2, 2)
        same_pad = (1, 2, 2)
        with_bias = not use_spectral_norm

        def sn_conv(c_in, c_out, padding):
            conv = nn.Conv3d(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=stride,
                             padding=padding, bias=with_bias)
            return spectral_norm(conv, use_spectral_norm)

        # (in, out, padding) per spectrally-normalised stage. The first stage
        # uses padding=1 on every axis, the rest use (1, 2, 2).
        stage_specs = [
            (in_channels, nf * 1, 1),
            (nf * 1, nf * 2, same_pad),
            (nf * 2, nf * 4, same_pad),
            (nf * 4, nf * 4, same_pad),
            (nf * 4, nf * 4, same_pad),
        ]
        layers = []
        for c_in, c_out, padding in stage_specs:
            layers.append(sn_conv(c_in, c_out, padding))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final projection: plain conv with bias, no spectral norm and no
        # activation.
        layers.append(nn.Conv3d(nf * 4, nf * 4, kernel_size=kernel,
                                stride=stride, padding=same_pad))
        self.conv = nn.Sequential(*layers)

        if init_weights:
            self.init_weights()

    def forward(self, xs):
        """Score a clip.

        xs: 4-D tensor, laid out as (T, C, H, W) per the original comment.
        Returns a 5-D tensor laid out as (B, T, C, H, W) with B == 1.
        """
        # Swap the first two axes and add a leading batch axis: B, C, T, H, W
        clip = torch.transpose(xs, 0, 1).unsqueeze(0)
        feat = self.conv(clip)
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        return torch.transpose(feat, 1, 2)  # B, T, C, H, W
def spectral_norm(module, mode=True):
    """Wrap ``module`` with ``_spectral_norm`` when ``mode`` is truthy; otherwise return it unchanged."""
    return _spectral_norm(module) if mode else module