diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 75de4b2a..f2bd6a44 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -33,13 +33,13 @@ jobs: pytorch-version: "1.11.*" # Latest supported versions - - name: Linux (CUDA 11.2, Python 3.10, PyTorch 1.12) + - name: Linux (CUDA 11.8, Python 3.10, PyTorch 2.0) os: ubuntu-22.04 - cuda-version: "11.2.2" + cuda-version: "11.8.0" gcc-version: "10.3.*" - nvcc-version: "11.2" + nvcc-version: "11.8" python-version: "3.10" - pytorch-version: "1.12.*" + pytorch-version: "2.0.*" - name: MacOS (Python 3.9, PyTorch 1.9) os: macos-11 diff --git a/python/openmmtorch.i b/python/openmmtorch.i index 94880522..05988e3d 100644 --- a/python/openmmtorch.i +++ b/python/openmmtorch.i @@ -14,6 +14,7 @@ #include "openmm/RPMDIntegrator.h" #include "openmm/RPMDMonteCarloBarostat.h" #include +#include %} /* @@ -28,23 +29,27 @@ } } -%typemap(in) const torch::jit::Module&(torch::jit::Module module) { +%typemap(in) const torch::jit::Module&(torch::jit::Module mod) { py::object o = py::reinterpret_borrow($input); - module = torch::jit::as_module(o).value(); - $1 = &module; + py::object pybuffer = py::module::import("io").attr("BytesIO")(); + py::module::import("torch.jit").attr("save")(o, pybuffer); + std::string s = py::cast(pybuffer.attr("getvalue")()); + std::stringstream buffer(s); + mod = torch::jit::load(buffer); + $1 = &mod; } %typemap(out) const torch::jit::Module& { - auto fileName = std::tmpnam(nullptr); - $1->save(fileName); - $result = py::module::import("torch.jit").attr("load")(fileName).release().ptr(); - //This typemap assumes that torch does not require the file to exist after construction - std::remove(fileName); + std::stringstream buffer; + $1->save(buffer); + auto pybuffer = py::module::import("io").attr("BytesIO")(py::bytes(buffer.str())); + $result = py::module::import("torch.jit").attr("load")(pybuffer).release().ptr(); } %typecheck(SWIG_TYPECHECK_POINTER) const torch::jit::Module& { py::object o = py::reinterpret_borrow($input); - $1 = torch::jit::as_module(o).has_value() ? 1 : 0; + py::handle ScriptModule = py::module::import("torch.jit").attr("ScriptModule"); + $1 = py::isinstance(o, ScriptModule); } namespace std {