From 1d489a16c5cceed6a9a7cf88e42fc33476041275 Mon Sep 17 00:00:00 2001 From: roger <18309862+rogerwwww@users.noreply.github.com> Date: Wed, 12 Oct 2022 11:30:35 +0800 Subject: [PATCH] fix affinity matrix for unbalanced edges --- pygmtools/jittor_backend.py | 2 +- pygmtools/numpy_backend.py | 2 +- pygmtools/paddle_backend.py | 2 +- pygmtools/pytorch_backend.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pygmtools/jittor_backend.py b/pygmtools/jittor_backend.py index 29f97595..9c85a151 100644 --- a/pygmtools/jittor_backend.py +++ b/pygmtools/jittor_backend.py @@ -417,7 +417,7 @@ def _aff_mat_from_node_edge_aff(node_aff: Var, edge_aff: Var, connectivity1: Var if ne1 is None: ne1 = [edge_aff.shape[1]] * batch_size if ne2 is None: - ne2 = [edge_aff.shape[1]] * batch_size + ne2 = [edge_aff.shape[2]] * batch_size else: # device = node_aff.device dtype = node_aff.dtype diff --git a/pygmtools/numpy_backend.py b/pygmtools/numpy_backend.py index 6c95c811..35776711 100644 --- a/pygmtools/numpy_backend.py +++ b/pygmtools/numpy_backend.py @@ -351,7 +351,7 @@ def _aff_mat_from_node_edge_aff(node_aff: np.ndarray, edge_aff: np.ndarray, conn if ne1 is None: ne1 = [edge_aff.shape[1]] * batch_size if ne2 is None: - ne2 = [edge_aff.shape[1]] * batch_size + ne2 = [edge_aff.shape[2]] * batch_size else: dtype = node_aff.dtype batch_size = node_aff.shape[0] diff --git a/pygmtools/paddle_backend.py b/pygmtools/paddle_backend.py index d6f728cc..e1696cd7 100644 --- a/pygmtools/paddle_backend.py +++ b/pygmtools/paddle_backend.py @@ -395,7 +395,7 @@ def _aff_mat_from_node_edge_aff(node_aff: paddle.Tensor, edge_aff: paddle.Tensor if ne1 is None: ne1 = [edge_aff.shape[1]] * batch_size if ne2 is None: - ne2 = [edge_aff.shape[1]] * batch_size + ne2 = [edge_aff.shape[2]] * batch_size else: device = node_aff.place dtype = node_aff.dtype diff --git a/pygmtools/pytorch_backend.py b/pygmtools/pytorch_backend.py index 785feb59..2eff2831 100644 --- a/pygmtools/pytorch_backend.py +++ b/pygmtools/pytorch_backend.py @@ -1275,7 +1275,7 @@ def _aff_mat_from_node_edge_aff(node_aff: Tensor, edge_aff: Tensor, connectivity if ne1 is None: ne1 = [edge_aff.shape[1]] * batch_size if ne2 is None: - ne2 = [edge_aff.shape[1]] * batch_size + ne2 = [edge_aff.shape[2]] * batch_size else: device = node_aff.device dtype = node_aff.dtype