diff --git a/pkg/corset/expr.go b/pkg/corset/expr.go index adb02fd..ef414e5 100644 --- a/pkg/corset/expr.go +++ b/pkg/corset/expr.go @@ -190,9 +190,13 @@ func (e *Exp) Dependencies() []Symbol { // IfZero // ============================================================================ -// IfZero returns the (optional) true branch when the condition evaluates to zero, and +// If returns the (optional) true branch when the condition evaluates to zero, and // the (optional false branch otherwise. -type IfZero struct { +type If struct { + // Indicates whether this is an if-zero (kind==1) or an if-notzero + // (kind==2). Any other kind value implies this has not yet been + // determined. + kind uint8 // Elements contained within this list. Condition Expr // True branch (optional). @@ -201,9 +205,31 @@ type IfZero struct { FalseBranch Expr } +// IsIfZero determines whether or not this has been determined as an IfZero +// condition. +func (e *If) IsIfZero() bool { + return e.kind == 1 +} + +// IsIfNotZero determines whether or not this has been determined as an +// IfNotZero condition. +func (e *If) IsIfNotZero() bool { + return e.kind == 2 +} + +// FixSemantics fixes the semantics for this condition to be either "if-zero" or +// "if-notzero". +func (e *If) FixSemantics(ifzero bool) { + if ifzero { + e.kind = 1 + } else { + e.kind = 2 + } +} + // AsConstant attempts to evaluate this expression as a constant (signed) value. // If this expression is not constant, then nil is returned. -func (e *IfZero) AsConstant() *big.Int { +func (e *If) AsConstant() *big.Int { if condition := e.Condition.AsConstant(); condition != nil { // Determine whether condition holds true (or not). holds := condition.Cmp(big.NewInt(0)) == 0 @@ -220,20 +246,20 @@ func (e *IfZero) AsConstant() *big.Int { // Multiplicity determines the number of values that evaluating this expression // can generate. -func (e *IfZero) Multiplicity() uint { +func (e *If) Multiplicity() uint { return determineMultiplicity([]Expr{e.Condition, e.TrueBranch, e.FalseBranch}) } // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *IfZero) Context() Context { +func (e *If) Context() Context { return ContextOfExpressions([]Expr{e.Condition, e.TrueBranch, e.FalseBranch}) } // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. -func (e *IfZero) Lisp() sexp.SExp { +func (e *If) Lisp() sexp.SExp { if e.FalseBranch != nil { return sexp.NewList([]sexp.SExp{ sexp.NewSymbol("if"), @@ -248,15 +274,15 @@ func (e *IfZero) Lisp() sexp.SExp { // Substitute all variables (such as for function parameters) arising in // this expression. -func (e *IfZero) Substitute(args []Expr) Expr { - return &IfZero{e.Condition.Substitute(args), +func (e *If) Substitute(args []Expr) Expr { + return &If{e.kind, e.Condition.Substitute(args), SubstituteOptionalExpression(e.TrueBranch, args), SubstituteOptionalExpression(e.FalseBranch, args), } } // Dependencies needed to signal declaration. -func (e *IfZero) Dependencies() []Symbol { +func (e *If) Dependencies() []Symbol { return DependenciesOfExpressions([]Expr{e.Condition, e.TrueBranch, e.FalseBranch}) } diff --git a/pkg/corset/parser.go b/pkg/corset/parser.go index 6a3ef22..0f1d84e 100644 --- a/pkg/corset/parser.go +++ b/pkg/corset/parser.go @@ -757,9 +757,6 @@ func (p *Parser) parseType(term sexp.SExp) (Type, bool, *SyntaxError) { } // Access string of symbol parts := strings.Split(symbol.Value, "@") - if len(parts) > 2 { - return nil, false, p.translator.SyntaxError(term, "malformed type") - } // Determine whether type should be proven or not. var datatype Type // See what we've got. @@ -864,9 +861,9 @@ func mulParserRule(_ string, args []Expr) (Expr, error) { func ifParserRule(_ string, args []Expr) (Expr, error) { if len(args) == 2 { - return &IfZero{args[0], args[1], nil}, nil + return &If{0, args[0], args[1], nil}, nil } else if len(args) == 3 { - return &IfZero{args[0], args[1], args[2]}, nil + return &If{0, args[0], args[1], args[2]}, nil } return nil, errors.New("incorrect number of arguments") diff --git a/pkg/corset/resolver.go b/pkg/corset/resolver.go index 97c39ee..8211602 100644 --- a/pkg/corset/resolver.go +++ b/pkg/corset/resolver.go @@ -469,8 +469,9 @@ func (r *resolver) finaliseExpressionsInModule(scope LocalScope, args []Expr) ([ // //nolint:staticcheck func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) (Type, []SyntaxError) { - if _, ok := expr.(*Constant); ok { - return nil, nil + if v, ok := expr.(*Constant); ok { + nbits := v.Val.BitLen() + return NewUintType(uint(nbits)), nil } else if v, ok := expr.(*Add); ok { types, errs := r.finaliseExpressionsInModule(scope, v.Args) return JoinAll(types), errs @@ -480,9 +481,8 @@ func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) (Type _, pow_errs := r.finaliseExpressionInModule(purescope, v.Pow) // combine errors return arg_types, append(arg_errs, pow_errs...) - } else if v, ok := expr.(*IfZero); ok { - types, errs := r.finaliseExpressionsInModule(scope, []Expr{v.Condition, v.TrueBranch, v.FalseBranch}) - return JoinAll(types), errs + } else if v, ok := expr.(*If); ok { + return r.finaliseIfInModule(scope, v) } else if v, ok := expr.(*Invoke); ok { return r.finaliseInvokeInModule(scope, v) } else if v, ok := expr.(*List); ok { @@ -509,6 +509,30 @@ func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) (Type } } +// Resolve an if condition contained within some expression which, in turn, is +// contained within some module. An important step occurrs here where, based on +// the semantics of the condition, this is inferred as an "if-zero" or an +// "if-notzero". +func (r *resolver) finaliseIfInModule(scope LocalScope, expr *If) (Type, []SyntaxError) { + types, errs := r.finaliseExpressionsInModule(scope, []Expr{expr.Condition, expr.TrueBranch, expr.FalseBranch}) + // Sanity check + if len(errs) != 0 { + return nil, errs + } + // Check & Resolve Condition + if types[0].HasLoobeanSemantics() { + // if-zero + expr.FixSemantics(true) + } else if types[0].HasBooleanSemantics() { + // if-notzero + expr.FixSemantics(false) + } else { + return nil, r.srcmap.SyntaxErrors(expr.Condition, "invalid condition (neither loobean nor boolean)") + } + // Join result types + return JoinAll(types[1:]), errs +} + // Resolve a specific invocation contained within some expression which, in // turn, is contained within some module. Note, qualified accesses are only // permitted in a global context. @@ -524,7 +548,7 @@ func (r *resolver) finaliseInvokeInModule(scope LocalScope, expr *Invoke) (Type, return nil, r.srcmap.SyntaxErrors(expr, "not permitted in pure context") } // Success - return nil, nil + return NewFieldType(), nil } // Resolve a specific variable access contained within some expression which, in @@ -548,15 +572,19 @@ func (r *resolver) finaliseVariableInModule(scope LocalScope, } else if scope.IsPure() { return nil, r.srcmap.SyntaxErrors(expr, "not permitted in pure context") } - } else if _, ok := expr.Binding().(*ConstantBinding); !ok { - // Unable to resolve variable - return nil, r.srcmap.SyntaxErrors(expr, "refers to a function") + // Use column's datatype + return binding.dataType, nil + } else if binding, ok := expr.Binding().(*ConstantBinding); ok { + // Is this safe? + constant := binding.value.AsConstant() + // + return NewUintType(uint(constant.BitLen())), nil } - // Done - return nil, nil + // Unable to resolve variable + return nil, r.srcmap.SyntaxErrors(expr, "refers to a function") } else if scope.Bind(expr) { // Must be a local variable or parameter access, so we're all good. - return nil, nil + return NewFieldType(), nil } // Unable to resolve variable return nil, r.srcmap.SyntaxErrors(expr, "unresolved symbol") diff --git a/pkg/corset/translator.go b/pkg/corset/translator.go index 4e4e881..35d4066 100644 --- a/pkg/corset/translator.go +++ b/pkg/corset/translator.go @@ -363,9 +363,16 @@ func (t *translator) translateExpressionInModule(expr Expr, module string, shift return &hir.Add{Args: args}, errs } else if e, ok := expr.(*Exp); ok { return t.translateExpInModule(e, module, shift) - } else if v, ok := expr.(*IfZero); ok { + } else if v, ok := expr.(*If); ok { args, errs := t.translateExpressionsInModule([]Expr{v.Condition, v.TrueBranch, v.FalseBranch}, module) - return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: args[2]}, errs + if v.IsIfZero() { + return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: args[2]}, errs + } else if v.IsIfNotZero() { + // In this case, switch the ordering. + return &hir.IfZero{Condition: args[0], TrueBranch: args[2], FalseBranch: args[1]}, errs + } + // Should be unreachable + return nil, t.srcmap.SyntaxErrors(expr, "unresolved conditional") } else if e, ok := expr.(*Invoke); ok { return t.translateInvokeInModule(e, module, shift) } else if v, ok := expr.(*List); ok { diff --git a/pkg/corset/type.go b/pkg/corset/type.go index 1b769a5..d5d4eea 100644 --- a/pkg/corset/type.go +++ b/pkg/corset/type.go @@ -50,11 +50,11 @@ func JoinAll(types []Type) Type { var datatype Type // for _, t := range types { - if t == nil { - return nil + if datatype == nil { + datatype = t + } else if t != nil { + datatype = Join(datatype, t) } - // - datatype = Join(datatype, t) } // return datatype