Skip to content

Commit

Permalink
Improve kernel caching (#3982)
Browse files Browse the repository at this point in the history
* Improve kernel caching

Each parloop invocation was doing more than necessary leading to quite
poor performance.
  • Loading branch information
connorjward authored Jan 22, 2025
1 parent b97d14a commit e21c4bb
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 68 deletions.
5 changes: 3 additions & 2 deletions firedrake/preconditioners/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pyop2.codegen.builder import Pack, MatPack, DatPack
from pyop2.codegen.representation import Comparison, Literal
from pyop2.codegen.rep2loopy import register_petsc_function
from pyop2.global_kernel import compile_global_kernel

__all__ = ("PatchPC", "PlaneSmoother", "PatchSNES")

Expand Down Expand Up @@ -222,7 +223,7 @@ def matrix_funptr(form, state):

wrapper_knl_args = tuple(a.global_kernel_arg for a in args)
mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True)
kernels.append(CompiledKernel(mod.compile(iterset.comm), kinfo))
kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo))
return cell_kernels, int_facet_kernels


Expand Down Expand Up @@ -316,7 +317,7 @@ def residual_funptr(form, state):

wrapper_knl_args = tuple(a.global_kernel_arg for a in args)
mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True)
kernels.append(CompiledKernel(mod.compile(iterset.comm), kinfo))
kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo))
return cell_kernels, int_facet_kernels


Expand Down
7 changes: 2 additions & 5 deletions firedrake/scripts/firedrake_clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
from firedrake.configuration import setup_cache_dirs
from pyop2.compilation import clear_compiler_disk_cache as pyop2_clear_cache
from firedrake.tsfc_interface import clear_cache as tsfc_clear_cache
try:
import platformdirs as appdirs
except ImportError:
import appdirs
import platformdirs


def main():
Expand All @@ -20,7 +17,7 @@ def main():
print(f"Removing cached PyOP2 code from {os.environ.get('PYOP2_CACHE_DIR', '???')}")
pyop2_clear_cache()

pytools_cache = appdirs.user_cache_dir("pytools", "pytools")
pytools_cache = platformdirs.user_cache_dir("pytools", "pytools")
print(f"Removing cached pytools files from {pytools_cache}")
if os.path.exists(pytools_cache):
shutil.rmtree(pytools_cache, ignore_errors=True)
Expand Down
3 changes: 2 additions & 1 deletion pyop2/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,8 @@ def wrapper(*args, **kwargs):
value = local_cache.get(key, CACHE_MISS)

if value is CACHE_MISS:
value = func(*args, **kwargs)
with PETSc.Log.Event("pyop2: handle cache miss"):
value = func(*args, **kwargs)
return local_cache.setdefault(key, value)

return wrapper
Expand Down
111 changes: 67 additions & 44 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
import loopy as lp
import numpy as np
import pytools
from loopy.codegen.result import process_preambles
from petsc4py import PETSc

from pyop2 import mpi
from pyop2.caching import parallel_cache, serial_cache
from pyop2.compilation import add_profiling_events, load
from pyop2.configuration import configuration
from pyop2.datatypes import IntType, as_ctypes
from pyop2.codegen.rep2loopy import generate
from pyop2.types import IterationRegion, Constant, READ
from pyop2.utils import cached_property, get_petsc_dir

