Skip to content

Commit

Permalink
Refactor when expression handling
Browse files Browse the repository at this point in the history
- Use expr.Patch and a visitor to only replace identifiers that start
  with `$` with a `Env.` member lookup.
- use `expr.AsBool()` to assert the expression produces a boolean
- Clean up envMap parsing to use strings.Cut()
- Expand tests, using testing.T.Run() with test names.
  • Loading branch information
mjpieters committed Jun 21, 2024
1 parent af5d7bb commit 03b51c1
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 43 deletions.
59 changes: 36 additions & 23 deletions cmdutil/when.go
Original file line number Diff line number Diff line change
@@ -1,49 +1,62 @@
package cmdutil

import (
"fmt"
"os"
"strings"

"github.com/expr-lang/expr"
"github.com/expr-lang/expr/ast"
"github.com/pkg/errors"
)

// AST walker which replaces `$IDENTIFIER` with `Env.IDENTIFIER` member lookup expressions.
type EnvPatcher struct{}

func (ep *EnvPatcher) Visit(node *ast.Node) {
if n, ok := (*node).(*ast.IdentifierNode); ok && n.Value[0] == '$' && n.Value != "$env" {
ast.Patch(
node,
&ast.MemberNode{
Node: &ast.IdentifierNode{Value: "Env"},
Property: &ast.StringNode{Value: n.Value[1:]},
},
)
}
}

// The predefined variables of a when expression
type WhenEnv struct {
Env map[string]string
}

var NewWhenEnv = func() *WhenEnv {
return &WhenEnv{Env: envMap()}
}

func IsAllowedToExecute(when string) (bool, error) {
if when == "" {
return true, nil
}
ropts := []string{}
em := envMap()
for k := range em {
ropts = append(ropts, fmt.Sprintf("$%s", k), fmt.Sprintf("Env.%s", k))
}
r := strings.NewReplacer(ropts...)
when = r.Replace(when)
got, err := expr.Eval(fmt.Sprintf("(%s) == true", when), struct {
Env map[string]string
}{
Env: em,
})

whenEnv := NewWhenEnv()
// when expressions must produce a boolean result
program, err := expr.Compile(when, expr.Patch(&EnvPatcher{}), expr.AsBool(), expr.Env(whenEnv))
if err != nil {
return false, errors.WithStack(err)
}
return got.(bool), nil
if got, err := expr.Run(program, whenEnv); err != nil {
return false, errors.WithStack(err)
} else {
return got.(bool), nil
}
}

func envMap() map[string]string {
m := map[string]string{}
for _, kv := range os.Environ() {
if !strings.Contains(kv, "=") {
continue
}
parts := strings.SplitN(kv, "=", 2)
k := parts[0]
if len(parts) < 2 {
m[k] = ""
continue
if k, v, ok := strings.Cut(kv, "="); ok {
m[k] = v
}
m[k] = parts[1]
}
return m
}
129 changes: 109 additions & 20 deletions cmdutil/when_test.go
Original file line number Diff line number Diff line change
@@ -1,48 +1,137 @@
package cmdutil

import (
"os"
"strings"
"testing"
)

func TestEnvMap(t *testing.T) {
t.Setenv("TEST_ENV_EMPTY", "")
t.Setenv("TEST_ENV_SET", "value")
result := envMap()
if value, ok := result["TEST_ENV_EMPTY"]; !ok {
t.Error("Expected TEST_ENV_EMPTY to be set")
} else if value != "" {
t.Errorf("Expected TEST_ENV_EMPTY to be an empty string, got %v", value)
}
if value, ok := result["TEST_ENV_SET"]; !ok {
t.Error("Expected TEST_ENV_SET to be set")
} else if value != "value" {
t.Errorf("Expected TEST_ENV_SET to be 'value', got %v", value)
}
}

func TestIsAllowedToExecute(t *testing.T) {
tests := []struct {
envset map[string]string
when string
want bool
name string
envset map[string]string
when string
want bool
errorContains any
}{
{
name: "Empty expression",
envset: map[string]string{},
when: "",
want: true,
errorContains: nil,
},
{
name: "Equality test, true",
envset: map[string]string{
"TEST_ENV1": "a",
},
when: "$TEST_ENV1 == 'a'",
want: true,
errorContains: nil,
},
{
name: "Equality test, false",
envset: map[string]string{
"TEST_ENV1": "a",
},
when: "$TEST_ENV1 == 'b'",
want: false,
errorContains: nil,
},
{
name: "Containment in $env",
envset: map[string]string{
"env": "should not replace $env",
"TEST_ENV1": "a",
},
when: "$TEST_ENV1 == 'a'",
want: true,
when: `'TEST_ENV1' not in $env`,
want: true,
errorContains: nil,
},
{
name: "Containment in Env",
envset: map[string]string{
"TEST_ENV1": "a",
},
when: "$TEST_ENV1 == 'b'",
want: false,
when: "'TEST_ENV1' in Env",
want: true,
errorContains: nil,
},
{
name: "Env var name is used in string literal",
envset: map[string]string{
"TEST_ENV1": "foo",
"TEST_ENV2": "$TEST_ENV1",
},
when: `$TEST_ENV2 == '$TEST_ENV1'`,
want: true,
errorContains: nil,
},
{
name: "Env var not set",
envset: map[string]string{},
when: `$TEST_ENV_NONESUCH == ""`,
want: true,
errorContains: nil,
},
{
name: "Invalid expression",
envset: map[string]string{},
when: `($TEST_ENV1 == "Missing parentheses"`,
want: false,
errorContains: "unexpected token EOF",
},
{
name: "Expression produces a non-boolean result",
envset: map[string]string{},
when: `"String literal expression"`,
want: false,
errorContains: "expected bool, but got string",
},
{
name: "Expression references an unknown variable",
envset: map[string]string{
"TEST_ENV1": "a",
},
when: `$TEST_ENV1 == "a"`,
want: true,
when: `$TEST_ENV1 == NoneSuchVariable`,
want: false,
errorContains: "unknown name NoneSuchVariable",
},
}
for _, tt := range tests {
for k, v := range tt.envset {
os.Setenv(k, v)
}
got, err := IsAllowedToExecute(tt.when)
if err != nil {
t.Fatal(err)
}
if got != tt.want {
t.Errorf("got %v\nwant %v", got, tt.want)
}
t.Run(tt.name, func(t *testing.T) {
NewWhenEnv = func() *WhenEnv { return &WhenEnv{Env: tt.envset} }
got, err := IsAllowedToExecute(tt.when)
if err != nil {
if tt.errorContains != nil {
if !strings.Contains(err.Error(), tt.errorContains.(string)) {
t.Errorf("Error %v does not contain %s", err, tt.errorContains)
}
} else {
t.Error(err)
}
} else if tt.errorContains != nil {
t.Errorf("Expected an error containing %v", tt.errorContains)
}
if got != tt.want {
t.Errorf("got %v\nwant %v", got, tt.want)
}
})
}
}

0 comments on commit 03b51c1

Please sign in to comment.