Skip to content

Commit

Permalink
Use K"function" for short form function AST (#466)
Browse files Browse the repository at this point in the history
A pain point when writing macros is detecting all the kinds of expressions which might be lowered to functions. This is partly due to the existence of short form function definitions, which in Julia's classic AST are parsed with a `:(=)` head rather than a `:function` head. To determine whether such an `=` is an assignment or a function definition, one needs to traverse recursively into the left hand side of the expression.

This change modifies the parsing of short form functions to use the `K"function"` kind. A new syntax flag `SHORT_FORM_FUNCTION_FLAG` is set to enable AST consumers to detect short vs long form functions.
  • Loading branch information
c42f authored Jul 30, 2024
1 parent da801cc commit 25f8eb2
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 56 deletions.
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ JuliaSyntax.COLON_QUOTE
JuliaSyntax.TOPLEVEL_SEMICOLONS_FLAG
JuliaSyntax.MUTABLE_FLAG
JuliaSyntax.BARE_MODULE_FLAG
JuliaSyntax.SHORT_FORM_FUNCTION_FLAG
```

## Syntax trees
Expand Down
39 changes: 16 additions & 23 deletions src/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ macro isexpr(ex, head, nargs)
length($(esc(ex)).args) == $(esc(nargs)))
end

# Return `true` when `ex` is a `:call` expression, possibly wrapped in any
# number of `:where` and/or `:(::)` expressions along the first argument.
# Anything which is not an `Expr` yields `false`.
function is_eventually_call(ex)
    ex isa Expr || return false
    head = ex.head
    head === :call && return true
    if head === :where || head === :(::)
        # Only the first argument can hold the call signature.
        return is_eventually_call(ex.args[1])
    end
    return false
end

function _reorder_parameters!(args::Vector{Any}, params_pos)
p = 0
for i = length(args):-1:1
Expand Down Expand Up @@ -233,16 +228,6 @@ function _internal_node_to_Expr(source, srcrange, head, childranges, childheads,

if k == K"?"
headsym = :if
elseif k == K"=" && !is_decorated(head)
a2 = args[2]
if is_eventually_call(args[1])
if @isexpr(a2, :block)
pushfirst!(a2.args, loc)
else
# Add block for short form function locations
args[2] = Expr(:block, loc, a2)
end
end
elseif k == K"macrocall"
do_lambda = _extract_do_lambda!(args)
_reorder_parameters!(args, 2)
Expand Down Expand Up @@ -399,14 +384,22 @@ function _internal_node_to_Expr(source, srcrange, head, childranges, childheads,
end
elseif k == K"function"
if length(args) > 1
a1 = args[1]
if @isexpr(a1, :tuple)
# Convert to weird Expr forms for long-form anonymous functions.
#
# (function (tuple (... xs)) body) ==> (function (... xs) body)
if length(a1.args) == 1 && (a11 = a1.args[1]; @isexpr(a11, :...))
# function (xs...) \n body end
args[1] = a11
if has_flags(head, SHORT_FORM_FUNCTION_FLAG)
a2 = args[2]
if !@isexpr(a2, :block)
args[2] = Expr(:block, a2)
end
headsym = :(=)
else
a1 = args[1]
if @isexpr(a1, :tuple)
# Convert to weird Expr forms for long-form anonymous functions.
#
# (function (tuple (... xs)) body) ==> (function (... xs) body)
if length(a1.args) == 1 && (a11 = a1.args[1]; @isexpr(a11, :...))
# function (xs...) \n body end
args[1] = a11
end
end
end
pushfirst!((args[2]::Expr).args, loc)
Expand Down
64 changes: 48 additions & 16 deletions src/parse_stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ Set for K"toplevel" which is delimited by parentheses
"""
const TOPLEVEL_SEMICOLONS_FLAG = RawFlags(1<<5)

"""
Set for K"function" nodes which were parsed from short form definitions such as
`f() = 1`, distinguishing them from long form `function f() ... end` definitions
"""
const SHORT_FORM_FUNCTION_FLAG = RawFlags(1<<5)

"""
Set for K"struct" when mutable
"""
Expand Down Expand Up @@ -143,6 +148,8 @@ function untokenize(head::SyntaxHead; unique=true, include_flag_suff=true)
has_flags(head, COLON_QUOTE) && (str = str*"-:")
elseif kind(head) == K"toplevel"
has_flags(head, TOPLEVEL_SEMICOLONS_FLAG) && (str = str*"-;")
elseif kind(head) == K"function"
has_flags(head, SHORT_FORM_FUNCTION_FLAG) && (str = str*"-=")
elseif kind(head) == K"struct"
has_flags(head, MUTABLE_FLAG) && (str = str*"-mut")
elseif kind(head) == K"module"
Expand Down Expand Up @@ -646,17 +653,17 @@ function peek_behind(stream::ParseStream, pos::ParseStreamPosition)
end

function first_child_position(stream::ParseStream, pos::ParseStreamPosition)
ranges = stream.ranges
@assert pos.range_index > 0
parent = ranges[pos.range_index]
# Find the first nontrivia range which is a child of this range but not a
# child of the child
c = 0
@assert pos.range_index > 0
parent = stream.ranges[pos.range_index]
for i = pos.range_index-1:-1:1
if stream.ranges[i].first_token < parent.first_token
if ranges[i].first_token < parent.first_token
break
end
if (c == 0 || stream.ranges[i].first_token < stream.ranges[c].first_token) &&
!is_trivia(stream.ranges[i])
if (c == 0 || ranges[i].first_token < ranges[c].first_token) && !is_trivia(ranges[i])
c = i
end
end
Expand All @@ -670,19 +677,44 @@ function first_child_position(stream::ParseStream, pos::ParseStreamPosition)
end
end

if c != 0
if t != 0
if stream.ranges[c].first_token > t
# Need a child index strictly before `t`. `c=0` works.
return ParseStreamPosition(t, 0)
else
return ParseStreamPosition(stream.ranges[c].last_token, c)
end
else
return ParseStreamPosition(stream.ranges[c].last_token, c)
if c == 0 || (t != 0 && ranges[c].first_token > t)
# Return leaf node at `t`
return ParseStreamPosition(t, 0)
else
# Return interior node at `c`
return ParseStreamPosition(ranges[c].last_token, c)
end
end

"""
    last_child_position(stream::ParseStream, pos::ParseStreamPosition)

Return the position of the last child of the interior node at `pos`.

`pos.range_index` must be positive, i.e. `pos` must refer to an entry of
`stream.ranges`. The result is one of

* `ParseStreamPosition(t, 0)` - a leaf, where `t` is the index of the last
  non-trivia token inside the parent range (`t` is `0` if the parent range
  contains no non-trivia token), or
* `ParseStreamPosition(ranges[c].last_token, c)` - an interior node, where `c`
  is the index of the child range.
"""
function last_child_position(stream::ParseStream, pos::ParseStreamPosition)
    ranges = stream.ranges
    @assert pos.range_index > 0
    parent = ranges[pos.range_index]
    # Find the last range which is a child of this range. Only the range
    # directly preceding the parent is considered; it is a child when it does
    # not start before the parent.
    # NOTE(review): presumably ranges are stored with children before their
    # parent, which is what makes this single check sufficient - confirm.
    c = 0
    if pos.range_index > 1
        i = pos.range_index-1
        if ranges[i].first_token >= parent.first_token
            # Valid child of current range
            c = i
        end
    end

    # Find last nontrivia token
    t = 0
    for i = parent.last_token:-1:parent.first_token
        if !is_trivia(stream.tokens[i])
            t = i
            break
        end
    end

    if c == 0 || (t != 0 && ranges[c].last_token < t)
        # Return leaf node at `t`: either there is no child range, or the
        # last non-trivia token lies after the end of the child range.
        return ParseStreamPosition(t, 0)
    else
        # Return interior node at `c`
        return ParseStreamPosition(ranges[c].last_token, c)
    end
end

Expand Down
22 changes: 19 additions & 3 deletions src/parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ function first_child_position(ps::ParseState, pos::ParseStreamPosition)
first_child_position(ps.stream, pos)
end

# Convenience method: look up the last child position using the `ParseStream`
# held in `ps.stream`.
function last_child_position(ps::ParseState, pos::ParseStreamPosition)
    stream = ps.stream
    return last_child_position(stream, pos)
end
#-------------------------------------------------------------------------------
# Parser Utils

Expand Down Expand Up @@ -325,6 +328,12 @@ function was_eventually_call(ps::ParseState)
return true
elseif b.kind == K"where" || b.kind == K"parens" ||
(b.kind == K"::" && has_flags(b.flags, INFIX_FLAG))
if b.kind == K"::"
p_last = last_child_position(ps, p)
if p == p_last
return false
end
end
p = first_child_position(ps, p)
else
return false
Expand Down Expand Up @@ -618,12 +627,19 @@ function parse_assignment_with_initial_ex(ps::ParseState, mark, down::T) where {
parse_assignment(ps, down)
emit(ps, mark, is_dotted(t) ? K"dotcall" : K"call", INFIX_FLAG)
else
# a += b ==> (+= a b)
# a .= b ==> (.= a b)
# f() = 1 ==> (function-= (call f) 1)
# f() .= 1 ==> (.= (call f) 1)
# a += b ==> (+= a b)
# a .= b ==> (.= a b)
is_short_form_func = k == K"=" && !is_dotted(t) && was_eventually_call(ps)
bump(ps, TRIVIA_FLAG)
bump_trivia(ps)
# Syntax Edition TODO: We'd like to call `down` here when
# is_short_form_func is true, to prevent `f() = 1 = 2` from parsing.
parse_assignment(ps, down)
emit(ps, mark, k, flags(t))
emit(ps, mark,
is_short_form_func ? K"function" : k,
is_short_form_func ? SHORT_FORM_FUNCTION_FLAG : flags(t))
end
end

Expand Down
8 changes: 7 additions & 1 deletion test/parse_packages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ base_path = let
p
end
@testset "Parse Base at $base_path" begin
    # The do-block selects the comparison function used for each file `f`
    # (presumably the path of the file being parsed - confirm against
    # `test_parse_all_in_path`).
    test_parse_all_in_path(base_path) do f
        if endswith(f, "gmp.jl")
            # Loose comparison due to `f(::g(w) = z) = a` syntax
            return exprs_roughly_equal
        end
        return exprs_equal_no_linenum
    end
end

base_tests_path = joinpath(Sys.BINDIR, Base.DATAROOTDIR, "julia", "test")
Expand Down
23 changes: 18 additions & 5 deletions test/parse_stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using JuliaSyntax: ParseStream,
peek, peek_token,
bump, bump_trivia, bump_invisible,
emit, emit_diagnostic, TRIVIA_FLAG, INFIX_FLAG,
ParseStreamPosition, first_child_position
ParseStreamPosition, first_child_position, last_child_position

# Here we manually issue parse events in the order the Julia parser would issue
# them
Expand Down Expand Up @@ -110,27 +110,40 @@ end
st = parse_sexpr("((a b) c)")
child1_pos = first_child_position(st, position(st))
@test child1_pos == ParseStreamPosition(7, 1)
child2_pos = first_child_position(st, child1_pos)
@test child2_pos == ParseStreamPosition(4, 0)
@test first_child_position(st, child1_pos) == ParseStreamPosition(4, 0)
@test last_child_position(st, position(st)) == ParseStreamPosition(9, 0)
@test last_child_position(st, child1_pos) == ParseStreamPosition(6, 0)

st = parse_sexpr("( (a b) c)")
child1_pos = first_child_position(st, position(st))
@test child1_pos == ParseStreamPosition(8, 1)
child2_pos = first_child_position(st, child1_pos)
@test child2_pos == ParseStreamPosition(5, 0)
@test first_child_position(st, child1_pos) == ParseStreamPosition(5, 0)
@test last_child_position(st, position(st)) == ParseStreamPosition(10, 0)
@test last_child_position(st, child1_pos) == ParseStreamPosition(7, 0)

st = parse_sexpr("(a (b c))")
@test first_child_position(st, position(st)) == ParseStreamPosition(3, 0)
child2_pos = last_child_position(st, position(st))
@test child2_pos == ParseStreamPosition(9, 1)
@test first_child_position(st, child2_pos) == ParseStreamPosition(6, 0)
@test last_child_position(st, child2_pos) == ParseStreamPosition(8, 0)

st = parse_sexpr("( a (b c))")
@test first_child_position(st, position(st)) == ParseStreamPosition(4, 0)
child2_pos = last_child_position(st, position(st))
@test child2_pos == ParseStreamPosition(10, 1)
@test first_child_position(st, child2_pos) == ParseStreamPosition(7, 0)
@test last_child_position(st, child2_pos) == ParseStreamPosition(9, 0)

st = parse_sexpr("a (b c)")
@test first_child_position(st, position(st)) == ParseStreamPosition(5, 0)
@test last_child_position(st, position(st)) == ParseStreamPosition(7, 0)

st = parse_sexpr("(a) (b c)")
@test first_child_position(st, position(st)) == ParseStreamPosition(7, 0)
@test last_child_position(st, position(st)) == ParseStreamPosition(9, 0)

st = parse_sexpr("(() ())")
@test first_child_position(st, position(st)) == ParseStreamPosition(4, 1)
@test last_child_position(st, position(st)) == ParseStreamPosition(7, 2)
end
13 changes: 10 additions & 3 deletions test/parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ tests = [
"a .~ b" => "(dotcall-i a ~ b)"
"[a ~ b c]" => "(hcat (call-i a ~ b) c)"
"[a~b]" => "(vect (call-i a ~ b))"
"f(x) .= 1" => "(.= (call f x) 1)"
"::g() = 1" => "(= (::-pre (call g)) 1)"
"f(x) = 1" => "(function-= (call f x) 1)"
"f(x)::T = 1" => "(function-= (::-i (call f x) T) 1)"
"f(x) where S where U = 1" => "(function-= (where (where (call f x) S) U) 1)"
"(f(x)::T) where S = 1" => "(function-= (where (parens (::-i (call f x) T)) S) 1)"
"f(x) = 1 = 2" => "(function-= (call f x) (= 1 2))" # Should be a warning!
],
JuliaSyntax.parse_pair => [
"a => b" => "(call-i a => b)"
Expand Down Expand Up @@ -449,7 +456,7 @@ tests = [
],
JuliaSyntax.parse_resword => [
# In normal_context
"begin f() where T = x end" => "(block (= (where (call f) T) x))"
"begin f() where T = x end" => "(block (function-= (where (call f) T) x))"
# block
"begin end" => "(block)"
"begin a ; b end" => "(block a b)"
Expand Down Expand Up @@ -955,14 +962,14 @@ tests = [
"if true \n public A, B \n end" => PARSE_ERROR
"public export=true foo, bar" => PARSE_ERROR # but these may be
"public experimental=true foo, bar" => PARSE_ERROR # supported soon ;)
"public(x::String) = false" => "(= (call public (::-i x String)) false)"
"public(x::String) = false" => "(function-= (call public (::-i x String)) false)"
"module M; export @a; end" => "(module M (block (export @a)))"
"module M; public @a; end" => "(module M (block (public @a)))"
"module M; export ⤈; end" => "(module M (block (export ⤈)))"
"module M; public ⤈; end" => "(module M (block (public ⤈)))"
"public = 4" => "(= public 4)"
"public[7] = 5" => "(= (ref public 7) 5)"
"public() = 6" => "(= (call public) 6)"
"public() = 6" => "(function-= (call public) 6)"
]),
JuliaSyntax.parse_docstring => [
""" "notdoc" ] """ => "(string \"notdoc\")"
Expand Down
18 changes: 13 additions & 5 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ function exprs_equal_no_linenum(fl_ex, ex)
remove_all_linenums!(deepcopy(ex)) == remove_all_linenums!(deepcopy(fl_ex))
end

# Test helper: does `ex` reduce to a `:call` expression once any enclosing
# `:where` / `:(::)` wrappers are peeled off its first argument?
function is_eventually_call(ex)
    node = ex
    while node isa Expr
        if node.head === :call
            return true
        elseif node.head === :where || node.head === :(::)
            # Descend into the wrapped signature
            node = node.args[1]
        else
            return false
        end
    end
    return false
end

# Compare Expr from reference parser expression to JuliaSyntax parser, ignoring
# differences due to bugs in the reference parser.
function exprs_roughly_equal(fl_ex, ex)
Expand Down Expand Up @@ -149,7 +154,7 @@ function exprs_roughly_equal(fl_ex, ex)
fl_args[1] = Expr(:tuple, Expr(:parameters, kwargs...), posargs...)
elseif h == :for
iterspec = args[1]
if JuliaSyntax.is_eventually_call(iterspec.args[1]) &&
if is_eventually_call(iterspec.args[1]) &&
Meta.isexpr(iterspec.args[2], :block)
blk = iterspec.args[2]
if length(blk.args) == 2 && blk.args[1] isa LineNumberNode
Expand All @@ -158,6 +163,11 @@ function exprs_roughly_equal(fl_ex, ex)
iterspec.args[2] = blk.args[2]
end
end
elseif (h == :(=) || h == :kw) && Meta.isexpr(fl_args[1], :(::), 1) &&
Meta.isexpr(fl_args[2], :block, 2) && fl_args[2].args[1] isa LineNumberNode
# The flisp parser adds an extra block around `w` in the following case
# f(::g(z) = w) = 1
fl_args[2] = fl_args[2].args[2]
end
if length(fl_args) != length(args)
return false
Expand All @@ -169,9 +179,7 @@ function exprs_roughly_equal(fl_ex, ex)
fl_args[1] = Expr(:macrocall, map(kw_to_eq, args[1].args)...)
end
for i = 1:length(args)
flarg = fl_args[i]
arg = args[i]
if !exprs_roughly_equal(flarg, arg)
if !exprs_roughly_equal(fl_args[i], args[i])
return false
end
end
Expand Down Expand Up @@ -307,7 +315,7 @@ between flisp and JuliaSyntax parsers and return the source text of those
subtrees.
"""
function reduce_tree(text::AbstractString; kws...)
    # Parse the text exactly once. `ignore_warnings=true` is passed so that,
    # presumably, parser warnings do not abort the reduction - confirm against
    # `parseall`.
    tree = parseall(SyntaxNode, text, ignore_warnings=true)
    # Reduce the tree, then map each resulting subtree back to its source text.
    return sourcetext.(reduce_tree(tree; kws...))
end

Expand Down

0 comments on commit 25f8eb2

Please sign in to comment.