-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LSTMs for 3, and 4 neighbors, and various training utilities
- Loading branch information
Showing
6 changed files
with
435 additions
and
226 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
require 'nn' | ||
local utils = require 'misc.utils' | ||
local net_utils = require 'misc.net_utils' | ||
local LSTM = require 'lstm' | ||
local mvn = require 'misc.mvn' | ||
|
||
------------------------------------------------------------------------------- | ||
-- Pixel Model Mixture of Gaussian Density Criterion | ||
------------------------------------------------------------------------------- | ||
|
||
local crit, parent = torch.class('nn.PixelModelCriterion', 'nn.Criterion')

--- Criterion for a mixture-of-Gaussians pixel model.
-- @param pixel_size   channels per pixel: 1 (gray) or 3 (color)
-- @param num_mixtures number of Gaussian components in the mixture
function crit:__init(pixel_size, num_mixtures)
  parent.__init(self)
  self.pixel_size = pixel_size
  self.num_mixtures = num_mixtures
  -- encoding width per mixture component: mean + diag variance
  -- (+ off-diagonal covariance for color) + mixture weight
  local per_mixture
  if pixel_size == 3 then
    per_mixture = 3 + 3 + 3 + 1
  else
    per_mixture = 1 + 1 + 0 + 1
  end
  self.output_size = num_mixtures * per_mixture
  -- reusable helper modules (allocated once, used in updateOutput)
  if pixel_size == 3 then
    self.var_mm = nn.MM() -- batched matrix product, used for L * L^T
  end
  self.w_softmax = nn.SoftMax()
  self.var_exp = nn.Exp()
