diff --git a/model/BIMPM.py b/model/BIMPM.py
index 046cf87..86b46b5 100644
--- a/model/BIMPM.py
+++ b/model/BIMPM.py
@@ -314,11 +314,22 @@ def div_with_small_value(n, d, eps=1e-8):
 
         # 4. Max-Attentive-Matching
         # (batch, seq_len1, hidden_size)
-        att_max_h_fw, _ = att_h_fw.max(dim=2)
-        att_max_h_bw, _ = att_h_bw.max(dim=2)
+        _, tmp_idx = att_fw.max(dim=2)
+        att_max_h_fw_idx = torch.stack([tmp_idx] * self.args.hidden_size, dim=2)
+        _, tmp_idx = att_bw.max(dim=2)
+        att_max_h_bw_idx = torch.stack([tmp_idx] * self.args.hidden_size, dim=2)
+
+        att_max_h_fw = torch.gather(con_h_fw, dim=1, index=att_max_h_fw_idx)
+        att_max_h_bw = torch.gather(con_h_bw, dim=1, index=att_max_h_bw_idx)
+
         # (batch, seq_len2, hidden_size)
-        att_max_p_fw, _ = att_p_fw.max(dim=1)
-        att_max_p_bw, _ = att_p_bw.max(dim=1)
+        _, tmp_idx = att_fw.max(dim=1)
+        att_max_p_fw_idx = torch.stack([tmp_idx] * self.args.hidden_size, dim=2)
+        _, tmp_idx = att_bw.max(dim=1)
+        att_max_p_bw_idx = torch.stack([tmp_idx] * self.args.hidden_size, dim=2)
+
+        att_max_p_fw = torch.gather(con_p_fw, dim=1, index=att_max_p_fw_idx)
+        att_max_p_bw = torch.gather(con_p_bw, dim=1, index=att_max_p_bw_idx)
 
         # (batch, seq_len, l)
         mv_p_att_max_fw = mp_matching_func(con_p_fw, att_max_h_fw, self.mp_w7)