diff --git a/lib/elixir/lib/module/types/descr.ex b/lib/elixir/lib/module/types/descr.ex index b4bd5dc257f..0133750a248 100644 --- a/lib/elixir/lib/module/types/descr.ex +++ b/lib/elixir/lib/module/types/descr.ex @@ -23,10 +23,10 @@ defmodule Module.Types.Descr do @bit_pid 1 <<< 4 @bit_port 1 <<< 5 @bit_reference 1 <<< 6 - @bit_fun 1 <<< 7 - @bit_top (1 <<< 8) - 1 + @bit_top (1 <<< 7) - 1 @bit_number @bit_integer ||| @bit_float + @fun_top :fun_top @atom_top {:negation, :sets.new(version: 2)} @map_top [{:open, %{}, []}] @non_empty_list_top [{:term, :term, []}] @@ -39,7 +39,8 @@ defmodule Module.Types.Descr do atom: @atom_top, tuple: @tuple_top, map: @map_top, - list: @non_empty_list_top + list: @non_empty_list_top, + fun: @fun_top } @empty_list %{bitmap: @bit_empty_list} @not_non_empty_list Map.delete(@term, :list) @@ -72,7 +73,6 @@ defmodule Module.Types.Descr do def empty_map(), do: %{map: @map_empty} def integer(), do: %{bitmap: @bit_integer} def float(), do: %{bitmap: @bit_float} - def fun(), do: %{bitmap: @bit_fun} def list(type), do: list_descr(type, @empty_list, true) def non_empty_list(type, tail \\ @empty_list), do: list_descr(type, tail, false) def open_map(), do: %{map: @map_top} @@ -87,6 +87,29 @@ defmodule Module.Types.Descr do @boolset :sets.from_list([true, false], version: 2) def boolean(), do: %{atom: {:union, @boolset}} + def fun(), do: %{fun: @fun_top} + + @doc """ + Creates a function type with the given arguments and return type. + + ## Examples + iex> fun([integer()], atom()) # Creates (integer) -> atom + iex> fun([integer(), float()], boolean()) # Creates (integer, float) -> boolean + """ + def fun(args, return) when is_list(args), do: fun_descr(args, return) + + @doc """ + Creates the top function type for the given arity, where all arguments are none() + and return is term(). + + ## Examples + iex> fun(1) # Creates (none) -> term + iex> fun(2) # Creates (none, none) -> term + """ + def fun(arity) when is_integer(arity) and arity >= 0 do + fun(List.duplicate(none(), arity), term()) + end + ## Optional # `not_set()` is a special base type that represents an not_set field in a map. @@ -227,6 +250,7 @@ defmodule Module.Types.Descr do defp union(:map, v1, v2), do: map_union(v1, v2) defp union(:optional, 1, 1), do: 1 defp union(:tuple, v1, v2), do: tuple_union(v1, v2) + defp union(:fun, v1, v2), do: fun_union(v1, v2) @doc """ Computes the intersection of two descrs. @@ -269,6 +293,7 @@ defmodule Module.Types.Descr do defp intersection(:map, v1, v2), do: map_intersection(v1, v2) defp intersection(:optional, 1, 1), do: 1 defp intersection(:tuple, v1, v2), do: tuple_intersection(v1, v2) + defp intersection(:fun, v1, v2), do: fun_intersection(v1, v2) @doc """ Computes the difference between two types. @@ -328,6 +353,7 @@ defmodule Module.Types.Descr do defp difference(:map, v1, v2), do: map_difference(v1, v2) defp difference(:optional, 1, 1), do: 0 defp difference(:tuple, v1, v2), do: tuple_difference(v1, v2) + defp difference(:fun, v1, v2), do: fun_difference(v1, v2) @doc """ Compute the negation of a type. @@ -359,7 +385,8 @@ defmodule Module.Types.Descr do not Map.has_key?(descr, :optional) and (not Map.has_key?(descr, :map) or map_empty?(descr.map)) and (not Map.has_key?(descr, :list) or list_empty?(descr.list)) and - (not Map.has_key?(descr, :tuple) or tuple_empty?(descr.tuple)) + (not Map.has_key?(descr, :tuple) or tuple_empty?(descr.tuple)) and + (not Map.has_key?(descr, :fun) or fun_empty?(descr.fun)) end end @@ -420,6 +447,7 @@ defmodule Module.Types.Descr do defp to_quoted(:map, dnf, opts), do: map_to_quoted(dnf, opts) defp to_quoted(:list, dnf, opts), do: list_to_quoted(dnf, false, opts) defp to_quoted(:tuple, dnf, opts), do: tuple_to_quoted(dnf, opts) + defp to_quoted(:fun, dnf, opts), do: fun_to_quoted(dnf, opts) @doc """ Converts a descr to its quoted string representation. @@ -648,8 +676,7 @@ defmodule Module.Types.Descr do float: @bit_float, pid: @bit_pid, port: @bit_port, - reference: @bit_reference, - fun: @bit_fun + reference: @bit_reference ] for {type, mask} <- pairs, @@ -660,25 +687,35 @@ defmodule Module.Types.Descr do ## Funs @doc """ - Checks there is a function type (and only functions) with said arity. + Checks if a function type with the specified arity exists in the descriptor. + + 1. If there is no dynamic component: + - The static part must be a non-empty function type of the given arity + + 2. If there is a dynamic component: + - Either the static part is a non-empty function type of the given arity, or + - The static part is empty and the dynamic part contains functions of the given arity """ def fun_fetch(:term, _arity), do: :error - def fun_fetch(%{} = descr, _arity) do - {static_or_dynamic, static} = Map.pop(descr, :dynamic, descr) + def fun_fetch(%{} = descr, arity) when is_integer(arity) do + case :maps.take(:dynamic, descr) do + :error -> + if not empty?(descr) and fun_only?(descr, arity), do: :ok, else: :error - if fun_only?(static) do - case static_or_dynamic do - :term -> :ok - %{bitmap: bitmap} when (bitmap &&& @bit_fun) != 0 -> :ok - %{} -> :error - end - else - :error + {dynamic, static} -> + empty_static = empty?(static) + + cond do + not empty_static -> if fun_only?(static, arity), do: :ok, else: :error + empty_static and not empty?(intersection(dynamic, fun(arity))) -> :ok + true -> :error + end end end - defp fun_only?(descr), do: empty?(difference(descr, fun())) + defp fun_only?(descr), do: empty?(Map.delete(descr, :fun)) + defp fun_only?(descr, arity), do: empty?(difference(descr, fun(arity))) ## Atoms @@ -839,6 +876,554 @@ defmodule Module.Types.Descr do |> List.wrap() end + ## Functions + + # Function types are represented using Binary Decision Diagrams (BDDs) for efficient + # handling of unions, intersections, and negations. + + ### Key concepts: + + # * BDD structure: A tree with function nodes and :fun_top/:fun_bottom leaves. Paths to :fun_top + # represent valid function types. Nodes are positive when following a left + # branch (e.g. (int, float) -> bool) and negative otherwise. + + # * Function variance: + # - Contravariance in arguments: If s <: t, then (t → r) <: (s → r) + # - Covariance in returns: If s <: t, then (u → s) <: (u → t) + + # * Representation: + # - fun(): Top function type (leaf 1) + # - Function literals: {tag, [t1, ..., tn], t} where [t1, ..., tn] are argument types and t is return type + # tag is either `:weak` or `:strong` + # TODO: implement `:strong` + # - Normalized form for function applications: {domain, arrows, arity} is produced by `fun_normalize/1` + + # * Examples: + # - fun([integer()], atom()): A function from integer to atom + # - intersection(fun([integer()], atom()), fun([float()], boolean())): A function handling both signatures + + # Note: Function domains are expressed as tuple types. We use separate representations rather than + # unary functions with tuple domains to handle special cases like representing functions of a + # specific arity (e.g., (none,none->term) for arity 2). + + defp fun_new(inputs, output), do: {{inputs, output}, :fun_top, :fun_bottom} + + @doc """ + Creates a function type from a list of inputs and an output where the inputs and/or output may be dynamic. + + For function (t → s) with dynamic components: + - Static part: (upper_bound(t) → lower_bound(s)) + - Dynamic part: dynamic(lower_bound(t) → upper_bound(s)) + + When handling dynamic types: + - `upper_bound(t)` extracts the upper bound (most general type) of a gradual type. + For `dynamic(integer())`, it is `integer()`. + - `lower_bound(t)` extracts the lower bound (most specific type) of a gradual type. + """ + def fun_descr(args, output) when is_list(args) do + dynamic_arguments? = are_arguments_dynamic?(args) + dynamic_output? = match?(%{dynamic: _}, output) + + if dynamic_arguments? or dynamic_output? do + input_static = if dynamic_arguments?, do: materialize_arguments(args, :up), else: args + input_dynamic = if dynamic_arguments?, do: materialize_arguments(args, :down), else: args + + output_static = if dynamic_output?, do: lower_bound(output), else: output + output_dynamic = if dynamic_output?, do: upper_bound(output), else: output + + %{ + fun: fun_new(input_static, output_static), + dynamic: %{fun: fun_new(input_dynamic, output_dynamic)} + } + else + # No dynamic components, use standard function type + %{fun: fun_new(args, output)} + end + end + + # Gets the upper bound of a gradual type. + defp upper_bound(%{dynamic: dynamic}), do: dynamic + defp upper_bound(static), do: static + + # Gets the lower bound of a gradual type. + defp lower_bound(:term), do: :term + defp lower_bound(type), do: Map.delete(type, :dynamic) + + # Tuples represent function domains, using unions to combine parameters. + # Example: for functions (integer,float)->:ok and (float,integer)->:error + # domain isn't {integer|float,integer|float} as that would incorrectly accept {float,float} + # Instead, it is {integer,float} or {float,integer} + def domain_descr(types) when is_list(types), do: tuple(types) + + @doc """ + Calculates the domain of a function type. + + For a function type, the domain is the set of valid input types. + Returns: + - `:badfunction` if the type is not a function type + - A tuple type representing the domain for valid function types + + Handles both static and dynamic function types: + 1. For static functions, returns their exact domain + 2. For dynamic functions, computes domain based on both static and dynamic parts + + Formula is dom(t) = dom(upper_bound(t)) ∪ dynamic(dom(lower_bound(t))). + See Definition 6.15 in https://vlanvin.fr/papers/thesis.pdf. + + ## Examples + iex> fun_domain(fun([integer()], atom())) + domain_repr([integer()]) + + iex> fun_domain(fun([integer(), float()], boolean())) + domain_repr([integer(), float()]) + """ + def fun_domain(:term), do: :badfunction + + def fun_domain(type) do + result = + case :maps.take(:dynamic, type) do + :error -> + # Static function type + with true <- fun_only?(type), {:ok, domain} <- fun_domain_static(type) do + domain + else + _ -> :badfunction + end + + {dynamic, static} when static == @none -> + with {:ok, domain} <- fun_domain_static(dynamic), do: domain + + {dynamic, static} -> + with true <- fun_only?(static), + {:ok, static_domain} <- fun_domain_static(static), + {:ok, dynamic_domain} <- fun_domain_static(dynamic) do + union(dynamic_domain, dynamic(static_domain)) + else + _ -> :badfunction + end + end + + case result do + :badfunction -> :badfunction + result -> if empty?(result), do: :badfunction, else: result + end + end + + # Returns {:ok, domain} if the domain of the static type is well-defined. + # For that, it has to contain a non-empty function type. + # Otherwise, returns :badfunction. + defp fun_domain_static(%{fun: bdd}) do + case fun_normalize(bdd) do + {domain, _, _} -> {:ok, domain} + _ -> {:ok, none()} + end + end + + defp fun_domain_static(:term), do: :badfunction + defp fun_domain_static(%{}), do: {:ok, none()} + defp fun_domain_static(:emptyfunction), do: {:ok, none()} + + @doc """ + Applies a function type to a list of argument types. + + Returns the result type if the application is valid, or `:badarguments` if not. + + Handles both static and dynamic function types: + 1. For static functions: checks exact argument types + 2. For dynamic functions: computes result based on both static and dynamic parts + 3. For mixed static/dynamic: computes all valid combinations + + # Function application formula for dynamic types: + # τ◦τ′ = (lower_bound(τ) ◦ upper_bound(τ′)) ∨ (dynamic(upper_bound(τ) ◦ lower_bound(τ′))) + # + # Where: + # - τ is a dynamic function type + # - τ′ are the arguments + # - ◦ is function application + # + # For more details, see Definition 6.15 in https://vlanvin.fr/papers/thesis.pdf + + ## Examples + iex> fun_apply(fun([integer()], atom()), [integer()]) + atom() + + iex> fun_apply(fun([integer()], atom()), [float()]) + :badarguments + + iex> fun_apply(fun([dynamic()], atom()), [dynamic()]) + atom() + """ + def fun_apply(fun, arguments) do + if empty?(domain_descr(arguments)) do + :badarguments + else + case :maps.take(:dynamic, fun) do + :error -> fun_apply_with_strategy(fun, nil, arguments) + {fun_dynamic, fun_static} -> fun_apply_with_strategy(fun_static, fun_dynamic, arguments) + end + end + end + + defp fun_apply_with_strategy(fun_static, fun_dynamic, arguments) do + args_dynamic? = are_arguments_dynamic?(arguments) + + # For non-dynamic function and arguments, just return the static result + if fun_dynamic == nil and not args_dynamic? do + with {:ok, type} <- fun_apply_static(fun_static, arguments), do: type + else + # For dynamic cases, combine static and dynamic results + {static_args, dynamic_args} = + if args_dynamic?, + do: {materialize_arguments(arguments, :up), materialize_arguments(arguments, :down)}, + else: {arguments, arguments} + + dynamic_fun = fun_dynamic || fun_static + + with {:ok, res1} <- fun_apply_static(fun_static, static_args), + {:ok, res2} <- fun_apply_static(dynamic_fun, dynamic_args) do + union(res1, dynamic(res2)) + else + _ -> :badarguments + end + end + end + + # Materializes arguments using the specified direction (up or down) + defp materialize_arguments(arguments, :up), do: Enum.map(arguments, &upper_bound/1) + defp materialize_arguments(arguments, :down), do: Enum.map(arguments, &lower_bound/1) + + defp are_arguments_dynamic?(arguments), do: Enum.any?(arguments, &match?(%{dynamic: _}, &1)) + + defp fun_apply_static(%{fun: fun_bdd}, arguments) do + type_args = domain_descr(arguments) + + if empty?(type_args) do + # At this stage we do not check that the function can be applied to the arguments (using domain) + with {_domain, arrows, arity} <- fun_normalize(fun_bdd), + true <- arity == length(arguments) do + # Opti: short-circuits when inner loop is none() or outer loop is term() + result = + Enum.reduce_while(arrows, none(), fn intersection_of_arrows, acc -> + Enum.reduce_while(intersection_of_arrows, term(), fn + {_dom, _ret}, acc when acc == @none -> {:halt, acc} + {_dom, ret}, acc -> {:cont, intersection(acc, ret)} + end) + |> case do + :term -> {:halt, :term} + inner -> {:cont, union(inner, acc)} + end + end) + + {:ok, result} + else + false -> :badarity + end + else + with {domain, arrows, arity} <- fun_normalize(fun_bdd), + true <- arity == length(arguments), + true <- subtype?(type_args, domain) do + result = + Enum.reduce(arrows, none(), fn intersection_of_arrows, acc -> + aux_apply(acc, type_args, term(), intersection_of_arrows) + end) + + {:ok, result} + else + _ -> :badarguments + end + end + end + + # Helper function for function application that handles the application of + # function arrows to input types. + + # This function recursively processes a list of function arrows (an intersection), + # applying each arrow to the input type and accumulating the result. + + # ## Parameters + + # - result: The accumulated result type so far + # - input: The input type being applied to the function + # - rets_reached: The intersection of return types reached so far + # - arrow_intersections: The list of function arrows to process + + # For more details, see Definitions 2.20 or 6.11 in https://vlanvin.fr/papers/thesis.pdf + defp aux_apply(result, _input, rets_reached, []) do + if subtype?(rets_reached, result), do: result, else: union(result, rets_reached) + end + + defp aux_apply(result, input, returns_reached, [{dom, ret} | arrow_intersections]) do + # Calculate the part of the input not covered by this arrow's domain + dom_subtract = difference(input, domain_descr(dom)) + + # Refine the return type by intersecting with this arrow's return type + ret_refine = intersection(returns_reached, ret) + + # Phase 1: Domain partitioning + # If the input is not fully covered by the arrow's domain, then the result type should be + # _augmented_ with the outputs obtained by applying the remaining arrows to the non-covered + # parts of the domain. + # + # e.g. (integer()->atom()) and (float()->pid()) when applied to number() should unite + # both atoms and pids in the result. + result = + if empty?(dom_subtract) do + result + else + aux_apply(result, dom_subtract, returns_reached, arrow_intersections) + end + + # 2. Return type refinement + # The result type is also refined (intersected) in the sense that, if several arrows match + # the same part of the input, then the result type is an intersection of the return types of + # those arrows. + + # e.g. (integer()->atom()) and (integer()->pid()) when applied to integer() + # should result in (atom() ∩ pid()), which is none(). + aux_apply(result, input, ret_refine, arrow_intersections) + end + + # Takes all the paths from the root to the leaves finishing with a 1, + # and compile into tuples of positive and negative nodes. Positive nodes are + # those followed by a left path, negative nodes are those followed by a right path. + def fun_get(bdd), do: fun_get([], [], [], bdd) + + def fun_get(acc, pos, neg, bdd) do + case bdd do + :fun_bottom -> acc + :fun_top -> [{pos, neg} | acc] + {fun, left, right} -> fun_get(fun_get(acc, [fun | pos], neg, left), pos, [fun | neg], right) + end + end + + # Transforms a binary decision diagram (BDD) into the canonical form {domain, arrows, arity}: + # + # 1. **domain**: The union of all domains from positive functions in the BDD + # 2. **arrows**: List of lists, where each inner list contains an intersection of function arrows + # 3. **arity**: Function arity (number of parameters) + # + ## Return Values + # + # - `{domain, arrows, arity}` for valid function BDDs + # - `:emptyfunction` if the BDD represents an empty function type + # + # ## Internal Use + # + # This function is used internally by `fun_apply`, `fun_domain`, and others to + # ensure consistent handling of function types in all operations. + defp fun_normalize(bdd) do + {domain, arrows, arity} = + fun_get(bdd) + |> Enum.reduce({term(), [], nil}, fn {pos_funs, neg_funs}, {domain, arrows, arity} -> + # Skip empty function intersections + if fun_empty?(pos_funs, neg_funs) do + {domain, arrows, arity} + else + # Determine arity from first positive function or keep existing + new_arity = arity || pos_funs |> List.first() |> elem(0) |> length() + + # Calculate domain from all positive functions + path_domain = + Enum.reduce(pos_funs, none(), fn {args, _}, acc -> + union(acc, domain_descr(args)) + end) + + {intersection(domain, path_domain), [pos_funs | arrows], new_arity} + end + end) + + if arrows == [], do: :emptyfunction, else: {domain, arrows, arity} + end + + # Checks if a function type is empty. + # + # A function type is empty if: + # 1. It is the empty type (0) + # 2. For each path in the BDD (Binary Decision Diagram) from root to leaf ending in 1, + # the intersection of positive functions and the negation of negative functions is empty. + # + # For example: + # - `fun(1)` is not empty + # - `fun(1) and not fun(1)` is empty + # - `fun(integer() -> atom()) and not fun(none() -> term())` is empty + # - `fun(integer() -> atom()) and not fun(atom() -> integer())` is not empty + defp fun_empty?(bdd) do + case bdd do + :fun_bottom -> true + :fun_top -> false + bdd -> fun_get(bdd) |> Enum.all?(fn {posits, negats} -> fun_empty?(posits, negats) end) + end + end + + # Checks if a function type represented by positive and negative function literals is empty. + + # A function type {positives, negatives} is empty if either: + # 1. The positive functions have different arities (incompatible function types) + # 2. There exists a negative function that negates the whole positive intersection + + ## Examples + # - `{[fun(1)], []}` is not empty + # - `{[fun(1), fun(2)], []}` is empty (different arities) + # - `{[fun(integer() -> atom())], [fun(none() -> term())]}` is empty + # - `{[], _}` (representing the top function type fun()) is never empty + # + # TODO: test performance + defp fun_empty?([], _), do: false + + defp fun_empty?(positives, negatives) do + case fetch_arity_and_domain(positives) do + # If there are functions with different arities in positives, then the function type is empty + {:empty, _} -> + true + + {positive_arity, positive_domain} -> + # Check if any negative function negates the whole positive intersection + # e.g. (integer()->atom()) is negated by + # i) (none()->term()) ii) (none()->atom()) + # ii) (integer()->term()) iv) (integer()->atom()) + Enum.any?(negatives, fn {neg_arguments, neg_return} -> + # Filter positives to only those with matching arity, then check if the negative + # function's domain is a supertype of the positive domain and if the phi function + # determines emptiness. + length(neg_arguments) == positive_arity and + subtype?(domain_descr(neg_arguments), positive_domain) and + phi_starter(neg_arguments, negation(neg_return), positives) + end) + end + end + + # Checks the list of arrows positives and returns {:empty, nil} if there exists two arrows with + # different arities. Otherwise, it returns {arity, domain} with domain the union of all domains of + # the arrows in positives. + defp fetch_arity_and_domain(positives) do + positives + |> Enum.reduce_while({:empty, none()}, fn + {args, _}, {:empty, _} -> + {:cont, {length(args), domain_descr(args)}} + + {args, _}, {arity, dom} when length(args) == arity -> + {:cont, {arity, union(dom, domain_descr(args))}} + + {_args, _}, {_arity, _} -> + {:halt, {:empty, none()}} + end) + end + + # Implements the Φ (phi) function for determining function subtyping relationships. + # + ## Algorithm + # + # For inputs t₁...tₙ, booleans b₁...bₙ, negated return type t, and set of arrow types P: + # + # Φ((b₁,t₁)...(bₙ,tₙ), (b,t), ∅) = (∃j ∈ [1,n]. bⱼ and tⱼ ≤ ∅) ∨ (b and t ≤ ∅) + # + # Φ((b₁,t₁)...(bₙ,tₙ), t, {(t'₁...t'ₙ) → t'} ∪ P) = + # Φ((b₁,t₁)...(bₙ,tₙ), (true,t ∧ t'), P) ∧ + # ∀j ∈ [1,n]. Φ((b₁,t₁)...(true,tⱼ∖t'ⱼ)...(bₙ,tₙ), (b,t), P) + # + # Returns true if the intersection of the positives is a subtype of (t1,...,tn)->(not t). + # + # See [Castagna and Lanvin (2024)](https://arxiv.org/abs/2408.14345), Theorem 4.2. + + defp phi_starter(arguments, return, positives) do + n = length(arguments) + # Arity mismatch: if there is one positive function with a different arity, + # then it cannot be a subtype of the (arguments->type) functions. + if Enum.any?(positives, fn {args, _ret} -> length(args) != n end) do + false + else + arguments = Enum.map(arguments, &{false, &1}) + phi(arguments, {false, return}, positives) + end + end + + defp phi(args, {b, t}, []) do + Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t)) + end + + defp phi(args, {b, ret}, [{arguments, return} | rest_positive]) do + phi(args, {true, intersection(ret, return)}, rest_positive) and + Enum.all?(Enum.with_index(arguments), fn {type, index} -> + List.update_at(args, index, fn {_, arg} -> {true, difference(arg, type)} end) + |> phi({b, ret}, rest_positive) + end) + end + + defp fun_union(bdd1, bdd2) do + case {bdd1, bdd2} do + {:fun_top, _} -> :fun_top + {_, :fun_top} -> :fun_top + {:fun_bottom, bdd} -> bdd + {bdd, :fun_bottom} -> bdd + {{fun, l1, r1}, {fun, l2, r2}} -> {fun, fun_union(l1, l2), fun_union(r1, r2)} + # Note: this is a deep merge, that goes down bdd1 to insert bdd2 into it. + # It is the same as going down bdd1 to insert bdd1 into it. + # Possible opti: insert into the bdd with smallest height + {{fun, l, r}, bdd} -> {fun, fun_union(l, bdd), fun_union(r, bdd)} + end + end + + defp fun_intersection(bdd1, bdd2) do + case {bdd1, bdd2} do + # Base cases + {_, :fun_bottom} -> :fun_bottom + {:fun_bottom, _} -> :fun_bottom + {:fun_top, bdd} -> bdd + {bdd, :fun_top} -> bdd + # Optimizations + # If intersecting with a single positive or negative function, we insert + # it at the root instead of merging the trees (this avoids going down the + # whole bdd). + {bdd, {fun, :fun_top, :fun_bottom}} -> {fun, bdd, :fun_bottom} + {bdd, {fun, :fun_bottom, :fun_top}} -> {fun, :fun_bottom, bdd} + {{fun, :fun_top, :fun_bottom}, bdd} -> {fun, bdd, :fun_bottom} + {{fun, :fun_bottom, :fun_top}, bdd} -> {fun, :fun_bottom, bdd} + # General cases + {{fun, l1, r1}, {fun, l2, r2}} -> {fun, fun_intersection(l1, l2), fun_intersection(r1, r2)} + {{fun, l, r}, bdd} -> {fun, fun_intersection(l, bdd), fun_intersection(r, bdd)} + end + end + + defp fun_difference(bdd1, bdd2) do + case {bdd1, bdd2} do + {:fun_bottom, _} -> :fun_bottom + {_, :fun_top} -> :fun_bottom + {bdd, :fun_bottom} -> bdd + {:fun_top, {fun, l, r}} -> {fun, fun_difference(:fun_top, l), fun_difference(:fun_top, r)} + {{fun, l1, r1}, {fun, l2, r2}} -> {fun, fun_difference(l1, l2), fun_difference(r1, r2)} + {{fun, l, r}, bdd} -> {fun, fun_difference(l, bdd), fun_difference(r, bdd)} + end + end + + # Converts a function BDD (Binary Decision Diagram) to its quoted representation. + defp fun_to_quoted(:fun, _opts), do: [{:fun, [], []}] + + defp fun_to_quoted(bdd, opts) do + arrows = bdd |> fun_get() + + for {positives, negatives} <- arrows, + not fun_empty?(positives, negatives) do + fun_intersection_to_quoted(positives, opts) + end + |> case do + [] -> [] + [single] -> [single] + multiple -> [Enum.reduce(multiple, &{:or, [], [&2, &1]})] + end + end + + defp fun_intersection_to_quoted(intersection, opts) do + intersection + |> Enum.map(fn {args, ret} -> + {:->, [], [[to_quoted(tuple_descr(:closed, args), opts)], to_quoted(ret, opts)]} + end) + |> case do + [] -> {:fun, [], []} + [single] -> single + multiple -> Enum.reduce(multiple, &{:and, [], [&2, &1]}) + end + end + ## List # Represents both list and improper list simultaneously using a pair diff --git a/lib/elixir/lib/module/types/expr.ex b/lib/elixir/lib/module/types/expr.ex index cb56929ec39..4459bb94030 100644 --- a/lib/elixir/lib/module/types/expr.ex +++ b/lib/elixir/lib/module/types/expr.ex @@ -329,7 +329,7 @@ defmodule Module.Types.Expr do {patterns, _guards} = extract_head(head) domain = Enum.map(patterns, fn _ -> dynamic() end) {_acc, context} = of_clauses(clauses, domain, @pending, nil, :fn, stack, {none(), context}) - {fun(), context} + {dynamic(fun(length(patterns))), context} end def of_expr({:try, _meta, [[do: body] ++ blocks]}, expected, expr, stack, original) do @@ -451,7 +451,7 @@ defmodule Module.Types.Expr do # TODO: fun.(args) def of_expr({{:., meta, [fun]}, _meta, args} = call, _expected, _expr, stack, context) do - {fun_type, context} = of_expr(fun, fun(), call, stack, context) + {fun_type, context} = of_expr(fun, fun(length(args)), call, stack, context) {_args_types, context} = Enum.map_reduce(args, context, &of_expr(&1, @pending, &1, stack, &2)) diff --git a/lib/elixir/test/elixir/module/types/descr_test.exs b/lib/elixir/test/elixir/module/types/descr_test.exs index 90285845a75..242e86a2526 100644 --- a/lib/elixir/test/elixir/module/types/descr_test.exs +++ b/lib/elixir/test/elixir/module/types/descr_test.exs @@ -13,6 +13,12 @@ defmodule Module.Types.DescrTest do import Module.Types.Descr describe "union" do + test "zoom" do + # 1. dynamic() -> dynamic() applied to dynamic() gives dynamic() + f = fun([dynamic()], dynamic()) + assert fun_apply(f, [dynamic()]) == dynamic() + end + test "bitmap" do assert union(integer(), float()) == union(float(), integer()) end @@ -109,6 +115,14 @@ defmodule Module.Types.DescrTest do |> equal?(list(term())) end + test "fun" do + assert equal?(union(fun(), fun()), fun()) + assert equal?(union(fun(), fun(1)), fun()) + + dynamic_fun = intersection(fun(), dynamic()) + assert equal?(union(dynamic_fun, fun()), fun()) + end + test "optimizations (maps)" do # The tests are checking the actual implementation, not the semantics. # This is why we are using structural comparisons. @@ -513,6 +527,25 @@ defmodule Module.Types.DescrTest do assert difference(list(integer(), atom()), list(integer())) == non_empty_list(integer(), atom()) end + + test "fun" do + for arity <- [0, 1, 2, 3] do + assert empty?(difference(fun(arity), fun(arity))) + end + + assert empty?(difference(fun(), fun())) + assert empty?(difference(fun(3), fun())) + refute empty?(difference(fun(), fun(1))) + refute empty?(difference(fun(2), fun(3))) + assert empty?(intersection(fun(2), fun(3))) + + f1f2 = union(fun(1), fun(2)) + assert f1f2 |> difference(fun(1)) |> difference(fun(2)) |> empty?() + assert fun(1) |> difference(difference(f1f2, fun(2))) |> empty?() + assert f1f2 |> difference(fun(1)) |> equal?(fun(2)) + + assert fun([integer()], term()) |> difference(fun([none()], term())) |> empty?() + end end describe "creation" do @@ -585,6 +618,59 @@ defmodule Module.Types.DescrTest do assert subtype?(list(integer()), list(term())) assert subtype?(list(term()), list(term(), term())) end + + test "fun" do + assert equal?(fun([], term()), fun([], term())) + refute equal?(fun([], integer()), fun([], atom())) + refute subtype?(fun([none()], term()), fun([integer()], integer())) + + # Difference with argument/return type variations + int_to_atom = fun([integer()], atom()) + num_to_atom = fun([number()], atom()) + int_to_bool = fun([integer()], boolean()) + + # number->atom is a subtype of int->atom + assert subtype?(num_to_atom, int_to_atom) + refute subtype?(int_to_atom, num_to_atom) + assert subtype?(int_to_bool, int_to_atom) + refute subtype?(int_to_bool, num_to_atom) + + # Multi-arity + f1 = fun([integer(), atom()], boolean()) + f2 = fun([number(), atom()], boolean()) + + # (int,atom)->boolean is a subtype of (number,atom)->boolean + # since number is a supertype of int + assert subtype?(f2, f1) + # f1 is not a subtype of f2 + refute subtype?(f1, f2) + + # Unary functions / Output covariance + assert subtype?(fun([], float()), fun([], term())) + refute subtype?(fun([], term()), fun([], float())) + + # Contravariance of domain + refute subtype?(fun([integer()], boolean()), fun([number()], boolean())) + assert subtype?(fun([number()], boolean()), fun([integer()], boolean())) + + # Nested function types + higher_order = fun([fun([integer()], atom())], boolean()) + specific = fun([fun([number()], atom())], boolean()) + + assert subtype?(higher_order, specific) + refute subtype?(specific, higher_order) + + ## Multi-arity + f = fun([none(), integer()], atom()) + assert subtype?(f, f) + assert subtype?(f, fun([none(), integer()], term())) + assert subtype?(fun([none(), number()], atom()), f) + assert subtype?(fun([tuple(), number()], atom()), f) + refute subtype?(fun([none(), float()], atom()), f) + refute subtype?(fun([pid(), float()], atom()), f) + # A function with the wrong arity is refused + refute subtype?(fun([none()], atom()), f) + end end describe "compatible" do @@ -655,14 +741,160 @@ defmodule Module.Types.DescrTest do assert closed_map(a: integer(), b: none()) |> empty?() assert intersection(closed_map(b: atom()), open_map(a: integer())) |> empty?() end + + test "fun" do + refute empty?(fun()) + refute empty?(fun(1)) + refute empty?(fun([integer()], atom())) + + assert empty?(intersection(fun(1), fun(2))) + refute empty?(intersection(fun(), fun(1))) + assert empty?(difference(fun(1), union(fun(1), fun(2)))) + end + end + + describe "function operators" do + defmacro assert_domain(f, expected) do + quote do + assert equal?(fun_domain(unquote(f)), domain_descr(unquote(expected))) + end + end + + test "domain operator" do + # For function domain: + # 1. The domain of an intersection of functions is the union of the domains of the functions + # 2. The domain of a union of functions is the intersection of the domains of the functions + # 3. If a type is not a function or its domain is empty, return :badfunction + + # For gradual domain of a function type t: + # It is dom(t) = dom(up(t)) ∪ dynamic(dom(down(t))) + # where dom is the static domain, up is the upcast, and down is the downcast. + + ## Basic domain tests + assert fun_domain(term()) == :badfunction + assert fun_domain(none()) == :badfunction + assert fun_domain(intersection(fun(1), fun(2))) == :badfunction + assert union(atom(), intersection(fun(1), fun(2))) |> fun_domain() == :badfunction + assert fun_domain(fun([none()], term())) == :badfunction + assert fun_domain(difference(fun([pid()], pid()), fun([pid()], term()))) == :badfunction + + assert_domain(fun([], term()), []) + assert_domain(fun([term()], atom()), [term()]) + assert_domain(fun([integer(), atom()], boolean()), [integer(), atom()]) + # See 1. for intersection of functions + assert_domain(intersection(fun([float()], term()), fun([integer()], term())), [number()]) + # See 2. for union of functions + assert_domain(union(fun([number()], term()), fun([float()], term())), [float()]) + + ## Gradual domain tests + assert fun_domain(dynamic()) == :badfunction + assert fun_domain(intersection(dynamic(), fun([none()], term()))) == :badfunction + assert_domain(fun([dynamic()], dynamic()), [dynamic()]) + assert_domain(fun([dynamic(), dynamic()], dynamic()), [dynamic(), dynamic()]) + assert_domain(intersection(fun([integer()], atom()), dynamic()), [integer()]) + assert_domain(intersection(fun([integer()], term()), fun([float()], term())), [number()]) + + assert_domain( + intersection(fun([dynamic(integer())], float()), fun([float()], term())), + [union(dynamic(integer()), float())] + ) + + assert_domain( + intersection(fun([dynamic(integer())], term()), fun([integer()], term())), + [integer()] + ) + + # Domain of an intersection is union of domains + f = intersection(fun([atom(), pid()], term()), fun([pid(), atom()], term())) + dom = fun_domain(f) + refute dom |> equal?(domain_descr([union(atom(), pid()), union(pid(), atom())])) + assert dom |> equal?(union(domain_descr([atom(), pid()]), domain_descr([pid(), atom()]))) + + assert_domain( + intersection(fun([none(), integer()], term()), fun([float(), float()], term())), + [float(), float()] + ) + + # Intersection of domains int and float is empty + assert union(fun([integer()], atom()), fun([float()], boolean())) |> fun_domain() == + :badfunction + end + + test "function application" do + # This should not be empty + assert not empty?(intersection(negation(fun(2)), negation(fun(3)))) + + # Basic function application scenarios + assert fun_apply(fun([integer()], atom()), [integer()]) == atom() + assert fun_apply(fun([integer()], atom()), [float()]) == :badarguments + assert fun_apply(fun([integer()], atom()), [term()]) == :badarguments + assert fun_apply(fun([integer()], none()), [integer()]) == none() + assert fun_apply(fun([integer()], term()), [integer()]) == term() + + # Arity mismatches + assert fun_apply(fun([dynamic()], integer()), [dynamic(), dynamic()]) == :badarguments + assert fun_apply(fun([integer(), atom()], boolean()), [integer()]) == :badarguments + + # Dynamic type handling + assert fun_apply(fun([dynamic()], term()), [dynamic()]) == term() + assert fun_apply(fun([dynamic()], integer()), [dynamic()]) |> equal?(integer()) + assert fun_apply(fun([dynamic(), atom()], float()), [dynamic(), atom()]) |> equal?(float()) + assert fun_apply(fun([integer()], dynamic()), [integer()]) == dynamic() + + # Function intersection tests - basic + fun1 = intersection(fun([integer()], atom()), fun([number()], term())) + assert fun_apply(fun1, [integer()]) == atom() + assert fun_apply(fun1, [float()]) == term() + + # Function intersection with same domain, different codomains + assert fun([integer()], term()) + |> intersection(fun([integer()], atom())) + |> fun_apply([integer()]) == atom() + + # Function intersection with singleton atoms + fun3 = intersection(fun([atom([:ok])], atom([:success])), fun([atom([:ok])], atom([:done]))) + assert fun_apply(fun3, [atom([:ok])]) == none() + + # (dynamic(integer()) -> atom() + # cannot apply it to integer() bc integer() is not a subtype of dynamic() /\ integer() + # dynamic(atom()) + + # $ dynamic(map()) -> map() + # def f(x) when is_map(x) do + # x.foo + # end + + fun9 = fun([intersection(dynamic(), integer())], atom()) + assert fun_apply(fun9, [dynamic(integer())]) |> equal?(atom()) + assert fun_apply(fun9, [dynamic()]) == :badarguments + # TODO: discuss this case + assert fun_apply(fun9, [integer()]) == :badarguments + + # Dynamic with function type combinations + fun12 = + intersection( + fun([union(integer(), atom())], dynamic()), + fun([union(integer(), pid())], atom()) + ) + + assert fun_apply(fun12, [integer()]) == dynamic(atom()) + assert fun_apply(fun12, [atom()]) == dynamic() + assert fun_apply(fun12, [pid()]) |> equal?(atom()) + end end describe "projections" do test "fun_fetch" do + assert fun_fetch(none(), 1) == :error assert fun_fetch(term(), 1) == :error assert fun_fetch(union(term(), dynamic(fun())), 1) == :error - assert fun_fetch(fun(), 1) == :ok + assert fun_fetch(union(atom(), dynamic(fun())), 1) == :error + assert fun_fetch(intersection(fun([], term()), fun([], atom())), 0) == :ok + assert fun_fetch(fun([], term()), 0) == :ok + assert fun_fetch(union(fun([], term()), fun([pid()], term())), 0) == :error + assert fun_fetch(dynamic(fun()), 1) == :ok assert fun_fetch(dynamic(), 1) == :ok + assert fun_fetch(dynamic(fun(2)), 1) == :error end test "truthness" do @@ -1018,8 +1250,8 @@ defmodule Module.Types.DescrTest do assert equal?(value_type, intersection(atom(), negation(atom([:foo, :bar])))) - assert closed_map(a: union(atom(), pid()), b: integer(), c: tuple()) - |> difference(open_map(a: atom(), b: integer())) + assert closed_map(a: union(atom([:ok]), pid()), b: integer(), c: tuple()) + |> difference(open_map(a: atom([:ok]), b: integer())) |> difference(open_map(a: atom(), c: tuple())) |> map_fetch(:a) == {false, pid()} diff --git a/lib/elixir/test/elixir/module/types/expr_test.exs b/lib/elixir/test/elixir/module/types/expr_test.exs index 61f46a7969e..ccb021a0712 100644 --- a/lib/elixir/test/elixir/module/types/expr_test.exs +++ b/lib/elixir/test/elixir/module/types/expr_test.exs @@ -25,8 +25,8 @@ defmodule Module.Types.ExprTest do assert typecheck!("foo") == binary() assert typecheck!([]) == empty_list() assert typecheck!(%{}) == closed_map([]) - assert typecheck!(& &1) == fun() - assert typecheck!(fn -> :ok end) == fun() + assert typecheck!(& &1) == dynamic(fun(1)) + assert typecheck!(fn -> :ok end) == dynamic(fun(0)) end test "generated" do @@ -136,7 +136,7 @@ defmodule Module.Types.ExprTest do x.(1, 2) x ) - ) == dynamic(fun()) + ) == dynamic(fun(2)) end test "incompatible" do