DataLoader for a single image, training on a single texture
zhirongw committed Jan 13, 2016
1 parent ff44fc5 commit 158a185
Showing 3 changed files with 54 additions and 58 deletions.
58 changes: 35 additions & 23 deletions misc/DataLoaderRaw.lua
@@ -39,11 +39,24 @@ function DataLoaderRaw:__init(opt)
self.N = #self.files
print('DataLoaderRaw found ' .. self.N .. ' images')

-- for now, work on the first texture only (D1.png)
self.iterator = 1
self.images = {}
local img = image.load(self.files[self.iterator], 3, 'float')
img = image.scale(img, opt.img_size, opt.img_size)
-- print(img[{{},1,1}])
self.images[self.iterator] = img
self.nChannels = img:size(1)
self.nHeight = img:size(2)
self.nWidth = img:size(3)
end

function DataLoaderRaw:resetIterator()
self.iterator = 1
--self.iterator = 1
end

function DataLoaderRaw:getChannelSize()
return self.nChannels
end

--[[
@@ -54,37 +67,36 @@ end
--]]
function DataLoaderRaw:getBatch(opt)
-- optionally preprocess the image by resizing or cropping
local seq_length = utils.getopt(opt, 'seq_length', 1000)
local patch_size = utils.getopt(opt, 'patch_size', 7)
local seq_length = patch_size * patch_size - 1
local batch_size = utils.getopt(opt, 'batch_size', 5)
-- load an image
local img = image.load(self.files[self.iterator], 3, 'byte')
local img_raw = image.scale(img, 256, 256)
local pixel_size = img_raw:size(3)
local img = self.images[self.iterator]

local pixels = torch.FloatTensor(seq_length, batch_size, pixel_size+1)
local patches = torch.FloatTensor(batch_size, self.nChannels, patch_size, patch_size)


local infos = {}
--local infos = {}
for i=1,batch_size do
local ri = self.iterator
local ri_next = ri + 1 -- increment iterator
if ri_next > max_index then ri_next = 1; wrapped = true end -- wrap back around
self.iterator = ri_next
local h = torch.random(1, self.nHeight-patch_size+1)
local w = torch.random(1, self.nWidth-patch_size+1)

-- load the image
local img = image.load(self.files[ri], 3, 'byte')
img_batch_raw[i] = image.scale(img, 256, 256)
patches[i] = img[{{}, {h, h+patch_size-1}, {w, w+patch_size-1}}]

-- and record associated info as well
local info_struct = {}
info_struct.id = self.ids[ri]
info_struct.file_path = self.files[ri]
table.insert(infos, info_struct)
-- local info_struct = {}
-- info_struct.id = self.ids[ri]
-- info_struct.file_path = self.files[ri]
-- table.insert(infos, info_struct)
end

patches = patches:view(batch_size, self.nChannels, -1)
patches = patches:permute(3, 1, 2)
local data = {}
data.images = img_batch_raw
data.bounds = {it_pos_now = self.iterator, it_max = self.N, wrapped = wrapped}
data.infos = infos
data.pixels = patches[{{1,seq_length},{},{}}]
data.targets = patches[{seq_length+1,{},{}}]
if opt.gpu >= 0 then
data.pixels = data.pixels:cuda()
data.targets = data.targets:cuda()
end
-- data.infos = infos
return data
end
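
For orientation, here is a minimal usage sketch of the reworked loader (not part of the commit); it assumes a hypothetical texture folder and the option names used in train.lua below, and the tensor shapes follow from the patch extraction and permute above.

-- minimal driver sketch (assumes this repo's misc/DataLoaderRaw.lua is on the package path)
require 'torch'
require 'image'
require 'misc.DataLoaderRaw'

local loader = DataLoaderRaw{folder_path = 'textures/', img_size = 256}  -- hypothetical folder
print('channels:', loader:getChannelSize())

local batch = loader:getBatch{batch_size = 5, patch_size = 7, gpu = -1}
-- with patch_size = 7, seq_length = 7*7 - 1 = 48, so:
--   batch.pixels  is 48 x 5 x nChannels  (all but the last pixel of each flattened patch)
--   batch.targets is      5 x nChannels  (the final pixel of each patch, to be predicted)
print(batch.pixels:size())
print(batch.targets:size())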
13 changes: 2 additions & 11 deletions pm.lua
@@ -21,19 +21,17 @@ function layer:__init(opt)
self.num_mixtures = utils.getopt(opt, 'num_mixtures')
local dropout = utils.getopt(opt, 'dropout', 0)
-- options for Pixel Model
self.seq_length = utils.getopt(opt, 'seq_length')
self.recurrent_stride = utils.getopt(opt, 'recurrent_stride')
self.seq_length = utils.getopt(opt, 'seq_length')
self.mult_in = utils.getopt(opt, 'mult_in')
if self.pixel_size == 3 then
self.output_size = self.num_mixtures * (3+3+3+1)
else
self.output_size = self.num_mixtures * (1+1+0+1)
end
-- create the core lstm network.
-- note +1 for addition end tokens, true for multiple input to deep layer connections.
-- mult_in for multiple input to deep layer connections.
self.core = LSTM.lstm2d(self.pixel_size, self.output_size, self.rnn_size, self.num_layers, dropout, self.mult_in)
-- decoding the output to gaussian mixture parameters
-- self.gmm = nn.GMMDecoder(self.pixel_size, self.num_mixtures)
self:_createInitState(1) -- will be lazily resized later during forward passes
end

@@ -59,10 +57,8 @@ function layer:createClones()
-- construct the net clones
print('constructing clones inside the PixelModel')
self.clones = {self.core}
self.gmms = {self.gmm}
for t=2,self.seq_length do
self.clones[t] = self.core:clone('weight', 'bias', 'gradWeight', 'gradBias')
-- self.gmms[t] = self.gmm:clone('weight', 'bias', 'gradWeight', 'gradBias')
end
end

@@ -73,17 +69,12 @@ end
function layer:parameters()
-- we only have two internal modules, return their params
local p1,g1 = self.core:parameters()
-- local p2,g2 = self.gmm:parameters()

local params = {}
for k,v in pairs(p1) do table.insert(params, v) end
-- assert(p2 == nil, 'GMM decoder should have no params')
-- for k,v in pairs(p2) do table.insert(params, v) end

local grad_params = {}
for k,v in pairs(g1) do table.insert(grad_params, v) end
-- assert(g2 == nil, 'GMM decoder should have no params')
-- for k,v in pairs(g2) do table.insert(grad_params, v) end

-- todo: invalidate self.clones if params were requested?
-- what if someone outside of us decided to call getParameters() or something?
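
The resized __init above packs all mixture parameters into the LSTM output; a small sketch of that arithmetic follows. The per-mixture grouping named in the comments (means, variances, covariance terms, weight) is an interpretation of the (3+3+3+1) expression, not something the diff states.

-- sketch of the output_size computation in layer:__init (hypothetical standalone helper)
local function gmm_output_size(pixel_size, num_mixtures)
  if pixel_size == 3 then
    -- assumed grouping per mixture: 3 means + 3 variances + 3 covariance terms + 1 weight
    return num_mixtures * (3 + 3 + 3 + 1)
  else
    -- assumed grouping per mixture: 1 mean + 1 variance + 1 weight
    return num_mixtures * (1 + 1 + 0 + 1)
  end
end

print(gmm_output_size(3, 20))  -- 200 outputs for RGB with the default 20 mixtures
print(gmm_output_size(1, 20))  -- 60 outputs for grayscale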
41 changes: 17 additions & 24 deletions train.lua
@@ -2,11 +2,12 @@ require 'torch'
require 'nn'
require 'nngraph'
-- local imports
require 'pm'
local utils = require 'misc.utils'
local net_utils = require 'misc.net_utils'
require 'misc.optim_updates'
require 'misc.DataLoaderRaw'
require 'misc.DataLoader'
--require 'misc.DataLoader'

-------------------------------------------------------------------------------
-- Input arguments and options
@@ -19,12 +20,13 @@ cmd:text('Options')

-- Data input settings
cmd:option('-folder_path','','path to the preprocessed textures')
cmd:option('-image_size',256,'resize the input image to')
--cmd:option('-input_h5','coco/data.h5','path to the h5file containing the preprocessed dataset')
--cmd:option('-input_json','coco/data.json','path to the json file containing additional info and vocab')
cmd:option('-start_from', '', 'path to a model checkpoint to initialize model weights from. Empty = don\'t')

-- Model settings
cmd:option('-rnn_size',512,'size of the rnn in number of hidden nodes in each layer')
cmd:option('-rnn_size',12,'size of the rnn in number of hidden nodes in each layer')
cmd:option('-num_layers',3,'number of layers in stacked RNN/LSTMs')
cmd:option('-num_mixtures',20,'number of gaussian mixtures to encode the output pixel')
cmd:option('-patch_size',7,'size of the neighbor patch that a pixel is conditioned on')
@@ -34,7 +36,7 @@ cmd:option('-max_iters', -1, 'max number of iterations to run for (-1 = run forever)')
cmd:option('-batch_size',16,'what is the batch size in number of images per batch? (there will be x seq_per_img sentences)')
cmd:option('-grad_clip',0.1,'clip gradients at this value (note should be lower than usual 5 because we normalize grads by both batch and seq_length)')
cmd:option('-drop_prob_pm', 0.5, 'strength of dropout in the Pixel RNN')
cmd:option('-mult_in', true, 'An extension of the LSTM architecture')
cmd:option('-mult_in', false, 'An extension of the LSTM architecture')
-- Optimization: for the Pixel Model
cmd:option('-optim','adam','what update to use? rmsprop|sgd|sgdmom|adagrad|adam')
cmd:option('-learning_rate',4e-4,'learning rate')
@@ -45,10 +47,8 @@ cmd:option('-optim_beta',0.999,'beta used for adam')
cmd:option('-optim_epsilon',1e-8,'epsilon that goes into denominator for smoothing')

-- Evaluation/Checkpointing
cmd:option('-val_images_use', 3200, 'how many images to use when periodically evaluating the validation loss? (-1 = all)')
cmd:option('-save_checkpoint_every', 2500, 'how often to save a model checkpoint?')
cmd:option('-checkpoint_path', '', 'folder to save checkpoints into (empty = this folder)')
cmd:option('-language_eval', 0, 'Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
cmd:option('-losses_log_every', 25, 'How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')

-- misc
@@ -77,7 +77,7 @@ end
-------------------------------------------------------------------------------
-- Create the Data Loader instance
-------------------------------------------------------------------------------
local loader = DataLoaderRaw{folder_path = opt.data_folder}
local loader = DataLoaderRaw{folder_path = opt.folder_path, img_size = opt.image_size}

-------------------------------------------------------------------------------
-- Initialize the networks
@@ -99,12 +99,11 @@ else
pmOpt.pixel_size = loader:getChannelSize()
pmOpt.rnn_size = opt.rnn_size
pmOpt.num_mixtures = opt.num_mixtures
pmOpt.recurrent_stride = loader:getImgSize() + 1
pmOpt.num_layers = opt.num_layers
pmOpt.dropout = opt.drop_prob_pm
pmOpt.seq_length = loader:getSeqLength()
pmOpt.batch_size = opt.batch_size
pmOpt.seq_length = opt.patch_size * (loader:getImgSize() + 1)
pmOpt.recurrent_stride = opt.patch_size
pmOpt.seq_length = opt.patch_size * opt.patch_size - 1
pmOpt.mult_in = opt.mult_in
protos.pm = nn.PixelModel(pmOpt)
-- criterion for the pixel model
@@ -153,14 +152,11 @@ local function eval_split(split, evalopt)
while true do

-- fetch a batch of data
local data = loader:getBatch{batch_size = opt.batch_size, split = split, seq_per_img = opt.seq_per_img}
data.images = net_utils.prepro(data.images, false, opt.gpuid >= 0) -- preprocess in place, and don't augment
n = n + data.images:size(1)
local data = loader:getBatch{batch_size = opt.batch_size, patch_size = opt.patch_size}

-- forward the model to get loss
local feats = protos.cnn:forward(data.images)
local logprobs = protos.pm:forward{feats, data.labels}
local loss = protos.crit:forward(logprobs, data.labels)
local gmms = protos.pm:forward(data.pixels)
local loss = protos.crit:forward(gmms, data.targets)
loss_sum = loss_sum + loss
loss_evals = loss_evals + 1

@@ -193,7 +189,7 @@ end
-------------------------------------------------------------------------------
-- Loss function
-------------------------------------------------------------------------------
local iter = 0
local iter = 1
local function lossFun()
protos.pm:training()
grad_params:zero()
@@ -202,23 +198,20 @@ local function lossFun()
-- Forward pass
-----------------------------------------------------------------------------
-- get batch of data
local data = loader:getBatch{batch_size = opt.batch_size, split = 'train', seq_per_img = opt.seq_per_img}
data.images = net_utils.prepro(data.images, true, opt.gpuid >= 0) -- preprocess in place, do data augmentation
-- data.images: Nx3x224x224
-- data.seq: LxM where L is sequence length upper bound, and M = N*seq_per_img
local data = loader:getBatch{batch_size = opt.batch_size, patch_size = opt.patch_size, gpu = opt.gpuid, split = 'train'}

-- forward the pixel model
local output = protos.pm:forward(data.pixels)
local gmms = protos.pm:forward(data.pixels)
-- forward the pixel model criterion
local loss = protos.crit:forward(output, data.pixels)
local loss = protos.crit:forward(gmms, data.targets)

-----------------------------------------------------------------------------
-- Backward pass
-----------------------------------------------------------------------------
-- backprop criterion
local doutput = protos.crit:backward(output, data.pixels)
local dgmms = protos.crit:backward(gmms, data.targets)
-- backprop pixel model
local dpixels = protos.pm:backward(data.pixels, doutput)
local dpixels = protos.pm:backward(data.pixels, dgmms)

-- clip gradients
-- print(string.format('claming %f%% of gradients', 100*torch.mean(torch.gt(torch.abs(grad_params), opt.grad_clip))))
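
Taken together, one optimization step in the rewritten lossFun reduces to the sequence below (a sketch; loader, protos, grad_params and opt are assumed to be set up as in the diff, and the clamp call sits in the collapsed part of the hunk).

-- one forward/backward step, mirroring lossFun() after this commit (sketch)
local data = loader:getBatch{batch_size = opt.batch_size, patch_size = opt.patch_size,
                             gpu = opt.gpuid, split = 'train'}

local gmms = protos.pm:forward(data.pixels)            -- LSTM over the first P*P-1 pixels of each patch
local loss = protos.crit:forward(gmms, data.targets)   -- mixture loss against the final pixel

local dgmms   = protos.crit:backward(gmms, data.targets)
local dpixels = protos.pm:backward(data.pixels, dgmms)

grad_params:clamp(-opt.grad_clip, opt.grad_clip)       -- assumed standard clipping, as hinted by the comment above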
