Skip to content

Commit

Permalink
reorganize sampling code.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhirongw committed Jan 23, 2016
1 parent 123dd85 commit 30b5da9
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 120 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
*.ipynb
models/
samples/
scripts/
logs/
72 changes: 21 additions & 51 deletions gmms.lua
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ end

local gmms = {}

function crit:sample(input, gt_pixels)
function crit:sample(input, temperature, gt_pixels)
local N, G = input:size(1), input:size(2)
local ps = self.pixel_size
local nm = self.num_mixtures
Expand Down Expand Up @@ -215,71 +215,29 @@ function crit:sample(input, gt_pixels)
-- weights coeffs is taken care of at final loss, for computation efficiency and stability
local g_w_input = input:narrow(2,p*nm*ps+1, nm):clone()
local g_w = torch.exp(g_w_input:view(-1, nm))
--print(g_w)
g_w = g_w:cdiv(torch.repeatTensor(torch.sum(g_w,2),1,nm))
local g_ws = torch.exp(g_w_input:div(temperature))
g_ws = g_ws:cdiv(torch.repeatTensor(torch.sum(g_ws,2),1,nm))
--print(g_w)

local pixels = torch.Tensor(N, ps):type(input:type())
local train_pixels = gt_pixels:float()
local losses = 0
local train_losses = 0
local g_clk_x, g_w_x, g_mean_x
if input:type() == 'torch.CudaTensor' then
g_clk_x = g_clk:float()
g_w_x = g_w:float()
g_mean_x = g_mean:float()
else
g_clk_x = g_clk
g_w_x = g_w
g_mean_x = g_mean
end

-- sampling process
local mix_idx
mix_idx = torch.multinomial(g_w, 1)
--ignore, mix_idx = torch.max(g_w, 2)
--mix_idx = mix_idx:resize(N,1)
local mix_idx = torch.multinomial(g_ws, 1)

for b=1,N do
-- print('------------------------------------------')
-- sample from the multinomial
--print(mix_idx)
--local max_prob, mix_idx
--max_prob, mix_idx = torch.max(g_w[b], 1)
--mix_idx = mix_idx[1]
--print(mix_idx)
-- sample from the mvn gaussians
-- print(g_mean[{b, mix_idx, {}}])
-- print(g_clk[{b, mix_idx, {}, {}}])
local p = mvn.rnd(g_mean[{b, mix_idx[{b,1}], {}}], g_clk[{b, mix_idx[{b,1}], {},{}}])
--p = g_mean[{b, mix_idx[{b,1}], {}}]
pixels[b] = p
if ps == 3 then
-- evaluate the loss
local g_rpb_ = torch.Tensor(nm):zero()
local pf = p:float()
--print(pf)
for g=1,nm do -- iterate over mixtures
g_rpb_[g] = mvn.pdf(pf, g_mean_x[{b,g,{}}], g_clk_x[{b,g,{},{}}]) * g_w_x[{b,g}]
end
local pdf = torch.sum(g_rpb_)
losses = losses - torch.log(pdf)
-- VALIDATION
local train_g_rpb_ = torch.Tensor(nm):zero()
local train_pf = train_pixels[b]
--print(val_pf)
for g=1,nm do -- iterate over mixtures
train_g_rpb_[g] = mvn.pdf(train_pf, g_mean_x[{b,g,{}}], g_clk_x[{b,g,{},{}}]) * g_w_x[{b,g}]
end
local train_pdf = torch.sum(train_g_rpb_)
train_losses = train_losses - torch.log(train_pdf)
end
end

