Skip to content

Commit

Permalink
[ML-6097] Update Horovod Runner to include return value and multi-gpu…
Browse files Browse the repository at this point in the history
… info. (#183)

* [ML-6012] add return value in oss, update document and add test (#138)

In this PR, we:

change the horovodrunner.run api to return values in horovod/runner_base.py
update the documentation to reflect that change
add a test to runner_base_test.py to test the api change

* Update HorovodRunner docs to include multi-gpu support.

* Revert "Update HorovodRunner docs to include multi-gpu support."

This reverts commit 0172bbe.

* update np doc (#136)
  • Loading branch information
MrBago authored and mengxr committed Jan 25, 2019
1 parent 7aff76e commit dfa3c07
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/sparkdl/horovod/runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(self, np):
:param np: number of parallel processes to use for the Horovod job.
This argument only takes effect on Databricks Runtime 5.0 ML and above.
It is ignored in the open-source version.
On Databricks, each process will take an available task slot,
which maps to a GPU on a GPU cluster or a CPU core on a CPU cluster.
Accepted values are:
- If -1, this will spawn a subprocess on the driver node to run the Horovod job locally.
Expand Down Expand Up @@ -79,12 +81,14 @@ def run(self, main, **kwargs):
Avoid referencing large objects in the function, which might result large pickled data,
making the job slow to start.
:param kwargs: keyword arguments passed to the main function at invocation time.
:return: None
:return: return value of the main function.
With `np>=0`, this returns the value from the rank 0 process. Note that the returned
value should be serializable using cloudpickle.
"""
logger = logging.getLogger("HorovodRunner")
logger.warning(
"You are running the open-source version of HorovodRunner. "
"It only does basic checks and invokes the main function, "
"which is for local development only. "
"Please use Databricks Runtime ML 5.0+ to distribute the job.")
main(**kwargs)
return main(**kwargs)
7 changes: 7 additions & 0 deletions python/tests/horovod/runner_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,10 @@ def append(value):

hr.run(append, value=1)
self.assertEquals(data[0], 1)

def test_return_value(self):
"""Test that the return value is returned to the user."""
hr = HorovodRunner(np=-1)
return_value = hr.run(lambda: 42)
self.assertEquals(return_value, 42)

0 comments on commit dfa3c07

Please sign in to comment.