Skip to content

Commit

Permalink
Merge pull request #22 from saibalmars/nk_fix
Browse files Browse the repository at this point in the history
Networkit fix: update to support the nk's new neighbor iterator
  • Loading branch information
saibalmars committed Dec 28, 2020
2 parents 462e76a + 5fd3b79 commit db226c4
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 11 deletions.
17 changes: 9 additions & 8 deletions GraphRicciCurvature/OllivierRicci.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ def _get_single_node_neighbors_distributions(node, direction="successors"):
"""
if _Gk.isDirected():
if direction == "predecessors":
neighbors = _Gk.inNeighbors(node)
neighbors = list(_Gk.iterInNeighbors(node))
else: # successors
neighbors = _Gk.neighbors(node)
neighbors = list(_Gk.iterNeighbors(node))
else:
neighbors = _Gk.neighbors(node)
neighbors = list(_Gk.iterNeighbors(node))

# Get sum of distributions from x's all neighbors
heap_weight_node_pair = []
Expand All @@ -93,10 +93,11 @@ def _get_single_node_neighbors_distributions(node, direction="successors"):

nbr_edge_weight_sum = sum([x[0] for x in heap_weight_node_pair])

if len(neighbors) == 0:
if not neighbors:
# No neighbor, all mass stay at node
return [1], [node]
elif nbr_edge_weight_sum > EPSILON:

if nbr_edge_weight_sum > EPSILON:
# Sum need to be not too small to prevent divided by zero
distributions = [(1.0 - _alpha) * w / nbr_edge_weight_sum for w, _ in heap_weight_node_pair]
else:
Expand Down Expand Up @@ -283,10 +284,10 @@ def _average_transportation_distance(source, target):

t0 = time.time()
if _Gk.isDirected():
source_nbr = _Gk.inNeighbors(source)
source_nbr = list(_Gk.iterInNeighbors(source))
else:
source_nbr = _Gk.neighbors(source)
target_nbr = _Gk.neighbors(target)
source_nbr = list(_Gk.iterNeighbors(source))
target_nbr = list(_Gk.iterNeighbors(target))

share = (1.0 - _alpha) / (len(source_nbr) * len(target_nbr))
cost_nbr = 0
Expand Down
2 changes: 1 addition & 1 deletion GraphRicciCurvature/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.5.2"
__version__ = "0.5.2.1"
__author__ = "Chien-Chun Ni"
__email__ = "saibalmars@gmail.com"
2 changes: 1 addition & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# -----------------------------------------------
print("\n-Construct a directed graph example")
Gd = nx.DiGraph()
Gd.add_edges_from([(1, 2), (2, 3), (3, 4), (2, 4), (4, 2)])
Gd.add_edges_from([(0, 1), (1, 2), (2, 3), (1, 3), (3, 1)])

print("\n===== Compute the Ollivier-Ricci curvature of the given directed graph Gd =====")
orc_directed = OllivierRicci(Gd)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setuptools.setup(
name="GraphRicciCurvature",
version="0.5.2",
version="0.5.2.1",
author="Chien-Chun Ni",
author_email="saibalmars@gmail.com",
description="Compute discrete Ricci curvatures and Ricci flow on NetworkX graphs.",
Expand Down
34 changes: 34 additions & 0 deletions test/test_OllivierRicci.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,40 @@ def test_compute_ricci_curvature():
npt.assert_array_almost_equal(rc, ans)


def test_compute_ricci_curvature_directed():
Gd = nx.DiGraph()
Gd.add_edges_from([(0, 1), (1, 2), (2, 3), (1, 3), (3, 1)])
orc = OllivierRicci(Gd, method="OTD", alpha=0.5)
Gout = orc.compute_ricci_curvature()
rc = list(nx.get_edge_attributes(Gout, "ricciCurvature").values())
ans = [-0.49999999999999956,
-3.842615114990622e-11,
0.49999999996158007,
0.49999999992677135,
0.7499999999364129]

npt.assert_array_almost_equal(rc, ans)


def test_compute_ricci_curvature_ATD():
G = nx.karate_club_graph()
orc = OllivierRicci(G, alpha=0.5, method="ATD", verbose="INFO")
orc.compute_ricci_curvature()
Gout = orc.compute_ricci_curvature()
rc = list(nx.get_edge_attributes(Gout, "ricciCurvature").values())
ans = [-0.343750, -0.437500, -0.265625, -0.250000, -0.390625, -0.390625, -0.195312, -0.443750, -0.250000,
0.000000, -0.140625, -0.287500, -0.109375, -0.291667, -0.109375, -0.640625, -0.311111, -0.175926,
-0.083333, -0.166667, 0.000000, -0.166667, 0.000000, -0.333333, -0.241667, -0.137500, -0.220000,
-0.125000, -0.160000, -0.400000, -0.200000, -0.479167, 0.020833, 0.041667, -0.100000, -0.041667,
0.055556, -0.062500, -0.041667, 0.000000, 0.000000, -0.075000, -0.275000, -0.300000, -0.176471,
-0.464706, 0.000000, -0.073529, 0.000000, -0.073529, 0.000000, -0.073529, -0.421569, 0.000000,
-0.073529, 0.000000, -0.073529, -0.200000, -0.200000, -0.125000, -0.291667, -0.335294, -0.055556,
-0.208333, -0.194444, -0.194444, 0.062500, -0.176471, -0.375000, -0.166667, -0.245098, -0.197917,
-0.227941, -0.250000, -0.294118, -0.430556, -0.455882, -0.355392]

npt.assert_array_almost_equal(rc, ans)


def test_compute_ricci_flow():
G = nx.karate_club_graph()
orc = OllivierRicci(G, method="OTD", alpha=0.5)
Expand Down

0 comments on commit db226c4

Please sign in to comment.