Skip to content

Commit

Permalink
Enable Purity Checking (#426)
Browse files Browse the repository at this point in the history
This puts in place purity checking which is used in a few different
contexts.  For example, in checking the body of a defpurefun definition.
Likewise, constants cannot refer to defun declarations, etc.
  • Loading branch information
DavePearce authored Dec 10, 2024
1 parent aac6849 commit 6da83cd
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 22 deletions.
2 changes: 0 additions & 2 deletions pkg/corset/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,8 +754,6 @@ func (p *DefProperty) Lisp() sexp.SExp {
// defined within its enclosing context.
type DefFun struct {
name string
// Specify whether is pure (or not)
pure bool
// Parameters
parameters []*DefParameter
//
Expand Down
5 changes: 5 additions & 0 deletions pkg/corset/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ func NewFunctionBinding(pure bool, paramTypes []sc.Type, returnType sc.Type, bod
return FunctionBinding{pure, paramTypes, returnType, body}
}

// IsPure checks whether this is a defpurefun or not
func (p *FunctionBinding) IsPure() bool {
return p.pure
}

// IsFinalised checks whether this binding has been finalised yet or not.
func (p *FunctionBinding) IsFinalised() bool {
return p.returnType != nil
Expand Down
4 changes: 2 additions & 2 deletions pkg/corset/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -628,9 +628,9 @@ func (p *Parser) parseDefFun(pure bool, elements []sexp.SExp) (Declaration, []Sy
paramTypes[i] = p.DataType
}
// Construct binding
binding := NewFunctionBinding(true, paramTypes, ret, body)
binding := NewFunctionBinding(pure, paramTypes, ret, body)
//
return &DefFun{name, pure, params, binding}, nil
return &DefFun{name, params, binding}, nil
}

func (p *Parser) parseFunctionSignature(elements []sexp.SExp) (string, sc.Type, []*DefParameter, []SyntaxError) {
Expand Down
36 changes: 26 additions & 10 deletions pkg/corset/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func (r *resolver) declarationDependenciesAreFinalised(scope *ModuleScope,
// Finalise a declaration.
func (r *resolver) finaliseDeclaration(scope *ModuleScope, decl Declaration) []SyntaxError {
if d, ok := decl.(*DefConst); ok {
return r.finaliseDefConstInModule(d)
return r.finaliseDefConstInModule(scope, d)
} else if d, ok := decl.(*DefConstraint); ok {
return r.finaliseDefConstraintInModule(scope, d)
} else if d, ok := decl.(*DefFun); ok {
Expand All @@ -273,10 +273,14 @@ func (r *resolver) finaliseDeclaration(scope *ModuleScope, decl Declaration) []S
// Finalise one or more constant definitions within a given module.
// Specifically, we need to check that the constant values provided are indeed
// constants.
func (r *resolver) finaliseDefConstInModule(decl *DefConst) []SyntaxError {
func (r *resolver) finaliseDefConstInModule(enclosing Scope, decl *DefConst) []SyntaxError {
var errors []SyntaxError
//
for _, c := range decl.constants {
scope := NewLocalScope(enclosing, false, true)
// Resolve constant body
errors = append(errors, r.finaliseExpressionInModule(scope, c.binding.value)...)
// Check it is indeed constant!
if constant := c.binding.value.AsConstant(); constant == nil {
err := r.srcmap.SyntaxError(c, "definition not constant")
errors = append(errors, *err)
Expand All @@ -292,7 +296,7 @@ func (r *resolver) finaliseDefConstInModule(decl *DefConst) []SyntaxError {
func (r *resolver) finaliseDefConstraintInModule(enclosing Scope, decl *DefConstraint) []SyntaxError {
var (
errors []SyntaxError
scope = NewLocalScope(enclosing, false)
scope = NewLocalScope(enclosing, false, false)
)
// Resolve guard
if decl.Guard != nil {
Expand Down Expand Up @@ -385,7 +389,7 @@ func (r *resolver) finaliseDefPermutationInModule(decl *DefPermutation) []Syntax
func (r *resolver) finaliseDefInRangeInModule(enclosing Scope, decl *DefInRange) []SyntaxError {
var (
errors []SyntaxError
scope = NewLocalScope(enclosing, false)
scope = NewLocalScope(enclosing, false, false)
)
// Resolve property body
errors = append(errors, r.finaliseExpressionInModule(scope, decl.Expr)...)
Expand All @@ -401,7 +405,7 @@ func (r *resolver) finaliseDefInRangeInModule(enclosing Scope, decl *DefInRange)
func (r *resolver) finaliseDefFunInModule(enclosing Scope, decl *DefFun) []SyntaxError {
var (
errors []SyntaxError
scope = NewLocalScope(enclosing, false)
scope = NewLocalScope(enclosing, false, decl.IsPure())
)
// Declare parameters in local scope
for _, p := range decl.Parameters() {
Expand All @@ -417,8 +421,8 @@ func (r *resolver) finaliseDefFunInModule(enclosing Scope, decl *DefFun) []Synta
func (r *resolver) finaliseDefLookupInModule(enclosing Scope, decl *DefLookup) []SyntaxError {
var (
errors []SyntaxError
sourceScope = NewLocalScope(enclosing, true)
targetScope = NewLocalScope(enclosing, true)
sourceScope = NewLocalScope(enclosing, true, false)
targetScope = NewLocalScope(enclosing, true, false)
)
// Resolve source expressions
errors = append(errors, r.finaliseExpressionsInModule(sourceScope, decl.Sources)...)
Expand All @@ -432,7 +436,7 @@ func (r *resolver) finaliseDefLookupInModule(enclosing Scope, decl *DefLookup) [
func (r *resolver) finaliseDefPropertyInModule(enclosing Scope, decl *DefProperty) []SyntaxError {
var (
errors []SyntaxError
scope = NewLocalScope(enclosing, false)
scope = NewLocalScope(enclosing, false, false)
)
// Resolve property body
errors = append(errors, r.finaliseExpressionInModule(scope, decl.Assertion)...)
Expand Down Expand Up @@ -466,7 +470,11 @@ func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) []Syn
} else if v, ok := expr.(*Add); ok {
return r.finaliseExpressionsInModule(scope, v.Args)
} else if v, ok := expr.(*Exp); ok {
return r.finaliseExpressionsInModule(scope, []Expr{v.Arg, v.Pow})
purescope := scope.NestedPureScope()
arg_errs := r.finaliseExpressionInModule(scope, v.Arg)
pow_errs := r.finaliseExpressionInModule(purescope, v.Pow)
// combine errors
return append(arg_errs, pow_errs...)
} else if v, ok := expr.(*IfZero); ok {
return r.finaliseExpressionsInModule(scope, []Expr{v.Condition, v.TrueBranch, v.FalseBranch})
} else if v, ok := expr.(*Invoke); ok {
Expand All @@ -478,7 +486,11 @@ func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) []Syn
} else if v, ok := expr.(*Normalise); ok {
return r.finaliseExpressionInModule(scope, v.Arg)
} else if v, ok := expr.(*Shift); ok {
return r.finaliseExpressionsInModule(scope, []Expr{v.Arg, v.Shift})
purescope := scope.NestedPureScope()
arg_errs := r.finaliseExpressionInModule(scope, v.Arg)
shf_errs := r.finaliseExpressionInModule(purescope, v.Shift)
// combine errors
return append(arg_errs, shf_errs...)
} else if v, ok := expr.(*Sub); ok {
return r.finaliseExpressionsInModule(scope, v.Args)
} else if v, ok := expr.(*VariableAccess); ok {
Expand All @@ -499,6 +511,8 @@ func (r *resolver) finaliseInvokeInModule(scope LocalScope, expr *Invoke) []Synt
// Lookup the corresponding function definition.
if !scope.Bind(expr) {
return r.srcmap.SyntaxErrors(expr, "unknown function")
} else if scope.IsPure() && !expr.binding.IsPure() {
return r.srcmap.SyntaxErrors(expr, "not permitted in pure context")
}
// Success
return nil
Expand All @@ -522,6 +536,8 @@ func (r *resolver) finaliseVariableInModule(scope LocalScope,
if binding, ok := expr.Binding().(*ColumnBinding); ok {
if !scope.FixContext(binding.Context()) {
return r.srcmap.SyntaxErrors(expr, "conflicting context")
} else if scope.IsPure() {
return r.srcmap.SyntaxErrors(expr, "not permitted in pure context")
}
} else if _, ok := expr.Binding().(*ConstantBinding); !ok {
// Unable to resolve variable
Expand Down
28 changes: 25 additions & 3 deletions pkg/corset/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ func (p *ModuleScope) Alias(alias string, symbol Symbol) bool {
// which must be evaluated within.
type LocalScope struct {
global bool
// Determines whether or not this scope is "pure" (i.e. whether or not
// columns can be accessed, etc).
pure bool
// Represents the enclosing scope
enclosing Scope
// Context for this scope
Expand All @@ -220,11 +223,11 @@ type LocalScope struct {
// local scope can have local variables declared within it. A local scope can
// also be "global" in the sense that accessing symbols from other modules is
// permitted.
func NewLocalScope(enclosing Scope, global bool) LocalScope {
func NewLocalScope(enclosing Scope, global bool, pure bool) LocalScope {
context := tr.VoidContext[string]()
locals := make(map[string]uint)
//
return LocalScope{global, enclosing, &context, locals}
return LocalScope{global, pure, enclosing, &context, locals}
}

// NestedScope creates a nested scope within this local scope.
Expand All @@ -235,7 +238,19 @@ func (p LocalScope) NestedScope() LocalScope {
nlocals[k] = v
}
// Done
return LocalScope{p.global, p.enclosing, p.context, nlocals}
return LocalScope{p.global, p.pure, p, p.context, nlocals}
}

// NestedPureScope creates a nested scope within this local scope which, in
// addition, is always pure.
func (p LocalScope) NestedPureScope() LocalScope {
nlocals := make(map[string]uint)
// Clone allocated variables
for k, v := range p.locals {
nlocals[k] = v
}
// Done
return LocalScope{p.global, true, p, p.context, nlocals}
}

// IsGlobal determines whether symbols can be accessed in modules other than the
Expand All @@ -244,6 +259,13 @@ func (p LocalScope) IsGlobal() bool {
return p.global
}

// IsPure determines whether or not this scope is pure. That is, whether or not
// expressions in this scope are permitted to access columns (either directly,
// or indirectly via impure invocations).
func (p LocalScope) IsPure() bool {
return p.pure
}

// FixContext fixes the context for this scope. Since every scope requires
// exactly one context, this fails if we fix it to incompatible contexts.
func (p LocalScope) FixContext(context Context) bool {
Expand Down
25 changes: 20 additions & 5 deletions pkg/test/invalid_corset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,18 @@ func Test_Invalid_Constant_14(t *testing.T) {
CheckInvalid(t, "constant_invalid_14")
}

func Test_Invalid_Constant_15(t *testing.T) {
CheckInvalid(t, "constant_invalid_15")
}

func Test_Invalid_Constant_16(t *testing.T) {
CheckInvalid(t, "constant_invalid_16")
}

func Test_Invalid_Constant_17(t *testing.T) {
CheckInvalid(t, "constant_invalid_17")
}

// ===================================================================
// Alias Tests
// ===================================================================
Expand Down Expand Up @@ -373,11 +385,10 @@ func Test_Invalid_PureFun_03(t *testing.T) {
CheckInvalid(t, "purefun_invalid_03")
}

/*
func Test_Invalid_PureFun_04(t *testing.T) {
CheckInvalid(t, "purefun_invalid_04")
}
*/
func Test_Invalid_PureFun_04(t *testing.T) {
CheckInvalid(t, "purefun_invalid_04")
}

func Test_Invalid_PureFun_05(t *testing.T) {
CheckInvalid(t, "purefun_invalid_05")
}
Expand All @@ -388,6 +399,10 @@ func Test_Invalid_PureFun_05(t *testing.T) {
}
*/

func Test_Invalid_PureFun_07(t *testing.T) {
CheckInvalid(t, "purefun_invalid_07")
}

// ===================================================================
// Test Helpers
// ===================================================================
Expand Down
2 changes: 2 additions & 0 deletions testdata/constant_invalid_15.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
(defun (ONE) 1)
(defconst X (ONE))
3 changes: 3 additions & 0 deletions testdata/constant_invalid_16.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
(defcolumns X)
(defun (ONE) 1)
(defconstraint c1 () (* X (^ 2 (ONE))))
3 changes: 3 additions & 0 deletions testdata/constant_invalid_17.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
(defcolumns X)
(defun (ONE) 1)
(defconstraint c1 () (shift X (ONE)))
5 changes: 5 additions & 0 deletions testdata/purefun_invalid_07.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
(defcolumns A)
(defun (getA) A)
;; not pure!
(defpurefun (id x) (+ x (getA)))
(defconstraint test () (id 1))

0 comments on commit 6da83cd

Please sign in to comment.