end
|
||
--[[ | ||
-- this is an optimized version of the gmm loss, though it looks a bit ugly
inputs: | ||
input is a Tensor of size DxNx(G), encodings of the gmms | ||
target is a Tensor of size DxNx(M+1). | ||
where, D is the sequence length, N is the batch size, M is the pixel channels, | ||
Criterion: | ||
Mixture of Gaussian, Log probability. | ||
The way we infer the target | ||
in this criterion is as follows: | ||
- at first time step the output is ignored (loss = 0). It's the image tick | ||
- the label sequence "seq" is shifted by one to produce targets | ||
- at last time step the output is always the special END token (last dimension) | ||
The criterion must be able to accommodate variably-sized sequences by making sure
the gradients are properly set to zeros where appropriate. | ||
--]] | ||
function crit:updateOutput(input, target)
  -- input:  D x N x output_size tensor of gmm encodings
  -- target: D x N x M tensor of ground-truth pixels (M == self.pixel_size)
  -- Computes the negative mean log-likelihood AND caches self.gradInput
  -- for the subsequent updateGradInput call.
  local D_ = input:size(1)
  local N_ = input:size(2)
  local D,N,Mp1= target:size(1), target:size(2), target:size(3)
  local ps = Mp1 -- pixel size
  assert(D == D_, 'input Tensor should have the same sequence length as the target')
  assert(N == N_, 'input Tensor should have the same batch size as the target')
  assert(ps == self.pixel_size, 'input dimensions of pixel do not match')
  local nm = self.num_mixtures

  -- the color path uses torch.inverse / mvn.pdf below, which are host-side
  -- ops, so pull CUDA tensors down to float first (pushed back at the end)
  local input_x, target_x
  if ps == 3 and input:type() == 'torch.CudaTensor' then
    input_x = input:float()
    target_x = target:float()
  else
    input_x = input
    target_x = target
  end
  -- decode the gmms first
  -- mean undertake no changes
  local g_mean_input = input_x:narrow(3,1,nm*ps):clone()
  local g_mean = g_mean_input:view(D, N, nm, ps)
  -- residual (x - mu) replicated for every mixture component
  local g_mean_diff = torch.repeatTensor(target_x:view(D, N, 1, ps), 1,1,nm,1):add(-1, g_mean)
  -- we use 6 numbers to denote the cholesky decompositions
  local g_var_input = input_x:narrow(3, nm*ps+1, nm*ps):clone()
  g_var_input = g_var_input:view(-1, ps)
  -- exponentiate so the diagonal of the cholesky factor stays positive
  local g_var = self.var_exp:forward(g_var_input)

  local g_cov_input
  local g_clk   -- cholesky factor L
  local g_clk_T -- its transpose L^T
  local g_sigma -- covariance Sigma = L * L^T
  -- chunk offset (in units of nm*ps) into the encoding; NOTE: was an
  -- accidental global before, now properly declared local
  local p = 2
  if ps == 3 then
    -- off-diagonal entries of the cholesky factor (3 per component)
    g_cov_input = input_x:narrow(3, p*nm*ps+1, nm*ps):clone()
    g_cov_input = g_cov_input:view(-1, 3)
    p = p + 1
    -- assemble lower-triangular L and upper-triangular L^T explicitly
    g_clk = torch.Tensor(D*N*nm, 3, 3):fill(0):type(g_var_input:type())
    g_clk_T = torch.Tensor(D*N*nm, 3, 3):fill(0):type(g_var_input:type())
    g_clk[{{}, 1, 1}] = g_var[{{}, 1}]
    g_clk[{{}, 2, 2}] = g_var[{{}, 2}]
    g_clk[{{}, 3, 3}] = g_var[{{}, 3}]
    g_clk[{{}, 2, 1}] = g_cov_input[{{}, 1}]
    g_clk[{{}, 3, 1}] = g_cov_input[{{}, 2}]
    g_clk[{{}, 3, 2}] = g_cov_input[{{}, 3}]
    g_clk_T[{{}, 1, 1}] = g_var[{{}, 1}]
    g_clk_T[{{}, 2, 2}] = g_var[{{}, 2}]
    g_clk_T[{{}, 3, 3}] = g_var[{{}, 3}]
    g_clk_T[{{}, 1, 2}] = g_cov_input[{{}, 1}]
    g_clk_T[{{}, 1, 3}] = g_cov_input[{{}, 2}]
    g_clk_T[{{}, 2, 3}] = g_cov_input[{{}, 3}]
    g_sigma = self.var_mm:forward({g_clk, g_clk_T})
    g_clk = g_clk:view(D, N, nm, ps, ps)
    g_clk_T = g_clk_T:view(D, N, nm, ps, ps)
  else
    -- scalar case: L and L^T are the same 1x1 "matrix"
    g_clk = g_var
    g_clk = g_clk:view(D, N, nm, ps, ps)
    g_clk_T = g_var
    g_clk_T = g_clk_T:view(D, N, nm, ps, ps)
    g_sigma = torch.cmul(g_clk, g_clk_T)
  end
  g_sigma = g_sigma:view(D, N, nm, ps, ps)
  local g_sigma_inv = torch.Tensor(g_sigma:size()):type(g_sigma:type())
  -- weights coeffs is taken care of at final loss, for computation efficiency and stability
  local g_w_input = input_x:narrow(3,p*nm*ps+1, nm):clone()
  local g_w = self.w_softmax:forward(g_w_input:view(-1, nm))
  g_w = g_w:view(D, N, nm)

  -- do the loss the gradients
  local loss1 = 0 -- loss of pixels, Mixture of Gaussians
  local grad_g_mean = torch.Tensor(g_mean:size()):type(g_mean:type())
  local grad_g_sigma = torch.Tensor(g_sigma:size()):type(g_sigma:type())
  local grad_g_w = torch.Tensor(g_w:size()):type(g_w:type())

  if ps == 1 then
    -- gray-scale: fully vectorized per-component normal pdf
    local g_rpb = mvn.bnormpdf(g_mean_diff, g_clk)
    g_rpb = g_rpb:cmul(g_w)
    local pdf = torch.sum(g_rpb, 3)
    loss1 = - torch.sum(torch.log(pdf))
    g_sigma_inv:fill(1):cdiv(g_sigma)
    -- negative responsibilities; combined with the softmax grad further down
    grad_g_w = - torch.cdiv(g_rpb, torch.repeatTensor(pdf,1,1,nm,1))
  else
    -- for color pixels, we have to calculate the inverse one at a time. FIX THIS?
    for t=1,D do -- iterate over timestep
      for b=1,N do -- iterate over batches
        -- can we vectorize this? Now constrains by the MVN.PDF
        local g_rpb_ = torch.Tensor(nm):zero()
        for g=1,nm do -- iterate over mixtures
          g_rpb_[g] = mvn.pdf(g_mean_diff[{t,b,g,{}}], g_clk[{t,b,g,{},{}}]) * g_w[{t,b,g}]
          g_sigma_inv[{t,b,g,{},{}}] = torch.inverse(g_sigma[{t,b,g,{},{}}])
        end
        local pdf = torch.sum(g_rpb_)
        loss1 = loss1 - torch.log(pdf)

        -- normalize the responsibilities for backprop
        g_rpb_:div(pdf)

        -- gradient of weight is tricky, making it efficient together with softmax
        grad_g_w[{t,b, {}}] = - g_rpb_
      end
    end
  end

  -- shared backward path: gradients w.r.t. mean and sigma from the
  -- (negative) responsibilities stored in grad_g_w
  local mean_left = g_mean_diff:view(-1, ps, 1)
  local mean_right = g_mean_diff:view(-1, 1, ps)
  g_sigma_inv = g_sigma_inv:view(-1, ps, ps)
  local g_rpb = torch.repeatTensor(grad_g_w:view(-1,1),1,ps)
  -- gradient for mean
  grad_g_mean = torch.cmul(g_rpb, torch.bmm(g_sigma_inv, mean_left))
  -- gradient for sigma
  local g_temp = torch.bmm(torch.bmm(g_sigma_inv, torch.bmm(mean_left, mean_right)), g_sigma_inv) - g_sigma_inv
  g_rpb = torch.repeatTensor(g_rpb:view(-1,ps,1),1,1,ps)
  grad_g_sigma = torch.cmul(g_rpb, g_temp):mul(0.5)

  -- back prop encodings
  -- mean undertake no changes
  grad_g_mean = grad_g_mean:view(D, N, -1)
  -- gradient of weight is tricky, making it efficient together with softmax
  grad_g_w:add(g_w)
  grad_g_w = grad_g_w:view(D, N, -1)
  -- gradient of the var, and cov
  local grad_g_var
  local grad_g_cov
  g_clk = g_clk:view(-1, ps, ps)
  g_clk_T = g_clk_T:view(-1, ps, ps)
  if ps == 3 then
    grad_g_sigma = grad_g_sigma:view(-1, 3, 3)
    local grad_g_clk = self.var_mm:backward({g_clk, g_clk_T}, grad_g_sigma)
    grad_g_clk = grad_g_clk[1]:mul(2)
    -- scatter the cholesky gradient back into diag (var) / off-diag (cov) parts
    grad_g_var = torch.Tensor(D*N*nm, ps):type(grad_g_clk:type())
    grad_g_cov = torch.Tensor(D*N*nm, ps):type(grad_g_clk:type())
    grad_g_var[{{}, 1}] = grad_g_clk[{{},1,1}]
    grad_g_var[{{}, 2}] = grad_g_clk[{{},2,2}]
    grad_g_var[{{}, 3}] = grad_g_clk[{{},3,3}]
    grad_g_var = self.var_exp:backward(g_var_input, grad_g_var)
    grad_g_var = grad_g_var:view(D, N, -1)
    grad_g_cov[{{}, 1}] = grad_g_clk[{{},2,1}]
    grad_g_cov[{{}, 2}] = grad_g_clk[{{},3,1}]
    grad_g_cov[{{}, 3}] = grad_g_clk[{{},3,2}]
    grad_g_cov = grad_g_cov:view(D, N, -1)
  else
    grad_g_var = torch.cmul(g_var, grad_g_sigma):mul(2)
    grad_g_var = self.var_exp:backward(g_var_input, grad_g_var)
    grad_g_var = grad_g_var:view(D, N, -1)
  end

  -- normalize all gradients by the number of predictions, matching the loss
  grad_g_mean:div(D*N)
  grad_g_var:div(D*N)
  if ps == 3 then grad_g_cov:div(D*N) end
  grad_g_w:div(D*N)

  -- concat to gradInput
  if self.pixel_size == 3 then
    -- torch does not allow us to concat more than 2 tensors for FloatTensors
    self.gradInput = torch.cat(torch.cat(grad_g_mean, grad_g_var), torch.cat(grad_g_cov, grad_g_w))
  else
    self.gradInput = torch.cat(torch.cat(grad_g_mean, grad_g_var), grad_g_w)
  end
  if input:type() == 'torch.CudaTensor' and ps == 3 then
    self.gradInput = self.gradInput:cuda()
  end
  -- return the loss
  self.output = (loss1) / (D*N)
  return self.output
end
|
||
--- Backward pass.
-- The gradient w.r.t. the input was already computed and cached in
-- self.gradInput during updateOutput, so this simply hands it back.
function crit:updateGradInput(input, target)
  return self.gradInput
end
Oops, something went wrong.