diff --git a/Project.toml b/Project.toml index 3efd5e9..9e9987d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OpenPolicyAgent" uuid = "8f257efb-743c-4ebc-8197-d291a1f743b4" authors = ["JuliaHub Inc.", "Tanmay Mohapatra "] -version = "0.3.0" +version = "0.3.1" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/src/utils/sql.jl b/src/utils/sql.jl index 43b6c2d..ecff1eb 100644 --- a/src/utils/sql.jl +++ b/src/utils/sql.jl @@ -46,6 +46,14 @@ const SQL_OP_MAP = Dict{String,String}( "lte" => "<=", "equal" => "=", "internal.member_2" => "in", + "bits.and" => "&", + "bits.or" => "|", + "bits.xor" => "#", + "bits.negate" => "~", + "bits.lsh" => "<<", + "bits.rsh" => ">>", + "plus" => "+", + "minus" => "-", ) const VALID_SQL_OPS = Set(keys(SQL_OP_MAP)) @@ -156,6 +164,21 @@ function visit(visitor::SQLVisitor, arr::AST.OPAArray) return nothing end +function visit(visitor::SQLVisitor, call::AST.OPACall) + op_name = walk(visitor, call.operator) + if !(op_name in VALID_SQL_OPS) + error("Invalid SQL operator: $op_name") + end + op = SQL_OP_MAP[op_name] + + op_operands = call.operands + @assert length(op_operands) == 2 + op_lhs = walk(visitor, op_operands[1]) + op_rhs = walk(visitor, op_operands[2]) + + push!(visitor.result_stack, join(["(", op_lhs, op, op_rhs, ")"], " ")) +end + function visit(visitor::SQLVisitor, expr::AST.OPAExpr) @assert AST.is_call(expr) @@ -171,7 +194,7 @@ function visit(visitor::SQLVisitor, expr::AST.OPAExpr) op_rhs = walk(visitor, op_operands[2]) op = SQL_OP_MAP[op_name] - push!(visitor.result_stack, join([op_lhs, op, op_rhs], " ")) + push!(visitor.result_stack, join(["(", op_lhs, op, op_rhs, ")"], " ")) end end # module SQL \ No newline at end of file diff --git a/test/sql_translate.jl b/test/sql_translate.jl index 5aa08a9..b37bdba 100644 --- a/test/sql_translate.jl +++ b/test/sql_translate.jl @@ -157,6 +157,14 @@ const SQL_OP_MAP = Dict{String,String}( "lte" => "<=", "equal" => "=", "internal.member_2" => "in", + "bits.and" => "&", + "bits.or" => "|", + "bits.xor" => "#", + "bits.negate" => "~", + "bits.lsh" => "<<", + "bits.rsh" => ">>", + "plus" => "+", + "minus" => "-", ) const VALID_SQL_OPS = Set(keys(SQL_OP_MAP)) @@ -229,7 +237,20 @@ function to_sql(expr::OPAExpr) op_lhs = to_sql(op_operands[1]) op_rhs = to_sql(op_operands[2]) - return join([op_lhs, op, op_rhs], " ") + return join(["(", op_lhs, op, op_rhs, ")"], " ") +end + +function to_sql(call::OPACall) + op_name = to_sql(call.operator) + if !(op_name in VALID_SQL_OPS) + error("Invalid SQL operator: $op_name") + end + + op = SQL_OP_MAP[op_name] + op_lhs = to_sql(call.operands[1]) + op_rhs = to_sql(call.operands[2]) + + return join(["(", op_lhs, op, op_rhs, ")"], " ") end function to_sql(query::OPAQuery) diff --git a/test/test_data.jl b/test/test_data.jl index f8094f6..b0e88d9 100644 --- a/test/test_data.jl +++ b/test/test_data.jl @@ -38,7 +38,7 @@ const PARTIAL_COMPILE_CASES = [ "disableInlining" => [] ), unknowns = ["data.reports"], - sql = "4 >= public.juliahub_reports.clearance_level", + sql = "( 4 >= public.juliahub_reports.clearance_level )", ), ( policy = """package example @@ -66,7 +66,7 @@ const PARTIAL_COMPILE_CASES = [ "disableInlining" => [] ), unknowns = ["data.reports"], - sql = "public.juliahub_reports.public = true or\n4 >= public.juliahub_reports.clearance_level and 'bob' = public.juliahub_reports.owner", + sql = "( public.juliahub_reports.public = true ) or\n( 4 >= public.juliahub_reports.clearance_level ) and ( 'bob' = public.juliahub_reports.owner )", ), ( # always allowed if the policy is fully satisfied with the given input for any one condition @@ -163,7 +163,75 @@ const PARTIAL_COMPILE_CASES = [ "disableInlining" => [] ), unknowns = ["data.reports"], - sql = "public.juliahub_reports.category in ('public', 'pinned') or\n4 >= public.juliahub_reports.clearance_level", + sql = "( public.juliahub_reports.category in ('public', 'pinned') ) or\n( 4 >= public.juliahub_reports.clearance_level )", + ), + ( + policy = """package example + allow { + bits.and(data.reports[_].clearance_level, input.subject.clearance_level) >= input.subject.clearance_level + }""", + query = "data.example.allow == true", + input = Dict{String,Any}( + "subject" => Dict{String,Any}( + "clearance_level" => 4 + ) + ), + options = Dict{String,Any}( + "disableInlining" => [] + ), + unknowns = ["data.reports"], + sql = "( ( public.juliahub_reports.clearance_level & 4 ) >= 4 )", + ), + ( + policy = """package example + allow { + bits.or(data.reports[_].clearance_level, input.subject.clearance_level) >= input.subject.clearance_level + }""", + query = "data.example.allow == true", + input = Dict{String,Any}( + "subject" => Dict{String,Any}( + "clearance_level" => 4 + ) + ), + options = Dict{String,Any}( + "disableInlining" => [] + ), + unknowns = ["data.reports"], + sql = "( ( public.juliahub_reports.clearance_level | 4 ) >= 4 )", + ), + ( + policy = """package example + allow { + (data.reports[_].clearance_level + input.subject.clearance_level) >= input.subject.clearance_level + }""", + query = "data.example.allow == true", + input = Dict{String,Any}( + "subject" => Dict{String,Any}( + "clearance_level" => 4 + ) + ), + options = Dict{String,Any}( + "disableInlining" => [] + ), + unknowns = ["data.reports"], + sql = "( ( public.juliahub_reports.clearance_level + 4 ) >= 4 )", + ), + ( + policy = """package example + allow { + (data.reports[_].clearance_level - input.subject.clearance_level) >= 0 + }""", + query = "data.example.allow == true", + input = Dict{String,Any}( + "subject" => Dict{String,Any}( + "clearance_level" => 4 + ) + ), + options = Dict{String,Any}( + "disableInlining" => [] + ), + unknowns = ["data.reports"], + sql = "( ( public.juliahub_reports.clearance_level - 4 ) >= 0 )", ), ]