Skip to content

Commit

Permalink
GPU compilation fixes (#496)
Browse files Browse the repository at this point in the history
* Fix __from_gpu_new__

* Fix GPU tests

* Update GPU debug codegen

* Add will-return attribute for GPU compilation

* Fix isinstance on unresolved types

* Fix union type instantiation and pendingRealizations placement

* Add float16, bfloat16 and float128 IR types

* Add float16, bfloat16 and float128 types

* Mark complex64 as no-python

* Fix float methods

* Add float tests

* Disable some float tests

* Fix bitset in reaching definitions analysis

* Fix static bool unification

---------

Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>
  • Loading branch information
arshajii and inumanag authored Nov 18, 2023
1 parent 4eb641e commit 2c74407
Show file tree
Hide file tree
Showing 21 changed files with 1,137 additions and 25 deletions.
4 changes: 2 additions & 2 deletions codon/cir/analyze/dataflow/reaching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ struct BitSet {
return res;
}

void set(unsigned bit) { words.data()[bit / B] |= (1 << (bit % B)); }
void set(unsigned bit) { words.data()[bit / B] |= (1UL << (bit % B)); }

bool get(unsigned bit) const {
return (words.data()[bit / B] & (1 << (bit % B))) != 0;
return (words.data()[bit / B] & (1UL << (bit % B))) != 0;
}

bool equals(const BitSet &other, unsigned size) {
Expand Down
8 changes: 8 additions & 0 deletions codon/cir/llvm/gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,14 @@ void moduleToPTX(llvm::Module *M, const std::string &filename,
linkLibdevice(M, libdevice);
remapFunctions(M);

// Strip debug info and remove noinline from functions (added in debug mode).
// Also, tell LLVM that all functions will return.
for (auto &F : *M) {
F.removeFnAttr(llvm::Attribute::AttrKind::NoInline);
F.setWillReturn();
}
llvm::StripDebugInfo(*M);

// Run NVPTX passes and general opt pipeline.
{
llvm::LoopAnalysisManager lam;
Expand Down
28 changes: 28 additions & 0 deletions codon/cir/llvm/llvisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2086,6 +2086,18 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) {
return B->getFloatTy();
}

if (auto *x = cast<types::Float16Type>(t)) {
return B->getHalfTy();
}

if (auto *x = cast<types::BFloat16Type>(t)) {
return B->getBFloatTy();
}

if (auto *x = cast<types::Float128Type>(t)) {
return llvm::Type::getFP128Ty(*context);
}

if (auto *x = cast<types::BoolType>(t)) {
return B->getInt8Ty();
}
Expand Down Expand Up @@ -2203,6 +2215,22 @@ llvm::DIType *LLVMVisitor::getDITypeHelper(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}

if (auto *x = cast<types::Float16Type>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}

if (auto *x = cast<types::BFloat16Type>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}

if (auto *x = cast<types::Float128Type>(t)) {
return db.builder->createBasicType(x->getName(),
layout.getTypeAllocSizeInBits(type),
llvm::dwarf::DW_ATE_HP_float128);
}

if (auto *x = cast<types::BoolType>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_boolean);
Expand Down
21 changes: 21 additions & 0 deletions codon/cir/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ const std::string Module::BYTE_NAME = "byte";
const std::string Module::INT_NAME = "int";
const std::string Module::FLOAT_NAME = "float";
const std::string Module::FLOAT32_NAME = "float32";
const std::string Module::FLOAT16_NAME = "float16";
const std::string Module::BFLOAT16_NAME = "bfloat16";
const std::string Module::FLOAT128_NAME = "float128";
const std::string Module::STRING_NAME = "str";

const std::string Module::EQ_MAGIC_NAME = "__eq__";
Expand Down Expand Up @@ -239,6 +242,24 @@ types::Type *Module::getFloat32Type() {
return Nr<types::Float32Type>();
}

types::Type *Module::getFloat16Type() {
if (auto *rVal = getType(FLOAT16_NAME))
return rVal;
return Nr<types::Float16Type>();
}

types::Type *Module::getBFloat16Type() {
if (auto *rVal = getType(BFLOAT16_NAME))
return rVal;
return Nr<types::BFloat16Type>();
}

types::Type *Module::getFloat128Type() {
if (auto *rVal = getType(FLOAT128_NAME))
return rVal;
return Nr<types::Float128Type>();
}

types::Type *Module::getStringType() {
if (auto *rVal = getType(STRING_NAME))
return rVal;
Expand Down
9 changes: 9 additions & 0 deletions codon/cir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class Module : public AcceptorExtend<Module, Node> {
static const std::string INT_NAME;
static const std::string FLOAT_NAME;
static const std::string FLOAT32_NAME;
static const std::string FLOAT16_NAME;
static const std::string BFLOAT16_NAME;
static const std::string FLOAT128_NAME;
static const std::string STRING_NAME;

static const std::string EQ_MAGIC_NAME;
Expand Down Expand Up @@ -338,6 +341,12 @@ class Module : public AcceptorExtend<Module, Node> {
types::Type *getFloatType();
/// @return the float32 type
types::Type *getFloat32Type();
/// @return the float16 type
types::Type *getFloat16Type();
/// @return the bfloat16 type
types::Type *getBFloat16Type();
/// @return the float128 type
types::Type *getFloat128Type();
/// @return the string type
types::Type *getStringType();
/// Gets a pointer type.
Expand Down
6 changes: 6 additions & 0 deletions codon/cir/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ const char FloatType::NodeId = 0;

const char Float32Type::NodeId = 0;

const char Float16Type::NodeId = 0;

const char BFloat16Type::NodeId = 0;

const char Float128Type::NodeId = 0;

const char BoolType::NodeId = 0;

const char ByteType::NodeId = 0;
Expand Down
27 changes: 27 additions & 0 deletions codon/cir/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,33 @@ class Float32Type : public AcceptorExtend<Float32Type, PrimitiveType> {
Float32Type() : AcceptorExtend("float32") {}
};

/// Float16 type (16-bit float)
class Float16Type : public AcceptorExtend<Float16Type, PrimitiveType> {
public:
static const char NodeId;

/// Constructs a float16 type.
Float16Type() : AcceptorExtend("float16") {}
};

/// BFloat16 type (16-bit brain float)
class BFloat16Type : public AcceptorExtend<BFloat16Type, PrimitiveType> {
public:
static const char NodeId;

/// Constructs a bfloat16 type.
BFloat16Type() : AcceptorExtend("bfloat16") {}
};

/// Float128 type (128-bit float)
class Float128Type : public AcceptorExtend<Float128Type, PrimitiveType> {
public:
static const char NodeId;

/// Constructs a float128 type.
Float128Type() : AcceptorExtend("float128") {}
};

/// Bool type (8-bit unsigned integer; either 0 or 1)
class BoolType : public AcceptorExtend<BoolType, PrimitiveType> {
public:
Expand Down
9 changes: 9 additions & 0 deletions codon/cir/util/format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,15 @@ class FormatVisitor : util::ConstVisitor {
void visit(const types::Float32Type *v) override {
fmt::print(os, FMT_STRING("(float32 '\"{}\")"), v->referenceString());
}
void visit(const types::Float16Type *v) override {
fmt::print(os, FMT_STRING("(float16 '\"{}\")"), v->referenceString());
}
void visit(const types::BFloat16Type *v) override {
fmt::print(os, FMT_STRING("(bfloat16 '\"{}\")"), v->referenceString());
}
void visit(const types::Float128Type *v) override {
fmt::print(os, FMT_STRING("(float128 '\"{}\")"), v->referenceString());
}
void visit(const types::BoolType *v) override {
fmt::print(os, FMT_STRING("(bool '\"{}\")"), v->referenceString());
}
Expand Down
6 changes: 6 additions & 0 deletions codon/cir/util/visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); }
void Visitor::visit(types::IntType *x) { defaultVisit(x); }
void Visitor::visit(types::FloatType *x) { defaultVisit(x); }
void Visitor::visit(types::Float32Type *x) { defaultVisit(x); }
void Visitor::visit(types::Float16Type *x) { defaultVisit(x); }
void Visitor::visit(types::BFloat16Type *x) { defaultVisit(x); }
void Visitor::visit(types::Float128Type *x) { defaultVisit(x); }
void Visitor::visit(types::BoolType *x) { defaultVisit(x); }
void Visitor::visit(types::ByteType *x) { defaultVisit(x); }
void Visitor::visit(types::VoidType *x) { defaultVisit(x); }
Expand Down Expand Up @@ -114,6 +117,9 @@ void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::IntType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::FloatType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float32Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float16Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BFloat16Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float128Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BoolType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::ByteType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::VoidType *x) { defaultVisit(x); }
Expand Down
9 changes: 9 additions & 0 deletions codon/cir/util/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class PrimitiveType;
class IntType;
class FloatType;
class Float32Type;
class Float16Type;
class BFloat16Type;
class Float128Type;
class BoolType;
class ByteType;
class VoidType;
Expand Down Expand Up @@ -152,6 +155,9 @@ class Visitor {
VISIT(types::IntType);
VISIT(types::FloatType);
VISIT(types::Float32Type);
VISIT(types::Float16Type);
VISIT(types::BFloat16Type);
VISIT(types::Float128Type);
VISIT(types::BoolType);
VISIT(types::ByteType);
VISIT(types::VoidType);
Expand Down Expand Up @@ -229,6 +235,9 @@ class ConstVisitor {
CONST_VISIT(types::IntType);
CONST_VISIT(types::FloatType);
CONST_VISIT(types::Float32Type);
CONST_VISIT(types::Float16Type);
CONST_VISIT(types::BFloat16Type);
CONST_VISIT(types::Float128Type);
CONST_VISIT(types::BoolType);
CONST_VISIT(types::ByteType);
CONST_VISIT(types::VoidType);
Expand Down
6 changes: 4 additions & 2 deletions codon/parser/visitors/typecheck/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ ExprPtr TypecheckVisitor::transformArray(CallExpr *expr) {
/// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type
ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
expr->setType(unify(expr->type, ctx->getType("bool")));
expr->staticValue.type = StaticValue::INT; // prevent branching until this is resolved
transform(expr->args[0].value);
auto typ = expr->args[0].value->type->getClass();
if (!typ || !typ->canRealize())
Expand Down Expand Up @@ -947,10 +948,11 @@ ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) {
auto &args = expr->args[0].value->getCall()->args;
for (size_t i = 0; i < args.size(); i++) {
realize(args[i].value->type);
fmt::print(stderr, "[static_print] {}: {} := {}{}\n", getSrcInfo(),
fmt::print(stderr, "[static_print] {}: {} := {}{} (iter: {})\n", getSrcInfo(),
FormatVisitor::apply(args[i].value),
args[i].value->type ? args[i].value->type->debugString(1) : "-",
args[i].value->isStatic() ? " [static]" : "");
args[i].value->isStatic() ? " [static]" : "",
ctx->getRealizationBase()->iteration);
}
return nullptr;
}
Expand Down
4 changes: 2 additions & 2 deletions codon/parser/visitors/typecheck/ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo,
if (auto l = i.second->getLink()) {
i.second->setSrcInfo(srcInfo);
if (l->defaultType) {
pendingDefaults.insert(i.second);
getRealizationBase()->pendingDefaults.insert(i.second);
}
}
}
if (t->getUnion() && !t->getUnion()->isSealed()) {
t->setSrcInfo(srcInfo);
pendingDefaults.insert(t);
getRealizationBase()->pendingDefaults.insert(t);
}
if (auto r = t->getRecord())
if (r->repeats && r->repeats->canRealize())
Expand Down
2 changes: 1 addition & 1 deletion codon/parser/visitors/typecheck/ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ struct TypeContext : public Context<TypecheckItem> {
types::TypePtr returnType = nullptr;
/// Typechecking iteration
int iteration = 0;
std::set<types::TypePtr> pendingDefaults;
};
std::vector<RealizationBase> realizationBases;

/// The current type-checking level (for type instantiation and generalization).
int typecheckLevel;
std::set<types::TypePtr> pendingDefaults;
int changedNodes;

/// The age of the currently parsed statement.
Expand Down
17 changes: 13 additions & 4 deletions codon/parser/visitors/typecheck/infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
bool anotherRound = false;
// Special case: return type might have default as well (e.g., Union)
if (ctx->getRealizationBase()->returnType)
ctx->pendingDefaults.insert(ctx->getRealizationBase()->returnType);
for (auto &unbound : ctx->pendingDefaults) {
ctx->getRealizationBase()->pendingDefaults.insert(
ctx->getRealizationBase()->returnType);
for (auto &unbound : ctx->getRealizationBase()->pendingDefaults) {
if (auto tu = unbound->getUnion()) {
// Seal all dynamic unions after the iteration is over
if (!tu->isSealed()) {
Expand All @@ -113,7 +114,7 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
anotherRound = true;
}
}
ctx->pendingDefaults.clear();
ctx->getRealizationBase()->pendingDefaults.clear();
if (anotherRound)
continue;

Expand Down Expand Up @@ -653,6 +654,12 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
handle = module->getFloatType();
} else if (t->name == "float32") {
handle = module->getFloat32Type();
} else if (t->name == "float16") {
handle = module->getFloat16Type();
} else if (t->name == "bfloat16") {
handle = module->getBFloat16Type();
} else if (t->name == "float128") {
handle = module->getFloat128Type();
} else if (t->name == "str") {
handle = module->getStringType();
} else if (t->name == "Int" || t->name == "UInt") {
Expand Down Expand Up @@ -936,7 +943,9 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"),
N<StringExpr>("invalid union call"))));
// suite->stmts.push_back(N<ReturnStmt>(N<NoneExpr>()));
unify(type->getRetType(), ctx->instantiate(ctx->getType("Union")));

auto ret = ctx->instantiate(ctx->getType("Union"));
unify(type->getRetType(), ret);
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union_first:0")) {
// def __internal__.get_union_first(union: Union[T0]):
Expand Down
4 changes: 3 additions & 1 deletion codon/parser/visitors/typecheck/typecheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {

auto typ = expr->type;
if (!expr->done) {
bool isIntStatic = expr->staticValue.type == StaticValue::INT;
TypecheckVisitor v(ctx, prependStmts);
v.setSrcInfo(expr->getSrcInfo());
ctx->pushSrcInfo(expr->getSrcInfo());
Expand All @@ -60,7 +61,8 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {
expr = v.resultExpr;
}
seqassert(expr->type, "type not set for {}", expr);
unify(typ, expr->type);
if (!(isIntStatic && expr->type->is("bool")))
unify(typ, expr->type);
if (expr->done) {
ctx->changedNodes++;
}
Expand Down
21 changes: 21 additions & 0 deletions stdlib/internal/core.codon
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,27 @@ class float32:
MIN_10_EXP = -37
pass

@tuple
@__internal__
@__notuple__
class float16:
MIN_10_EXP = -4
pass

@tuple
@__internal__
@__notuple__
class bfloat16:
MIN_10_EXP = -37
pass

@tuple
@__internal__
@__notuple__
class float128:
MIN_10_EXP = -4931
pass

@tuple
@__internal__
class type:
Expand Down
2 changes: 1 addition & 1 deletion stdlib/internal/internal.codon
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ class __magic__:

# @dataclass parameter: gpu=True
def from_gpu_new(other: T, T: type) -> T:
__internal__.class_from_gpu_new(other)
return __internal__.class_from_gpu_new(other)

# @dataclass parameter: repr=True
def repr(slf) -> str:
Expand Down
Loading

0 comments on commit 2c74407

Please sign in to comment.