diff --git a/runners/mlcube_singularity/mlcube_singularity/singularity_client.py b/runners/mlcube_singularity/mlcube_singularity/singularity_client.py index 4f141d4..ea09395 100644 --- a/runners/mlcube_singularity/mlcube_singularity/singularity_client.py +++ b/runners/mlcube_singularity/mlcube_singularity/singularity_client.py @@ -104,16 +104,18 @@ def from_env(cls) -> "Client": def supports_fakeroot(self) -> bool: singularity_35 = ( self.version.runtime == Runtime.SINGULARITY - and self.version >= semver.VersionInfo(major=3, minor=5) + and self.version.version >= semver.VersionInfo(major=3, minor=5) ) apptainer = self.version.runtime == Runtime.APPTAINER return singularity_35 or apptainer - def __init__(self, singularity: t.Union[str, t.List]) -> None: + def __init__( + self, singularity: t.Union[str, t.List], version: t.Optional[Version] = None + ) -> None: if isinstance(singularity, str): singularity = singularity.split(" ") self.singularity: t.List[str] = [c.strip() for c in singularity if c.strip()] - self.version: t.Optional[Version] = None + self.version: t.Optional[Version] = version self.init() logger.debug( "Client.__init__ executable=%s, version=%s", self.singularity, self.version diff --git a/runners/mlcube_singularity/mlcube_singularity/tests/test_singularity_client.py b/runners/mlcube_singularity/mlcube_singularity/tests/test_singularity_client.py new file mode 100644 index 0000000..d7c076b --- /dev/null +++ b/runners/mlcube_singularity/mlcube_singularity/tests/test_singularity_client.py @@ -0,0 +1,40 @@ +from unittest import TestCase + +import semver +from mlcube_singularity.singularity_client import Client, Runtime, Version + + +class TestSingularityRunner(TestCase): + def test___init__(self) -> None: + client = Client( + "sudo singularity", Version(Runtime.APPTAINER, semver.VersionInfo(3, 7, 5)) + ) + self.assertListEqual(["sudo", "singularity"], client.singularity) + self.assertEqual(Runtime.APPTAINER, client.version.runtime) + self.assertEqual(3, client.version.version.major) + self.assertEqual(7, client.version.version.minor) + self.assertEqual(5, client.version.version.patch) + + def test_supports_fakeroot(self) -> None: + client = Client( + "sudo singularity", Version(Runtime.APPTAINER, semver.VersionInfo(3, 7, 5)) + ) + self.assertTrue(client.supports_fakeroot()) + + client = Client( + "sudo singularity", + Version(Runtime.SINGULARITY, semver.VersionInfo(3, 7, 5)), + ) + self.assertTrue(client.supports_fakeroot()) + + client = Client( + "sudo singularity", + Version(Runtime.SINGULARITY, semver.VersionInfo(3, 5, 0)), + ) + self.assertTrue(client.supports_fakeroot()) + + client = Client( + "sudo singularity", + Version(Runtime.SINGULARITY, semver.VersionInfo(3, 4, 9)), + ) + self.assertFalse(client.supports_fakeroot())