Skip to content

Commit

Permalink
Type checking of protocols in for-comprehensions (#14124)
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim authored Dec 29, 2024
1 parent 4787116 commit d3cef1f
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 56 deletions.
5 changes: 2 additions & 3 deletions lib/elixir/lib/calendar/date.ex
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,8 @@ defmodule Date do
end

def utc_today(calendar) do
calendar
|> DateTime.utc_now()
|> DateTime.to_date()
%{year: year, month: month, day: day} = DateTime.utc_now(calendar)
%Date{year: year, month: month, day: day, calendar: calendar}
end

@doc """
Expand Down
75 changes: 49 additions & 26 deletions lib/elixir/lib/module/types/apply.ex
Original file line number Diff line number Diff line change
Expand Up @@ -811,14 +811,12 @@ defmodule Module.Types.Apply do
end

def format_diagnostic({:badremote, mfac, expr, args_types, domain, clauses, context}) do
traces = collect_traces(expr, context)
{mod, fun, arity, converter} = mfac
meta = elem(expr, 1)

# Protocol errors can be very verbose, so we collapse structs
{banner, hints, opts} =
cond do
meta[:from_interpolation] ->
{banner, hints, traces} =
case Keyword.get(meta, :type_check) do
:interpolation ->
{_, _, [arg]} = expr

{"""
Expand All @@ -827,34 +825,59 @@ defmodule Module.Types.Apply do
#{expr_to_string(arg) |> indent(4)}
it has type:
""", [:interpolation], [collapse_structs: true]}
""", [:interpolation], collect_traces(expr, context)}

Code.ensure_loaded?(mod) and
Keyword.has_key?(mod.module_info(:attributes), :__protocol__) ->
{nil, [{:protocol, mod}], [collapse_structs: true]}
:generator ->
{:<-, _, [_, arg]} = expr

true ->
{nil, [], []}
end
{"""
incompatible value given to for-comprehension:
explanation =
empty_arg_reason(converter.(args_types)) ||
"""
but expected one of:
#{clauses_args_to_quoted_string(clauses, converter, opts)}
"""
#{expr_to_string(expr) |> indent(4)}
mfa_or_fa = if mod, do: Exception.format_mfa(mod, fun, arity), else: "#{fun}/#{arity}"
it has type:
""", [:generator], collect_traces(arg, context)}

banner =
banner ||
"""
incompatible types given to #{mfa_or_fa}:
:into ->
{"""
incompatible value given to :into option in for-comprehension:
#{expr_to_string(expr) |> indent(4)}
into: #{expr_to_string(expr) |> indent(4)}
given types:
"""
it has type:
""", [:into], collect_traces(expr, context)}

_ ->
mfa_or_fa = if mod, do: Exception.format_mfa(mod, fun, arity), else: "#{fun}/#{arity}"

{"""
incompatible types given to #{mfa_or_fa}:
#{expr_to_string(expr) |> indent(4)}
given types:
""", [], collect_traces(expr, context)}
end

explanation =
cond do
reason = empty_arg_reason(converter.(args_types)) ->
reason

Code.ensure_loaded?(mod) and
Keyword.has_key?(mod.module_info(:attributes), :__protocol__) ->
# Protocol errors can be very verbose, so we collapse structs
"""
but expected a type that implements the #{inspect(mod)} protocol, it must be one of:
#{clauses_args_to_quoted_string(clauses, converter, collapse_structs: true)}
"""

true ->
"""
but expected one of:
#{clauses_args_to_quoted_string(clauses, converter, [])}
"""
end

