Skip to content

Commit

Permalink
add i3d model
Browse files Browse the repository at this point in the history
  • Loading branch information
ruiliu-ai committed Nov 20, 2021
1 parent 0686bec commit 3f47e17
Showing 1 changed file with 335 additions and 0 deletions.
335 changes: 335 additions & 0 deletions model/i3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,335 @@
# This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py
# I only fix flake8 errors and do some cleaning here
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaxPool3dSamePadding(nn.MaxPool3d):

def compute_pad(self, dim, s):
if s % self.stride[dim] == 0:
return max(self.kernel_size[dim] - self.stride[dim], 0)
else:
return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)

def forward(self, x):
# compute 'same' padding
(batch, channel, t, h, w) = x.size()
pad_t = self.compute_pad(0, t)
pad_h = self.compute_pad(1, h)
pad_w = self.compute_pad(2, w)

pad_t_f = pad_t // 2
pad_t_b = pad_t - pad_t_f
pad_h_f = pad_h // 2
pad_h_b = pad_h - pad_h_f
pad_w_f = pad_w // 2
pad_w_b = pad_w - pad_w_f

pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
x = F.pad(x, pad)
return super(MaxPool3dSamePadding, self).forward(x)


class Unit3D(nn.Module):

def __init__(self, in_channels,
output_channels,
kernel_shape=(1, 1, 1),
stride=(1, 1, 1),
padding=0,
activation_fn=F.relu,
use_batch_norm=True,
use_bias=False,
name='unit_3d'):
"""Initializes Unit3D module."""
super(Unit3D, self).__init__()

self._output_channels = output_channels
self._kernel_shape = kernel_shape
self._stride = stride
self._use_batch_norm = use_batch_norm
self._activation_fn = activation_fn
self._use_bias = use_bias
self.name = name
self.padding = padding

self.conv3d = nn.Conv3d(in_channels=in_channels,
out_channels=self._output_channels,
kernel_size=self._kernel_shape,
stride=self._stride,
padding=0, # we always want padding to be 0 here. We will
# dynamically pad based on input size in forward function
bias=self._use_bias)

if self._use_batch_norm:
self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01)

def compute_pad(self, dim, s):
if s % self._stride[dim] == 0:
return max(self._kernel_shape[dim] - self._stride[dim], 0)
else:
return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)

def forward(self, x):
# compute 'same' padding
(batch, channel, t, h, w) = x.size()
pad_t = self.compute_pad(0, t)
pad_h = self.compute_pad(1, h)
pad_w = self.compute_pad(2, w)

pad_t_f = pad_t // 2
pad_t_b = pad_t - pad_t_f
pad_h_f = pad_h // 2
pad_h_b = pad_h - pad_h_f
pad_w_f = pad_w // 2
pad_w_b = pad_w - pad_w_f

pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
x = F.pad(x, pad)

x = self.conv3d(x)
if self._use_batch_norm:
x = self.bn(x)
if self._activation_fn is not None:
x = self._activation_fn(x)
return x


class InceptionModule(nn.Module):
def __init__(self, in_channels, out_channels, name):
super(InceptionModule, self).__init__()

self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
name=name + '/Branch_0/Conv3d_0a_1x1')
self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
name=name + '/Branch_1/Conv3d_0a_1x1')
self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
name=name + '/Branch_1/Conv3d_0b_3x3')
self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
name=name + '/Branch_2/Conv3d_0a_1x1')
self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
name=name + '/Branch_2/Conv3d_0b_3x3')
self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
stride=(1, 1, 1), padding=0)
self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
name=name + '/Branch_3/Conv3d_0b_1x1')
self.name = name

def forward(self, x):
b0 = self.b0(x)
b1 = self.b1b(self.b1a(x))
b2 = self.b2b(self.b2a(x))
b3 = self.b3b(self.b3a(x))
return torch.cat([b0, b1, b2, b3], dim=1)


class InceptionI3d(nn.Module):
"""Inception-v1 I3D architecture.
The model is introduced in:
Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
Joao Carreira, Andrew Zisserman
https://arxiv.org/pdf/1705.07750v1.pdf.
See also the Inception architecture, introduced in:
Going deeper with convolutions
Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
http://arxiv.org/pdf/1409.4842v1.pdf.
"""

