Skip to content

Commit

Permalink
Enable variant constructors to be real consts
Browse files Browse the repository at this point in the history
  • Loading branch information
tjammer committed Oct 21, 2023
1 parent 4999c38 commit 8d71de2
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 210 deletions.
1 change: 1 addition & 0 deletions lib/codegen/abi.ml
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ module Make (T : Lltypes_intf.S) : Abi_intf.S = struct
match value.typ with
| Trecord (_, _, fields) -> (
match fields.(0).ftyp with Tbool -> false | _ -> true)
| Tvariant _ -> true
| _ -> failwith "Internal Error: Not a record to unbox"
in
let value = Llvm.const_extractvalue value.value [| 0 |] in
Expand Down
53 changes: 51 additions & 2 deletions lib/codegen/codegen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,11 @@ end = struct
| Mfield (expr, index) -> gen_field param expr index |> fin
| Mset (expr, value, moved) -> gen_set param expr value moved
| Mseq (expr, cont) -> gen_chain param expr cont
| Mctor (ctor, allocref, id) ->
gen_ctor param ctor typed_expr.typ allocref id
| Mctor (ctor, allocref, id) -> (
match typed_expr.const with
| Cnot -> gen_ctor param ctor typed_expr.typ allocref id
| Const ->
gen_ctor_const param ctor typed_expr.typ allocref typed_expr.return)
| Mvar_index expr -> gen_var_index param expr |> fin
| Mvar_data (expr, mid) -> gen_var_data param expr mid typed_expr.typ |> fin
| Mfmt (fmts, allocref, id) ->
Expand Down Expand Up @@ -968,6 +971,52 @@ end = struct
List.iter (fun id -> Strtbl.replace free_tbl id v) ms;
v

and gen_ctor_const param (variant, tag, expr) typ allocref return =
let lltyp = get_struct typ in
let elems = Llvm.struct_element_types lltyp in

let tag = Llvm.const_int i32_t tag in
let value =
match expr with
| Some expr ->
(* Get largest ctor to figure out the size of the variant and pad
accordingly *)
let largestsize =
match typ with
| Tvariant (_, _, ctors) ->
variant_get_largest ctors |> Option.get |> sizeof_typ
| _ -> failwith "unreachable"
in
let data = gen_constexpr param expr in
(* Change to the type of the greatest payload, or construct a type
with needed padding *)
let oursize = sizeof_typ data.typ in
if largestsize > oursize then
let padding =
let padtype = Llvm.array_type u8_t (largestsize - oursize) in
Llvm.undef padtype
in
let value = Llvm.(const_struct context [| data.value; padding |]) in
Llvm.const_named_struct lltyp [| tag; value |]
else
let data = Llvm.const_bitcast data.value elems.(1) in
Llvm.const_named_struct lltyp [| tag; data |]
| None ->
(* We might need a payload type *)
if Array.length elems > 1 then
let null = Llvm.undef elems.(1) in
Llvm.const_named_struct lltyp [| tag; null |]
else Llvm.const_named_struct lltyp [| tag |]
in
let value, kind =
if return then (
let variant = get_prealloc !allocref param lltyp variant in
ignore (Llvm.build_store value variant builder);
(variant, Const_ptr))
else (value, Const)
in
{ value; typ; lltyp; kind }

and gen_var_index param expr =
let var = gen_expr param expr in
var_index var
Expand Down
20 changes: 14 additions & 6 deletions lib/typing/patternmatching.ml
Original file line number Diff line number Diff line change
Expand Up @@ -560,13 +560,21 @@ module Make (C : Core) (R : Recs) = struct
| Some typ, Some expr ->
let texpr = convert env expr in
unify (loc, "In constructor " ^ snd name ^ ":") typ texpr.typ env;
let expr = Ctor (Path.get_hd typename, ctor.index, Some texpr) in

