Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Nov 14, 2024
1 parent de1767e commit eeb779c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 15 deletions.
4 changes: 3 additions & 1 deletion onedal/common/policy_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.
# ==============================================================================

from operator import getitem

from onedal._device_offload import DummySyclQueue


Expand Down Expand Up @@ -42,7 +44,7 @@ def get_queue(*data):
if len(data) < 1:
return
if iface := getattr(data[0], "__sycl_usm_array_interface__", None):
queue = getattr(iface, "syclobj")
queue = iface.get("syclobj")
if not queue:
raise KeyError("No syclobj in provided data")
return queue
Expand Down
30 changes: 21 additions & 9 deletions onedal/common/tests/test_backend_manager.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,48 @@
import pytest

from onedal.common.backend_manager import BackendManager


# Define a simple backend module for testing
class TestBackend:
class DummyBackend:

class Module:
class Submodule:
def method(self, *args, **kwargs):
return "method_result"

def __init__(self):
self.submodule_instance = self.Submodule()

def method(self, *args, **kwargs):
return "method_result"

module = Module()
def __init__(self):
self.module_instance = self.Module()
self.is_dpc = False
self.is_spmd = False


from onedal.common.backend_manager import BackendManager
@property
def module(self):
return self.module_instance


@pytest.fixture
def backend_manager():
backend = TestBackend()
backend = DummyBackend()
return BackendManager(backend)


def test_get_backend_component_with_method(backend_manager):
result = backend_manager.get_backend_component("module", "Submodule", "method")
result = backend_manager.get_backend_component(
"module", "submodule_instance", "method"
)
assert result == "method_result"


def test_get_backend_component_with_method_and_args(backend_manager):
result = backend_manager.get_backend_component(
"module", "Submodule", "method", "arg1", kwarg1="kwarg1"
"module", "submodule_instance", "method", "arg1", kwarg1="kwarg1"
)
assert result == "method_result"

Expand All @@ -41,5 +53,5 @@ def test_get_backend_component_without_submodule(backend_manager):


def test_get_backend_component_without_method(backend_manager):
result = backend_manager.get_backend_component("module", "Submodule")
assert result == backend_manager.backend.module.Submodule
result = backend_manager.get_backend_component("module", "submodule_instance")
assert result == backend_manager.backend.module.submodule_instance
21 changes: 16 additions & 5 deletions onedal/common/tests/test_policy_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import pytest

from onedal._device_offload import DummySyclQueue
from onedal.common.policy_manager import PolicyManager
from onedal.common.policy_manager import Policy, PolicyManager


# Define a simple backend module for testing
class DummyBackend:
def __init__(self, is_dpc):
self.is_dpc = is_dpc
self.is_spmd = False

def data_parallel_policy(self, obj):
return f"data_parallel_policy({obj})"
Expand Down Expand Up @@ -53,19 +54,25 @@ def test_get_queue_without_sycl_usm_array_interface():

def test_get_policy_with_provided_queue(policy_manager_dpc):
provided_queue = create_autospec(DummySyclQueue)
provided_queue.sycl_device = MagicMock()
provided_queue.sycl_device.get_filter_string.return_value = "filter_string"
policy = policy_manager_dpc.get_policy(provided_queue)
assert policy == "data_parallel_policy(filter_string)"
assert policy.policy == "data_parallel_policy(filter_string)"
assert policy.is_dpc is True
assert policy.is_spmd is False


def test_get_policy_with_data_queue(policy_manager_dpc):
data = [MagicMock()]
data[0].__sycl_usm_array_interface__ = {"syclobj": create_autospec(DummySyclQueue)}
data[0].__sycl_usm_array_interface__["syclobj"].sycl_device = MagicMock()
data[0].__sycl_usm_array_interface__[
"syclobj"
].sycl_device.get_filter_string.return_value = "filter_string"
policy = policy_manager_dpc.get_policy(None, *data)
assert policy == "data_parallel_policy(filter_string)"
assert policy.policy == "data_parallel_policy(filter_string)"
assert policy.is_dpc is True
assert policy.is_spmd is False


def test_get_policy_with_host_backend_and_queue(policy_manager_host):
Expand All @@ -78,9 +85,13 @@ def test_get_policy_with_host_backend_and_queue(policy_manager_host):

def test_get_policy_with_host_backend(policy_manager_host):
policy = policy_manager_host.get_policy(None)
assert policy == "host_policy"
assert policy.policy == "host_policy"
assert policy.is_dpc is False
assert policy.is_spmd is False


def test_get_policy_with_dpc_backend_no_queue(policy_manager_dpc):
policy = policy_manager_dpc.get_policy(None)
assert policy == "host_policy"
assert policy.policy == "host_policy"
assert policy.is_dpc is False
assert policy.is_spmd is False

0 comments on commit eeb779c

Please sign in to comment.