
Commit

comment out bug fixed update...
AyanKumarBhunia committed Aug 31, 2021
1 parent d8aebd0 commit 709b998
Showing 1 changed file with 39 additions and 39 deletions.
78 changes: 39 additions & 39 deletions Photo_to_Sketch_2D_Attention/model.py
@@ -40,45 +40,45 @@ def Image2Sketch_Train(self, rgb_image, sketch_vector, length_sketch, step, sket
backbone_feature, rgb_encoded_dist = self.Image_Encoder(rgb_image)
rgb_encoded_dist_z_vector = rgb_encoded_dist.rsample()

# """ Ditribution Matching Loss """
# prior_distribution = torch.distributions.Normal(torch.zeros_like(rgb_encoded_dist.mean),
# torch.ones_like(rgb_encoded_dist.stddev))
#
# kl_cost_rgb = torch.max(torch.distributions.kl_divergence(rgb_encoded_dist, prior_distribution).mean(), torch.tensor(self.hp.kl_tolerance).to(device))
#
# ##############################################################
# ##############################################################
# """ Cross Modal the Decoding """
# ##############################################################
# ##############################################################
#
# photo2sketch_output = self.Sketch_Decoder(backbone_feature, rgb_encoded_dist_z_vector, sketch_vector, length_sketch + 1)
#
# end_token = torch.stack([torch.tensor([0, 0, 0, 0, 1])] * rgb_image.shape[0]).unsqueeze(0).to(device).float()
# batch = torch.cat([sketch_vector, end_token], 0)
# x_target = batch.permute(1, 0, 2) # batch-> Seq_Len, Batch, Feature_dim
#
# sup_p2s_loss = sketch_reconstruction_loss(photo2sketch_output, x_target) #TODO: Photo to Sketch Loss
#
# loss = sup_p2s_loss + curr_kl_weight*kl_cost_rgb
#
# set_learninRate(self.optimizer, curr_learning_rate)
# loss.backward()
# nn.utils.clip_grad_norm_(self.train_params, self.hp.grad_clip)
# self.optimizer.step()
#
# print('Step:{} ** sup_p2s_loss:{} ** kl_cost_rgb:{} ** Total_loss:{}'.format(step, sup_p2s_loss,
# kl_cost_rgb, loss))


# if step%5 == 0:
#
# data = {}
# data['Reconstrcution_Loss'] = sup_p2s_loss
# data['KL_Loss'] = kl_cost_rgb
# data['Total Loss'] = loss
#
# self.visualizer.plot_scalars(data, step)
""" Ditribution Matching Loss """
prior_distribution = torch.distributions.Normal(torch.zeros_like(rgb_encoded_dist.mean),
torch.ones_like(rgb_encoded_dist.stddev))

kl_cost_rgb = torch.max(torch.distributions.kl_divergence(rgb_encoded_dist, prior_distribution).mean(), torch.tensor(self.hp.kl_tolerance).to(device))

##############################################################
##############################################################
""" Cross Modal the Decoding """
##############################################################
##############################################################

photo2sketch_output = self.Sketch_Decoder(backbone_feature, rgb_encoded_dist_z_vector, sketch_vector, length_sketch + 1)

end_token = torch.stack([torch.tensor([0, 0, 0, 0, 1])] * rgb_image.shape[0]).unsqueeze(0).to(device).float()
batch = torch.cat([sketch_vector, end_token], 0)
x_target = batch.permute(1, 0, 2) # batch-> Seq_Len, Batch, Feature_dim

sup_p2s_loss = sketch_reconstruction_loss(photo2sketch_output, x_target) #TODO: Photo to Sketch Loss

loss = sup_p2s_loss + curr_kl_weight*kl_cost_rgb

set_learninRate(self.optimizer, curr_learning_rate)
loss.backward()
nn.utils.clip_grad_norm_(self.train_params, self.hp.grad_clip)
self.optimizer.step()

print('Step:{} ** sup_p2s_loss:{} ** kl_cost_rgb:{} ** Total_loss:{}'.format(step, sup_p2s_loss,
kl_cost_rgb, loss))


if step%5 == 0:

data = {}
data['Reconstrcution_Loss'] = sup_p2s_loss
data['KL_Loss'] = kl_cost_rgb
data['Total Loss'] = loss

self.visualizer.plot_scalars(data, step)


if step%1 == 0:
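For readers skimming the diff, here is a minimal, self-contained sketch of the training-step pattern that the uncommented block follows: a KL term against a standard Normal prior floored at a kl_tolerance value, added with a weight to a supervised photo-to-sketch reconstruction loss whose targets end in the [0, 0, 0, 0, 1] stroke token, followed by gradient clipping and an optimizer step. This is an assumption-laden toy, not the repository's code: the linear encoder/decoder and the L2 loss are placeholders for Image_Encoder, Sketch_Decoder, and sketch_reconstruction_loss, and the hyperparameter values are illustrative, not taken from self.hp.

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

batch_size, z_dim, seq_len, feat_dim = 4, 128, 25, 5
kl_tolerance, kl_weight, grad_clip = 0.2, 1.0, 1.0               # stand-ins for self.hp.* values

encoder = nn.Linear(256, 2 * z_dim).to(device)                   # toy image encoder (not Image_Encoder)
decoder = nn.Linear(z_dim, (seq_len + 1) * feat_dim).to(device)  # toy sketch decoder (not Sketch_Decoder)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

# Encode a fake image feature into a diagonal Gaussian and draw a reparameterised sample
rgb_feat = torch.randn(batch_size, 256, device=device)
mean, log_std = encoder(rgb_feat).chunk(2, dim=-1)
rgb_encoded_dist = torch.distributions.Normal(mean, log_std.exp())
z = rgb_encoded_dist.rsample()

# Distribution matching: KL to a standard Normal prior, floored at kl_tolerance
prior = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(mean))
kl_cost = torch.distributions.kl_divergence(rgb_encoded_dist, prior).mean()
kl_cost = torch.max(kl_cost, torch.tensor(kl_tolerance, device=device))

# Target construction in the 5-element stroke format: append one end token
# [0, 0, 0, 0, 1] per batch item, then permute (Seq_Len, Batch, Feat) -> (Batch, Seq_Len, Feat)
sketch_vector = torch.rand(seq_len, batch_size, feat_dim, device=device)
end_token = torch.stack([torch.tensor([0, 0, 0, 0, 1.0])] * batch_size).unsqueeze(0).to(device)
x_target = torch.cat([sketch_vector, end_token], dim=0).permute(1, 0, 2)

# Toy decode and an L2 reconstruction loss standing in for sketch_reconstruction_loss
photo2sketch_output = decoder(z).view(batch_size, seq_len + 1, feat_dim)
sup_p2s_loss = ((photo2sketch_output - x_target) ** 2).mean()

# Weighted total loss, then the usual backward / clip / step
loss = sup_p2s_loss + kl_weight * kl_cost
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(params, grad_clip)
optimizer.step()
print('sup_p2s_loss: {:.4f}  kl_cost: {:.4f}  total: {:.4f}'.format(sup_p2s_loss.item(), kl_cost.item(), loss.item()))

The torch.max floor keeps the KL term from being optimised below the tolerance (often called "free bits" in Sketch-RNN-style VAEs), which helps prevent posterior collapse while the reconstruction term dominates early training.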
