diff --git a/codon/cir/analyze/dataflow/reaching.cpp b/codon/cir/analyze/dataflow/reaching.cpp index d5d8e61e..016c906d 100644 --- a/codon/cir/analyze/dataflow/reaching.cpp +++ b/codon/cir/analyze/dataflow/reaching.cpp @@ -71,10 +71,10 @@ struct BitSet { return res; } - void set(unsigned bit) { words.data()[bit / B] |= (1 << (bit % B)); } + void set(unsigned bit) { words.data()[bit / B] |= (1UL << (bit % B)); } bool get(unsigned bit) const { - return (words.data()[bit / B] & (1 << (bit % B))) != 0; + return (words.data()[bit / B] & (1UL << (bit % B))) != 0; } bool equals(const BitSet &other, unsigned size) { diff --git a/codon/cir/llvm/gpu.cpp b/codon/cir/llvm/gpu.cpp index b05ebfb1..3205d059 100644 --- a/codon/cir/llvm/gpu.cpp +++ b/codon/cir/llvm/gpu.cpp @@ -500,6 +500,14 @@ void moduleToPTX(llvm::Module *M, const std::string &filename, linkLibdevice(M, libdevice); remapFunctions(M); + // Strip debug info and remove noinline from functions (added in debug mode). + // Also, tell LLVM that all functions will return. + for (auto &F : *M) { + F.removeFnAttr(llvm::Attribute::AttrKind::NoInline); + F.setWillReturn(); + } + llvm::StripDebugInfo(*M); + // Run NVPTX passes and general opt pipeline. 
{ llvm::LoopAnalysisManager lam; diff --git a/codon/cir/llvm/llvisitor.cpp b/codon/cir/llvm/llvisitor.cpp index a6f3a1bd..a4385be4 100644 --- a/codon/cir/llvm/llvisitor.cpp +++ b/codon/cir/llvm/llvisitor.cpp @@ -2086,6 +2086,18 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) { return B->getFloatTy(); } + if (auto *x = cast(t)) { + return B->getHalfTy(); + } + + if (auto *x = cast(t)) { + return B->getBFloatTy(); + } + + if (auto *x = cast(t)) { + return llvm::Type::getFP128Ty(*context); + } + if (auto *x = cast(t)) { return B->getInt8Ty(); } @@ -2203,6 +2215,22 @@ llvm::DIType *LLVMVisitor::getDITypeHelper( x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float); } + if (auto *x = cast(t)) { + return db.builder->createBasicType( + x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float); + } + + if (auto *x = cast(t)) { + return db.builder->createBasicType( + x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float); + } + + if (auto *x = cast(t)) { + return db.builder->createBasicType(x->getName(), + layout.getTypeAllocSizeInBits(type), + llvm::dwarf::DW_ATE_HP_float128); + } + if (auto *x = cast(t)) { return db.builder->createBasicType( x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_boolean); diff --git a/codon/cir/module.cpp b/codon/cir/module.cpp index 3588a455..e4ee00a0 100644 --- a/codon/cir/module.cpp +++ b/codon/cir/module.cpp @@ -62,6 +62,9 @@ const std::string Module::BYTE_NAME = "byte"; const std::string Module::INT_NAME = "int"; const std::string Module::FLOAT_NAME = "float"; const std::string Module::FLOAT32_NAME = "float32"; +const std::string Module::FLOAT16_NAME = "float16"; +const std::string Module::BFLOAT16_NAME = "bfloat16"; +const std::string Module::FLOAT128_NAME = "float128"; const std::string Module::STRING_NAME = "str"; const std::string Module::EQ_MAGIC_NAME = "__eq__"; @@ -239,6 +242,24 @@ types::Type *Module::getFloat32Type() { return Nr(); } 
+types::Type *Module::getFloat16Type() { + if (auto *rVal = getType(FLOAT16_NAME)) + return rVal; + return Nr(); +} + +types::Type *Module::getBFloat16Type() { + if (auto *rVal = getType(BFLOAT16_NAME)) + return rVal; + return Nr(); +} + +types::Type *Module::getFloat128Type() { + if (auto *rVal = getType(FLOAT128_NAME)) + return rVal; + return Nr(); +} + types::Type *Module::getStringType() { if (auto *rVal = getType(STRING_NAME)) return rVal; diff --git a/codon/cir/module.h b/codon/cir/module.h index d4ffc866..3f90b52b 100644 --- a/codon/cir/module.h +++ b/codon/cir/module.h @@ -34,6 +34,9 @@ class Module : public AcceptorExtend { static const std::string INT_NAME; static const std::string FLOAT_NAME; static const std::string FLOAT32_NAME; + static const std::string FLOAT16_NAME; + static const std::string BFLOAT16_NAME; + static const std::string FLOAT128_NAME; static const std::string STRING_NAME; static const std::string EQ_MAGIC_NAME; @@ -338,6 +341,12 @@ class Module : public AcceptorExtend { types::Type *getFloatType(); /// @return the float32 type types::Type *getFloat32Type(); + /// @return the float16 type + types::Type *getFloat16Type(); + /// @return the bfloat16 type + types::Type *getBFloat16Type(); + /// @return the float128 type + types::Type *getFloat128Type(); /// @return the string type types::Type *getStringType(); /// Gets a pointer type. 
diff --git a/codon/cir/types/types.cpp b/codon/cir/types/types.cpp index 1688f4f2..a5353bf2 100644 --- a/codon/cir/types/types.cpp +++ b/codon/cir/types/types.cpp @@ -69,6 +69,12 @@ const char FloatType::NodeId = 0; const char Float32Type::NodeId = 0; +const char Float16Type::NodeId = 0; + +const char BFloat16Type::NodeId = 0; + +const char Float128Type::NodeId = 0; + const char BoolType::NodeId = 0; const char ByteType::NodeId = 0; diff --git a/codon/cir/types/types.h b/codon/cir/types/types.h index 8d55aea5..0845e4cf 100644 --- a/codon/cir/types/types.h +++ b/codon/cir/types/types.h @@ -169,6 +169,33 @@ class Float32Type : public AcceptorExtend { Float32Type() : AcceptorExtend("float32") {} }; +/// Float16 type (16-bit float) +class Float16Type : public AcceptorExtend { +public: + static const char NodeId; + + /// Constructs a float16 type. + Float16Type() : AcceptorExtend("float16") {} +}; + +/// BFloat16 type (16-bit brain float) +class BFloat16Type : public AcceptorExtend { +public: + static const char NodeId; + + /// Constructs a bfloat16 type. + BFloat16Type() : AcceptorExtend("bfloat16") {} +}; + +/// Float128 type (128-bit float) +class Float128Type : public AcceptorExtend { +public: + static const char NodeId; + + /// Constructs a float128 type. 
+ Float128Type() : AcceptorExtend("float128") {} +}; + /// Bool type (8-bit unsigned integer; either 0 or 1) class BoolType : public AcceptorExtend { public: diff --git a/codon/cir/util/format.cpp b/codon/cir/util/format.cpp index 17151da5..c3146443 100644 --- a/codon/cir/util/format.cpp +++ b/codon/cir/util/format.cpp @@ -295,6 +295,15 @@ class FormatVisitor : util::ConstVisitor { void visit(const types::Float32Type *v) override { fmt::print(os, FMT_STRING("(float32 '\"{}\")"), v->referenceString()); } + void visit(const types::Float16Type *v) override { + fmt::print(os, FMT_STRING("(float16 '\"{}\")"), v->referenceString()); + } + void visit(const types::BFloat16Type *v) override { + fmt::print(os, FMT_STRING("(bfloat16 '\"{}\")"), v->referenceString()); + } + void visit(const types::Float128Type *v) override { + fmt::print(os, FMT_STRING("(float128 '\"{}\")"), v->referenceString()); + } void visit(const types::BoolType *v) override { fmt::print(os, FMT_STRING("(bool '\"{}\")"), v->referenceString()); } diff --git a/codon/cir/util/visitor.cpp b/codon/cir/util/visitor.cpp index 2044e26d..5856516e 100644 --- a/codon/cir/util/visitor.cpp +++ b/codon/cir/util/visitor.cpp @@ -54,6 +54,9 @@ void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); } void Visitor::visit(types::IntType *x) { defaultVisit(x); } void Visitor::visit(types::FloatType *x) { defaultVisit(x); } void Visitor::visit(types::Float32Type *x) { defaultVisit(x); } +void Visitor::visit(types::Float16Type *x) { defaultVisit(x); } +void Visitor::visit(types::BFloat16Type *x) { defaultVisit(x); } +void Visitor::visit(types::Float128Type *x) { defaultVisit(x); } void Visitor::visit(types::BoolType *x) { defaultVisit(x); } void Visitor::visit(types::ByteType *x) { defaultVisit(x); } void Visitor::visit(types::VoidType *x) { defaultVisit(x); } @@ -114,6 +117,9 @@ void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::IntType *x) { 
defaultVisit(x); } void ConstVisitor::visit(const types::FloatType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::Float32Type *x) { defaultVisit(x); } +void ConstVisitor::visit(const types::Float16Type *x) { defaultVisit(x); } +void ConstVisitor::visit(const types::BFloat16Type *x) { defaultVisit(x); } +void ConstVisitor::visit(const types::Float128Type *x) { defaultVisit(x); } void ConstVisitor::visit(const types::BoolType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::ByteType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::VoidType *x) { defaultVisit(x); } diff --git a/codon/cir/util/visitor.h b/codon/cir/util/visitor.h index 4767f0e2..199440c7 100644 --- a/codon/cir/util/visitor.h +++ b/codon/cir/util/visitor.h @@ -19,6 +19,9 @@ class PrimitiveType; class IntType; class FloatType; class Float32Type; +class Float16Type; +class BFloat16Type; +class Float128Type; class BoolType; class ByteType; class VoidType; @@ -152,6 +155,9 @@ class Visitor { VISIT(types::IntType); VISIT(types::FloatType); VISIT(types::Float32Type); + VISIT(types::Float16Type); + VISIT(types::BFloat16Type); + VISIT(types::Float128Type); VISIT(types::BoolType); VISIT(types::ByteType); VISIT(types::VoidType); @@ -229,6 +235,9 @@ class ConstVisitor { CONST_VISIT(types::IntType); CONST_VISIT(types::FloatType); CONST_VISIT(types::Float32Type); + CONST_VISIT(types::Float16Type); + CONST_VISIT(types::BFloat16Type); + CONST_VISIT(types::Float128Type); CONST_VISIT(types::BoolType); CONST_VISIT(types::ByteType); CONST_VISIT(types::VoidType); diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index b91f0777..ed43095f 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -709,6 +709,7 @@ ExprPtr TypecheckVisitor::transformArray(CallExpr *expr) { /// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type ExprPtr 
TypecheckVisitor::transformIsInstance(CallExpr *expr) { expr->setType(unify(expr->type, ctx->getType("bool"))); + expr->staticValue.type = StaticValue::INT; // prevent branching until this is resolved transform(expr->args[0].value); auto typ = expr->args[0].value->type->getClass(); if (!typ || !typ->canRealize()) @@ -947,10 +948,11 @@ ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) { auto &args = expr->args[0].value->getCall()->args; for (size_t i = 0; i < args.size(); i++) { realize(args[i].value->type); - fmt::print(stderr, "[static_print] {}: {} := {}{}\n", getSrcInfo(), + fmt::print(stderr, "[static_print] {}: {} := {}{} (iter: {})\n", getSrcInfo(), FormatVisitor::apply(args[i].value), args[i].value->type ? args[i].value->type->debugString(1) : "-", - args[i].value->isStatic() ? " [static]" : ""); + args[i].value->isStatic() ? " [static]" : "", + ctx->getRealizationBase()->iteration); } return nullptr; } diff --git a/codon/parser/visitors/typecheck/ctx.cpp b/codon/parser/visitors/typecheck/ctx.cpp index 5df9e861..40d0dd36 100644 --- a/codon/parser/visitors/typecheck/ctx.cpp +++ b/codon/parser/visitors/typecheck/ctx.cpp @@ -100,13 +100,13 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo, if (auto l = i.second->getLink()) { i.second->setSrcInfo(srcInfo); if (l->defaultType) { - pendingDefaults.insert(i.second); + getRealizationBase()->pendingDefaults.insert(i.second); } } } if (t->getUnion() && !t->getUnion()->isSealed()) { t->setSrcInfo(srcInfo); - pendingDefaults.insert(t); + getRealizationBase()->pendingDefaults.insert(t); } if (auto r = t->getRecord()) if (r->repeats && r->repeats->canRealize()) diff --git a/codon/parser/visitors/typecheck/ctx.h b/codon/parser/visitors/typecheck/ctx.h index 8eb8dfce..9d8a36e4 100644 --- a/codon/parser/visitors/typecheck/ctx.h +++ b/codon/parser/visitors/typecheck/ctx.h @@ -50,12 +50,12 @@ struct TypeContext : public Context { types::TypePtr returnType = nullptr; /// Typechecking iteration int 
iteration = 0; + std::set pendingDefaults; }; std::vector realizationBases; /// The current type-checking level (for type instantiation and generalization). int typecheckLevel; - std::set pendingDefaults; int changedNodes; /// The age of the currently parsed statement. diff --git a/codon/parser/visitors/typecheck/infer.cpp b/codon/parser/visitors/typecheck/infer.cpp index f53f66fd..2e91ad9c 100644 --- a/codon/parser/visitors/typecheck/infer.cpp +++ b/codon/parser/visitors/typecheck/infer.cpp @@ -99,8 +99,9 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) { bool anotherRound = false; // Special case: return type might have default as well (e.g., Union) if (ctx->getRealizationBase()->returnType) - ctx->pendingDefaults.insert(ctx->getRealizationBase()->returnType); - for (auto &unbound : ctx->pendingDefaults) { + ctx->getRealizationBase()->pendingDefaults.insert( + ctx->getRealizationBase()->returnType); + for (auto &unbound : ctx->getRealizationBase()->pendingDefaults) { if (auto tu = unbound->getUnion()) { // Seal all dynamic unions after the iteration is over if (!tu->isSealed()) { @@ -113,7 +114,7 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) { anotherRound = true; } } - ctx->pendingDefaults.clear(); + ctx->getRealizationBase()->pendingDefaults.clear(); if (anotherRound) continue; @@ -653,6 +654,12 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) { handle = module->getFloatType(); } else if (t->name == "float32") { handle = module->getFloat32Type(); + } else if (t->name == "float16") { + handle = module->getFloat16Type(); + } else if (t->name == "bfloat16") { + handle = module->getBFloat16Type(); + } else if (t->name == "float128") { + handle = module->getFloat128Type(); } else if (t->name == "str") { handle = module->getStringType(); } else if (t->name == "Int" || t->name == "UInt") { @@ -936,7 +943,9 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) { 
N(N(N("std.internal.types.error.TypeError"), N("invalid union call")))); // suite->stmts.push_back(N(N())); - unify(type->getRetType(), ctx->instantiate(ctx->getType("Union"))); + + auto ret = ctx->instantiate(ctx->getType("Union")); + unify(type->getRetType(), ret); ast->suite = suite; } else if (startswith(ast->name, "__internal__.get_union_first:0")) { // def __internal__.get_union_first(union: Union[T0]): diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index 64ab4cd9..93030123 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -49,6 +49,7 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) { auto typ = expr->type; if (!expr->done) { + bool isIntStatic = expr->staticValue.type == StaticValue::INT; TypecheckVisitor v(ctx, prependStmts); v.setSrcInfo(expr->getSrcInfo()); ctx->pushSrcInfo(expr->getSrcInfo()); @@ -60,7 +61,8 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) { expr = v.resultExpr; } seqassert(expr->type, "type not set for {}", expr); - unify(typ, expr->type); + if (!(isIntStatic && expr->type->is("bool"))) + unify(typ, expr->type); if (expr->done) { ctx->changedNodes++; } diff --git a/stdlib/internal/core.codon b/stdlib/internal/core.codon index 940b47a3..b2febe89 100644 --- a/stdlib/internal/core.codon +++ b/stdlib/internal/core.codon @@ -45,6 +45,27 @@ class float32: MIN_10_EXP = -37 pass +@tuple +@__internal__ +@__notuple__ +class float16: + MIN_10_EXP = -4 + pass + +@tuple +@__internal__ +@__notuple__ +class bfloat16: + MIN_10_EXP = -37 + pass + +@tuple +@__internal__ +@__notuple__ +class float128: + MIN_10_EXP = -4931 + pass + @tuple @__internal__ class type: diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index f94fd0be..87a5dfc4 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -618,7 +618,7 @@ class __magic__: # @dataclass parameter: gpu=True def 
from_gpu_new(other: T, T: type) -> T: - __internal__.class_from_gpu_new(other) + return __internal__.class_from_gpu_new(other) # @dataclass parameter: repr=True def repr(slf) -> str: diff --git a/stdlib/internal/types/complex.codon b/stdlib/internal/types/complex.codon index 9e803190..d7e85934 100644 --- a/stdlib/internal/types/complex.codon +++ b/stdlib/internal/types/complex.codon @@ -1,6 +1,6 @@ # Copyright (C) 2022-2023 Exaloop Inc. -@tuple +@tuple(python=False) class complex64: real: float32 imag: float32 diff --git a/stdlib/internal/types/float.codon b/stdlib/internal/types/float.codon index c1cf3db2..e5b3cfc1 100644 --- a/stdlib/internal/types/float.codon +++ b/stdlib/internal/types/float.codon @@ -92,7 +92,7 @@ class float: mod = self % other div = (self - mod) / other if mod: - if (other < 0) != (mod < 0): + if (other < 0.0) != (mod < 0.0): mod += other div -= 1.0 else: @@ -475,7 +475,7 @@ class float32: %tmp = fmul float %a, %b ret float %tmp - def __floordiv__(self, other: float32) -> float: + def __floordiv__(self, other: float32) -> float32: return self.__truediv__(other).__floor__() @pure @@ -494,19 +494,19 @@ class float32: mod = self % other div = (self - mod) / other if mod: - if (other < 0) != (mod < 0): + if (other < float32(0.0)) != (mod < float32(0.0)): mod += other - div -= 1.0 + div -= float32(1.0) else: - mod = (0.0).copysign(other) + mod = float32(0.0).copysign(other) - floordiv = 0.0 + floordiv = float32(0.0) if div: floordiv = div.__floor__() - if div - floordiv > 0.5: - floordiv += 1.0 + if div - floordiv > float32(0.5): + floordiv += float32(1.0) else: - floordiv = (0.0).copysign(self / other) + floordiv = float32(0.0).copysign(self / other) return (floordiv, mod) @@ -752,10 +752,913 @@ class float32: def __match__(self, obj: float32) -> bool: return self == obj +@extend +class float16: + @pure + @llvm + def __new__(self: float) -> float16: + %0 = fptrunc double %self to half + ret half %0 + + def __new__(what: float16) -> float16: + 
return what + + def __new__() -> float16: + return float16.__new__(0.0) + + def __repr__(self) -> str: + return self.__float__().__repr__() + + def __format__(self, format_spec: str) -> str: + return self.__float__().__format__(format_spec) + + def __copy__(self) -> float16: + return self + + def __deepcopy__(self) -> float16: + return self + + @pure + @llvm + def __int__(self) -> int: + %0 = fptosi half %self to i64 + ret i64 %0 + + @pure + @llvm + def __float__(self) -> float: + %0 = fpext half %self to double + ret double %0 + + @pure + @llvm + def __bool__(self) -> bool: + %0 = fcmp une half %self, 0.000000e+00 + %1 = zext i1 %0 to i8 + ret i8 %1 + + def __pos__(self) -> float16: + return self + + @pure + @llvm + def __neg__(self) -> float16: + %0 = fneg half %self + ret half %0 + + @pure + @commutative + @llvm + def __add__(a: float16, b: float16) -> float16: + %tmp = fadd half %a, %b + ret half %tmp + + @pure + @llvm + def __sub__(a: float16, b: float16) -> float16: + %tmp = fsub half %a, %b + ret half %tmp + + @pure + @commutative + @llvm + def __mul__(a: float16, b: float16) -> float16: + %tmp = fmul half %a, %b + ret half %tmp + + def __floordiv__(self, other: float16) -> float16: + return self.__truediv__(other).__floor__() + + @pure + @llvm + def __truediv__(a: float16, b: float16) -> float16: + %tmp = fdiv half %a, %b + ret half %tmp + + @pure + @llvm + def __mod__(a: float16, b: float16) -> float16: + %tmp = frem half %a, %b + ret half %tmp + + def __divmod__(self, other: float16) -> Tuple[float16, float16]: + mod = self % other + div = (self - mod) / other + if mod: + if (other < float16(0.0)) != (mod < float16(0.0)): + mod += other + div -= float16(1.0) + else: + mod = float16(0.0).copysign(other) + + floordiv = float16(0.0) + if div: + floordiv = div.__floor__() + if div - floordiv > float16(0.5): + floordiv += float16(1.0) + else: + floordiv = float16(0.0).copysign(self / other) + + return (floordiv, mod) + + @pure + @llvm + def __eq__(a: float16, b: 
float16) -> bool: + %tmp = fcmp oeq half %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __ne__(a: float16, b: float16) -> bool: + %tmp = fcmp une half %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __lt__(a: float16, b: float16) -> bool: + %tmp = fcmp olt half %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __gt__(a: float16, b: float16) -> bool: + %tmp = fcmp ogt half %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __le__(a: float16, b: float16) -> bool: + %tmp = fcmp ole half %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __ge__(a: float16, b: float16) -> bool: + %tmp = fcmp oge half %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def sqrt(a: float16) -> float16: + declare half @llvm.sqrt.f16(half %a) + %tmp = call half @llvm.sqrt.f16(half %a) + ret half %tmp + + @pure + @llvm + def sin(a: float16) -> float16: + declare half @llvm.sin.f16(half %a) + %tmp = call half @llvm.sin.f16(half %a) + ret half %tmp + + @pure + @llvm + def cos(a: float16) -> float16: + declare half @llvm.cos.f16(half %a) + %tmp = call half @llvm.cos.f16(half %a) + ret half %tmp + + @pure + @llvm + def exp(a: float16) -> float16: + declare half @llvm.exp.f16(half %a) + %tmp = call half @llvm.exp.f16(half %a) + ret half %tmp + + @pure + @llvm + def exp2(a: float16) -> float16: + declare half @llvm.exp2.f16(half %a) + %tmp = call half @llvm.exp2.f16(half %a) + ret half %tmp + + @pure + @llvm + def log(a: float16) -> float16: + declare half @llvm.log.f16(half %a) + %tmp = call half @llvm.log.f16(half %a) + ret half %tmp + + @pure + @llvm + def log10(a: float16) -> float16: + declare half @llvm.log10.f16(half %a) + %tmp = call half @llvm.log10.f16(half %a) + ret half %tmp + + @pure + @llvm + def log2(a: float16) -> float16: + declare half @llvm.log2.f16(half %a) + %tmp = call half @llvm.log2.f16(half %a) + ret half %tmp + + @pure + @llvm + def 
__abs__(a: float16) -> float16: + declare half @llvm.fabs.f16(half %a) + %tmp = call half @llvm.fabs.f16(half %a) + ret half %tmp + + @pure + @llvm + def __floor__(a: float16) -> float16: + declare half @llvm.floor.f16(half %a) + %tmp = call half @llvm.floor.f16(half %a) + ret half %tmp + + @pure + @llvm + def __ceil__(a: float16) -> float16: + declare half @llvm.ceil.f16(half %a) + %tmp = call half @llvm.ceil.f16(half %a) + ret half %tmp + + @pure + @llvm + def __trunc__(a: float16) -> float16: + declare half @llvm.trunc.f16(half %a) + %tmp = call half @llvm.trunc.f16(half %a) + ret half %tmp + + @pure + @llvm + def rint(a: float16) -> float16: + declare half @llvm.rint.f16(half %a) + %tmp = call half @llvm.rint.f16(half %a) + ret half %tmp + + @pure + @llvm + def nearbyint(a: float16) -> float16: + declare half @llvm.nearbyint.f16(half %a) + %tmp = call half @llvm.nearbyint.f16(half %a) + ret half %tmp + + @pure + @llvm + def __round__(a: float16) -> float16: + declare half @llvm.round.f16(half %a) + %tmp = call half @llvm.round.f16(half %a) + ret half %tmp + + @pure + @llvm + def __pow__(a: float16, b: float16) -> float16: + declare half @llvm.pow.f16(half %a, half %b) + %tmp = call half @llvm.pow.f16(half %a, half %b) + ret half %tmp + + @pure + @llvm + def min(a: float16, b: float16) -> float16: + declare half @llvm.minnum.f16(half %a, half %b) + %tmp = call half @llvm.minnum.f16(half %a, half %b) + ret half %tmp + + @pure + @llvm + def max(a: float16, b: float16) -> float16: + declare half @llvm.maxnum.f16(half %a, half %b) + %tmp = call half @llvm.maxnum.f16(half %a, half %b) + ret half %tmp + + @pure + @llvm + def copysign(a: float16, b: float16) -> float16: + declare half @llvm.copysign.f16(half %a, half %b) + %tmp = call half @llvm.copysign.f16(half %a, half %b) + ret half %tmp + + @pure + @llvm + def fma(a: float16, b: float16, c: float16) -> float16: + declare half @llvm.fma.f16(half %a, half %b, half %c) + %tmp = call half @llvm.fma.f16(half %a, half 
%b, half %c) + ret half %tmp + + def __hash__(self) -> int: + return self.__float__().__hash__() + + def __match__(self, obj: float16) -> bool: + return self == obj + +@extend +class bfloat16: + @pure + @llvm + def __new__(self: float) -> bfloat16: + %0 = fptrunc double %self to bfloat + ret bfloat %0 + + def __new__(what: bfloat16) -> bfloat16: + return what + + def __new__() -> bfloat16: + return bfloat16.__new__(0.0) + + def __repr__(self) -> str: + return self.__float__().__repr__() + + def __format__(self, format_spec: str) -> str: + return self.__float__().__format__(format_spec) + + def __copy__(self) -> bfloat16: + return self + + def __deepcopy__(self) -> bfloat16: + return self + + @pure + @llvm + def __int__(self) -> int: + %0 = fptosi bfloat %self to i64 + ret i64 %0 + + @pure + @llvm + def __float__(self) -> float: + %0 = fpext bfloat %self to double + ret double %0 + + @pure + @llvm + def __bool__(self) -> bool: + %0 = fcmp une bfloat %self, 0.000000e+00 + %1 = zext i1 %0 to i8 + ret i8 %1 + + def __pos__(self) -> bfloat16: + return self + + @pure + @llvm + def __neg__(self) -> bfloat16: + %0 = fneg bfloat %self + ret bfloat %0 + + @pure + @commutative + @llvm + def __add__(a: bfloat16, b: bfloat16) -> bfloat16: + %tmp = fadd bfloat %a, %b + ret bfloat %tmp + + @pure + @llvm + def __sub__(a: bfloat16, b: bfloat16) -> bfloat16: + %tmp = fsub bfloat %a, %b + ret bfloat %tmp + + @pure + @commutative + @llvm + def __mul__(a: bfloat16, b: bfloat16) -> bfloat16: + %tmp = fmul bfloat %a, %b + ret bfloat %tmp + + def __floordiv__(self, other: bfloat16) -> bfloat16: + return self.__truediv__(other).__floor__() + + @pure + @llvm + def __truediv__(a: bfloat16, b: bfloat16) -> bfloat16: + %tmp = fdiv bfloat %a, %b + ret bfloat %tmp + + @pure + @llvm + def __mod__(a: bfloat16, b: bfloat16) -> bfloat16: + %tmp = frem bfloat %a, %b + ret bfloat %tmp + + def __divmod__(self, other: bfloat16) -> Tuple[bfloat16, bfloat16]: + mod = self % other + div = (self - mod) / 
other + if mod: + if (other < bfloat16(0.0)) != (mod < bfloat16(0.0)): + mod += other + div -= bfloat16(1.0) + else: + mod = bfloat16(0.0).copysign(other) + + floordiv = bfloat16(0.0) + if div: + floordiv = div.__floor__() + if div - floordiv > bfloat16(0.5): + floordiv += bfloat16(1.0) + else: + floordiv = bfloat16(0.0).copysign(self / other) + + return (floordiv, mod) + + @pure + @llvm + def __eq__(a: bfloat16, b: bfloat16) -> bool: + %tmp = fcmp oeq bfloat %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __ne__(a: bfloat16, b: bfloat16) -> bool: + %tmp = fcmp une bfloat %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __lt__(a: bfloat16, b: bfloat16) -> bool: + %tmp = fcmp olt bfloat %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __gt__(a: bfloat16, b: bfloat16) -> bool: + %tmp = fcmp ogt bfloat %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __le__(a: bfloat16, b: bfloat16) -> bool: + %tmp = fcmp ole bfloat %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __ge__(a: bfloat16, b: bfloat16) -> bool: + %tmp = fcmp oge bfloat %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def sqrt(a: bfloat16) -> bfloat16: + declare bfloat @llvm.sqrt.bf16(bfloat %a) + %tmp = call bfloat @llvm.sqrt.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def sin(a: bfloat16) -> bfloat16: + declare bfloat @llvm.sin.bf16(bfloat %a) + %tmp = call bfloat @llvm.sin.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def cos(a: bfloat16) -> bfloat16: + declare bfloat @llvm.cos.bf16(bfloat %a) + %tmp = call bfloat @llvm.cos.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def exp(a: bfloat16) -> bfloat16: + declare bfloat @llvm.exp.bf16(bfloat %a) + %tmp = call bfloat @llvm.exp.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def exp2(a: bfloat16) -> bfloat16: + declare bfloat @llvm.exp2.bf16(bfloat %a) + %tmp = call bfloat 
@llvm.exp2.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def log(a: bfloat16) -> bfloat16: + declare bfloat @llvm.log.bf16(bfloat %a) + %tmp = call bfloat @llvm.log.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def log10(a: bfloat16) -> bfloat16: + declare bfloat @llvm.log10.bf16(bfloat %a) + %tmp = call bfloat @llvm.log10.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def log2(a: bfloat16) -> bfloat16: + declare bfloat @llvm.log2.bf16(bfloat %a) + %tmp = call bfloat @llvm.log2.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def __abs__(a: bfloat16) -> bfloat16: + declare bfloat @llvm.fabs.bf16(bfloat %a) + %tmp = call bfloat @llvm.fabs.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def __floor__(a: bfloat16) -> bfloat16: + declare bfloat @llvm.floor.bf16(bfloat %a) + %tmp = call bfloat @llvm.floor.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def __ceil__(a: bfloat16) -> bfloat16: + declare bfloat @llvm.ceil.bf16(bfloat %a) + %tmp = call bfloat @llvm.ceil.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def __trunc__(a: bfloat16) -> bfloat16: + declare bfloat @llvm.trunc.bf16(bfloat %a) + %tmp = call bfloat @llvm.trunc.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def rint(a: bfloat16) -> bfloat16: + declare bfloat @llvm.rint.bf16(bfloat %a) + %tmp = call bfloat @llvm.rint.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def nearbyint(a: bfloat16) -> bfloat16: + declare bfloat @llvm.nearbyint.bf16(bfloat %a) + %tmp = call bfloat @llvm.nearbyint.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def __round__(a: bfloat16) -> bfloat16: + declare bfloat @llvm.round.bf16(bfloat %a) + %tmp = call bfloat @llvm.round.bf16(bfloat %a) + ret bfloat %tmp + + @pure + @llvm + def __pow__(a: bfloat16, b: bfloat16) -> bfloat16: + declare bfloat @llvm.pow.bf16(bfloat %a, bfloat %b) + %tmp = call bfloat @llvm.pow.bf16(bfloat %a, bfloat %b) + ret bfloat %tmp + + @pure + @llvm + def min(a: bfloat16, b: bfloat16) -> 
bfloat16: + declare bfloat @llvm.minnum.bf16(bfloat %a, bfloat %b) + %tmp = call bfloat @llvm.minnum.bf16(bfloat %a, bfloat %b) + ret bfloat %tmp + + @pure + @llvm + def max(a: bfloat16, b: bfloat16) -> bfloat16: + declare bfloat @llvm.maxnum.bf16(bfloat %a, bfloat %b) + %tmp = call bfloat @llvm.maxnum.bf16(bfloat %a, bfloat %b) + ret bfloat %tmp + + @pure + @llvm + def copysign(a: bfloat16, b: bfloat16) -> bfloat16: + declare bfloat @llvm.copysign.bf16(bfloat %a, bfloat %b) + %tmp = call bfloat @llvm.copysign.bf16(bfloat %a, bfloat %b) + ret bfloat %tmp + + @pure + @llvm + def fma(a: bfloat16, b: bfloat16, c: bfloat16) -> bfloat16: + declare bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c) + %tmp = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c) + ret bfloat %tmp + + def __hash__(self) -> int: + return self.__float__().__hash__() + + def __match__(self, obj: bfloat16) -> bool: + return self == obj + +@extend +class float128: + @pure + @llvm + def __new__(self: float) -> float128: + %0 = fpext double %self to fp128 + ret fp128 %0 + + def __new__(what: float128) -> float128: + return what + + def __new__() -> float128: + return float128.__new__(0.0) + + def __repr__(self) -> str: + return self.__float__().__repr__() + + def __format__(self, format_spec: str) -> str: + return self.__float__().__format__(format_spec) + + def __copy__(self) -> float128: + return self + + def __deepcopy__(self) -> float128: + return self + + @pure + @llvm + def __int__(self) -> int: + %0 = fptosi fp128 %self to i64 + ret i64 %0 + + @pure + @llvm + def __float__(self) -> float: + %0 = fptrunc fp128 %self to double + ret double %0 + + @pure + @llvm + def __bool__(self) -> bool: + %0 = fcmp une fp128 %self, 0xL00000000000000000000000000000000 + %1 = zext i1 %0 to i8 + ret i8 %1 + + def __pos__(self) -> float128: + return self + + @pure + @llvm + def __neg__(self) -> float128: + %0 = fneg fp128 %self + ret fp128 %0 + + @pure + @commutative + @llvm + def __add__(a: float128, 
b: float128) -> float128: + %tmp = fadd fp128 %a, %b + ret fp128 %tmp + + @pure + @llvm + def __sub__(a: float128, b: float128) -> float128: + %tmp = fsub fp128 %a, %b + ret fp128 %tmp + + @pure + @commutative + @llvm + def __mul__(a: float128, b: float128) -> float128: + %tmp = fmul fp128 %a, %b + ret fp128 %tmp + + def __floordiv__(self, other: float128) -> float128: + return self.__truediv__(other).__floor__() + + @pure + @llvm + def __truediv__(a: float128, b: float128) -> float128: + %tmp = fdiv fp128 %a, %b + ret fp128 %tmp + + @pure + @llvm + def __mod__(a: float128, b: float128) -> float128: + %tmp = frem fp128 %a, %b + ret fp128 %tmp + + def __divmod__(self, other: float128) -> Tuple[float128, float128]: + mod = self % other + div = (self - mod) / other + if mod: + if (other < float128(0.0)) != (mod < float128(0.0)): + mod += other + div -= float128(1.0) + else: + mod = float128(0.0).copysign(other) + + floordiv = float128(0.0) + if div: + floordiv = div.__floor__() + if div - floordiv > float128(0.5): + floordiv += float128(1.0) + else: + floordiv = float128(0.0).copysign(self / other) + + return (floordiv, mod) + + @pure + @llvm + def __eq__(a: float128, b: float128) -> bool: + %tmp = fcmp oeq fp128 %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __ne__(a: float128, b: float128) -> bool: + %tmp = fcmp une fp128 %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __lt__(a: float128, b: float128) -> bool: + %tmp = fcmp olt fp128 %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __gt__(a: float128, b: float128) -> bool: + %tmp = fcmp ogt fp128 %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __le__(a: float128, b: float128) -> bool: + %tmp = fcmp ole fp128 %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __ge__(a: float128, b: float128) -> bool: + %tmp = fcmp oge fp128 %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def 
sqrt(a: float128) -> float128: + declare fp128 @llvm.sqrt.f128(fp128 %a) + %tmp = call fp128 @llvm.sqrt.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def sin(a: float128) -> float128: + declare fp128 @llvm.sin.f128(fp128 %a) + %tmp = call fp128 @llvm.sin.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def cos(a: float128) -> float128: + declare fp128 @llvm.cos.f128(fp128 %a) + %tmp = call fp128 @llvm.cos.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def exp(a: float128) -> float128: + declare fp128 @llvm.exp.f128(fp128 %a) + %tmp = call fp128 @llvm.exp.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def exp2(a: float128) -> float128: + declare fp128 @llvm.exp2.f128(fp128 %a) + %tmp = call fp128 @llvm.exp2.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def log(a: float128) -> float128: + declare fp128 @llvm.log.f128(fp128 %a) + %tmp = call fp128 @llvm.log.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def log10(a: float128) -> float128: + declare fp128 @llvm.log10.f128(fp128 %a) + %tmp = call fp128 @llvm.log10.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def log2(a: float128) -> float128: + declare fp128 @llvm.log2.f128(fp128 %a) + %tmp = call fp128 @llvm.log2.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def __abs__(a: float128) -> float128: + declare fp128 @llvm.fabs.f128(fp128 %a) + %tmp = call fp128 @llvm.fabs.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def __floor__(a: float128) -> float128: + declare fp128 @llvm.floor.f128(fp128 %a) + %tmp = call fp128 @llvm.floor.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def __ceil__(a: float128) -> float128: + declare fp128 @llvm.ceil.f128(fp128 %a) + %tmp = call fp128 @llvm.ceil.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def __trunc__(a: float128) -> float128: + declare fp128 @llvm.trunc.f128(fp128 %a) + %tmp = call fp128 @llvm.trunc.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def rint(a: float128) -> float128: + declare fp128 
@llvm.rint.f128(fp128 %a) + %tmp = call fp128 @llvm.rint.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def nearbyint(a: float128) -> float128: + declare fp128 @llvm.nearbyint.f128(fp128 %a) + %tmp = call fp128 @llvm.nearbyint.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def __round__(a: float128) -> float128: + declare fp128 @llvm.round.f128(fp128 %a) + %tmp = call fp128 @llvm.round.f128(fp128 %a) + ret fp128 %tmp + + @pure + @llvm + def __pow__(a: float128, b: float128) -> float128: + declare fp128 @llvm.pow.f128(fp128 %a, fp128 %b) + %tmp = call fp128 @llvm.pow.f128(fp128 %a, fp128 %b) + ret fp128 %tmp + + @pure + @llvm + def min(a: float128, b: float128) -> float128: + declare fp128 @llvm.minnum.f128(fp128 %a, fp128 %b) + %tmp = call fp128 @llvm.minnum.f128(fp128 %a, fp128 %b) + ret fp128 %tmp + + @pure + @llvm + def max(a: float128, b: float128) -> float128: + declare fp128 @llvm.maxnum.f128(fp128 %a, fp128 %b) + %tmp = call fp128 @llvm.maxnum.f128(fp128 %a, fp128 %b) + ret fp128 %tmp + + @pure + @llvm + def copysign(a: float128, b: float128) -> float128: + declare fp128 @llvm.copysign.f128(fp128 %a, fp128 %b) + %tmp = call fp128 @llvm.copysign.f128(fp128 %a, fp128 %b) + ret fp128 %tmp + + @pure + @llvm + def fma(a: float128, b: float128, c: float128) -> float128: + declare fp128 @llvm.fma.f128(fp128 %a, fp128 %b, fp128 %c) + %tmp = call fp128 @llvm.fma.f128(fp128 %a, fp128 %b, fp128 %c) + ret fp128 %tmp + + def __hash__(self) -> int: + return self.__float__().__hash__() + + def __match__(self, obj: float128) -> bool: + return self == obj + @extend class float: def __suffix_f32__(double) -> float32: return float32.__new__(double) + def __suffix_f16__(double) -> float16: + return float16.__new__(double) + + def __suffix_bf16__(double) -> bfloat16: + return bfloat16.__new__(double) + + def __suffix_f128__(double) -> float128: + return float128.__new__(double) + +f16 = float16 +bf16 = bfloat16 f32 = float32 f64 = float +f128 = float128 diff --git 
a/test/core/arithmetic.codon b/test/core/arithmetic.codon
index dc8c90be..203190b2 100644
--- a/test/core/arithmetic.codon
+++ b/test/core/arithmetic.codon
@@ -142,3 +142,45 @@ def test_int_pow():
     assert f(T2(0)) ** f(T2(0)) == T2(1)
     assert str(f(T2(31)) ** f(T2(31))) == '17069174130723235958610643029059314756044734431'
 test_int_pow()
+
+@test
+def test_float(F: type):  # runs the same checks against whichever float type F is passed in
+    x = F(5.5)
+    assert str(x) == '5.5'
+    assert F(x) == x
+    assert F() == F(0.0)  # default construction must compare equal to zero
+    assert x.__copy__() == x
+    assert x.__deepcopy__() == x
+    assert int(x) == 5
+    assert float(x) == 5.5
+    assert bool(x)
+    assert not bool(F())
+    assert +x == x
+    assert -x == F(-5.5)
+    assert x + x == F(11.0)
+    assert x - F(1.0) == F(4.5)
+    assert x * F(3.0) == F(16.5)
+    assert x / F(2.0) == F(2.75)
+    if F is not float128: # LLVM ops give wrong results for fp128
+        assert x // F(2.0) == F(2.0)
+        assert x % F(0.75) == F(0.25)
+        assert divmod(x, F(0.75)) == (F(7.0), F(0.25))
+    assert x == x
+    assert x != F()
+    assert x < F(6.5)
+    assert x > F(4.5)
+    assert x <= F(6.5)
+    assert x >= F(4.5)
+    assert x >= x
+    assert x <= x
+    assert abs(x) == x
+    assert abs(-x) == x
+    assert x.__match__(x)
+    assert not x.__match__(F())
+    assert hash(x) == hash(5.5)  # hash must agree with the plain float hash of the same value
+
+test_float(float)
+test_float(float32)
+#test_float(float16)
+#test_float(bfloat16)
+#test_float(float128)
diff --git a/test/transform/kernels.codon b/test/transform/kernels.codon
index df2d93e7..f20a05c5 100644
--- a/test/transform/kernels.codon
+++ b/test/transform/kernels.codon
@@ -34,9 +34,19 @@ def test_conversions():
     def kernel(x, v):
         v[0] = x
 
+    def empty_tuple(x):  # recursively builds a tuple with a default-constructed value for each element of x
+        if staticlen(x) == 0:  # base case: nothing left to mirror
+            return ()
+        else:
+            T = type(x[0])
+            return (T(),) + empty_tuple(x[1:])  # default for the head, then recurse on the tail
+
     def check(x):
-        T = type(x)
-        v = [T()]
+        if isinstance(x, Tuple):  # tuples get an element-wise default; everything else uses type(x)()
+            e = empty_tuple(x)
+        else:
+            e = type(x)()
+        v = [e]  # one-slot list that kernel overwrites with x
         kernel(x, v, grid=1, block=1)
         return v == [x]