Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-seed image if needed #97

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 42 additions & 20 deletions neural_style.lua
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ function nn.SpatialConvolutionMM:accGradParameters()
end

local function main(params)

local lastTotalLoss = 0.0

if params.gpu >= 0 then
require 'cutorch'
require 'cunn'
Expand Down Expand Up @@ -198,27 +201,36 @@ local function main(params)
end
collectgarbage()

-- Initialize the image
if params.seed >= 0 then
torch.manualSeed(params.seed)
end
local img = nil
if params.init == 'random' then
img = torch.randn(content_image:size()):float():mul(0.001)
elseif params.init == 'image' then
img = content_image_caffe:clone():float()
else
error('Invalid init type')
end
if params.gpu >= 0 then
img = img:cuda()
end
local y
local dy

local function initImage(params)
-- Initialize the image
if params.seed >= 0 then
torch.manualSeed(params.seed)
end
local img = nil
if params.init == 'random' then
img = torch.randn(content_image:size()):float():mul(0.001)
elseif params.init == 'image' then
img = content_image_caffe:clone():float()
else
error('Invalid init type')
end
if params.gpu >= 0 then
img = img:cuda()
end

-- Run it through the network once to get the proper size for the gradient
-- All the gradients will come from the extra loss modules, so we just pass
-- zeros into the top of the net on the backward pass.
local y = net:forward(img)
local dy = img.new(#y):zero()
-- Run it through the network once to get the proper size for the gradient
-- All the gradients will come from the extra loss modules, so we just pass
-- zeros into the top of the net on the backward pass.
y = net:forward(img)
dy = img.new(#y):zero()

return img
end

local img = initImage(params)

-- Declaring this here lets us access it in maybe_print
local optim_state = nil
Expand Down Expand Up @@ -283,6 +295,16 @@ local function main(params)
maybe_print(num_calls, loss)
maybe_save(num_calls)

if (lastTotalLoss - loss) == 0.0 then
print('********** 0 total loss from last iteration; re-seeding **********')
-- need to reinitialize image with new noise seed
initImage(params)
lastTotalLoss = 0
collectgarbage()
return 0, grad:view(grad:nElement())
end
lastTotalLoss = loss

collectgarbage()
-- optim.lbfgs expects a vector for gradients
return loss, grad:view(grad:nElement())
Expand Down