Skip to content

Commit

Permalink
Upstream polynomial.ntt and polynomial.intt (llvm#90992)
Browse files Browse the repository at this point in the history
These two ops represent a number-theoretic transform of a polynomial to
a tensor of evaluations of the polynomial at a list of powers of
primitive roots of the polynomial.

To support this, a new optional attribute is added to the ring attribute
to specify the primitive root of unity used for the NTT. A verifier for
the op is added to ensure the chosen root is a primitive nth root of
unity.

---------

Co-authored-by: Jeremy Kun <j2kun@users.noreply.github.com>
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse@gmail.com>
  • Loading branch information
3 people authored May 5, 2024
1 parent 716eab7 commit 624c9fc
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 5 deletions.
56 changes: 53 additions & 3 deletions mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
#poly = #polynomial.polynomial<x**1024 + 1>
```
}];
let parameters = (ins "Polynomial":$polynomial);
let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
let hasCustomAssemblyFormat = 1;
}

Expand Down Expand Up @@ -122,10 +122,19 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {

let parameters = (ins
"Type": $coefficientType,
OptionalParameter<"IntegerAttr">: $coefficientModulus,
OptionalParameter<"PolynomialAttr">: $polynomialModulus
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
);

let builders = [
AttrBuilder<
(ins "::mlir::Type":$coefficientTy,
"::mlir::IntegerAttr":$coefficientModulusAttr,
"::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
}]>
];
let hasCustomAssemblyFormat = 1;
}

Expand Down Expand Up @@ -416,4 +425,45 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
let assemblyFormat = "$input attr-dict `:` type($output)";
}

def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
let summary = "Computes point-value tensor representation of a polynomial.";
let description = [{
`polynomial.ntt` computes the forward integer Number Theoretic Transform
(NTT) on the input polynomial. It returns a tensor containing a point-value
representation of the input polynomial. The output tensor has shape equal
to the degree of the ring's `polynomialModulus`. The polynomial's RingAttr
is embedded as the encoding attribute of the output tensor.

Given an input polynomial `F(x)` over a ring whose `polynomialModulus` has
degree `n`, and a primitive `n`-th root of unity `omega_n`, the output is
the list of $n$ evaluations

`f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`

The choice of primitive root is determined by subsequent lowerings.
}];
let arguments = (ins Polynomial_PolynomialType:$input);
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasVerifier = 1;
}

def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
let summary = "Computes the reverse integer Number Theoretic Transform (NTT).";
let description = [{
`polynomial.intt` computes the reverse integer Number Theoretic Transform
(INTT) on the input tensor. This is the inverse operation of the
`polynomial.ntt` operation.

The input tensor is interpreted as a point-value representation of the
output polynomial at powers of a primitive `n`-th root of unity (see
`polynomial.ntt`). The ring of the polynomial is taken from the required
encoding attribute of the tensor.
}];
let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
let results = (outs Polynomial_PolynomialType:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasVerifier = 1;
}

#endif // POLYNOMIAL_OPS
18 changes: 17 additions & 1 deletion mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,27 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
polyAttr = attr;
}

Polynomial poly = polyAttr.getPolynomial();
APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
IntegerAttr rootAttr = nullptr;
if (succeeded(parser.parseOptionalComma())) {
if (failed(parser.parseKeyword("primitiveRoot")) ||
failed(parser.parseEqual()))
return {};

ParseResult result = parser.parseInteger(root);
if (failed(result)) {
parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
return {};
}
rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
}

if (failed(parser.parseGreater()))
return {};

return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
polyAttr);
polyAttr, rootAttr);
}

} // namespace polynomial
Expand Down
79 changes: 79 additions & 0 deletions mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,82 @@ LogicalResult MulScalarOp::verify() {

return success();
}

/// Test if a value is a primitive nth root of unity modulo cmod.
bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
const APInt &cmod) {
// Root bitwidth may be 1 less then cmod.
APInt r = APInt(root).zext(cmod.getBitWidth());
assert(r.ule(cmod) && "root must be less than cmod");

APInt a = r;
for (size_t k = 1; k < n; k++) {
if (a.isOne())
return false;
a = (a * r).urem(cmod);
}
return a.isOne();
}

/// Verify that the types involved in an NTT or INTT operation are
/// compatible.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
RankedTensorType tensorType) {
Attribute encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
<< "expects a ring encoding to be provided to the tensor";
}
auto encodedRing = dyn_cast<RingAttr>(encoding);
if (!encodedRing) {
return op->emitOpError()
<< "the provided tensor encoding is not a ring attribute";
}

if (encodedRing != ring) {
return op->emitOpError()
<< "encoded ring type " << encodedRing
<< " is not equivalent to the polynomial ring " << ring;
}

unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
ArrayRef<int64_t> tensorShape = tensorType.getShape();
bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
if (!compatible) {
InFlightDiagnostic diag = op->emitOpError()
<< "tensor type " << tensorType
<< " does not match output type " << ring;
diag.attachNote() << "the tensor must have shape [d] where d "
"is exactly the degree of the polynomialModulus of "
"the polynomial type's ring attribute";
return diag;
}

if (!ring.getPrimitiveRoot()) {
return op->emitOpError()
<< "ring type " << ring << " does not provide a primitive root "
<< "of unity, which is required to express an NTT";
}

if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
ring.getCoefficientModulus().getValue())) {
return op->emitOpError()
<< "ring type " << ring << " has a primitiveRoot attribute '"
<< ring.getPrimitiveRoot()
<< "' that is not a primitive root of the coefficient ring";
}

return success();
}

LogicalResult NTTOp::verify() {
auto ring = getInput().getType().getRing();
auto tensorType = getOutput().getType();
return verifyNTTOp(this->getOperation(), ring, tensorType);
}

LogicalResult INTTOp::verify() {
auto tensorType = getInput().getType();
auto ring = getOutput().getType().getRing();
return verifyNTTOp(this->getOperation(), ring, tensorType);
}
16 changes: 15 additions & 1 deletion mlir/test/Dialect/Polynomial/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
#one_plus_x_squared = #polynomial.polynomial<1 + x**2>

#ideal = #polynomial.polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=18, polynomialModulus=#ideal>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
!poly_ty = !polynomial.polynomial<#ring>

#ntt_poly = #polynomial.polynomial<-1 + x**8>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>

module {
func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -79,4 +83,14 @@ module {
%1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
return
}

func.func @test_ntt(%0 : !ntt_poly_ty) {
%1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
return
}

func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
%1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
return
}
}
87 changes: 87 additions & 0 deletions mlir/test/Dialect/Polynomial/ops_errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,90 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
%poly = polynomial.mul_scalar %arg0, %scalar : !ty, i32
return %poly : !ty
}

// -----

#my_poly = #polynomial.polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
!poly_ty = !polynomial.polynomial<#ring>

// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
// expected-error@below {{expects a ring encoding to be provided to the tensor}}
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
return
}

// -----

#my_poly = #polynomial.polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
!poly_ty = !polynomial.polynomial<#ring>

// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
// expected-error@below {{tensor encoding is not a ring attribute}}
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
return
}

// -----

#my_poly = #polynomial.polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
!poly_ty = !polynomial.polynomial<#ring>

// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
// expected-error@below {{not equivalent to the polynomial ring}}
%1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
return
}

// -----

#my_poly = #polynomial.polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
!poly_ty = !polynomial.polynomial<#ring>

// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
// expected-error@below {{does not match output type}}
// expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
%1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
return
}

// -----

#my_poly = #polynomial.polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<#ring>

// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
// expected-error@below {{does not provide a primitive root of unity, which is required to express an NTT}}
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
return
}

// -----

#my_poly = #polynomial.polynomial<-1 + x**8>
// A valid root is 31
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
!poly_ty = !polynomial.polynomial<#ring>

// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
// expected-error@below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
%1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
return
}

0 comments on commit 624c9fc

Please sign in to comment.