diff --git a/pygmtools/utils.py b/pygmtools/utils.py index 5ee0c44..01515aa 100644 --- a/pygmtools/utils.py +++ b/pygmtools/utils.py @@ -69,6 +69,10 @@ def set_backend(new_backend: str): >>> import pygmtools as pygm # numpy is the default backend >>> pygm.set_backend('pytorch') # set backend to pytorch, it will through an error if torch is not installed + >>> pygm.set_backend('tf') # throw an error, provide potential matches + ValueError: Unknown backend tf. Did you mean tensorflow? Supported backends: ['numpy', 'pytorch', 'jittor', 'paddle', 'mindspore', 'tensorflow'] + >>> pygm.set_backend('tensorflow') # this is the correct key and will work + """ new_backend = new_backend.lower() if new_backend not in pygmtools.SUPPORTED_BACKENDS: