Skip to content

Commit

Permalink
fix affinity matrix for unbalanced edges
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerwwww committed Oct 12, 2022
1 parent 36f891d commit 1d489a1
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pygmtools/jittor_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def _aff_mat_from_node_edge_aff(node_aff: Var, edge_aff: Var, connectivity1: Var
if ne1 is None:
ne1 = [edge_aff.shape[1]] * batch_size
if ne2 is None:
ne2 = [edge_aff.shape[1]] * batch_size
ne2 = [edge_aff.shape[2]] * batch_size
else:
# device = node_aff.device
dtype = node_aff.dtype
Expand Down
2 changes: 1 addition & 1 deletion pygmtools/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _aff_mat_from_node_edge_aff(node_aff: np.ndarray, edge_aff: np.ndarray, conn
if ne1 is None:
ne1 = [edge_aff.shape[1]] * batch_size
if ne2 is None:
ne2 = [edge_aff.shape[1]] * batch_size
ne2 = [edge_aff.shape[2]] * batch_size
else:
dtype = node_aff.dtype
batch_size = node_aff.shape[0]
Expand Down
2 changes: 1 addition & 1 deletion pygmtools/paddle_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def _aff_mat_from_node_edge_aff(node_aff: paddle.Tensor, edge_aff: paddle.Tensor
if ne1 is None:
ne1 = [edge_aff.shape[1]] * batch_size
if ne2 is None:
ne2 = [edge_aff.shape[1]] * batch_size
ne2 = [edge_aff.shape[2]] * batch_size
else:
device = node_aff.place
dtype = node_aff.dtype
Expand Down
2 changes: 1 addition & 1 deletion pygmtools/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,7 +1275,7 @@ def _aff_mat_from_node_edge_aff(node_aff: Tensor, edge_aff: Tensor, connectivity
if ne1 is None:
ne1 = [edge_aff.shape[1]] * batch_size
if ne2 is None:
ne2 = [edge_aff.shape[1]] * batch_size
ne2 = [edge_aff.shape[2]] * batch_size
else:
device = node_aff.device
dtype = node_aff.dtype
Expand Down

0 comments on commit 1d489a1

Please sign in to comment.