Activate FFI implementation of the QR decomposition.
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.

The FFI kernels have been available in jaxlib for over 3 weeks already, and they are included in the latest jaxlib release on PyPI, so we don't need to worry about the forward-compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backward-compatibility tests for now.
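For context (not part of the original commit message), here is a minimal sketch of the export-plus-shape-polymorphism path this change enables. It assumes a recent jax/jaxlib with the `jax.export` API; the shapes and dtypes are illustrative, not taken from this commit:

```python
import jax
import jax.numpy as jnp
from jax import export

# Shape-polymorphic export of QR: symbolic batch dim "b", static (4, 4) matrix dims.
args = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4, 4"), jnp.float32)
exported = export.export(jax.jit(jnp.linalg.qr))(args)

# The exported artifact can then be invoked with any concrete batch size.
q, r = exported.call(jnp.ones((3, 4, 4), jnp.float32))
```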

PiperOrigin-RevId: 676825798
dfm authored and Google-ML-Automation committed Sep 26, 2024
1 parent 5cef547 commit 1c2a6e6
Showing 9 changed files with 367 additions and 444 deletions.
10 changes: 3 additions & 7 deletions jax/_src/export/_export.py
@@ -961,11 +961,6 @@ def _check_lowering(lowering) -> None:
"lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
# svd on CPU
"lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd",
# qr on GPU
"cusolver_geqrf", "cublas_geqrf_batched",
"cusolver_orgqr",
"hipsolver_geqrf", "hipblas_geqrf_batched",
"hipsolver_orgqr",
# qr and svd on TPU
"Qr", "ProductOfElementaryHouseholderReflectors",
# triangular_solve on CPU
@@ -977,9 +972,10 @@ def _check_lowering(lowering) -> None:
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
# lu on GPU
"cu_lu_pivots_to_permutation",
# "cublas_getrf_batched", "cusolver_getrf",
# "hipblas_getrf_batched", "hipsolver_getrf",
"cusolver_getrf_ffi",
# qr on GPU
"cusolver_geqrf_ffi", "cusolver_orgqr_ffi",
# svd on GPU
# lu on TPU
"LuDecomposition",
# ApproxTopK on TPU
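(Aside, not from the commit: the list edited above is export's allow-list of stable custom-call target names. A quick way to check which targets an exported QR computation actually references — a sketch assuming a GPU backend and the `jax.export` API:)

```python
import jax
import jax.numpy as jnp
from jax import export

exported = export.export(jax.jit(jnp.linalg.qr))(
    jax.ShapeDtypeStruct((8, 8), jnp.float32))
# On CUDA the lowered module should now reference the FFI targets.
print("cusolver_geqrf_ffi" in exported.mlir_module())
```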

Large diffs are not rendered by default.

@@ -245,6 +245,7 @@ def run_one_test(self, func: Callable[..., jax.Array],
logging.info("Writing the updated testdata at %s", output_file)
with open(output_file, "w") as f:
f.write(updated_testdata)
print(updated_testdata)

if rtol is None:
rtol = 1.e-7
151 changes: 39 additions & 112 deletions jax/_src/lax/linalg.py
@@ -1678,7 +1678,7 @@ def _geqrf_abstract_eval(operand):
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = operand.shape
taus = operand.update(shape=(*batch_dims, min(m, n)))
taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)))
return operand, taus

def _geqrf_batching_rule(batched_args, batch_dims):
@@ -1707,60 +1707,20 @@ def _geqrf_lowering_rule(ctx, operand):
)
return op.results

def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *,
platform: str):
a_aval, taus_aval = ctx.avals_out
*batch_dims, m, n = a_aval.shape
# It should be possible to support fully-dynamic shapes, but since
# the last two dimensions (m, n) are used in more involved ways, we only
# support dynamic dimensions for the batch size for now.
if not is_constant_shape([m, n]):
raise NotImplementedError(
"Shape polymorphism for native serialization for qr on CPU and GPU is "
f"implemented only for the batch dimensions: {a_aval.shape}")
batch = math.prod(batch_dims)

if batch == 0 or m == 0 or n == 0:
return mlir.full_like_aval(ctx, 0, a_aval), mlir.full_like_aval(ctx, 0, taus_aval)

if not is_constant_shape(a_aval.shape):
if platform in ["cuda", "rocm"]:
# TODO(necula): remove the platform kwarg when we implement GPU support.
raise NotImplementedError(
"Shape polymorphism for native serialization for QR is not "
f"implemented, try to upgrade jaxlib; b/261671778; {a_aval.shape}")

if (batched_geqrf_impl is not None and batch > 1 and m // batch <= 128 and
n // batch <= 128):
a_out, taus = batched_geqrf_impl(a_aval.dtype, a)
def _geqrf_cpu_gpu_lowering(ctx, a, *, target_name_prefix: str):
operand_aval, = ctx.avals_in
batch_dims = operand_aval.shape[:-2]
nb = len(batch_dims)
layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
result_layouts = [layout, tuple(range(nb, -1, -1))]
if target_name_prefix == "cpu":
target_name = lapack.build_lapack_fn_target("geqrf_ffi", operand_aval.dtype)
else:
if platform in ["cuda", "rocm"]:
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
ctx_args = (
(ctx,) if platform == "cpu" else ()
)
a_out, taus, *maybe_info_geqrf = geqrf_impl(
*ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals
)
if not ctx.is_forward_compat():
# Skip the info parameter verification for the FFI kernel.
return a_out, taus
# TODO(b/344892332): This parameter will no longer be needed after
# the forward compatibility period
info_geqrf = maybe_info_geqrf[0]
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_hlo(info_geqrf, zeros, "EQ", "SIGNED")
select_ok_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_))
ok_a = mlir.broadcast_in_dim(ctx, ok, select_ok_a_aval,
broadcast_dimensions=range(len(batch_dims)))
a_out = _broadcasting_select_hlo(ctx, ok_a, select_ok_a_aval, a_out, a_aval, _nan_like_hlo(ctx, a_aval), a_aval)
select_ok_taus_aval = ShapedArray(batch_dims + [1], np.dtype(np.bool_))
ok_taus = mlir.broadcast_in_dim(ctx, ok, select_ok_taus_aval,
broadcast_dimensions=range(len(batch_dims)))
taus = _broadcasting_select_hlo(ctx, ok_taus, select_ok_taus_aval, taus, taus_aval, _nan_like_hlo(ctx, taus_aval), taus_aval)
return a_out, taus
target_name = f"{target_name_prefix}solver_geqrf_ffi"
rule = ffi.ffi_lowering(target_name, operand_layouts=[layout],
result_layouts=result_layouts,
operand_output_aliases={0: 0})
return rule(ctx, a)

geqrf_p = Primitive('geqrf')
geqrf_p.multiple_results = True
@@ -1770,20 +1730,15 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *,
mlir.register_lowering(geqrf_p, _geqrf_lowering_rule)

mlir.register_lowering(
geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_hlo, None,
platform='cpu'),
geqrf_p, partial(_geqrf_cpu_gpu_lowering, target_name_prefix='cpu'),
platform='cpu')
mlir.register_lowering(
geqrf_p,
partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
gpu_solver.cuda_geqrf_batched,
platform='cuda'),
partial(_geqrf_cpu_gpu_lowering, target_name_prefix='cu'),
platform='cuda')
mlir.register_lowering(
geqrf_p,
partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
gpu_solver.rocm_geqrf_batched,
platform='rocm'),
partial(_geqrf_cpu_gpu_lowering, target_name_prefix='hip'),
platform='rocm')
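(Not part of the diff — a note on the layout tuples used by the new lowering rules above and below: they are minor-to-major layouts that keep the trailing two matrix dimensions column-major, as the LAPACK/cuSOLVER kernels expect, while the batch dimensions stay row-major. A small sketch:)

```python
def fortran_matrix_layout(num_batch_dims: int) -> tuple[int, ...]:
    """Minor-to-major layout: column-major matrix dims, row-major batch dims."""
    nb = num_batch_dims
    return (nb, nb + 1) + tuple(range(nb - 1, -1, -1))

assert fortran_matrix_layout(0) == (0, 1)        # plain column-major matrix
assert fortran_matrix_layout(2) == (2, 3, 1, 0)  # batch dims remain row-major
```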


@@ -1813,7 +1768,7 @@ def _householder_product_abstract_eval(a, taus):
raise ValueError("Argument to Householder product must have ndims >= 2")
*batch_dims, m, n = a.shape
*taus_batch_dims, k = taus.shape
if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n):
if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > core.min_dim(m, n):
raise ValueError(f"Type mismatch for Householder product: {a=} {taus=}")
if m < n:
raise ValueError("Householder product inputs must have at least as many "
@@ -1841,48 +1796,23 @@ def _householder_product_lowering_rule(ctx, a, taus):
result_shapes=result_shapes)
return [op.result]

def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *,
platform: str):
a_aval, taus_aval = ctx.avals_in
*batch_dims, m, n = a_aval.shape
if not is_constant_shape([m, n]):
raise NotImplementedError(
"Shape polymorphism for native serialization for householder_product on "
f"CPU and GPU is implemented only for the batch dimensions: {a_aval.shape}")

if m == 0 or n == 0:
return [mlir.full_like_aval(ctx, 0, a_aval)]

if platform in ["rocm", "cuda"]:
# TODO(necula): remove the platform kwarg when we implement GPU support.
if not is_constant_shape(a_aval.shape):
raise NotImplementedError(
"Shape polymorphism for native serialization for householder_product "
f"on GPU is not implemented; b/261671778; {a_aval.shape}")
a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
def _householder_product_cpu_gpu_lowering(ctx, a, taus, *,
target_name_prefix: str):
a_aval, _ = ctx.avals_in
batch_dims = a_aval.shape[:-2]
nb = len(batch_dims)
layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
tau_layout = tuple(range(nb, -1, -1))
if target_name_prefix == "cpu":
dtype = a_aval.dtype
prefix = "un" if dtypes.issubdtype(dtype, np.complexfloating) else "or"
target_name = lapack.build_lapack_fn_target(f"{prefix}gqr_ffi", dtype)
else:
ctx_args = (
(ctx,) if platform == "cpu" else ()
)
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape)
a, *maybe_info_orgqr = orgqr_impl(*ctx_args, a_aval.dtype, a, taus,
a_shape_vals=a_shape_vals,
tau_shape_vals=tau_shape_vals)
if not ctx.is_forward_compat():
# Skip the info parameter verification for the FFI kernel.
return [a]
# TODO(b/344892332): This parameter will no longer be needed after
# the forward compatibility period
info_orgqr = maybe_info_orgqr[0]
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED")
select_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_))
ok = mlir.broadcast_in_dim(ctx, ok, select_a_aval,
broadcast_dimensions=range(len(batch_dims)))
a = _broadcasting_select_hlo(ctx, ok, select_a_aval, a, a_aval, _nan_like_hlo(ctx, a_aval), a_aval)
return [a]

target_name = f"{target_name_prefix}solver_orgqr_ffi"
rule = ffi.ffi_lowering(target_name, operand_layouts=[layout, tau_layout],
result_layouts=[layout],
operand_output_aliases={0: 0})
return rule(ctx, a, taus)

householder_product_p = Primitive('householder_product')
householder_product_p.def_impl(partial(dispatch.apply_primitive, householder_product_p))
@@ -1892,18 +1822,15 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *,

mlir.register_lowering(
householder_product_p,
partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_hlo,
platform='cpu'),
partial(_householder_product_cpu_gpu_lowering, target_name_prefix='cpu'),
platform='cpu')
mlir.register_lowering(
householder_product_p,
partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr,
platform='cuda'),
partial(_householder_product_cpu_gpu_lowering, target_name_prefix='cu'),
platform='cuda')
mlir.register_lowering(
householder_product_p,
partial(_householder_product_cpu_gpu_lowering, gpu_solver.rocm_orgqr,
platform='rocm'),
partial(_householder_product_cpu_gpu_lowering, target_name_prefix='hip'),
platform='rocm')
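(Also not from the commit: how the two primitives touched here compose in practice. `geqrf` returns the packed factorization plus the Householder scalars, and `householder_product` materializes Q from them. A sketch using the `jax.lax.linalg` wrappers:)

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.arange(12.0).reshape(4, 3)
packed, taus = lax_linalg.geqrf(a)   # R in the upper triangle, reflectors below it
r = jnp.triu(packed[:3, :])          # (3, 3) upper-triangular factor
q = lax_linalg.householder_product(packed, taus)  # (4, 3); q @ r ≈ a
```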


@@ -1916,7 +1843,7 @@ def _qr_abstract_eval(operand, *, full_matrices):
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = operand.shape
k = m if full_matrices else min(m, n)
k = m if full_matrices else core.min_dim(m, n)
q = operand.update(shape=(*batch_dims, m, k))
r = operand.update(shape=(*batch_dims, k, n))
else:
@@ -1953,7 +1880,7 @@ def _qr_batching_rule(batched_args, batch_dims, *, full_matrices):
def _qr_lowering(a, *, full_matrices):
*batch_dims, m, n = a.shape
if m == 0 or n == 0:
k = m if full_matrices else min(m, n)
k = m if full_matrices else core.min_dim(m, n)
q = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (m, k)),
(*batch_dims, m, k),
(len(batch_dims), len(batch_dims) + 1))
124 changes: 0 additions & 124 deletions jaxlib/gpu_solver.py
@@ -151,86 +151,6 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a):
rocm_getrf = partial(_getrf_hlo, "hip", _hipblas, _hipsolver)


def _geqrf_hlo(platform, gpu_solver, dtype, a):
"""QR decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = math.prod(batch_dims)

lwork, opaque = gpu_solver.build_geqrf_descriptor(
np.dtype(dtype), batch, m, n)

layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
i32_type = ir.IntegerType.get_signless(32)
out = custom_call(
f"{platform}solver_geqrf",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a],
backend_config=opaque,
operand_layouts=[layout],
result_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0}).results
return out[:3]

cuda_geqrf = partial(_geqrf_hlo, "cu", _cusolver)
rocm_geqrf = partial(_geqrf_hlo, "hip", _hipsolver)

def _geqrf_batched_hlo(platform, gpu_blas, dtype, a):
"""Batched QR decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = math.prod(batch_dims)

if not gpu_blas:
raise GpuLibNotLinkedError()

lwork, opaque = gpu_blas.build_geqrf_batched_descriptor(
np.dtype(dtype), batch, m, n)

layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
out = custom_call(
f"{platform}blas_geqrf_batched",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), a_type.element_type),
ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)),
ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)),
],
operands=[a],
backend_config=opaque,
operand_layouts=[layout],
result_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
[0],
[0],
],
operand_output_aliases={0: 0}
).results
return out[:2]

cuda_geqrf_batched = partial(_geqrf_batched_hlo, "cu", _cublas)
rocm_geqrf_batched = partial(_geqrf_batched_hlo, "hip", _hipblas)


def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
indices, indptr, b, tol, reorder):
"""Sparse solver via QR decomposition. CUDA only."""
@@ -256,50 +176,6 @@ def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)


def _orgqr_hlo(platform, gpu_solver, dtype, a, tau):
"""Product of elementary Householder reflections."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = math.prod(batch_dims)

tau_dims = ir.RankedTensorType(tau.type).shape
assert tau_dims[:-1] == dims[:-2]
k = tau_dims[-1]

lwork, opaque = gpu_solver.build_orgqr_descriptor(
np.dtype(dtype), batch, m, n, k)

layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
i32_type = ir.IntegerType.get_signless(32)
out = custom_call(
f"{platform}solver_orgqr",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a, tau],
backend_config=opaque,
operand_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
],
result_layouts=[
layout,
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0}).results
return out[:2]

cuda_orgqr = partial(_orgqr_hlo, "cu", _cusolver)
rocm_orgqr = partial(_orgqr_hlo, "hip", _hipsolver)


def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, *,
a_shape_vals: tuple[DimensionSize, ...], lower=False):
"""Symmetric (Hermitian) eigendecomposition."""
