Add torch compile as export method.
Tim-Salzmann committed Jul 30, 2024
1 parent 46f3578 commit 901019e
Showing 15 changed files with 280 additions and 61 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -44,6 +44,11 @@ sure that the PyTorch model handles a **two-dimensional** input matrix! Accordin
many optimization problems. However, you can explicitly request the generation of the Hessian by passing
`generate_jac_jac=True`.

+ L4CasADi v2 can use the new **torch compile** functionality starting from PyTorch 2.4 by passing `scripting=False`.
+ This leads to a longer compile time on the first L4CasADi function call, but to faster execution overall. Currently,
+ this functionality is experimental and not fully stable across all models. In the long term, there is a good chance
+ it will become the default over scripting once the functionality is stabilized by the Torch developers.
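
A minimal usage sketch of this option (the model, shapes, and test input are illustrative; any `torch.nn.Module` that accepts a two-dimensional input works):

```python
import casadi as cs
import torch
import l4casadi as l4c

model = torch.nn.Linear(2, 1)  # stand-in for any torch.nn.Module

# scripting=False selects the experimental torch compile export path (PyTorch >= 2.4).
l4c_model = l4c.L4CasADi(model, scripting=False, device='cpu')

x_sym = cs.MX.sym('x', 1, 2)  # two-dimensional input, as the README requires
y_sym = l4c_model(x_sym)
f = cs.Function('y', [x_sym], [y_sym])

print(f(cs.DM([[0., 2.]])))  # the first call is slower while Torch compiles
```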

## Table of Content
- [Projects using L4CasADi](#projects-using-l4casadi)
2 changes: 1 addition & 1 deletion examples/acados.py
@@ -128,7 +128,7 @@ def ocp(self):
ocp.cost.W = np.array([[1.]])

# Trivial PyTorch index 0
- l4c_y_expr = l4c.L4CasADi(lambda x: x[0], name='y_expr', model_expects_batch_dim=False)
+ l4c_y_expr = l4c.L4CasADi(lambda x: x[0], name='y_expr')

ocp.model.cost_y_expr = l4c_y_expr(x)
ocp.model.cost_y_expr_e = x[0]
2 changes: 1 addition & 1 deletion examples/cpp_usage/generate.py
@@ -10,7 +10,7 @@ def forward(self, x):


def generate():
- l4casadi_model = l4c.L4CasADi(TorchModel(), model_expects_batch_dim=False, name='sin_l4c')
+ l4casadi_model = l4c.L4CasADi(TorchModel(), name='sin_l4c')

sym_in = cs.MX.sym('x', 1, 1)

2 changes: 1 addition & 1 deletion examples/fish_turbulent_flow/utils.py
@@ -266,7 +266,7 @@ def import_l4casadi_model(device):
x = cs.MX.sym("x", 3)
xn = (x - meanX) / stdX

- y = l4c.L4CasADi(model, name="turbulent_model", model_expects_batch_dim=True)(xn)
+ y = l4c.L4CasADi(model, name="turbulent_model", generate_adj1=False, generate_jac_jac=True)(xn.T).T
y = y * stdY + meanY
fU = cs.Function("fU", [x], [y[0]])
fV = cs.Function("fV", [x], [y[1]])
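A short sketch isolating the derivative-generation flags this example now uses (the linear stand-in model and shapes are illustrative, not the real turbulence model):

```python
import casadi as cs
import torch
import l4casadi as l4c

model = torch.nn.Linear(3, 3)  # placeholder for the turbulence model

x = cs.MX.sym('x', 3)
# Skip adjoint generation, but explicitly request the Hessian (generate_jac_jac=True).
y = l4c.L4CasADi(model, name='turbulent_model',
                 generate_adj1=False, generate_jac_jac=True)(x.T).T
fU = cs.Function('fU', [x], [y[0]])
```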
2 changes: 1 addition & 1 deletion examples/matlab/export.py
@@ -10,7 +10,7 @@ def forward(self, x):


def generate():
- l4casadi_model = l4c.L4CasADi(TorchModel(), model_expects_batch_dim=False, name='sin_l4c')
+ l4casadi_model = l4c.L4CasADi(TorchModel(), name='sin_l4c')
sym_in = cs.MX.sym('x', 1, 1)
l4casadi_model.build(sym_in)
return
@@ -10,6 +10,7 @@

CASE = 1

+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

def polynomial(n, n_eval):
"""Generates a symbolic function for a polynomial of degree n-1"""
@@ -175,7 +176,7 @@ def main():
strict=False,
)
# -------------------------- Create L4CasADi Module -------------------------- #
- l4c_nerf = l4c.L4CasADi(model)
+ l4c_nerf = l4c.L4CasADi(model, scripting=False)

# ---------------------------------------------------------------------------- #
# NLP warmup #
@@ -8,6 +8,7 @@
import l4casadi as l4c
from density_nerf import DensityNeRF

+ import os

CASE = 1

4 changes: 2 additions & 2 deletions examples/readme.py
@@ -25,9 +25,9 @@ def forward(self, x):


pyTorch_model = MultiLayerPerceptron()
- l4c_model = l4c.L4CasADi(pyTorch_model, model_expects_batch_dim=True, device='cpu') # device='cuda' for GPU
+ l4c_model = l4c.L4CasADi(pyTorch_model, device='cpu') # device='cuda' for GPU

- x_sym = cs.MX.sym('x', 2, 1)
+ x_sym = cs.MX.sym('x', 1, 2)
y_sym = l4c_model(x_sym)
f = cs.Function('y', [x_sym], [y_sym])
df = cs.Function('dy', [x_sym], [cs.jacobian(y_sym, x_sym)])
4 changes: 2 additions & 2 deletions examples/simple_nlp.py
@@ -14,7 +14,7 @@ def forward(self, input):


f = PyTorchObjectiveModel() # objective
- f = l4c.L4CasADi(f, name='f', model_expects_batch_dim=False)(x)
+ f = l4c.L4CasADi(f, name='f')(x)


class PyTorchConstraintModel(torch.nn.Module):
@@ -23,7 +23,7 @@ def forward(self, input):


g = PyTorchConstraintModel() # constraint
- g = l4c.L4CasADi(g, name='g', model_expects_batch_dim=False)(x)
+ g = l4c.L4CasADi(g, name='g')(x)

nlp = {'x': x, 'f': f, 'g': g}

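Across the example files above, the recurring change is the removal of the `model_expects_batch_dim` argument in L4CasADi v2. A minimal before/after sketch (the `model` and symbol are illustrative):

```python
import casadi as cs
import torch
import l4casadi as l4c

model = torch.nn.Linear(2, 1)
x = cs.MX.sym('x', 1, 2)

# v1 (old): f = l4c.L4CasADi(model, model_expects_batch_dim=True)(x)
# v2 (new): no batch-dim flag; batch handling follows the CasADi symbol's shape.
f = l4c.L4CasADi(model)(x)
```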
46 changes: 39 additions & 7 deletions l4casadi/l4casadi.py
@@ -48,6 +48,7 @@ def __init__(self,
generate_adj1: bool = True,
generate_jac_adj1: bool = True,
generate_jac_jac: bool = False,
+ scripting: bool = True,
mutable: bool = False):
"""
:param model: PyTorch model.
@@ -65,11 +66,21 @@ def __init__(self,
:param generate_adj1: If True, generation of the adjoint of the model is attempted.
:param generate_jac_adj1: If True, generation of the Jacobian of the adjoint of the model is attempted.
:param generate_jac_jac: If True, generation of the Hessian of the model is attempted.
+ :param scripting: If True, the model is traced using TorchScript. If False, the model is AOT-compiled via torch compile (experimental).
:param mutable: If True, enables updating the model online via the update method.
"""
if platform.system() == "Windows":
warnings.warn("L4CasADi is currently not supported for Windows.")

+ if not scripting:
+ warnings.warn("L4CasADi with Torch AOT compilation is experimental at this point and might not work as "
+ "expected.")
+ if torch.__version__ < torch.torch_version.TorchVersion('2.4.0'):
+ raise RuntimeError("For PyTorch versions < 2.4.0 L4CasADi only supports jit scripting. Please pass "
+ "scripting=True.")
+ import torch._inductor.config as config
+ config.freezing = True

self.model = model
self.naive = False
if isinstance(self.model, NaiveL4CasADiModule):
@@ -94,6 +105,8 @@ def __init__(self,
self._generate_jac_adj1 = generate_jac_adj1
self._generate_jac_jac = generate_jac_jac

+ self._scripting = scripting

self._mutable = mutable

self._input_shape: Tuple[int, int] = (-1, -1)
@@ -284,6 +297,7 @@ def _generate_cpp_function_template(self, has_jac: bool, has_adj1: bool, has_jac
'has_adj1': 'true' if has_adj1 else 'false',
'has_jac_adj1': 'true' if has_jac_adj1 else 'false',
'has_jac_jac': 'true' if has_jac_jac else 'false',
+ 'scripting': 'true' if self._scripting else 'false',
'model_is_mutable': 'true' if self._mutable else 'false',
'batched': 'true' if self.batched else 'false',
'jac_ccs_len': len(jac_ccs) if self.batched else 0,
@@ -372,15 +386,15 @@ def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]:

out_folder = self.build_dir

- self._jit_compile_and_save(make_fx(functionalize(self.model, remove='mutations_and_views'))(d_inp),
+ self.model_compile(make_fx(functionalize(self.model, remove='mutations_and_views'))(d_inp),
(out_folder / f'{self.name}.pt').as_posix(),
(d_inp,))

exported_jac = False
if self._generate_jac:
jac_model = self._trace_jac_model(d_inp)

- exported_jac = self._jit_compile_and_save(
+ exported_jac = self.model_compile(
jac_model,
(out_folder / f'jac_{self.name}.pt').as_posix(),
(d_inp,)
@@ -389,7 +403,7 @@ def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]:
exported_adj1 = False
if self._generate_adj1:
adj1_model = self._trace_adj1_model()
- exported_adj1 = self._jit_compile_and_save(
+ exported_adj1 = self.model_compile(
adj1_model,
(out_folder / f'adj1_{self.name}.pt').as_posix(),
(d_inp, d_out)
@@ -398,7 +412,7 @@ def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]:
exported_jac_adj1 = False
if self._generate_jac_adj1:
jac_adj1_model = self._trace_jac_adj1_model()
- exported_jac_adj1 = self._jit_compile_and_save(
+ exported_jac_adj1 = self.model_compile(
jac_adj1_model,
(out_folder / f'jac_adj1_{self.name}.pt').as_posix(),
(d_inp, d_out)
@@ -413,17 +427,35 @@ def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]:
pass

if hess_model is not None:
- exported_hess = self._jit_compile_and_save(
+ exported_hess = self.model_compile(
hess_model,
(out_folder / f'jac_jac_{self.name}.pt').as_posix(),
(d_inp,)
)

return exported_jac, exported_adj1, exported_jac_adj1, exported_hess

+ def model_compile(self, model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]):
+ if self._scripting:
+ return self._jit_compile_and_save(model, file_path, dummy_inp)
+ else:
+ return self._aot_compile_and_save(model, file_path, dummy_inp)

+ @staticmethod
+ def _aot_compile_and_save(model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]):
+ try:
+ with torch.no_grad():
+ torch._export.aot_compile(
+ model,
+ dummy_inp,
+ options={"aot_inductor.output_path": file_path[:-2] + 'so'},
+ )
+ return True
+ except: # noqa
+ return False

@staticmethod
- def _jit_compile_and_save(model, file_path: str, dummy_inp: torch.Tensor):
+ # TODO: Could switch to torch export https://pytorch.org/docs/stable/export.html
+ def _jit_compile_and_save(model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]):
try:
# Try scripting
ts_compile(model).save(file_path)
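For orientation, a standalone sketch of the AOT path that `_aot_compile_and_save` wraps; the toy model, paths, and the `torch._export.aot_load` round-trip are illustrative assumptions rather than part of this commit:

```python
import os

import torch
import torch._export  # private PyTorch export API used by this commit
import torch._inductor.config as inductor_config

inductor_config.freezing = True  # mirrors the scripting=False branch in __init__

model = torch.nn.Linear(2, 1).eval()
dummy_inp = (torch.zeros(1, 2),)
so_path = os.path.join(os.getcwd(), 'model.so')

# As in _aot_compile_and_save: AOT-compile the model into a shared library.
with torch.no_grad():
    torch._export.aot_compile(
        model,
        dummy_inp,
        options={'aot_inductor.output_path': so_path},
    )

# Assumed round-trip check; torch._export.aot_load is the matching loader in PyTorch 2.4.
runner = torch._export.aot_load(so_path, device='cpu')
print(runner(*dummy_inp))
```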
@@ -1,6 +1,6 @@
#include <l4casadi.hpp>

- L4CasADi l4casadi("{{ model_path }}", "{{ name }}", {{ rows_in }}, {{ cols_in }}, {{ rows_out }}, {{ cols_out }}, "{{ device }}", {{ has_jac }}, {{ has_adj1 }}, {{ has_jac_adj1 }}, {{ has_jac_jac }}, {{ model_is_mutable }});
+ L4CasADi l4casadi("{{ model_path }}", "{{ name }}", {{ rows_in }}, {{ cols_in }}, {{ rows_out }}, {{ cols_out }}, "{{ device }}", {{ has_jac }}, {{ has_adj1 }}, {{ has_jac_adj1 }}, {{ has_jac_jac }}, {{ scripting }}, {{ model_is_mutable }});

#ifdef __cplusplus
extern "C" {
24 changes: 17 additions & 7 deletions libl4casadi/CMakeLists.txt
@@ -1,13 +1,7 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(L4CasADi)

- # Load CUDA if it is installed
- find_package(CUDAToolkit)
- find_package(CUDA)

- if (USE_CUDA)
- add_definitions(-DUSE_CUDA)
- endif ()
+ set(CMAKE_COMPILE_WARNING_AS_ERROR ON)

if (WIN32)
set (CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
@@ -18,6 +12,22 @@ endif ()
set(CMAKE_PREFIX_PATH ${CMAKE_TORCH_PATH})

find_package(Torch REQUIRED)

+ # Load CUDA if it is installed
+ find_package(CUDAToolkit)
+ find_package(CUDA)

+ add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR})
+ add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR})
+ add_definitions(-DTORCH_VERSION_PATCH=${Torch_VERSION_PATCH})

+ if (Torch_VERSION_MAJOR GREATER 2 OR (Torch_VERSION_MAJOR EQUAL 2 AND Torch_VERSION_MINOR GREATER_EQUAL 4))
+ #add_definitions(-DENABLE_TORCH_COMPILE)
+ endif ()
+ if (USE_CUDA)
+ add_definitions(-DUSE_CUDA)
+ endif ()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_library(l4casadi SHARED src/l4casadi.cpp include/l4casadi.hpp)
4 changes: 3 additions & 1 deletion libl4casadi/include/l4casadi.hpp
@@ -14,7 +14,7 @@ class L4CasADi
int cols_out;
public:
L4CasADi(std::string, std::string, int, int, int, int, std::string = "cpu", bool = false, bool = false, bool = false, bool = false,
- bool = false);
+ bool = false, bool = false);
~L4CasADi();
void forward(const double*, double*);
void jac(const double*, double*);
@@ -26,6 +26,8 @@ class L4CasADi

// PImpl Idiom
class L4CasADiImpl;
+ class L4CasADiScriptedImpl;
+ class L4CasADiCompiledImpl;
std::unique_ptr<L4CasADiImpl> pImpl;

};
(Remaining file diffs not shown.)