{ typ = variant; expr; attr = no_attr; loc }
let expr = Ctor (Path.get_hd typename, ctor.index, Some texpr)
and const =
(* There's a special case for string literals.
They will get copied here which makes them not const.
NOTE copy in convert_tuple *)
match texpr.expr with
| Const (String _) -> false
| _ -> texpr.attr.const
in
let attr = { no_attr with const } in
{ typ = variant; expr; attr; loc }
| None, None ->
let expr = Ctor (Path.get_hd typename, ctor.index, None) in
(* NOTE: Const handling for ctors is disabled, see #23 *)
{ typ = variant; expr; attr = no_attr; loc }
let expr = Ctor (Path.get_hd typename, ctor.index, None)
and attr = { no_attr with const = true } in
{ typ = variant; expr; attr; loc }
| _ -> mismatch_err (fst name) (snd name) ctor.ctyp arg

(* We want to be able to reference the exprs in the pattern match without
Expand Down
19 changes: 11 additions & 8 deletions test/misc.t/run.t
Original file line number Diff line number Diff line change
Expand Up @@ -1598,12 +1598,9 @@ Piping for ctors and field accessors
entry:
%0 = tail call i64 @__fun_schmu0(i64 1)
tail call void @Printi(i64 %0)
%option = alloca %option.t_int, align 8
%tag1 = bitcast %option.t_int* %option to i32*
store i32 0, i32* %tag1, align 4
%data = getelementptr inbounds %option.t_int, %option.t_int* %option, i32 0, i32 1
store i64 1, i64* %data, align 8
%1 = call i64 @__fun_schmu1(%option.t_int* %option)
%boxconst = alloca %option.t_int, align 8
store %option.t_int { i32 0, i64 1 }, %option.t_int* %boxconst, align 8
%1 = call i64 @__fun_schmu1(%option.t_int* %boxconst)
call void @Printi(i64 %1)
call void @Printi(i64 1)
ret i64 0
Expand Down Expand Up @@ -2257,13 +2254,14 @@ Global lets with expressions
%option.t_array_int = type { i32, i64* }
%r_array_int = type { i64* }

@schmu_a = internal constant %option.t_array_int { i32 1, i64* undef }
@schmu_b = global i64* null, align 8
@schmu_c = global i64 0, align 8

define void @schmu_ret-none(%option.t_array_int* noalias %0) {
entry:
%tag1 = bitcast %option.t_array_int* %0 to i32*
store i32 1, i32* %tag1, align 4
%1 = bitcast %option.t_array_int* %0 to i8*
tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* %1, i8* bitcast (%option.t_array_int* @schmu_a to i8*), i64 16, i1 false)
ret void
}

Expand All @@ -2290,6 +2288,9 @@ Global lets with expressions
ret i64 %4
}

; Function Attrs: argmemonly nofree nounwind willreturn
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly %0, i8* noalias nocapture readonly %1, i64 %2, i1 immarg %3) #0

declare i8* @malloc(i64 %0)

define i64 @main(i64 %arg) {
Expand Down Expand Up @@ -2371,6 +2372,8 @@ Global lets with expressions
}

declare void @free(i8* %0)

attributes #0 = { argmemonly nofree nounwind willreturn }

Mutual recursive function
$ schmu mutual_rec.smu && ./mutual_rec
Expand Down
55 changes: 14 additions & 41 deletions test/modules.t/run.t
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ Simplest module with 1 type and 1 nonpolymorphic function
source_filename = "context"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"

%nonpoly_func.either = type { i32 }

@0 = private unnamed_addr constant { i64, i64, [4 x i8] } { i64 3, i64 3, [4 x i8] c"%i\0A\00" }

declare i64 @nonpoly_func_add_ints(i64 %0, i64 %1)
Expand All @@ -50,9 +48,6 @@ Simplest module with 1 type and 1 nonpolymorphic function