-- evaluate the loss
local losses
local train_losses
if ps == 1 then
-- for synthesis pixels
local g_mean_diff = torch.repeatTensor(pixels:view(N, 1, ps),1,nm,1):add(-1, g_mean)
local g_rpb = mvn.bnormpdf(g_mean_diff, g_clk)
g_rpb = g_rpb:cmul(g_w)
local g_rpb = mvn.bnormpdf(g_mean_diff, g_clk):cmul(g_w)
local pdf = torch.sum(g_rpb, 2)
losses = - torch.sum(torch.log(pdf))
-- for training pixels
Expand All @@ -288,6 +246,18 @@ function crit:sample(input, gt_pixels)
g_rpb = g_rpb:cmul(g_w)
pdf = torch.sum(g_rpb, 2)
train_losses = - torch.sum(torch.log(pdf))
else
local g_clk_inv = mvn.btmi(g_clk)
-- for synthesis pixels
local g_mean_diff = torch.repeatTensor(pixels:view(N, 1, ps),1,nm,1):add(-1, g_mean)
local g_rpb = mvn.b3normpdf(g_mean_diff, g_clk_inv):cmul(g_w)
local pdf = torch.sum(g_rpb, 2)
losses = - torch.sum(torch.log(pdf))
-- for training pixels
g_mean_diff = torch.repeatTensor(gt_pixels:view(N, 1, ps),1,nm,1):add(-1, g_mean)
g_rpb = mvn.b3normpdf(g_mean_diff, g_clk_inv):cmul(g_w)
pdf = torch.sum(g_rpb, 2)
train_losses = - torch.sum(torch.log(pdf))
end

losses = losses / N
Expand Down
8 changes: 5 additions & 3 deletions misc/DataLoaderRaw.lua
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ function DataLoaderRaw:__init(opt)
if opt.color > 0 then self.nChannels = 3 else self.nChannels = 1 end

local img = image.load(self.files[self.iterator], self.nChannels, 'float')
img = image.scale(img, opt.img_size, opt.img_size)
img = img:resize(self.nChannels, opt.img_size, opt.img_size)
-- print(img[{{},1,1}])
--if self.nChannels == 1 then img = img:resize(1, img:size(1), img:size(2)) end
if img:size(2) > opt.img_size or img:size(3) > opt.img_size then
img = image.scale(img, opt.img_size)
end

self.images[self.iterator] = img
self.nHeight = img:size(2)
self.nWidth = img:size(3)
Expand Down
2 changes: 1 addition & 1 deletion misc/mvn.lua
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function mvn.btmi(cholesky)
inv[{{}, 2, 1}]:cmul(inv[{{}, 1, 1}])
inv[{{}, 3, 1}]:cmul(inv[{{}, 1, 1}])
inv[{{}, 3, 2}]:cmul(inv[{{}, 2, 2}])
inv[{{}, 3, 1}]:csub(torch.cmul(inv[{{}, 2, 1}], inv[{{}, 3, 2}]))
inv[{{}, 3, 1}]:add(-1, torch.cmul(inv[{{}, 2, 1}], inv[{{}, 3, 2}]))
inv[{{}, 2, 1}]:cmul(inv[{{}, 2, 2}]):mul(-1)
inv[{{}, 3, 2}]:cmul(inv[{{}, 3, 3}]):mul(-1)
inv[{{}, 3, 1}]:cmul(inv[{{}, 3, 3}]):mul(-1)
Expand Down
218 changes: 157 additions & 61 deletions sample.lua
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ assert(string.len(opt.model) > 0, 'must provide a model')
local checkpoint = torch.load(opt.model)
local batch_size = opt.batch_size
if opt.batch_size == 0 then batch_size = checkpoint.opt.batch_size end
local temperature = opt.temperature

-- change it to evaluation mode
local protos = checkpoint.protos
Expand All @@ -73,6 +74,7 @@ local pm = protos.pm
local crit = nn.PixelModelCriterion(pm.pixel_size, pm.num_mixtures)
pm.core:evaluate()
print('The loaded model is trained on patch size with: ', patch_size)
print('Number of neighbors used: ', pm.num_neighbors)

-- prepare the empty states
local init_state = {}
Expand All @@ -91,81 +93,175 @@ local img = image.load('imgs/D1.png', pm.pixel_size, 'float')
img = image.scale(img, 256, 256):resize(1, pm.pixel_size, 256, 256)
img = torch.repeatTensor(img, batch_size, 1, 1, 1)
img = img:cuda()
-------------------------------------------------

local loss_sum = 0
local train_loss_sum = 0
-- random seed the zero-th pixel
-- local pixel = torch.rand(batch_size, pm.pixel_size):cuda()
local pixel
local gmms
-- loop through each timestep
for h=1,pm.recurrent_stride do
for w=1,pm.recurrent_stride do
local pixel_left, pixel_up
if w == 1 then
if border == 0 then
pixel_left = torch.zeros(batch_size, pm.pixel_size):cuda()
local function sample2n()
local loss_sum = 0
local train_loss_sum = 0