# Endpoints of the model in order. During construction, all the endpoints up
# to a designated `final_endpoint` are returned in a dictionary as the
# second return value.
VALID_ENDPOINTS = (
'Conv3d_1a_7x7',
'MaxPool3d_2a_3x3',
'Conv3d_2b_1x1',
'Conv3d_2c_3x3',
'MaxPool3d_3a_3x3',
'Mixed_3b',
'Mixed_3c',
'MaxPool3d_4a_3x3',
'Mixed_4b',
'Mixed_4c',
'Mixed_4d',
'Mixed_4e',
'Mixed_4f',
'MaxPool3d_5a_2x2',
'Mixed_5b',
'Mixed_5c',
'Logits',
'Predictions',
)

def __init__(self, num_classes=400, spatial_squeeze=True,
final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):
"""Initializes I3D model instance.
Args:
num_classes: The number of outputs in the logit layer (default 400, which
matches the Kinetics dataset).
spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
before returning (default True).
final_endpoint: The model contains many possible endpoints.
`final_endpoint` specifies the last endpoint for the model to be built
up to. In addition to the output at `final_endpoint`, all the outputs
at endpoints up to `final_endpoint` will also be returned, in a
dictionary. `final_endpoint` must be one of
InceptionI3d.VALID_ENDPOINTS (default 'Logits').
name: A string (optional). The name of this module.
Raises:
ValueError: if `final_endpoint` is not recognized.
"""

if final_endpoint not in self.VALID_ENDPOINTS:
raise ValueError('Unknown final endpoint %s' % final_endpoint)

super(InceptionI3d, self).__init__()
self._num_classes = num_classes
self._spatial_squeeze = spatial_squeeze
self._final_endpoint = final_endpoint
self.logits = None

if self._final_endpoint not in self.VALID_ENDPOINTS:
raise ValueError('Unknown final endpoint %s' % self._final_endpoint)

self.end_points = {}
end_point = 'Conv3d_1a_7x7'
self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
stride=(2, 2, 2), padding=(3, 3, 3), name=name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'MaxPool3d_2a_3x3'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
padding=0)
if self._final_endpoint == end_point:
return

end_point = 'Conv3d_2b_1x1'
self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
name=name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'Conv3d_2c_3x3'
self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
name=name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'MaxPool3d_3a_3x3'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
padding=0)
if self._final_endpoint == end_point:
return

end_point = 'Mixed_3b'
self.end_points[end_point] = InceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'Mixed_3c'
self.end_points[end_point] = InceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'MaxPool3d_4a_3x3'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
padding=0)
if self._final_endpoint == end_point:
return

end_point = 'Mixed_4b'
self.end_points[end_point] = InceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'Mixed_4c'
self.end_points[end_point] = InceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'Mixed_4d'
self.end_points[end_point] = InceptionModule(160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'Mixed_4e'
self.end_points[end_point] = InceptionModule(128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'Mixed_4f'
self.end_points[end_point] = InceptionModule(
112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'MaxPool3d_5a_2x2'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
padding=0)
if self._final_endpoint == end_point:
return

end_point = 'Mixed_5b'
self.end_points[end_point] = InceptionModule(
256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128], name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'Mixed_5c'
self.end_points[end_point] = InceptionModule(
256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], name + end_point)
if self._final_endpoint == end_point:
return

end_point = 'Logits'
self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],
stride=(1, 1, 1))
self.dropout = nn.Dropout(dropout_keep_prob)
self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, output_channels=self._num_classes,
kernel_shape=[1, 1, 1],
padding=0,
activation_fn=None,
use_batch_norm=False,
use_bias=True,
name='logits')

self.build()

def replace_logits(self, num_classes):
self._num_classes = num_classes
self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, output_channels=self._num_classes,
kernel_shape=[1, 1, 1],
padding=0,
activation_fn=None,
use_batch_norm=False,
use_bias=True,
name='logits')

def build(self):
for k in self.end_points.keys():
self.add_module(k, self.end_points[k])

def forward(self, x):
for end_point in self.VALID_ENDPOINTS:
if end_point in self.end_points:
x = self._modules[end_point](x) # use _modules to work with dataparallel

x = self.logits(self.dropout(self.avg_pool(x)))
if self._spatial_squeeze:
logits = x.squeeze(3).squeeze(3)
# logits is batch X time X classes, which is what we want to work with
return logits

def extract_features(self, x, target_endpoint='Logits'):
for end_point in self.VALID_ENDPOINTS:
if end_point in self.end_points:
x = self._modules[end_point](x)
if end_point == target_endpoint:
break
if target_endpoint == 'Logits':
return x.mean(4).mean(3).mean(2)
else:
return x

0 comments on commit 3f47e17

Please sign in to comment.