Skip to content

Commit

Permalink
make sure final activation of siren layer defaults to just a linear layer if final_activation is not supplied
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 17, 2021
1 parent f9ffdc4 commit 3c602e7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'siren-pytorch',
packages = find_packages(),
version = '0.0.5',
version = '0.0.6',
license='MIT',
description = 'Implicit Neural Representations with Periodic Activation Functions',
author = 'Phil Wang',
Expand Down
11 changes: 9 additions & 2 deletions siren_pytorch/siren_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from torch import nn
import torch.nn.functional as F

# helpers

def exists(val):
    # Helper: True when *val* holds a value, i.e. it is not the None sentinel.
    return not (val is None)

# sin activation

class Sine(nn.Module):
Expand Down Expand Up @@ -34,7 +39,7 @@ def init_(self, weight, bias, c, w0):
w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
weight.uniform_(-w_std, w_std)

if bias is not None:
if exists(bias):
bias.uniform_(-w_std, w_std)

def forward(self, x):
Expand All @@ -45,7 +50,7 @@ def forward(self, x):
# siren network

class SirenNet(nn.Module):
def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0 = 1., w0_initial = 30., use_bias = True, final_activation = None):
def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0 = 30., w0_initial = 30., use_bias = True, final_activation = None):
super().__init__()
layers = []
for ind in range(num_layers):
Expand All @@ -62,6 +67,8 @@ def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0 = 1., w0_initial
))

self.net = nn.Sequential(*layers)

final_activation = nn.Identity() if not exists(final_activation) else final_activation
self.last_layer = Siren(dim_in = dim_hidden, dim_out = dim_out, w0 = w0, use_bias = use_bias, activation = final_activation)

def forward(self, x):
Expand Down

0 comments on commit 3c602e7

Please sign in to comment.