local pixel
local gmms
-- loop through each timestep
for h=1,pm.recurrent_stride do
for w=1,pm.recurrent_stride do
local pixel_left, pixel_up
if w == 1 then
if border == 0 then
pixel_left = torch.zeros(batch_size, pm.pixel_size):cuda()
else
pixel_left = torch.rand(batch_size, pm.pixel_size):cuda()
end
else
pixel_left = torch.rand(batch_size, pm.pixel_size):cuda()
pixel_left = images[{{}, {}, h, w-1}]
end
else
pixel_left = images[{{}, {}, h, w-1}]
end
if h == 1 then
if border == 0 then
pixel_up = torch.zeros(batch_size, pm.pixel_size):cuda()
if h == 1 then
if border == 0 then
pixel_up = torch.zeros(batch_size, pm.pixel_size):cuda()
else
pixel_up = torch.rand(batch_size, pm.pixel_size):cuda()
end
else
pixel_up = torch.rand(batch_size, pm.pixel_size):cuda()
pixel_up = images[{{}, {}, h-1,w}]
end
else
pixel_up = images[{{}, {}, h-1,w}]
end

-- inputs to LSTM, {input, states[t, t-1], states[t-1, t] }
-- Need to fix this for the new model
local inputs = {torch.cat(pixel_left, pixel_up, 2), unpack(states[w-1])}
local prev_w = w
if states[w] == nil then prev_w = 0 end
-- insert the states[t-1,t]
for i,v in ipairs(states[prev_w]) do table.insert(inputs, v) end
-- forward the network outputs, {next_c, next_h, next_c, next_h ..., output_vec}
local lsts = pm.core:forward(inputs)

