diff --git a/Photo_to_Sketch_2D_Attention/model.py b/Photo_to_Sketch_2D_Attention/model.py
index 6f0b8bd..e22b2fc 100644
--- a/Photo_to_Sketch_2D_Attention/model.py
+++ b/Photo_to_Sketch_2D_Attention/model.py
@@ -40,45 +40,45 @@ def Image2Sketch_Train(self, rgb_image, sketch_vector, length_sketch, step, sket
         backbone_feature, rgb_encoded_dist = self.Image_Encoder(rgb_image)
         rgb_encoded_dist_z_vector = rgb_encoded_dist.rsample()
 
-        # """ Ditribution Matching Loss """
-        # prior_distribution = torch.distributions.Normal(torch.zeros_like(rgb_encoded_dist.mean),
-        #                                                 torch.ones_like(rgb_encoded_dist.stddev))
-        #
-        # kl_cost_rgb = torch.max(torch.distributions.kl_divergence(rgb_encoded_dist, prior_distribution).mean(), torch.tensor(self.hp.kl_tolerance).to(device))
-        #
-        # ##############################################################
-        # ##############################################################
-        # """ Cross Modal the Decoding """
-        # ##############################################################
-        # ##############################################################
-        #
-        # photo2sketch_output = self.Sketch_Decoder(backbone_feature, rgb_encoded_dist_z_vector, sketch_vector, length_sketch + 1)
-        #
-        # end_token = torch.stack([torch.tensor([0, 0, 0, 0, 1])] * rgb_image.shape[0]).unsqueeze(0).to(device).float()
-        # batch = torch.cat([sketch_vector, end_token], 0)
-        # x_target = batch.permute(1, 0, 2) # batch-> Seq_Len, Batch, Feature_dim
-        #
-        # sup_p2s_loss = sketch_reconstruction_loss(photo2sketch_output, x_target) #TODO: Photo to Sketch Loss
-        #
-        # loss = sup_p2s_loss + curr_kl_weight*kl_cost_rgb
-        #
-        # set_learninRate(self.optimizer, curr_learning_rate)
-        # loss.backward()
-        # nn.utils.clip_grad_norm_(self.train_params, self.hp.grad_clip)
-        # self.optimizer.step()
-        #
-        # print('Step:{} ** sup_p2s_loss:{} ** kl_cost_rgb:{} ** Total_loss:{}'.format(step, sup_p2s_loss,
-        #                                                                              kl_cost_rgb, loss))
-
-
-        # if step%5 == 0:
-        #
-        #     data = {}
-        #     data['Reconstrcution_Loss'] = sup_p2s_loss
-        #     data['KL_Loss'] = kl_cost_rgb
-        #     data['Total Loss'] = loss
-        #
-        #     self.visualizer.plot_scalars(data, step)
+        """ Distribution Matching Loss """
+        prior_distribution = torch.distributions.Normal(torch.zeros_like(rgb_encoded_dist.mean),
+                                                        torch.ones_like(rgb_encoded_dist.stddev))
+
+        kl_cost_rgb = torch.max(torch.distributions.kl_divergence(rgb_encoded_dist, prior_distribution).mean(), torch.tensor(self.hp.kl_tolerance).to(device))
+
+        ##############################################################
+        ##############################################################
+        """ Cross Modal Decoding """
+        ##############################################################
+        ##############################################################
+
+        photo2sketch_output = self.Sketch_Decoder(backbone_feature, rgb_encoded_dist_z_vector, sketch_vector, length_sketch + 1)
+
+        end_token = torch.stack([torch.tensor([0, 0, 0, 0, 1])] * rgb_image.shape[0]).unsqueeze(0).to(device).float()
+        batch = torch.cat([sketch_vector, end_token], 0)
+        x_target = batch.permute(1, 0, 2)  # batch -> Seq_Len, Batch, Feature_dim
+
+        sup_p2s_loss = sketch_reconstruction_loss(photo2sketch_output, x_target)  # TODO: Photo to Sketch Loss
+
+        loss = sup_p2s_loss + curr_kl_weight * kl_cost_rgb
+
+        set_learninRate(self.optimizer, curr_learning_rate)
+        loss.backward()
+        nn.utils.clip_grad_norm_(self.train_params, self.hp.grad_clip)
+        self.optimizer.step()
+
+        print('Step:{} ** sup_p2s_loss:{} ** kl_cost_rgb:{} ** Total_loss:{}'.format(step, sup_p2s_loss,
+                                                                                     kl_cost_rgb, loss))
+
+
+        if step%5 == 0:
+
+            data = {}
+            data['Reconstruction_Loss'] = sup_p2s_loss
+            data['KL_Loss'] = kl_cost_rgb
+            data['Total Loss'] = loss
+
+            self.visualizer.plot_scalars(data, step)
 
 
         if step%1 == 0:
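For context, the KL term re-enabled above resembles the Sketch-RNN "free bits" floor: the KL divergence against a standard-normal prior is clamped from below by `hp.kl_tolerance`, and the result enters the total loss with an annealed weight. A minimal standalone sketch, where `kl_tolerance`, `curr_kl_weight`, and the batch/latent sizes are illustrative stand-ins for the repo's `hp` values:

```python
import torch

# Toy encoder outputs: batch of 8, 128-dim latent (illustrative sizes).
mu = torch.zeros(8, 128)
std = torch.ones(8, 128)

q_z = torch.distributions.Normal(mu, std)                  # encoder posterior q(z|x)
prior = torch.distributions.Normal(torch.zeros_like(mu),   # standard-normal prior
                                   torch.ones_like(std))

# Free-bits floor: once the mean KL drops below kl_tolerance,
# it no longer contributes gradient pressure on the encoder.
kl_tolerance = 0.2                                         # stand-in for hp.kl_tolerance
kl_cost = torch.max(torch.distributions.kl_divergence(q_z, prior).mean(),
                    torch.tensor(kl_tolerance))

# Total loss as in the diff: reconstruction plus annealed KL weight.
curr_kl_weight = 0.5                                       # stand-in for the annealing schedule
sup_p2s_loss = torch.tensor(1.0)                           # placeholder reconstruction loss
loss = sup_p2s_loss + curr_kl_weight * kl_cost
```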