Skip to content

Commit

Permalink
types: change Set() to take variadic arguments to match EntityUIDSet()
Browse files Browse the repository at this point in the history
This also cuts down quite a bit on the annoying []types.Value{} wrapper that we have to put around literal arguments.

Signed-off-by: Patrick Jakubowski <patrick.jakubowski@strongdm.com>
  • Loading branch information
patjakdev committed Nov 7, 2024
1 parent 5cefbcb commit 129b06f
Show file tree
Hide file tree
Showing 16 changed files with 109 additions and 112 deletions.
4 changes: 2 additions & 2 deletions internal/ast/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ func TestASTByTable(t *testing.T) {
},
{
"valueSet",
ast.Permit().When(ast.Value(types.NewSet([]types.Value{types.Long(42), types.Long(43)}))),
ast.Permit().When(ast.Value(types.NewSet(types.Long(42), types.Long(43)))),
ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{},
Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.NewSet([]types.Value{types.Long(42), types.Long(43)})}}},
Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.NewSet(types.Long(42), types.Long(43))}}},
},
},
{
Expand Down
2 changes: 1 addition & 1 deletion internal/eval/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node {
for i, e := range t.Entities {
vals[i] = e
}
return ast.NewNode(varNode).In(ast.Value(types.NewSet(vals)))
return ast.NewNode(varNode).In(ast.Value(types.NewSet(vals...)))
case ast.ScopeTypeIs:
return ast.NewNode(varNode).Is(t.Type)

Expand Down
2 changes: 1 addition & 1 deletion internal/eval/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func TestScopeToNode(t *testing.T) {
"inSet",
ast.NewActionNode(),
ast.ScopeTypeInSet{Entities: []types.EntityUID{types.NewEntityUID("T", "42")}},
ast.Action().In(ast.Value(types.NewSet([]types.Value{types.NewEntityUID("T", "42")}))),
ast.Action().In(ast.Value(types.NewSet(types.NewEntityUID("T", "42")))),
},
{
"is",
Expand Down
8 changes: 4 additions & 4 deletions internal/eval/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestToEval(t *testing.T) {
{
"set",
ast.Set(ast.Long(42)),
types.NewSet([]types.Value{types.Long(42)}),
types.NewSet(types.Long(42)),
testutil.OK,
},
{
Expand Down Expand Up @@ -182,19 +182,19 @@ func TestToEval(t *testing.T) {
},
{
"contains",
ast.Value(types.NewSet([]types.Value{types.Long(42)})).Contains(ast.Long(42)),
ast.Value(types.NewSet(types.Long(42))).Contains(ast.Long(42)),
types.True,
testutil.OK,
},
{
"containsAll",
ast.Value(types.NewSet([]types.Value{types.Long(42), types.Long(43), types.Long(44)})).ContainsAll(ast.Value(types.NewSet([]types.Value{types.Long(42), types.Long(43)}))),
ast.Value(types.NewSet(types.Long(42), types.Long(43), types.Long(44))).ContainsAll(ast.Value(types.NewSet(types.Long(42), types.Long(43)))),
types.True,
testutil.OK,
},
{
"containsAny",
ast.Value(types.NewSet([]types.Value{types.Long(42), types.Long(43), types.Long(44)})).ContainsAny(ast.Value(types.NewSet([]types.Value{types.Long(1), types.Long(42)}))),
ast.Value(types.NewSet(types.Long(42), types.Long(43), types.Long(44))).ContainsAny(ast.Value(types.NewSet(types.Long(1), types.Long(42)))),
types.True,
testutil.OK,
},
Expand Down
2 changes: 1 addition & 1 deletion internal/eval/evalers.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ func (n *setLiteralEval) Eval(env Env) (types.Value, error) {
}
vals[i] = v
}
return types.NewSet(vals), nil
return types.NewSet(vals...), nil
}

// containsEval
Expand Down
50 changes: 25 additions & 25 deletions internal/eval/evalers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1290,20 +1290,20 @@ func TestSetLiteralNode(t *testing.T) {
{"nested",
[]Evaler{
newLiteralEval(types.True),
newLiteralEval(types.NewSet([]types.Value{
newLiteralEval(types.NewSet(
types.False,
types.Long(1),
})),
)),
newLiteralEval(types.Long(10)),
},
types.NewSet([]types.Value{
types.NewSet(
types.True,
types.NewSet([]types.Value{
types.NewSet(
types.False,
types.Long(1),
}),
),
types.Long(10),
}),
),
nil},
}
for _, tt := range tests {
Expand Down Expand Up @@ -1343,8 +1343,8 @@ func TestContainsNode(t *testing.T) {
}
{
empty := types.Set{}
trueAndOne := types.NewSet([]types.Value{types.True, types.Long(1)})
nested := types.NewSet([]types.Value{trueAndOne, types.False, types.Long(2)})
trueAndOne := types.NewSet(types.True, types.Long(1))
nested := types.NewSet(trueAndOne, types.False, types.Long(2))

tests := []struct {
name string
Expand Down Expand Up @@ -1398,9 +1398,9 @@ func TestContainsAllNode(t *testing.T) {
}
{
empty := types.Set{}
trueOnly := types.NewSet([]types.Value{types.True})
trueAndOne := types.NewSet([]types.Value{types.True, types.Long(1)})
nested := types.NewSet([]types.Value{trueAndOne, types.False, types.Long(2)})
trueOnly := types.NewSet(types.True)
trueAndOne := types.NewSet(types.True, types.Long(1))
nested := types.NewSet(trueAndOne, types.False, types.Long(2))

tests := []struct {
name string
Expand Down Expand Up @@ -1452,10 +1452,10 @@ func TestContainsAnyNode(t *testing.T) {
}
{
empty := types.Set{}
trueOnly := types.NewSet([]types.Value{types.True})
trueAndOne := types.NewSet([]types.Value{types.True, types.Long(1)})
trueAndTwo := types.NewSet([]types.Value{types.True, types.Long(2)})
nested := types.NewSet([]types.Value{trueAndOne, types.False, types.Long(2)})
trueOnly := types.NewSet(types.True)
trueAndOne := types.NewSet(types.True, types.Long(1))
trueAndTwo := types.NewSet(types.True, types.Long(2))
nested := types.NewSet(trueAndOne, types.False, types.Long(2))

tests := []struct {
name string
Expand Down Expand Up @@ -1495,7 +1495,7 @@ func TestContainsAnyNode(t *testing.T) {
set2[i] = types.Long(setSize + i)
}

n := newContainsAnyEval(newLiteralEval(types.NewSet(set1)), newLiteralEval(types.NewSet(set2)))
n := newContainsAnyEval(newLiteralEval(types.NewSet(set1...)), newLiteralEval(types.NewSet(set2...)))

// This call would take several minutes if the evaluation of ContainsAny was quadratic
val, err := n.Eval(Env{})
Expand Down Expand Up @@ -1954,9 +1954,9 @@ func TestInNode(t *testing.T) {
{
"RhsTypeError2",
newLiteralEval(types.NewEntityUID("human", "joe")),
newLiteralEval(types.NewSet([]types.Value{
newLiteralEval(types.NewSet(
types.String("foo"),
})),
)),
map[string][]string{},
zeroValue(),
ErrType,
Expand All @@ -1972,9 +1972,9 @@ func TestInNode(t *testing.T) {
{
"Reflexive2",
newLiteralEval(types.NewEntityUID("human", "joe")),
newLiteralEval(types.NewSet([]types.Value{
newLiteralEval(types.NewSet(
types.NewEntityUID("human", "joe"),
})),
)),
map[string][]string{},
types.True,
nil,
Expand Down Expand Up @@ -2078,9 +2078,9 @@ func TestIsInNode(t *testing.T) {
"RhsTypeError2",
newLiteralEval(types.NewEntityUID("human", "joe")),
"human",
newLiteralEval(types.NewSet([]types.Value{
newLiteralEval(types.NewSet(
types.String("foo"),
})),
)),
map[string][]string{},
zeroValue(),
ErrType,
Expand All @@ -2098,9 +2098,9 @@ func TestIsInNode(t *testing.T) {
"Reflexive2",
newLiteralEval(types.NewEntityUID("human", "joe")),
"human",
newLiteralEval(types.NewSet([]types.Value{
newLiteralEval(types.NewSet(
types.NewEntityUID("human", "joe"),
})),
)),
map[string][]string{},
types.True,
nil,
Expand Down Expand Up @@ -2307,7 +2307,7 @@ func TestCedarString(t *testing.T) {
{"number", types.Long(42), `42`, `42`},
{"bool", types.True, `true`, `true`},
{"record", types.NewRecord(types.RecordMap{"a": types.Long(42), "b": types.Long(43)}), `{"a":42, "b":43}`, `{"a":42, "b":43}`},
{"set", types.NewSet([]types.Value{types.Long(42), types.Long(43)}), `[42, 43]`, `[42, 43]`},
{"set", types.NewSet(types.Long(42), types.Long(43)), `[42, 43]`, `[42, 43]`},
{"singleIP", types.IPAddr(netip.MustParsePrefix("192.168.0.42/32")), `192.168.0.42`, `ip("192.168.0.42")`},
{"ipPrefix", types.IPAddr(netip.MustParsePrefix("192.168.0.42/24")), `192.168.0.42/24`, `ip("192.168.0.42/24")`},
{"decimal", testutil.Must(types.NewDecimal(12345678, -4)), `1234.5678`, `decimal("1234.5678")`},
Expand Down
6 changes: 3 additions & 3 deletions internal/eval/fold_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ func TestFoldNode(t *testing.T) {
},
{"set-bake",
ast.Set(ast.True()),
ast.Value(types.NewSet([]types.Value{types.True})),
ast.Value(types.NewSet(types.True)),
},
{"record-fold",
ast.Record(ast.Pairs{{Key: "key", Value: ast.Long(6).Multiply(ast.Long(7))}}),
ast.Value(types.NewRecord(types.RecordMap{"key": types.Long(42)})),
},
{"set-fold",
ast.Set(ast.Long(6).Multiply(ast.Long(7))),
ast.Value(types.NewSet([]types.Value{types.Long(42)})),
ast.Value(types.NewSet(types.Long(42))),
},
{"record-blocked",
ast.Record(ast.Pairs{{Key: "key", Value: ast.Long(6).Multiply(ast.Context())}}),
Expand Down Expand Up @@ -205,7 +205,7 @@ func TestFoldPolicy(t *testing.T) {
{
"valueSetNodes",
ast.Permit().When(ast.Set(ast.Long(42), ast.Long(43))),
ast.Permit().When(ast.Value(types.NewSet([]types.Value{types.Long(42), types.Long(43)}))),
ast.Permit().When(ast.Value(types.NewSet(types.Long(42), types.Long(43)))),
},
{
"valueRecordElements",
Expand Down
8 changes: 4 additions & 4 deletions internal/eval/partial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ func TestPartialBasic(t *testing.T) {
{
"valueSetNodesFold",
ast.Set(ast.Long(42), ast.Long(43)),
ast.Value(types.NewSet([]types.Value{types.Long(42), types.Long(43)})),
ast.Value(types.NewSet(types.Long(42), types.Long(43))),
testutil.OK,
},
{
Expand Down Expand Up @@ -1102,7 +1102,7 @@ func TestPartialBasic(t *testing.T) {
{
"opContainsKeep",
ast.Set(ast.Long(42)).Contains(ast.Context()),
ast.Value(types.NewSet([]types.Value{types.Long(42)})).Contains(ast.Context()),
ast.Value(types.NewSet(types.Long(42))).Contains(ast.Context()),
testutil.OK,
},
{
Expand All @@ -1120,7 +1120,7 @@ func TestPartialBasic(t *testing.T) {
{
"opContainsAllKeep",
ast.Set(ast.Long(42)).ContainsAll(ast.Context()),
ast.Value(types.NewSet([]types.Value{types.Long(42)})).ContainsAll(ast.Context()),
ast.Value(types.NewSet(types.Long(42))).ContainsAll(ast.Context()),
testutil.OK,
},
{
Expand All @@ -1138,7 +1138,7 @@ func TestPartialBasic(t *testing.T) {
{
"opContainsAnyKeep",
ast.Set(ast.Long(42)).ContainsAny(ast.Context()),
ast.Value(types.NewSet([]types.Value{types.Long(42)})).ContainsAny(ast.Context()),
ast.Value(types.NewSet(types.Long(42))).ContainsAny(ast.Context()),
testutil.OK,
},
{
Expand Down
2 changes: 1 addition & 1 deletion internal/eval/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestUtil(t *testing.T) {
t.Parallel()
t.Run("roundTrip", func(t *testing.T) {
t.Parallel()
v := types.NewSet([]types.Value{types.Boolean(true), types.Long(1)})
v := types.NewSet(types.Boolean(true), types.Long(1))
slice, err := ValueToSet(v)
testutil.OK(t, err)
v2 := slice
Expand Down
6 changes: 3 additions & 3 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ func NewRecord(r RecordMap) Record {
return types.NewRecord(r)
}

// NewSet returns an immutable Set given a Go slice of Values. Duplicates are removed and order is not preserved.
func NewSet(s []types.Value) Set {
return types.NewSet(s)
// NewSet returns an immutable Set given a variadic set of Values. Duplicates are removed and order is not preserved.
func NewSet(s ...types.Value) Set {
return types.NewSet(s...)
}

// NewDecimal returns a Decimal value of i * 10^exponent.
Expand Down
8 changes: 4 additions & 4 deletions types/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestJSON_Value(t *testing.T) {
{"badDecimal", `{ "__extn": { "fn": "decimal", "arg": "bad" } }`, zeroValue(), ErrDecimal},
{"badDatetime", `{ "__extn": { "fn": "datetime", "arg": "bad" } }`, zeroValue(), ErrDatetime},
{"badDuration", `{ "__extn": { "fn": "duration", "arg": "bad" } }`, zeroValue(), ErrDuration},
{"set", `[42]`, NewSet([]Value{Long(42)}), nil},
{"set", `[42]`, NewSet(Long(42)), nil},
{"record", `{"a":"b"}`, NewRecord(RecordMap{"a": String("b")}), nil},
{"bool", `false`, Boolean(false), nil},
}
Expand Down Expand Up @@ -556,11 +556,11 @@ func TestJSONMarshal(t *testing.T) {
},
{
"set",
NewSet([]Value{
NewSet(
String("av"),
String("cv"),
String("bv"),
}),
),
`["cv","bv","av"]`,
},
{
Expand Down Expand Up @@ -608,7 +608,7 @@ func TestJSONSet(t *testing.T) {
})
t.Run("MarshalErr", func(t *testing.T) {
t.Parallel()
s := NewSet([]Value{&jsonErr{}})
s := NewSet(&jsonErr{})
_, err := json.Marshal(s)
testutil.Error(t, err)
})
Expand Down
6 changes: 3 additions & 3 deletions types/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ type Set struct {
hashVal uint64
}

// NewSet returns an immutable Set given a Go slice of Values. Duplicates are removed and order is not preserved.
func NewSet(v []Value) Set {
// NewSet returns an immutable Set given a variadic set of Values. Duplicates are removed and order is not preserved.
func NewSet(v ...Value) Set {
var set map[uint64]Value
if v != nil {
set = make(map[uint64]Value, len(v))
Expand Down Expand Up @@ -125,7 +125,7 @@ func (v *Set) UnmarshalJSON(b []byte) error {
vals[i] = vv.Value
}

*v = NewSet(vals)
*v = NewSet(vals...)
return nil
}

Expand Down
Loading

0 comments on commit 129b06f

Please sign in to comment.