%{
details: %{typing_traces: traces},
Expand Down
34 changes: 31 additions & 3 deletions lib/elixir/lib/module/types/descr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ defmodule Module.Types.Descr do
@map_empty [{:closed, %{}, []}]

@none %{}
@empty_list %{bitmap: @bit_empty_list}
@not_non_empty_list %{bitmap: @bit_top, atom: @atom_top, tuple: @tuple_top, map: @map_top}
@term %{
bitmap: @bit_top,
atom: @atom_top,
tuple: @tuple_top,
map: @map_top,
list: @non_empty_list_top
}
@empty_list %{bitmap: @bit_empty_list}
@not_non_empty_list Map.delete(@term, :list)

@empty_intersection [0, @none]
@empty_difference [0, []]
Expand Down Expand Up @@ -98,6 +98,7 @@ defmodule Module.Types.Descr do
@not_set %{optional: 1}
@term_or_optional Map.put(@term, :optional, 1)
@term_or_dynamic_optional Map.put(@term, :dynamic, %{optional: 1})
@not_atom_or_optional Map.delete(@term_or_optional, :atom)

def not_set(), do: @not_set
def if_set(:term), do: term_or_optional()
Expand Down Expand Up @@ -1751,10 +1752,32 @@ defmodule Module.Types.Descr do
end

# Two maps are fusible if they differ in at most one element.
defp non_fusible_maps?({_, fields1, []}, {_, fields2, []})
when map_size(fields1) > map_size(fields2) do
not fusible_maps?(Map.to_list(fields2), fields1, 0)
end

defp non_fusible_maps?({_, fields1, []}, {_, fields2, []}) do
Enum.count_until(fields1, fn {key, value} -> Map.fetch!(fields2, key) != value end, 2) > 1
not fusible_maps?(Map.to_list(fields1), fields2, 0)
end

defp fusible_maps?([{:__struct__, value} | rest], fields, count) do
case Map.fetch!(fields, :__struct__) do
^value -> fusible_maps?(rest, fields, count)
_ -> false
end
end

defp fusible_maps?([{key, value} | rest], fields, count) do
case Map.fetch!(fields, key) do
^value -> fusible_maps?(rest, fields, count)
_ when count == 1 -> false
_ when count == 0 -> fusible_maps?(rest, fields, count + 1)
end
end

defp fusible_maps?([], _fields, _count), do: true

defp map_non_negated_fuse_pair({tag, fields1, []}, {_, fields2, []}) do
fields =
symmetrical_merge(fields1, fields2, fn _k, v1, v2 ->
Expand Down Expand Up @@ -1818,6 +1841,11 @@ defmodule Module.Types.Descr do
{:empty_map, [], []}
end

def map_literal_to_quoted({:open, %{__struct__: @not_atom_or_optional} = fields}, _opts)
when map_size(fields) == 1 do
{:non_struct_map, [], []}
end

def map_literal_to_quoted({tag, fields}, opts) do
case tag do
:closed ->
Expand Down
37 changes: 29 additions & 8 deletions lib/elixir/lib/module/types/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,10 @@ defmodule Module.Types.Expr do
end

# TODO: for pat <- expr do expr end
def of_expr({:for, _meta, [_ | _] = args}, stack, context) do
def of_expr({:for, meta, [_ | _] = args}, stack, context) do
{clauses, [[{:do, block} | opts]]} = Enum.split(args, -1)
context = Enum.reduce(clauses, context, &for_clause(&1, stack, &2))
context = Enum.reduce(opts, context, &for_option(&1, stack, &2))
context = Enum.reduce(opts, context, &for_option(&1, meta, stack, &2))

if Keyword.has_key?(opts, :reduce) do
{_, context} = of_clauses(block, [dynamic()], :for_reduce, stack, {none(), context})
Expand Down Expand Up @@ -471,13 +471,17 @@ defmodule Module.Types.Expr do

## Comprehensions

defp for_clause({:<-, _, [left, right]} = expr, stack, context) do
defp for_clause({:<-, meta, [left, right]}, stack, context) do
expr = {:<-, [type_check: :generator] ++ meta, [left, right]}
{pattern, guards} = extract_head([left])
{_, context} = of_expr(right, stack, context)
{type, context} = of_expr(right, stack, context)

{_type, context} =
Pattern.of_match(pattern, guards, dynamic(), expr, :for, stack, context)

{_type, context} =
Apply.remote(Enumerable, :count, [right], [type], expr, stack, context)

context
end

Expand All @@ -500,17 +504,34 @@ defmodule Module.Types.Expr do
context
end

defp for_option({:into, expr}, stack, context) do
{_type, context} = of_expr(expr, stack, context)
defp for_option({:into, expr}, _meta, _stack, context) when is_list(expr) or is_binary(expr) do
context
end

defp for_option({:into, expr}, meta, stack, context) do
{type, context} = of_expr(expr, stack, context)

meta =
case expr do
{_, meta, _} -> meta
_ -> meta
end

wrapped_expr = {:__block__, [type_check: :into] ++ meta, [expr]}

{_type, context} =
Apply.remote(Collectable, :into, [expr], [type], wrapped_expr, stack, context)

context
end

defp for_option({:reduce, expr}, stack, context) do
defp for_option({:reduce, expr}, _meta, stack, context) do
{_type, context} = of_expr(expr, stack, context)
context
end

defp for_option({:uniq, _}, _stack, context) do
defp for_option({:uniq, _}, _meta, _stack, context) do
# This option is verified to be a boolean at compile-time
context
end

Expand Down
16 changes: 12 additions & 4 deletions lib/elixir/lib/module/types/helpers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,24 @@ defmodule Module.Types.Helpers do
:interpolation ->
"""
#{hint()} string interpolation in Elixir uses the String.Chars protocol to \
#{hint()} string interpolation uses the String.Chars protocol to \
convert a data structure into a string. Either convert the data type into a \
string upfront or implement the protocol accordingly
"""

{:protocol, protocol} ->
:generator ->
"""
#{hint()} #{inspect(protocol)} is a protocol in Elixir. Either make sure you \
give valid data types as arguments or implement the protocol accordingly
#{hint()} for-comprehensions use the Enumerable protocol to traverse \
data structures. Either convert the data type into a list (or another Enumerable) \
or implement the protocol accordingly
"""

:into ->
"""
#{hint()} the :into option in for-comprehensions use the Collectable protocol to \
build its result. Either pass a valid data type or implement the protocol accordingly
"""

:anonymous_rescue ->
Expand Down
4 changes: 2 additions & 2 deletions lib/elixir/lib/module/types/of.ex
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ defmodule Module.Types.Of do
{Function, fun()},
{Integer, integer()},
{List, list(term())},
{Map, open_map(__struct__: not_set())},
{Map, open_map(__struct__: if_set(negation(atom())))},
{Port, port()},
{PID, pid()},
{Reference, reference()},
Expand Down Expand Up @@ -339,7 +339,7 @@ defmodule Module.Types.Of do
{{:., _, [String.Chars, :to_string]} = dot, meta, [arg]},
{:binary, _, nil}
) do
{dot, [from_interpolation: true] ++ meta, [arg]}
{dot, [type_check: :interpolation] ++ meta, [arg]}
end

defp annotate_interpolation(left, _right) do
Expand Down
9 changes: 9 additions & 0 deletions lib/elixir/test/elixir/module/types/descr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,15 @@ defmodule Module.Types.DescrTest do
assert closed_map(__struct__: atom([Decimal]), coef: term(), exp: term(), sign: integer())
|> to_quoted_string(collapse_structs: true) ==
"%Decimal{sign: integer()}"

# Does not fuse structs
assert union(closed_map(__struct__: atom([Foo])), closed_map(__struct__: atom([Bar])))
|> to_quoted_string() ==
"%{__struct__: Bar} or %{__struct__: Foo}"

# Properly format non_struct_map
assert open_map(__struct__: if_set(negation(atom()))) |> to_quoted_string() ==
"non_struct_map()"
end
end

Expand Down
Loading

0 comments on commit d3cef1f

Please sign in to comment.