-- save the state
states[w] = {}
for i=1,pm.num_state do table.insert(states[w], lsts[i]:clone()) end
gmms = lsts[#lsts]


-- sampling
--if w < patch_size or h < patch_size then
if false then
pixel = img[{{}, {}, h, w}]
images[{{},{},h,w}] = pixel
else
-- inputs to LSTM, {input, states[t, t-1], states[t-1, t] }
-- Need to fix this for the new model
local inputs = {torch.cat(pixel_left, pixel_up, 2), unpack(states[w-1])}
local prev_w = w
if states[w] == nil then prev_w = 0 end
-- insert the states[t-1,t]
for i,v in ipairs(states[prev_w]) do table.insert(inputs, v) end
-- forward the network outputs, {next_c, next_h, next_c, next_h ..., output_vec}
local lsts = pm.core:forward(inputs)

-- save the state
states[w] = {}
for i=1,pm.num_state do table.insert(states[w], lsts[i]:clone()) end
gmms = lsts[#lsts]

-- sampling
local train_pixel = img[{{}, {}, h, w}]:clone()
pixel, loss, train_loss = crit:sample(gmms, train_pixel)
pixel, loss, train_loss = crit:sample(gmms, temperature, train_pixel)
--pixel = train_pixel
images[{{},{},h,w}] = pixel
loss_sum = loss_sum + loss
train_loss_sum = train_loss_sum + train_loss
end
collectgarbage()
end

-- output the sampled images
local images_cpu = images:float()
images_cpu = images_cpu[{{}, {}, {patch_size+1, pm.recurrent_stride},{patch_size+1, pm.recurrent_stride}}]
images_cpu = images_cpu:clamp(0,1):mul(255):type('torch.ByteTensor')
for i=1,batch_size do
local filename = path.join('samples', i .. '.png')
image.save(filename, images_cpu[{i,1,{},{}}])
end
collectgarbage()

--loss_sum = loss_sum / (opt.img_size * opt.img_size)
--train_loss_sum = train_loss_sum / (opt.img_size * opt.img_size)
loss_sum = loss_sum / (pm.recurrent_stride * pm.recurrent_stride)
train_loss_sum = train_loss_sum / (pm.recurrent_stride * pm.recurrent_stride)
print('testing loss: ', loss_sum)
print('training loss: ', train_loss_sum)
end

-- output the sampled images
local images_cpu = images:float()
images_cpu = images_cpu[{{}, {}, {patch_size+1, pm.recurrent_stride},{patch_size+1, pm.recurrent_stride}}]
images_cpu = images_cpu:clamp(0,1):mul(255):type('torch.ByteTensor')
for i=1,batch_size do
local filename = path.join('samples', i .. '.png')
image.save(filename, images_cpu[{i,1,{},{}}])
local function sample3n()
local loss_sum = 0
local train_loss_sum = 0

local pixel
local gmms
-- loop through each timestep
for h=1,pm.recurrent_stride do
for w=1,pm.recurrent_stride do
local ww = w -- actual coordinate
if h % 2 == 0 then ww = pm.recurrent_stride + 1 - w end

local pixel_left, pixel_up, pixel_right
local pl, pr, pu
if ww == 1 or h % 2 == 0 then
if border == 0 then
pixel_left = torch.zeros(batch_size, pm.pixel_size):cuda()
else
pixel_left = torch.rand(batch_size, pm.pixel_size):cuda()
end
pl = 0
else
pixel_left = images[{{}, {}, h, ww-1}]
pl = ww - 1
end
if ww == pm.recurrent_stride or h % 2 == 1 then
if border == 0 then
pixel_right = torch.zeros(batch_size, pm.pixel_size):cuda()
else
pixel_right = torch.rand(batch_size, pm.pixel_size):cuda()
end
pr = 0
else
pixel_right = images[{{}, {}, h, ww+1}]
pr = ww + 1
end
if h == 1 then
if border == 0 then
pixel_up = torch.zeros(batch_size, pm.pixel_size):cuda()
else
pixel_up = torch.rand(batch_size, pm.pixel_size):cuda()
end
pu = 0
else
pixel_up = images[{{}, {}, h-1, ww}]
pu = ww
end

-- inputs to LSTM, {input, states[t, t-1], states[t-1, t], states[t, t+1] }
-- Need to fix this for the new model
local inputs = {torch.cat(torch.cat(pixel_left, pixel_up, 2), pixel_right, 2), unpack(states[pl])}
-- insert the states[t-1,t]
for i,v in ipairs(states[pu]) do table.insert(inputs, v) end
-- insert the states[t,t+1]
for i,v in ipairs(states[pr]) do table.insert(inputs, v) end
-- forward the network outputs, {next_c, next_h, next_c, next_h ..., output_vec}
local lsts = pm.core:forward(inputs)

-- save the state
states[ww] = {}
for i=1,pm.num_state do table.insert(states[ww], lsts[i]:clone()) end
gmms = lsts[#lsts]

-- sampling
local train_pixel = img[{{}, {}, h, ww}]:clone()
pixel, loss, train_loss = crit:sample(gmms, temperature, train_pixel)
--pixel = train_pixel
images[{{},{},h,ww}] = pixel
loss_sum = loss_sum + loss
train_loss_sum = train_loss_sum + train_loss
end
collectgarbage()
end

-- output the sampled images
local images_cpu = images:float()
images_cpu = images_cpu[{{}, {}, {patch_size+1, pm.recurrent_stride},{patch_size+1, pm.recurrent_stride}}]
images_cpu = images_cpu:clamp(0,1):mul(255):type('torch.ByteTensor')
for i=1,batch_size do
local filename = path.join('samples', i .. '.png')
image.save(filename, images_cpu[{i,1,{},{}}])
end

--loss_sum = loss_sum / (opt.img_size * opt.img_size)
--train_loss_sum = train_loss_sum / (opt.img_size * opt.img_size)
loss_sum = loss_sum / (pm.recurrent_stride * pm.recurrent_stride)
train_loss_sum = train_loss_sum / (pm.recurrent_stride * pm.recurrent_stride)
print('testing loss: ', loss_sum)
print('training loss: ', train_loss_sum)
end

loss_sum = loss_sum / (opt.img_size * opt.img_size)
train_loss_sum = train_loss_sum / (opt.img_size * opt.img_size)
--loss_sum = loss_sum / (pm.recurrent_stride * pm.recurrent_stride)
--train_loss_sum = train_loss_sum / (pm.recurrent_stride * pm.recurrent_stride)
print('testing loss: ', loss_sum)
print('training loss: ', train_loss_sum)
if pm.num_neighbors == 2 then
sample2n()
elseif pm.num_neighbors == 3 then
sample3n()
else
print('not implemented')
end
Loading

0 comments on commit 30b5da9

Please sign in to comment.