Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve kernel caching #3982

Merged
merged 3 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions firedrake/preconditioners/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pyop2.codegen.builder import Pack, MatPack, DatPack
from pyop2.codegen.representation import Comparison, Literal
from pyop2.codegen.rep2loopy import register_petsc_function
from pyop2.global_kernel import compile_global_kernel

__all__ = ("PatchPC", "PlaneSmoother", "PatchSNES")

Expand Down Expand Up @@ -222,7 +223,7 @@ def matrix_funptr(form, state):

wrapper_knl_args = tuple(a.global_kernel_arg for a in args)
mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True)
kernels.append(CompiledKernel(mod.compile(iterset.comm), kinfo))
kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo))
return cell_kernels, int_facet_kernels


Expand Down Expand Up @@ -316,7 +317,7 @@ def residual_funptr(form, state):

wrapper_knl_args = tuple(a.global_kernel_arg for a in args)
mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True)
kernels.append(CompiledKernel(mod.compile(iterset.comm), kinfo))
kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo))
return cell_kernels, int_facet_kernels


Expand Down
7 changes: 2 additions & 5 deletions firedrake/scripts/firedrake_clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
from firedrake.configuration import setup_cache_dirs
from pyop2.compilation import clear_compiler_disk_cache as pyop2_clear_cache
from firedrake.tsfc_interface import clear_cache as tsfc_clear_cache
try:
import platformdirs as appdirs
except ImportError:
import appdirs
import platformdirs


def main():
Expand All @@ -20,7 +17,7 @@ def main():
print(f"Removing cached PyOP2 code from {os.environ.get('PYOP2_CACHE_DIR', '???')}")
pyop2_clear_cache()

pytools_cache = appdirs.user_cache_dir("pytools", "pytools")
pytools_cache = platformdirs.user_cache_dir("pytools", "pytools")
print(f"Removing cached pytools files from {pytools_cache}")
if os.path.exists(pytools_cache):
shutil.rmtree(pytools_cache, ignore_errors=True)
Expand Down
3 changes: 2 additions & 1 deletion pyop2/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,8 @@ def wrapper(*args, **kwargs):
value = local_cache.get(key, CACHE_MISS)

if value is CACHE_MISS:
value = func(*args, **kwargs)
with PETSc.Log.Event("pyop2: handle cache miss"):
value = func(*args, **kwargs)
return local_cache.setdefault(key, value)

return wrapper
Expand Down
111 changes: 67 additions & 44 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
import loopy as lp
import numpy as np
import pytools
from loopy.codegen.result import process_preambles
from petsc4py import PETSc

from pyop2 import mpi
from pyop2.caching import parallel_cache, serial_cache
from pyop2.compilation import add_profiling_events, load
from pyop2.configuration import configuration
from pyop2.datatypes import IntType, as_ctypes
from pyop2.codegen.rep2loopy import generate
from pyop2.types import IterationRegion, Constant, READ
from pyop2.utils import cached_property, get_petsc_dir

Expand Down Expand Up @@ -326,8 +329,7 @@ def __call__(self, comm, *args):
:arg comm: Communicator the execution is collective over.
:*args: Arguments to pass to the compiled kernel.
"""
# It is unnecessary to cache this call as it is cached in pyop2/compilation.py
func = self.compile(comm)
func = compile_global_kernel(self, comm)
func(*args)

@property
Expand Down Expand Up @@ -364,48 +366,7 @@ def builder(self):
@cached_property
def code_to_compile(self):
"""Return the C/C++ source code as a string."""
from pyop2.codegen.rep2loopy import generate

with PETSc.Log.Event("GlobalKernel: generate loopy"):
wrapper = generate(self.builder)

with PETSc.Log.Event("GlobalKernel: generate device code"):
code = lp.generate_code_v2(wrapper)

if self.local_kernel.cpp:
from loopy.codegen.result import process_preambles
preamble = "".join(process_preambles(getattr(code, "device_preambles", [])))
device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs)
return preamble + "\nextern \"C\" {\n" + device_code + "\n}\n"
return code.device_code()

@PETSc.Log.EventDecorator()
@mpi.collective
def compile(self, comm):
"""Compile the kernel.

