Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor when expression handling #600

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions cmdutil/when.go
Original file line number Diff line number Diff line change
@@ -1,49 +1,62 @@
package cmdutil

import (
"fmt"
"os"
"strings"

"github.com/expr-lang/expr"
"github.com/expr-lang/expr/ast"
"github.com/pkg/errors"
)

// AST walker which replaces `$IDENTIFIER` with `Env.IDENTIFIER` member lookup expressions.
type EnvPatcher struct{}

func (EnvPatcher) Visit(node *ast.Node) {
if n, ok := (*node).(*ast.IdentifierNode); ok && n.Value[0] == '$' && n.Value != "$env" {
ast.Patch(
node,
&ast.MemberNode{
Node: &ast.IdentifierNode{Value: "Env"},
Property: &ast.StringNode{Value: n.Value[1:]},
},
)
}
}

// The predefined variables of a when expression
type WhenEnv struct {
Env map[string]string
}

var NewWhenEnv = func() *WhenEnv {
return &WhenEnv{Env: envMap()}
}

func IsAllowedToExecute(when string) (bool, error) {
if when == "" {
return true, nil
}
ropts := []string{}
em := envMap()
for k := range em {
ropts = append(ropts, fmt.Sprintf("$%s", k), fmt.Sprintf("Env.%s", k))
}
r := strings.NewReplacer(ropts...)
when = r.Replace(when)
got, err := expr.Eval(fmt.Sprintf("(%s) == true", when), struct {
Env map[string]string
}{
Env: em,
})

whenEnv := NewWhenEnv()
// when expressions must produce a boolean result
program, err := expr.Compile(when, expr.Patch(&EnvPatcher{}), expr.AsBool(), expr.Env(whenEnv))
if err != nil {
return false, errors.WithStack(err)
}
return got.(bool), nil
if got, err := expr.Run(program, whenEnv); err != nil {
return false, errors.WithStack(err)
} else {
return got.(bool), nil
}
}

func envMap() map[string]string {
m := map[string]string{}
for _, kv := range os.Environ() {
if !strings.Contains(kv, "=") {
continue
}
parts := strings.SplitN(kv, "=", 2)
k := parts[0]
if len(parts) < 2 {
m[k] = ""
continue
if k, v, ok := strings.Cut(kv, "="); ok {
m[k] = v
}
m[k] = parts[1]
}
return m
}
129 changes: 109 additions & 20 deletions cmdutil/when_test.go
Original file line number Diff line number Diff line change
@@ -1,48 +1,137 @@
package cmdutil

import (
"os"
"strings"
"testing"
)

func TestEnvMap(t *testing.T) {
t.Setenv("TEST_ENV_EMPTY", "")
t.Setenv("TEST_ENV_SET", "value")
result := envMap()
if value, ok := result["TEST_ENV_EMPTY"]; !ok {
t.Error("Expected TEST_ENV_EMPTY to be set")
} else if value != "" {
t.Errorf("Expected TEST_ENV_EMPTY to be an empty string, got %v", value)
}
if value, ok := result["TEST_ENV_SET"]; !ok {
t.Error("Expected TEST_ENV_SET to be set")
} else if value != "value" {
t.Errorf("Expected TEST_ENV_SET to be 'value', got %v", value)
}
}

func TestIsAllowedToExecute(t *testing.T) {
tests := []struct {
envset map[string]string
when string
want bool
name string
envset map[string]string
when string
want bool
errorContains any
}{
{
name: "Empty expression",
envset: map[string]string{},
when: "",
want: true,
errorContains: nil,
},
{
name: "Equality test, true",
envset: map[string]string{
"TEST_ENV1": "a",
},
when: "$TEST_ENV1 == 'a'",
want: true,
errorContains: nil,
},
{
name: "Equality test, false",
envset: map[string]string{
"TEST_ENV1": "a",
},
when: "$TEST_ENV1 == 'b'",
want: false,
errorContains: nil,
},
{
name: "Containment in $env",
envset: map[string]string{
"env": "should not replace $env",
"TEST_ENV1": "a",
},
when: "$TEST_ENV1 == 'a'",
want: true,
when: `'TEST_ENV1' not in $env`,
want: true,
errorContains: nil,
},
{
name: "Containment in Env",
envset: map[string]string{
"TEST_ENV1": "a",
},
when: "$TEST_ENV1 == 'b'",
want: false,
when: "'TEST_ENV1' in Env",
want: true,
errorContains: nil,
},
{
name: "Env var name is used in string literal",
envset: map[string]string{
"TEST_ENV1": "foo",
"TEST_ENV2": "$TEST_ENV1",
},
when: `$TEST_ENV2 == '$TEST_ENV1'`,
want: true,
errorContains: nil,
},
{
name: "Env var not set",
envset: map[string]string{},
when: `$TEST_ENV_NONESUCH == ""`,
want: true,
errorContains: nil,
},
{
name: "Invalid expression",
envset: map[string]string{},
when: `($TEST_ENV1 == "Missing parentheses"`,
want: false,
errorContains: "unexpected token EOF",
},
{
name: "Expression produces a non-boolean result",
envset: map[string]string{},
when: `"String literal expression"`,
want: false,
errorContains: "expected bool, but got string",
},
{
name: "Expression references an unknown variable",
envset: map[string]string{
"TEST_ENV1": "a",
},
when: `$TEST_ENV1 == "a"`,
want: true,
when: `$TEST_ENV1 == NoneSuchVariable`,
want: false,
errorContains: "unknown name NoneSuchVariable",
},
}
for _, tt := range tests {
for k, v := range tt.envset {
os.Setenv(k, v)
}
got, err := IsAllowedToExecute(tt.when)
if err != nil {
t.Fatal(err)
}
if got != tt.want {
t.Errorf("got %v\nwant %v", got, tt.want)
}
t.Run(tt.name, func(t *testing.T) {
NewWhenEnv = func() *WhenEnv { return &WhenEnv{Env: tt.envset} }
got, err := IsAllowedToExecute(tt.when)
if err != nil {
if tt.errorContains != nil {
if !strings.Contains(err.Error(), tt.errorContains.(string)) {
t.Errorf("Error %v does not contain %s", err, tt.errorContains)
}
} else {
t.Error(err)
}
} else if tt.errorContains != nil {
t.Errorf("Expected an error containing %v", tt.errorContains)
}
if got != tt.want {
t.Errorf("got %v\nwant %v", got, tt.want)
}
})
}
}
Loading