Expand Down Expand Up @@ -326,8 +329,7 @@ def __call__(self, comm, *args):
:arg comm: Communicator the execution is collective over.
:*args: Arguments to pass to the compiled kernel.
"""
# It is unnecessary to cache this call as it is cached in pyop2/compilation.py
func = self.compile(comm)
func = compile_global_kernel(self, comm)
func(*args)

@property
Expand Down Expand Up @@ -364,48 +366,7 @@ def builder(self):
@cached_property
def code_to_compile(self):
"""Return the C/C++ source code as a string."""
from pyop2.codegen.rep2loopy import generate

with PETSc.Log.Event("GlobalKernel: generate loopy"):
wrapper = generate(self.builder)

with PETSc.Log.Event("GlobalKernel: generate device code"):
code = lp.generate_code_v2(wrapper)

if self.local_kernel.cpp:
from loopy.codegen.result import process_preambles
preamble = "".join(process_preambles(getattr(code, "device_preambles", [])))
device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs)
return preamble + "\nextern \"C\" {\n" + device_code + "\n}\n"
return code.device_code()

@PETSc.Log.EventDecorator()
@mpi.collective
def compile(self, comm):
"""Compile the kernel.
:arg comm: The communicator the compilation is collective over.
:returns: A ctypes function pointer for the compiled function.
"""
extension = "cpp" if self.local_kernel.cpp else "c"
cppargs = (
tuple("-I%s/include" % d for d in get_petsc_dir())
+ tuple("-I%s" % d for d in self.local_kernel.include_dirs)
+ ("-I%s" % os.path.abspath(os.path.dirname(__file__)),)
)
ldargs = (
tuple("-L%s/lib" % d for d in get_petsc_dir())
+ tuple("-Wl,-rpath,%s/lib" % d for d in get_petsc_dir())
+ ("-lpetsc", "-lm")
+ tuple(self.local_kernel.ldargs)
)

dll = load(self.code_to_compile, extension, cppargs=cppargs, ldargs=ldargs, comm=comm)
add_profiling_events(dll, self.local_kernel.events)
fn = getattr(dll, self.name)
fn.argtypes = self.argtypes
fn.restype = ctypes.c_int
return fn
return _generate_code_from_global_kernel(self)

@cached_property
def argtypes(self):
Expand All @@ -427,3 +388,65 @@ def num_flops(self, iterset):
elif region not in {IterationRegion.TOP, IterationRegion.BOTTOM}:
size = layers - 1
return size * self.local_kernel.num_flops

@cached_property
def _cppargs(self):
cppargs = [f"-I{d}/include" for d in get_petsc_dir()]
cppargs.extend(f"-I{d}" for d in self.local_kernel.include_dirs)
cppargs.append(f"-I{os.path.abspath(os.path.dirname(__file__))}")
return tuple(cppargs)

@cached_property
def _ldargs(self):
ldargs = [f"-L{d}/lib" for d in get_petsc_dir()]
ldargs.extend(f"-Wl,-rpath,{d}/lib" for d in get_petsc_dir())
ldargs.extend(["-lpetsc", "-lm"])
ldargs.extend(self.local_kernel.ldargs)
return tuple(ldargs)


@serial_cache(hashkey=lambda knl: knl.cache_key)
def _generate_code_from_global_kernel(kernel):
with PETSc.Log.Event("GlobalKernel: generate loopy"):
wrapper = generate(kernel.builder)

with PETSc.Log.Event("GlobalKernel: generate device code"):
code = lp.generate_code_v2(wrapper)

if kernel.local_kernel.cpp:
preamble = "".join(process_preambles(getattr(code, "device_preambles", [])))
device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs)
return preamble + "\nextern \"C\" {\n" + device_code + "\n}\n"

return code.device_code()


@parallel_cache(hashkey=lambda knl, _: knl.cache_key)
@mpi.collective
def compile_global_kernel(kernel, comm):
"""Compile the kernel.
Parameters
----------
kernel :
The global kernel to generate code for.
comm :
The communicator the compilation is collective over.
Returns
-------
A ctypes function pointer for the compiled function.
"""
dll = load(
kernel.code_to_compile,
"cpp" if kernel.local_kernel.cpp else "c",
cppargs=kernel._cppargs,
ldargs=kernel._ldargs,
comm=comm,
)
add_profiling_events(dll, kernel.local_kernel.events)
fn = getattr(dll, kernel.name)
fn.argtypes = kernel.argtypes
fn.restype = ctypes.c_int
return fn
39 changes: 23 additions & 16 deletions tests/pyop2/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,13 @@ def cache(self):
int_comm.Set_attr(comm_cache_keyval, _cache_collection)
return _cache_collection[default_cache_name]

def code_cache_len_equals(self, expected):
# We need to do this check because different things also get
# put into self.cache
return sum(
1 for key in self.cache if key[1] == "compile_global_kernel"
) == expected

@pytest.fixture
def a(cls, diterset):
return op2.Dat(diterset, list(range(nelems)), numpy.uint32, "a")
Expand All @@ -328,14 +335,14 @@ def test_same_args(self, iterset, iter2ind1, x, a):
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_cpy, "cpy"),
iterset,
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_diff_kernel(self, iterset, iter2ind1, x, a):
self.cache.clear()
Expand All @@ -348,7 +355,7 @@ def test_diff_kernel(self, iterset, iter2ind1, x, a):
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

kernel_cpy = "static void cpy(unsigned int* DST, unsigned int* SRC) { *DST = *SRC; }"

Expand All @@ -357,7 +364,7 @@ def test_diff_kernel(self, iterset, iter2ind1, x, a):
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 2
assert self.code_cache_len_equals(2)

def test_invert_arg_similar_shape(self, iterset, iter2ind1, x, y):
self.cache.clear()
Expand All @@ -377,14 +384,14 @@ def test_invert_arg_similar_shape(self, iterset, iter2ind1, x, y):
x(op2.RW, iter2ind1),
y(op2.RW, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_swap, "swap"),
iterset,
y(op2.RW, iter2ind1),
x(op2.RW, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_dloop_ignore_scalar(self, iterset, a, b):
self.cache.clear()
Expand All @@ -404,14 +411,14 @@ def test_dloop_ignore_scalar(self, iterset, a, b):
a(op2.RW),
b(op2.RW))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_swap, "swap"),
iterset,
b(op2.RW),
a(op2.RW))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_vector_map(self, iterset, x2, iter2ind2):
self.cache.clear()
Expand All @@ -431,13 +438,13 @@ def test_vector_map(self, iterset, x2, iter2ind2):
iterset,
x2(op2.RW, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_swap, "swap"),
iterset,
x2(op2.RW, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_same_iteration_space_works(self, iterset, x2, iter2ind2):
self.cache.clear()
Expand All @@ -447,12 +454,12 @@ def test_same_iteration_space_works(self, iterset, x2, iter2ind2):
op2.par_loop(k, iterset,
x2(op2.INC, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(k, iterset,
x2(op2.INC, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_change_dat_dtype_matters(self, iterset, diterset):
d = op2.Dat(diterset, list(range(nelems)), numpy.uint32)
Expand All @@ -463,12 +470,12 @@ def test_change_dat_dtype_matters(self, iterset, diterset):

op2.par_loop(k, iterset, d(op2.WRITE))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

d = op2.Dat(diterset, list(range(nelems)), numpy.int32)
op2.par_loop(k, iterset, d(op2.WRITE))

assert len(self.cache) == 2
assert self.code_cache_len_equals(2)

def test_change_global_dtype_matters(self, iterset, diterset):
g = op2.Global(1, 0, dtype=numpy.uint32, comm=COMM_WORLD)
Expand All @@ -479,12 +486,12 @@ def test_change_global_dtype_matters(self, iterset, diterset):

op2.par_loop(k, iterset, g(op2.INC))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

g = op2.Global(1, 0, dtype=numpy.float64, comm=COMM_WORLD)
op2.par_loop(k, iterset, g(op2.INC))

assert len(self.cache) == 2
assert self.code_cache_len_equals(2)


class TestSparsityCache:
Expand Down

0 comments on commit e21c4bb

Please sign in to comment.