:arg comm: The communicator the compilation is collective over.
:returns: A ctypes function pointer for the compiled function.
"""
extension = "cpp" if self.local_kernel.cpp else "c"
cppargs = (
tuple("-I%s/include" % d for d in get_petsc_dir())
+ tuple("-I%s" % d for d in self.local_kernel.include_dirs)
+ ("-I%s" % os.path.abspath(os.path.dirname(__file__)),)
)
ldargs = (
tuple("-L%s/lib" % d for d in get_petsc_dir())
+ tuple("-Wl,-rpath,%s/lib" % d for d in get_petsc_dir())
+ ("-lpetsc", "-lm")
+ tuple(self.local_kernel.ldargs)
)

dll = load(self.code_to_compile, extension, cppargs=cppargs, ldargs=ldargs, comm=comm)
add_profiling_events(dll, self.local_kernel.events)
fn = getattr(dll, self.name)
fn.argtypes = self.argtypes
fn.restype = ctypes.c_int
return fn
return _generate_code_from_global_kernel(self)

@cached_property
def argtypes(self):
Expand All @@ -427,3 +388,65 @@ def num_flops(self, iterset):
elif region not in {IterationRegion.TOP, IterationRegion.BOTTOM}:
size = layers - 1
return size * self.local_kernel.num_flops

@cached_property
def _cppargs(self):
    """Return the include flags used to compile this kernel, as a tuple.

    The flags are, in order: one ``-I<dir>/include`` per directory
    returned by ``get_petsc_dir()``, one ``-I<dir>`` per entry in
    ``self.local_kernel.include_dirs``, and finally ``-I`` pointing at
    the directory containing this module.
    """
    petsc_includes = tuple(f"-I{d}/include" for d in get_petsc_dir())
    kernel_includes = tuple(f"-I{d}" for d in self.local_kernel.include_dirs)
    # Directory of this source file, so headers shipped next to it resolve.
    module_dir = os.path.abspath(os.path.dirname(__file__))
    return petsc_includes + kernel_includes + (f"-I{module_dir}",)

@cached_property
def _ldargs(self):
    """Return the linker flags used to build this kernel, as a tuple.

    The flags are, in order: ``-L<dir>/lib`` and then
    ``-Wl,-rpath,<dir>/lib`` for every directory returned by
    ``get_petsc_dir()``, the fixed ``-lpetsc`` and ``-lm`` libraries,
    and whatever extra flags ``self.local_kernel.ldargs`` provides.
    """
    search_paths = tuple(f"-L{d}/lib" for d in get_petsc_dir())
    runtime_paths = tuple(f"-Wl,-rpath,{d}/lib" for d in get_petsc_dir())
    libraries = ("-lpetsc", "-lm")
    return search_paths + runtime_paths + libraries + tuple(self.local_kernel.ldargs)


@serial_cache(hashkey=lambda knl: knl.cache_key)
def _generate_code_from_global_kernel(kernel):
    """Return the C/C++ source for ``kernel`` as a string.

    The loopy wrapper is built from ``kernel.builder`` and lowered with
    ``lp.generate_code_v2``; each stage is wrapped in its own PETSc log
    event. Results are cached by ``serial_cache`` using
    ``kernel.cache_key`` as the key.
    """
    with PETSc.Log.Event("GlobalKernel: generate loopy"):
        loopy_wrapper = generate(kernel.builder)

    with PETSc.Log.Event("GlobalKernel: generate device code"):
        generated = lp.generate_code_v2(loopy_wrapper)

    if not kernel.local_kernel.cpp:
        return generated.device_code()

    # C++ kernels: emit the preambles first, then wrap the device programs
    # in an ``extern "C"`` block.
    preambles = getattr(generated, "device_preambles", [])
    preamble = "".join(process_preambles(preambles))
    device_code = "\n\n".join(str(dp.ast) for dp in generated.device_programs)
    return "".join((preamble, "\nextern \"C\" {\n", device_code, "\n}\n"))


@parallel_cache(hashkey=lambda knl, _: knl.cache_key)
@mpi.collective
def compile_global_kernel(kernel, comm):
    """Compile a global kernel and return its callable entry point.

    Results are cached by ``parallel_cache`` using ``kernel.cache_key``
    as the key.

    Parameters
    ----------
    kernel :
        The global kernel to compile.
    comm :
        The communicator the compilation is collective over.

    Returns
    -------
    A ctypes function pointer for the compiled function, with its
    ``argtypes`` taken from ``kernel.argtypes`` and an ``int`` return type.

    """
    source = kernel.code_to_compile
    extension = "cpp" if kernel.local_kernel.cpp else "c"
    library = load(
        source,
        extension,
        cppargs=kernel._cppargs,
        ldargs=kernel._ldargs,
        comm=comm,
    )
    add_profiling_events(library, kernel.local_kernel.events)

    # Look the wrapper up by name and describe its C signature to ctypes.
    entry_point = getattr(library, kernel.name)
    entry_point.argtypes = kernel.argtypes
    entry_point.restype = ctypes.c_int
    return entry_point
39 changes: 23 additions & 16 deletions tests/pyop2/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,13 @@ def cache(self):
int_comm.Set_attr(comm_cache_keyval, _cache_collection)
return _cache_collection[default_cache_name]

def code_cache_len_equals(self, expected):
    """Return whether exactly ``expected`` compiled-kernel entries are cached.

    ``self.cache`` also holds entries from other sources, so only keys
    whose second element is ``"compile_global_kernel"`` are counted.
    """
    compiled_keys = [
        key for key in self.cache if key[1] == "compile_global_kernel"
    ]
    return len(compiled_keys) == expected

@pytest.fixture
def a(cls, diterset):
return op2.Dat(diterset, list(range(nelems)), numpy.uint32, "a")
Expand All @@ -328,14 +335,14 @@ def test_same_args(self, iterset, iter2ind1, x, a):
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_cpy, "cpy"),
iterset,
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_diff_kernel(self, iterset, iter2ind1, x, a):
self.cache.clear()
Expand All @@ -348,7 +355,7 @@ def test_diff_kernel(self, iterset, iter2ind1, x, a):
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

kernel_cpy = "static void cpy(unsigned int* DST, unsigned int* SRC) { *DST = *SRC; }"

Expand All @@ -357,7 +364,7 @@ def test_diff_kernel(self, iterset, iter2ind1, x, a):
a(op2.WRITE),
x(op2.READ, iter2ind1))

assert len(self.cache) == 2
assert self.code_cache_len_equals(2)

def test_invert_arg_similar_shape(self, iterset, iter2ind1, x, y):
self.cache.clear()
Expand All @@ -377,14 +384,14 @@ def test_invert_arg_similar_shape(self, iterset, iter2ind1, x, y):
x(op2.RW, iter2ind1),
y(op2.RW, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_swap, "swap"),
iterset,
y(op2.RW, iter2ind1),
x(op2.RW, iter2ind1))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_dloop_ignore_scalar(self, iterset, a, b):
self.cache.clear()
Expand All @@ -404,14 +411,14 @@ def test_dloop_ignore_scalar(self, iterset, a, b):
a(op2.RW),
b(op2.RW))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_swap, "swap"),
iterset,
b(op2.RW),
a(op2.RW))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_vector_map(self, iterset, x2, iter2ind2):
self.cache.clear()
Expand All @@ -431,13 +438,13 @@ def test_vector_map(self, iterset, x2, iter2ind2):
iterset,
x2(op2.RW, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(op2.Kernel(kernel_swap, "swap"),
iterset,
x2(op2.RW, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_same_iteration_space_works(self, iterset, x2, iter2ind2):
self.cache.clear()
Expand All @@ -447,12 +454,12 @@ def test_same_iteration_space_works(self, iterset, x2, iter2ind2):
op2.par_loop(k, iterset,
x2(op2.INC, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

op2.par_loop(k, iterset,
x2(op2.INC, iter2ind2))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

def test_change_dat_dtype_matters(self, iterset, diterset):
d = op2.Dat(diterset, list(range(nelems)), numpy.uint32)
Expand All @@ -463,12 +470,12 @@ def test_change_dat_dtype_matters(self, iterset, diterset):

op2.par_loop(k, iterset, d(op2.WRITE))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

d = op2.Dat(diterset, list(range(nelems)), numpy.int32)
op2.par_loop(k, iterset, d(op2.WRITE))

assert len(self.cache) == 2
assert self.code_cache_len_equals(2)

def test_change_global_dtype_matters(self, iterset, diterset):
g = op2.Global(1, 0, dtype=numpy.uint32, comm=COMM_WORLD)
Expand All @@ -479,12 +486,12 @@ def test_change_global_dtype_matters(self, iterset, diterset):

op2.par_loop(k, iterset, g(op2.INC))

assert len(self.cache) == 1
assert self.code_cache_len_equals(1)

g = op2.Global(1, 0, dtype=numpy.float64, comm=COMM_WORLD)
op2.par_loop(k, iterset, g(op2.INC))

assert len(self.cache) == 2
assert self.code_cache_len_equals(2)


class TestSparsityCache:
Expand Down
Loading