Skip to content

Commit

Permalink
Support type inference through invocations
Browse files Browse the repository at this point in the history
The corset language supports an interesting notion of type inference in
the case that a function is declared without an explicit return type
being given.  Specifically, it types it polymorphically at the call site
based on the types of the given arguments.
  • Loading branch information
DavePearce committed Dec 12, 2024
1 parent 08c4c16 commit fc77ef4
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 45 deletions.
16 changes: 14 additions & 2 deletions pkg/corset/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,15 @@ func (p *DefConst) Dependencies() util.Iterator[Symbol] {
// Lisp converts this node into its lisp representation. This is primarily used
// for debugging purposes.
func (p *DefConst) Lisp() sexp.SExp {
panic("got here")
def := sexp.EmptyList()
def.Append(sexp.NewSymbol("defconst"))
//
for _, c := range p.constants {
def.Append(sexp.NewSymbol(c.name))
def.Append(c.binding.value.Lisp())
}
// Done
return def
}

// DefConstUnit represents the definition of exactly one constant value. As
Expand Down Expand Up @@ -698,7 +706,11 @@ func (p *DefFun) Dependencies() util.Iterator[Symbol] {
// Lisp converts this node into its lisp representation. This is primarily used
// for debugging purposes.
func (p *DefFun) Lisp() sexp.SExp {
panic("got here")
return sexp.NewList([]sexp.SExp{
sexp.NewSymbol("defun"),
sexp.NewSymbol(p.name),
sexp.NewSymbol("..."), // todo
})
}

// hasParameter checks whether this function has a parameter with the given
Expand Down
35 changes: 32 additions & 3 deletions pkg/corset/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ func (p *ColumnBinding) IsFinalised() bool {
return p.multiplier != 0
}

// Finalise this binding by providing the necessary missing information.
func (p *ColumnBinding) Finalise(multiplier uint, datatype Type) {
p.multiplier = multiplier
p.dataType = datatype
}

// Context returns the of this column. That is, the module in which this colunm
// was declared and also the length multiplier of that module it requires.
func (p *ColumnBinding) Context() Context {
Expand Down Expand Up @@ -83,11 +89,24 @@ func (p *ColumnBinding) ColumnId() uint {
type ConstantBinding struct {
// Constant expression which, when evaluated, produces a constant value.
value Expr
// Inferred type of the given expression
datatype Type
}

// NewConstantBinding creates a new constant binding (which is initially not
// finalised).
func NewConstantBinding(value Expr) ConstantBinding {
return ConstantBinding{value, nil}
}

// IsFinalised checks whether this binding has been finalised yet or not.
func (p *ConstantBinding) IsFinalised() bool {
return true
return p.datatype != nil
}

// Finalise this binding by providing the necessary missing information.
func (p *ConstantBinding) Finalise(datatype Type) {
p.datatype = datatype
}

// Context returns the of this constant, noting that constants (by definition)
Expand All @@ -104,6 +123,8 @@ func (p *ConstantBinding) Context() Context {
type ParameterBinding struct {
// Identifies the variable or column index (as appropriate).
index uint
// Type to use for this parameter.
datatype Type
}

// ============================================================================
Expand All @@ -124,13 +145,16 @@ type FunctionBinding struct {
paramTypes []Type
// Type of return (optional)
returnType Type
// Inferred type of the body. This is used to compare against the declared
// type (if there is one) to check for any descrepencies.
bodyType Type
// body of the function in question.
body Expr
}

// NewFunctionBinding constructs a new function binding.
func NewFunctionBinding(pure bool, paramTypes []Type, returnType Type, body Expr) FunctionBinding {
return FunctionBinding{pure, paramTypes, returnType, body}
return FunctionBinding{pure, paramTypes, returnType, nil, body}
}

// IsPure checks whether this is a defpurefun or not
Expand All @@ -140,14 +164,19 @@ func (p *FunctionBinding) IsPure() bool {

// IsFinalised checks whether this binding has been finalised yet or not.
func (p *FunctionBinding) IsFinalised() bool {
return true
return p.bodyType != nil
}

// Arity returns the number of parameters that this function accepts.
func (p *FunctionBinding) Arity() uint {
return uint(len(p.paramTypes))
}

// Finalise this binding by providing the necessary missing information.
func (p *FunctionBinding) Finalise(bodyType Type) {
p.bodyType = bodyType
}

// Apply a given set of arguments to this function binding.
func (p *FunctionBinding) Apply(args []Expr) Expr {
return p.body.Substitute(args)
Expand Down
2 changes: 1 addition & 1 deletion pkg/corset/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ func (p *Parser) parseDefConstUnit(name string, value sexp.SExp) (*DefConstUnit,
return nil, []SyntaxError{*err}
}
// Looks good
def := &DefConstUnit{name, ConstantBinding{expr}}
def := &DefConstUnit{name, NewConstantBinding(expr)}
// Map to source node
p.mapSourceNode(value, def)
// Done
Expand Down
75 changes: 42 additions & 33 deletions pkg/corset/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ func (r *resolver) finaliseDeclarationsInModule(scope *ModuleScope, decls []Decl
return nil
}

// Check that a given set of source columns have been finalised. This is
// important, since we cannot finalise a declaration until all of its
// dependencies have themselves been finalised.
// Check that a given set of symbols have been finalised. This is important,
// since we cannot finalise a declaration until all of its dependencies have
// themselves been finalised.
func (r *resolver) declarationDependenciesAreFinalised(scope *ModuleScope,
symbols util.Iterator[Symbol]) (bool, []SyntaxError) {
var (
Expand Down Expand Up @@ -357,9 +357,8 @@ func (r *resolver) finaliseDefInterleavedInModule(decl *DefInterleaved) []Syntax
length_multiplier *= uint(len(decl.Sources))
// Lookup existing declaration
binding := decl.Target.Binding().(*ColumnBinding)
// Update with completed information
binding.multiplier = length_multiplier
binding.dataType = datatype
// Finalise column binding
binding.Finalise(length_multiplier, datatype)
}
// Done
return errors
Expand Down Expand Up @@ -420,10 +419,14 @@ func (r *resolver) finaliseDefFunInModule(enclosing Scope, decl *DefFun) []Synta
)
// Declare parameters in local scope
for _, p := range decl.Parameters() {
scope.DeclareLocal(p.Name)
scope.DeclareLocal(p.Name, p.DataType)
}
// Resolve property body
_, errors := r.finaliseExpressionInModule(scope, decl.Body())
datatype, errors := r.finaliseExpressionInModule(scope, decl.Body())
// Finalise declaration
if len(errors) == 0 {
decl.binding.Finalise(datatype)
}
// Done
return errors
}
Expand Down Expand Up @@ -566,9 +569,13 @@ func (r *resolver) finaliseInvokeInModule(scope LocalScope, expr *Invoke) (Type,
// no need, it was provided
return expr.binding.returnType, nil
}
// TODO: this is potentially expensive
// TODO: this is potentially expensive, and it would likely be good if we
// could avoid it. Realistically, this is just about determining the right
// type information. Potentially, we could adjust the local scope to
// provide the required type information. Or we could have a separate pass
// which just determines the type.
body := expr.binding.Apply(expr.Args())
//
// Dig out the type
return r.finaliseExpressionInModule(scope, body)
}

Expand All @@ -583,30 +590,32 @@ func (r *resolver) finaliseVariableInModule(scope LocalScope,
} else if expr.IsQualified() && !scope.HasModule(expr.Module()) {
return nil, r.srcmap.SyntaxErrors(expr, fmt.Sprintf("unknown module %s", expr.Module()))
}
// Symbol should be resolved at this point, but we still need to check the
// context.
if expr.IsResolved() {
// Update context
if binding, ok := expr.Binding().(*ColumnBinding); ok {
if !scope.FixContext(binding.Context()) {
return nil, r.srcmap.SyntaxErrors(expr, "conflicting context")
} else if scope.IsPure() {
return nil, r.srcmap.SyntaxErrors(expr, "not permitted in pure context")
}
// Use column's datatype
return binding.dataType, nil
} else if binding, ok := expr.Binding().(*ConstantBinding); ok {
// Is this safe?
constant := binding.value.AsConstant()
//
return NewUintType(uint(constant.BitLen())), nil
}
// Symbol should be resolved at this point, but we'd better sanity check this.
if !expr.IsResolved() && !scope.Bind(expr) {
// Unable to resolve variable
return nil, r.srcmap.SyntaxErrors(expr, "unresolved symbol")
}
//
if binding, ok := expr.Binding().(*ColumnBinding); ok {
// For column bindings, we still need to sanity check the context is
// compatible.
if !scope.FixContext(binding.Context()) {
return nil, r.srcmap.SyntaxErrors(expr, "conflicting context")
} else if scope.IsPure() {
return nil, r.srcmap.SyntaxErrors(expr, "not permitted in pure context")
}
// Use column's datatype
return binding.dataType, nil
} else if binding, ok := expr.Binding().(*ConstantBinding); ok {
// Constant
return binding.datatype, nil
} else if binding, ok := expr.Binding().(*ParameterBinding); ok {
// Parameter
return binding.datatype, nil
} else if _, ok := expr.Binding().(*FunctionBinding); ok {
// Function doesn't makes sense here.
return nil, r.srcmap.SyntaxErrors(expr, "refers to a function")
} else if scope.Bind(expr) {
// Must be a local variable or parameter access, so we're all good.
return NewFieldType(), nil
}
// Unable to resolve variable
return nil, r.srcmap.SyntaxErrors(expr, "unresolved symbol")
// Should be unreachable.
return nil, r.srcmap.SyntaxErrors(expr, "unknown symbol kind")
}
21 changes: 16 additions & 5 deletions pkg/corset/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ type LocalScope struct {
context *Context
// Maps inputs parameters to the declaration index.
locals map[string]uint
// Actual parameter bindings
bindings []*ParameterBinding
}

// NewLocalScope constructs a new local scope within a given enclosing scope. A
Expand All @@ -229,31 +231,38 @@ type LocalScope struct {
func NewLocalScope(enclosing Scope, global bool, pure bool) LocalScope {
context := tr.VoidContext[string]()
locals := make(map[string]uint)
bindings := make([]*ParameterBinding, 0)
//
return LocalScope{global, pure, enclosing, &context, locals}
return LocalScope{global, pure, enclosing, &context, locals, bindings}
}

// NestedScope creates a nested scope within this local scope.
func (p LocalScope) NestedScope() LocalScope {
nlocals := make(map[string]uint)
nbindings := make([]*ParameterBinding, len(p.bindings))
// Clone allocated variables
for k, v := range p.locals {
nlocals[k] = v
}
// Copy over bindings.
copy(nbindings, p.bindings)
// Done
return LocalScope{p.global, p.pure, p, p.context, nlocals}
return LocalScope{p.global, p.pure, p, p.context, nlocals, nbindings}
}

// NestedPureScope creates a nested scope within this local scope which, in
// addition, is always pure.
func (p LocalScope) NestedPureScope() LocalScope {
nlocals := make(map[string]uint)
nbindings := make([]*ParameterBinding, len(p.bindings))
// Clone allocated variables
for k, v := range p.locals {
nlocals[k] = v
}
// Copy over bindings.
copy(nbindings, p.bindings)
// Done
return LocalScope{p.global, true, p, p.context, nlocals}
return LocalScope{p.global, true, p, p.context, nlocals, nbindings}
}

// IsGlobal determines whether symbols can be accessed in modules other than the
Expand Down Expand Up @@ -289,16 +298,18 @@ func (p LocalScope) Bind(symbol Symbol) bool {
// Check whether this is a local variable access.
if id, ok := p.locals[symbol.Name()]; ok && !symbol.IsFunction() && !symbol.IsQualified() {
// Yes, this is a local variable access.
return symbol.Resolve(&ParameterBinding{id})
return symbol.Resolve(p.bindings[id])
}
// No, this is not a local variable access.
return p.enclosing.Bind(symbol)
}

// DeclareLocal registers a new local variable (e.g. a parameter).
func (p LocalScope) DeclareLocal(name string) uint {
func (p *LocalScope) DeclareLocal(name string, datatype Type) uint {
index := uint(len(p.locals))
binding := ParameterBinding{index, datatype}
p.locals[name] = index
p.bindings = append(p.bindings, &binding)
// Return variable index
return index
}
2 changes: 1 addition & 1 deletion pkg/corset/stdlib.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
(if
(eq! lhs rhs)
;; True branch
0
(vanishes! 0)
;; False branch
then))

Expand Down

0 comments on commit fc77ef4

Please sign in to comment.