define i64 @main(i64 %arg) {
entry:
%either = alloca %nonpoly_func.either, align 8
%tag2 = bitcast %nonpoly_func.either* %either to i32*
store i32 0, i32* %tag2, align 4
%0 = tail call i64 @schmu_doo(i32 0)
tail call void @printf(i8* getelementptr (i8, i8* bitcast ({ i64, i64, [4 x i8] }* @0 to i8*), i64 16), i64 %0)
ret i64 0
Expand All @@ -65,8 +60,6 @@ Simplest module with 1 type and 1 nonpolymorphic function
source_filename = "context"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"

%nonpoly_func.either = type { i32 }

@0 = private unnamed_addr constant { i64, i64, [4 x i8] } { i64 3, i64 3, [4 x i8] c"%i\0A\00" }

declare i64 @nonpoly_func_add_ints(i64 %0, i64 %1)
Expand Down Expand Up @@ -107,14 +100,8 @@ Simplest module with 1 type and 1 nonpolymorphic function

define i64 @main(i64 %arg) {
entry:
%either = alloca %nonpoly_func.either, align 8
%tag6 = bitcast %nonpoly_func.either* %either to i32*
store i32 0, i32* %tag6, align 4
%0 = tail call i64 @schmu_doo(i32 0)
tail call void @printf(i8* getelementptr (i8, i8* bitcast ({ i64, i64, [4 x i8] }* @0 to i8*), i64 16), i64 %0)
%either2 = alloca %nonpoly_func.either, align 8
%tag37 = bitcast %nonpoly_func.either* %either2 to i32*
store i32 0, i32* %tag37, align 4
%1 = tail call i64 @schmu_do2(i32 0)
tail call void @printf(i8* getelementptr (i8, i8* bitcast ({ i64, i64, [4 x i8] }* @0 to i8*), i64 16), i64 %1)
ret i64 0
Expand Down Expand Up @@ -230,7 +217,7 @@ Simplest module with 1 type and 1 nonpolymorphic function
%option.t_float = type { i32, double }
%option.t_int = type { i32, i64 }

@schmu_none = global %option.t_float zeroinitializer, align 16
@schmu_none = constant %option.t_float { i32 1, double undef }
@0 = private unnamed_addr constant { i64, i64, [4 x i8] } { i64 3, i64 3, [4 x i8] c"%i\0A\00" }

declare void @printf(i8* %0, i64 %1)
Expand Down Expand Up @@ -269,21 +256,14 @@ Simplest module with 1 type and 1 nonpolymorphic function

define i64 @main(i64 %arg) {
entry:
%t = alloca %option.t_int, align 8
%tag4 = bitcast %option.t_int* %t to i32*
store i32 0, i32* %tag4, align 4
%data = getelementptr inbounds %option.t_int, %option.t_int* %t, i32 0, i32 1
store i64 3, i64* %data, align 8
%0 = call i64 @__option.tg.i_poly_func_classify_option.ti.i(%option.t_int* %t)
%boxconst = alloca %option.t_int, align 8
store %option.t_int { i32 0, i64 3 }, %option.t_int* %boxconst, align 8
%0 = call i64 @__option.tg.i_poly_func_classify_option.ti.i(%option.t_int* %boxconst)
call void @printf(i8* getelementptr (i8, i8* bitcast ({ i64, i64, [4 x i8] }* @0 to i8*), i64 16), i64 %0)
%t1 = alloca %option.t_float, align 8
%tag25 = bitcast %option.t_float* %t1 to i32*
store i32 0, i32* %tag25, align 4
%data3 = getelementptr inbounds %option.t_float, %option.t_float* %t1, i32 0, i32 1
store double 3.000000e+00, double* %data3, align 8
%1 = call i64 @__option.tg.i_poly_func_classify_option.tf.i(%option.t_float* %t1)
%boxconst1 = alloca %option.t_float, align 8
store %option.t_float { i32 0, double 3.000000e+00 }, %option.t_float* %boxconst1, align 8
%1 = call i64 @__option.tg.i_poly_func_classify_option.tf.i(%option.t_float* %boxconst1)
call void @printf(i8* getelementptr (i8, i8* bitcast ({ i64, i64, [4 x i8] }* @0 to i8*), i64 16), i64 %1)
store i32 1, i32* getelementptr inbounds (%option.t_float, %option.t_float* @schmu_none, i32 0, i32 0), align 4
%2 = call i64 @__option.tg.i_poly_func_classify_option.tf.i(%option.t_float* @schmu_none)
call void @printf(i8* getelementptr (i8, i8* bitcast ({ i64, i64, [4 x i8] }* @0 to i8*), i64 16), i64 %2)
ret i64 0
Expand All @@ -301,7 +281,7 @@ Simplest module with 1 type and 1 nonpolymorphic function
%option.t_float = type { i32, double }
%option.t_int = type { i32, i64 }

@schmu_none = global %option.t_float zeroinitializer, align 16
@schmu_none = constant %option.t_float { i32 1, double undef }
@0 = private unnamed_addr constant { i64, i64, [4 x i8] } { i64 3, i64 3, [4 x i8] c"%i\0A\00" }

declare void @printf(i8* %0, i64 %1)
Expand Down Expand Up @@ -340,21 +320,14 @@ Simplest module with 1 type and 1 nonpolymorphic function

define i64 @main(i64 %arg) {
entry:
%t = alloca %option.t_int, align 8
%tag4 = bitcast %option.t_int* %t to i32*
store i32 0, i32* %tag4, align 4
%data = getelementptr inbounds %option.t_int, %option.t_int* %t, i32 0, i32 1
store i64 3, i64* %data, align 8
%0 = call i64 @__option.tg.i_poly_func_classify_option.ti.i(%option.t_int* %t)
%boxconst = alloca %option.t_int, align 8
store %option.t_int { i32 0, i64 3 }, %option.t_int* %boxconst, align 8
%0 = call i64 @__option.tg.i_poly_func_classify_option.ti.i(%option.t_int* %boxconst)
call void @printf(i8* getelementptr (i8, i8* bitcast ({ i64, i64, [4 x i8] }* @0 to i8*), i64 16), i64 %0)
%t1 = alloca %option.t_float, align 8
%tag25 = bitcast %option.t_float* %t1 to i32*
store i32 0, i32* %tag25, align 4
%data3 = getelementptr inbounds %option.t_float, %option.t_float* %t1, i32 0, i32 1
store double 3.000000e+00, double* %data3, align 8
%1 = call i64 @__option.tg.i_poly_func_classify_option.tf.i(%option.t_float* %t1)
%boxconst1 = alloca %option.t_float, align 8
store %option.t_float { i32 0, double 3.000000e+00 }, %option.t_float* %boxconst1, align 8
%1 = call i64 @__option.tg.i_poly_func_classify_option.tf.i(%option.t_float* %boxconst1)
call void @printf(i8* getelementptr (i8, i8* bitcast ({ i64, i64, [4 x i8] }* @0 to i8*), i64 16), i64 %1)
store i32 1, i32* getelementptr inbounds (%option.t_float, %option.t_float* @schmu_none, i32 0, i32 0), align 4
%2 = call i64 @__option.tg.i_poly_func_classify_option.tf.i(%option.t_float* @schmu_none)
call void @printf(i8* getelementptr (i8, i8* bitcast ({ i64, i64, [4 x i8] }* @0 to i8*), i64 16), i64 %2)
ret i64 0
Expand Down
50 changes: 19 additions & 31 deletions test/std.t/run.t
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ Test hashtbl
define linkonce_odr void @__hashtbl.make.tgac.option.tg_hashtbl_make_string_find_hashtbl.make.tfac.option.tf(%option.t_float* noalias %0, i64 %1, i64 %2, i8* %key) {
entry:
%box = alloca { i64, i64 }, align 8
%fst17 = bitcast { i64, i64 }* %box to i64*
store i64 %1, i64* %fst17, align 8
%fst15 = bitcast { i64, i64 }* %box to i64*
store i64 %1, i64* %fst15, align 8
%snd = getelementptr inbounds { i64, i64 }, { i64, i64 }* %box, i32 0, i32 1
store i64 %2, i64* %snd, align 8
%3 = tail call i64 @__hashtbl.make.tgacb.i_hashtbl_make_string_probe-linear_hashtbl.make.tfacb.i(i64 %1, i64 %2, i8* %key, i1 false)
Expand All @@ -164,40 +164,38 @@ Test hashtbl
%7 = add i64 16, %6
%8 = getelementptr i8, i8* %5, i64 %7
%data = bitcast i8* %8 to %hashtbl.make.slot_float*
%tag18 = bitcast %hashtbl.make.slot_float* %data to i32*
%index = load i32, i32* %tag18, align 4
%tag16 = bitcast %hashtbl.make.slot_float* %data to i32*
%index = load i32, i32* %tag16, align 4
%eq = icmp eq i32 %index, 2
br i1 %eq, label %then, label %else

then: ; preds = %entry
%tag719 = bitcast %option.t_float* %0 to i32*
store i32 0, i32* %tag719, align 4
%tag717 = bitcast %option.t_float* %0 to i32*
store i32 0, i32* %tag717, align 4
%data8 = getelementptr inbounds %option.t_float, %option.t_float* %0, i32 0, i32 1
%sunkaddr = inttoptr i64 %1 to double*
%9 = bitcast double* %sunkaddr to i8*
%sunkaddr20 = getelementptr i8, i8* %9, i64 %6
%sunkaddr21 = getelementptr i8, i8* %sunkaddr20, i64 32
%10 = bitcast i8* %sunkaddr21 to double*
%sunkaddr18 = getelementptr i8, i8* %9, i64 %6
%sunkaddr19 = getelementptr i8, i8* %sunkaddr18, i64 32
%10 = bitcast i8* %sunkaddr19 to double*
%11 = load double, double* %10, align 8
store double %11, double* %data8, align 8
store double %11, double* %data8, align 8
br label %ifcont16
br label %ifcont14

else: ; preds = %entry
%eq11 = icmp eq i32 %index, 0
br i1 %eq11, label %then12, label %else14
br i1 %eq11, label %then12, label %else13

then12: ; preds = %else
%tag1322 = bitcast %option.t_float* %0 to i32*
store i32 1, i32* %tag1322, align 4
br label %ifcont16
store %option.t_float { i32 1, double undef }, %option.t_float* %0, align 8
br label %ifcont14

else14: ; preds = %else
%tag1523 = bitcast %option.t_float* %0 to i32*
store i32 1, i32* %tag1523, align 4
br label %ifcont16
else13: ; preds = %else
store %option.t_float { i32 1, double undef }, %option.t_float* %0, align 8
br label %ifcont14

ifcont16: ; preds = %then12, %else14, %then
ifcont14: ; preds = %then12, %else13, %then
ret void
}

Expand Down Expand Up @@ -399,13 +397,8 @@ Test hashtbl
%2 = bitcast %hashtbl.make.slot_float* %1 to i8*
%3 = getelementptr i8, i8* %2, i64 16
%data = bitcast i8* %3 to %hashtbl.make.slot_float*
%slot = alloca %hashtbl.make.slot_float, align 8
%tag2 = bitcast %hashtbl.make.slot_float* %slot to i32*
store i32 0, i32* %tag2, align 4
%4 = getelementptr inbounds %hashtbl.make.slot_float, %hashtbl.make.slot_float* %data, i64 %i
%5 = bitcast %hashtbl.make.slot_float* %4 to i8*
%6 = bitcast %hashtbl.make.slot_float* %slot to i8*
call void @llvm.memcpy.p0i8.p0i8.i64(i8* %5, i8* %6, i64 24, i1 false)
store %hashtbl.make.slot_float { i32 0, %hashtbl.make.item_float undef }, %hashtbl.make.slot_float* %4, align 8
ret void
}

Expand All @@ -418,13 +411,8 @@ Test hashtbl
%2 = bitcast %hashtbl.make.slot_float* %1 to i8*
%3 = getelementptr i8, i8* %2, i64 16
%data = bitcast i8* %3 to %hashtbl.make.slot_float*
%slot = alloca %hashtbl.make.slot_float, align 8
%tag2 = bitcast %hashtbl.make.slot_float* %slot to i32*
store i32 0, i32* %tag2, align 4
%4 = getelementptr inbounds %hashtbl.make.slot_float, %hashtbl.make.slot_float* %data, i64 %i
%5 = bitcast %hashtbl.make.slot_float* %4 to i8*
%6 = bitcast %hashtbl.make.slot_float* %slot to i8*
call void @llvm.memcpy.p0i8.p0i8.i64(i8* %5, i8* %6, i64 24, i1 false)
store %hashtbl.make.slot_float { i32 0, %hashtbl.make.item_float undef }, %hashtbl.make.slot_float* %4, align 8
ret void
}

Expand Down
16 changes: 16 additions & 0 deletions test/variants.t/const_ctor_issue.smu
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
(type thing {:i int :b {int int int int int}})

(type var ((#float float) (#thing thing)))

(def var (#float 10.0))

(match var
((#float _) (print "float"))
((#thing _) (print "thing")))

(defn dynamic (var)
(match var
((#float _) (print "float"))
((#thing _) (print "thing"))) )

(dynamic var)
Loading

0 comments on commit 8d71de2

Please sign in to comment.