Skip to content

Commit

Permalink
Update networks.py
Browse files Browse the repository at this point in the history
  • Loading branch information
znxlwm authored Aug 27, 2019
1 parent 1ed59f2 commit f7a4b59
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ def __init__(self, num_features, eps=1e-5):
self.rho.data.fill_(0.9)

def forward(self, input, gamma, beta):
in_mean, in_var = torch.mean(torch.mean(input, dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(input, dim=2, keepdim=True), dim=3, keepdim=True)
in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
ln_mean, ln_var = torch.mean(torch.mean(torch.mean(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(torch.var(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True)
ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
Expand All @@ -184,9 +184,9 @@ def __init__(self, num_features, eps=1e-5):
self.beta.data.fill_(0.0)

def forward(self, input):
in_mean, in_var = torch.mean(torch.mean(input, dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(input, dim=2, keepdim=True), dim=3, keepdim=True)
in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
ln_mean, ln_var = torch.mean(torch.mean(torch.mean(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True), torch.var(torch.var(torch.var(input, dim=1, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True)
ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
Expand Down

0 comments on commit f7a4b59

Please sign in to comment.