Skip to content

Commit

Permalink
Change the way of setting backends (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerwwww authored Nov 13, 2023
1 parent 546bfd7 commit 0307736
Show file tree
Hide file tree
Showing 35 changed files with 191 additions and 118 deletions.
2 changes: 1 addition & 1 deletion docs/guide/get_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Step 0: Import packages and set backend

>>> import numpy as np
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'numpy'
>>> pygm.set_backend('numpy')
>>> np.random.seed(1)

Step 1: Generate a batch of isomorphic graphs
Expand Down
22 changes: 11 additions & 11 deletions docs/guide/numerical_backends.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Once the backend is ready, you may switch to the backend globally by the followi
::

>>> import pygmtools as pygm
>>> pygm.BACKEND = 'pytorch' # replace 'pytorch' by other backend names
>>> pygm.set_backend('pytorch') # replace 'pytorch' by other backend names

PyTorch Backend
------------------------
Expand All @@ -41,7 +41,7 @@ How to enable PyTorch backend:

>>> import pygmtools as pygm
>>> import torch
>>> pygm.BACKEND = 'pytorch'
>>> pygm.set_backend('pytorch')

Example: Matching Isomorphic Graphs with ``pytorch`` backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -55,7 +55,7 @@ Step 0: Import packages and set backend

>>> import torch # pytorch backend
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'pytorch'
>>> pygm.set_backend('pytorch')
>>> torch.manual_seed(1) # fix random seed

Step 1: Generate two isomorphic graphs
Expand Down Expand Up @@ -130,7 +130,7 @@ How to enable Jittor backend:

>>> import pygmtools as pygm
>>> import jittor
>>> pygm.BACKEND = 'jittor'
>>> pygm.set_backend('jittor')


Example: Matching Isomorphic Graphs with ``jittor`` backend
Expand All @@ -145,7 +145,7 @@ Step 0: Import packages and set backend

>>> import jittor as jt # jittor backend
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'jittor'
>>> pygm.set_backend('jittor')
>>> jt.set_seed(1) # fix random seed
>>> jt.flags.use_cuda = jt.has_cuda # detect cuda

Expand Down Expand Up @@ -221,7 +221,7 @@ How to enable Paddle backend:

>>> import pygmtools as pygm
>>> import paddle
>>> pygm.BACKEND = 'paddle'
>>> pygm.set_backend('paddle')

Example: Matching Isomorphic Graphs with ``paddle`` backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -235,7 +235,7 @@ Step 0: Import packages and set backend

>>> import paddle # paddle backend
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'paddle'
>>> pygm.set_backend('paddle')
>>> paddle.seed(1) # fix random seed
>>> paddle.device.set_device('cpu') # set cpu

Expand Down Expand Up @@ -312,7 +312,7 @@ How to enable Tensorflow backend:

>>> import pygmtools as pygm
>>> import tensorflow
>>> pygm.BACKEND = 'tensorflow'
>>> pygm.set_backend('tensorflow')

Example: Matching Isomorphic Graphs with ``tensorflow`` backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -326,7 +326,7 @@ Step 0: Import packages and set backend

>>> import tensorflow as tf # tensorflow backend
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'tensorflow'
>>> pygm.set_backend('tensorflow')
>>> _ = tf.random.set_seed(1) # fix random seed

Step 1: Generate two isomorphic graphs
Expand Down Expand Up @@ -401,7 +401,7 @@ How to enable Mindspore backend:

>>> import pygmtools as pygm
>>> import mindspore
>>> pygm.BACKEND = 'mindspore'
>>> pygm.set_backend('mindspore')

Example: Matching Isomorphic Graphs with ``mindspore`` backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -415,7 +415,7 @@ Step 0: Import packages and set backend

>>> import mindspore as ms # mindspore backend
>>> import pygmtools as pygm
>>> pygm.BACKEND = 'mindspore'
>>> pygm.set_backend('mindspore')
>>> _ = ms.set_seed(1) # fix random seed

Step 1: Generate two isomorphic graphs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'jittor' # set default backend for pygmtools
pygm.set_backend('jittor') # set default backend for pygmtools
_ = jt.set_seed(1) # fix random seed

jt.flags.use_cuda = jt.has_cuda
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'numpy' # set default backend for pygmtools
pygm.set_backend('numpy') # set default backend for pygmtools
np.random.seed(1) # fix random seed

##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import networkx as nx # for plotting graphs
import warnings
warnings.filterwarnings("ignore")
pygm.BACKEND = 'paddle' # set default backend for pygmtools
pygm.set_backend('paddle') # set default backend for pygmtools

paddle.device.set_device('cpu')
_ = paddle.seed(1) # fix random seed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
pygm.set_backend('pytorch') # set default backend for pygmtools
_ = torch.manual_seed(1) # fix random seed

##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'jittor' # set default backend for pygmtools
pygm.set_backend('jittor') # set default backend for pygmtools
_ = jt.set_seed(1) # fix random seed

jt.flags.use_cuda = jt.has_cuda
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'numpy' # set default backend for pygmtools
pygm.set_backend('numpy') # set default backend for pygmtools
np.random.seed(1) # fix random seed

##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import warnings
warnings.filterwarnings("ignore")

pygm.BACKEND = 'paddle' # set default backend for pygmtools
pygm.set_backend('paddle') # set default backend for pygmtools

paddle.device.set_device('cpu')
_ = paddle.seed(1) # fix random seed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
pygm.set_backend('pytorch') # set default backend for pygmtools
_ = torch.manual_seed(1) # fix random seed

##############################################################################
Expand Down
2 changes: 1 addition & 1 deletion examples/3.discovering_subgraphs/plot_subgraphs_jittor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'jittor' # set default backend for pygmtools
pygm.set_backend('jittor') # set default backend for pygmtools
_ = jt.set_seed(1) # fix random seed

jt.flags.use_cuda = jt.has_cuda
Expand Down
2 changes: 1 addition & 1 deletion examples/3.discovering_subgraphs/plot_subgraphs_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'numpy' # set default backend for pygmtools
pygm.set_backend('numpy') # set default backend for pygmtools
np.random.seed(1) # fix random seed

##############################################################################
Expand Down
2 changes: 1 addition & 1 deletion examples/3.discovering_subgraphs/plot_subgraphs_paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import warnings
warnings.filterwarnings("ignore")

pygm.BACKEND = 'paddle' # set default backend for pygmtools
pygm.set_backend('paddle') # set default backend for pygmtools

paddle.device.set_device('cpu')
_ = paddle.seed(1) # fix random seed
Expand Down
2 changes: 1 addition & 1 deletion examples/3.discovering_subgraphs/plot_subgraphs_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import matplotlib.pyplot as plt # for plotting
from matplotlib.patches import ConnectionPatch # for plotting matching result
import networkx as nx # for plotting graphs
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
pygm.set_backend('pytorch') # set default backend for pygmtools
_ = torch.manual_seed(1) # fix random seed

##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from PIL import Image
from matplotlib.patches import ConnectionPatch # for plotting matching result

pygm.BACKEND = 'jittor' # set default backend for pygmtools
pygm.set_backend('jittor') # set default backend for pygmtools
jt.flags.use_cuda = True
##############################################################################
# Load the images
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from PIL import Image
from matplotlib.patches import ConnectionPatch # for plotting matching result

pygm.BACKEND = 'numpy' # set default backend for pygmtools
pygm.set_backend('numpy') # set default backend for pygmtools


##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from PIL import Image
from matplotlib.patches import ConnectionPatch # for plotting matching result

pygm.BACKEND = 'paddle' # set default backend for pygmtools
pygm.set_backend('paddle') # set default backend for pygmtools

##############################################################################
# Load the images
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from PIL import Image
from matplotlib.patches import ConnectionPatch # for plotting matching result

pygm.BACKEND = 'pytorch' # set default backend for pygmtools
pygm.set_backend('pytorch') # set default backend for pygmtools


##############################################################################
Expand Down
2 changes: 1 addition & 1 deletion examples/5.model_fusion/plot_model_fusion_jittor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import matplotlib.pyplot as plt
import pygmtools as pygm

pygm.BACKEND = 'jittor'
pygm.set_backend('jittor')
jt.flags.use_cuda = jt.has_cuda

##############################################################################
Expand Down
2 changes: 1 addition & 1 deletion examples/5.model_fusion/plot_model_fusion_paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import warnings
warnings.filterwarnings("ignore")

pygm.BACKEND = 'paddle'
pygm.set_backend('paddle')
device = paddle.device.get_device()
paddle.device.set_device(device)

Expand Down
2 changes: 1 addition & 1 deletion examples/5.model_fusion/plot_model_fusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import matplotlib.pyplot as plt
import pygmtools as pygm

pygm.BACKEND = 'pytorch'
pygm.set_backend('pytorch')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import itertools
import numpy as np
from PIL import Image
pygm.BACKEND = 'jittor' # set default backend for pygmtools
pygm.set_backend('jittor') # set default backend for pygmtools

jt.flags.use_cuda = jt.has_cuda

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from sklearn.decomposition import PCA as PCAdimReduc
import itertools
from PIL import Image
pygm.BACKEND = 'numpy' # set numpy as backend for pygmtools
pygm.set_backend('numpy') # set numpy as backend for pygmtools

##############################################################################
# Load the images
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from PIL import Image
import warnings
warnings.filterwarnings("ignore")
pygm.BACKEND = 'paddle' # set default backend for pygmtools
pygm.set_backend('paddle') # set default backend for pygmtools

##############################################################################
# Load the images
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import itertools
import numpy as np
from PIL import Image
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
pygm.set_backend('pytorch') # set default backend for pygmtools

##############################################################################
# Load the images
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import itertools
import numpy as np
from PIL import Image
pygm.BACKEND = 'jittor' # set default backend for pygmtools
pygm.set_backend('jittor') # set default backend for pygmtools
jt.flags.use_cuda = jt.has_cuda

##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from PIL import Image
import warnings
warnings.filterwarnings("ignore")
pygm.BACKEND = 'paddle' # set default backend for pygmtools
pygm.set_backend('paddle') # set default backend for pygmtools

device = paddle.device.get_device()
paddle.device.set_device(device) # paddle sets device globally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import itertools
import numpy as np
from PIL import Image
pygm.BACKEND = 'pytorch' # set default backend for pygmtools
pygm.set_backend('pytorch') # set default backend for pygmtools

##############################################################################
# Predicting Matching by Graph Matching Neural Networks
Expand Down
11 changes: 11 additions & 0 deletions pygmtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,21 @@
from .neural_solvers import pca_gm, ipca_gm, cie, ngm, genn_astar
import pygmtools.utils as utils
import importlib.util
set_backend = utils.set_backend

BACKEND = 'numpy'
__version__ = '0.4.2a2'
__author__ = 'ThinkLab at SJTU'

SUPPORTED_BACKENDS = [
'numpy',
'pytorch',
'jittor',
'paddle',
'mindspore',
'tensorflow'
]


def env_report():
"""
Expand Down
Loading

0 comments on commit 0307736

Please sign in to comment.