Skip to content

Commit

Permalink
fix crash if content is not provided
Browse files Browse the repository at this point in the history
  • Loading branch information
vadim-v-lebedev committed Mar 25, 2016
1 parent fc749d6 commit 7b77c37
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
26 changes: 20 additions & 6 deletions fast_neural_doodle.lua
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,36 @@ local function main()
-- Load images
local f_data = hdf5.open(params.masks_hdf5)
local style_img = f_data:read('style_img'):all()
local content_img = f_data:read('content_img'):all()
if cur_resolution ~= 0 then
style_img = image.scale(style_img, cur_resolution, cur_resolution)
content_img = image.scale(content_img, cur_resolution, cur_resolution)
end
style_img = preprocess(style_img):float()
content_img = preprocess(content_img):float()

local has_content = f_data:read('has_content')[0]
if has_content then
local content_img = f_data:read('content_img'):all()
content_img = preprocess(content_img):float()
if cur_resolution ~= 0 then
content_img = image.scale(content_img, cur_resolution, cur_resolution)
end
content_img = preprocess(content_img):float()
else
print('Content image is not provided, content weight will be set to zero')
params.content_weight = 0
end

if params.gpu >= 0 then
if params.backend ~= 'clnn' then
style_img = style_img:cuda()
content_img = content_img:cuda()
if has_content then
content_img = content_img:cuda()
end
else
style_img = style_img:cl()
content_img = content_img:cl()
if has_content then
content_img = content_img:cl()
end
end
end

Expand Down Expand Up @@ -162,7 +177,7 @@ local function main()
net:add(layer)

-- Content
if name == content_layers[next_content_idx] then
if has_content and name == content_layers[next_content_idx] then
print("Setting up content layer", i, ":", layer.name)
local target = net:forward(content_img):clone()
local norm = params.normalize_gradients
Expand Down Expand Up @@ -215,7 +230,6 @@ local function main()
end

local norm = params.normalize_gradients
print('style loss')
local loss_module = nn.StyleLoss(params.style_weight, target_grams, norm, deepcopy(target_masks)):float()
if params.gpu >= 0 then
if params.backend ~= 'clnn' then
Expand Down
4 changes: 3 additions & 1 deletion get_mask_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@
f['style_img'] = img_style.transpose(2, 0, 1).astype(float) / 255.
if args.content_image != None:
f['content_img'] = img_content.transpose(2, 0, 1).astype(float) / 255.
f['has_content'] = np.array([True])
else:
f['content_img'] = None
#f['content_img'] = np.array([0])
f['has_content'] = np.array([False])
f['n_colors'] = np.array([args.n_colors]) # Torch does not want to read just number

f.close()
Expand Down

0 comments on commit 7b77c37

Please sign in to comment.