diff --git a/TypedSyntax/Project.toml b/TypedSyntax/Project.toml index 757e5812..af8a1477 100644 --- a/TypedSyntax/Project.toml +++ b/TypedSyntax/Project.toml @@ -1,7 +1,7 @@ name = "TypedSyntax" uuid = "d265eb64-f81a-44ad-a842-4247ee1503de" authors = ["Tim Holy and contributors"] -version = "1.0.2" +version = "1.0.3" [deps] CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" diff --git a/TypedSyntax/src/node.jl b/TypedSyntax/src/node.jl index 9c58ef73..e75d0dfc 100644 --- a/TypedSyntax/src/node.jl +++ b/TypedSyntax/src/node.jl @@ -13,6 +13,9 @@ TypedSyntaxData(sd::SyntaxData, src::CodeInfo, typ=nothing) = TypedSyntaxData(sd const TypedSyntaxNode = TreeNode{TypedSyntaxData} const MaybeTypedSyntaxNode = Union{SyntaxNode,TypedSyntaxNode} +struct NoDefaultValue end +const no_default_value = NoDefaultValue() + # These are TypedSyntaxNode constructor helpers # Call these directly if you want both the TypedSyntaxNode and the `mappings` list, # where `mappings[i]` corresponds to the list of nodes matching `(src::CodeInfo).code[i]`. @@ -69,22 +72,25 @@ function TypedSyntaxNode(rootnode::SyntaxNode, src::CodeInfo, mappings, symtyps) sig = child(sig, 1) end @assert kind(sig) == K"call" - i = 1 + i = j = 1 for arg in Iterators.drop(children(sig), 1) kind(arg) == K"parameters" && break # kw args if kind(arg) == K"..." arg = only(children(arg)) end + defaultval = no_default_value if kind(arg) == K"=" - arg = first(children(arg)) + defaultval = child(arg, 2) + arg = child(arg, 1) end if kind(arg) == K"::" nchildren = length(children(arg)) if nchildren == 1 # unnamed argument + argc = child(arg, 1) found = false while i <= length(src.slotnames) - if src.slotnames[i] == Symbol("#unused#") + if src.slotnames[i] == Symbol("#unused#") || (defaultval != no_default_value && kind(argc) == K"curly" && src.slotnames[i] == Symbol("")) arg.typ = unwrapconst(src.slottypes[i]) i += 1 found = true @@ -92,7 +98,10 @@ function TypedSyntaxNode(rootnode::SyntaxNode, src::CodeInfo, mappings, symtyps) end i += 1 end - @assert found + found && continue + @assert kind(argc) == K"curly" + arg.typ = unwrapconst(src.ssavaluetypes[j]) + j += 1 continue elseif nchildren == 2 arg = child(arg, 1) # extract the name @@ -102,6 +111,11 @@ function TypedSyntaxNode(rootnode::SyntaxNode, src::CodeInfo, mappings, symtyps) end kind(arg) == K"Identifier" || @show sig arg @assert kind(arg) == K"Identifier" + if i > length(src.slotnames) + @assert defaultval != no_default_value + arg.typ = Core.Typeof(unwrapconst(defaultval.val)) + continue + end argname = arg.val while i <= length(src.slotnames) if src.slotnames[i] == argname diff --git a/TypedSyntax/test/runtests.jl b/TypedSyntax/test/runtests.jl index f338463e..c1791d0a 100644 --- a/TypedSyntax/test/runtests.jl +++ b/TypedSyntax/test/runtests.jl @@ -53,6 +53,7 @@ splats(x, y) = vcat(x..., y...) myoftype(ref, val) = typeof(ref)(val) defaultarg(x, y=2) = x + y +hasdefaulttypearg(::Type{T}=Rational{Int}) where T = zero(T) charset1 = 'a':'z' getchar1(idx) = charset1[idx] @@ -243,6 +244,7 @@ end tsn = TypedSyntaxNode(TSN.defaultarg, (Float32,)) sig, body = children(tsn) @test has_name_typ(child(sig, 2), :x, Float32) + @test has_name_typ(child(sig, 3, 1), :y, Int) # there is no argument 2 in tsn.typedsource tsn = TypedSyntaxNode(TSN.defaultarg, (Float32,Int)) sig, body = children(tsn) @@ -250,6 +252,15 @@ end nodearg = child(sig, 3) @test kind(nodearg) == K"=" @test has_name_typ(child(nodearg, 1), :y, Int) + # default position args that are types + tsn = TypedSyntaxNode(TSN.hasdefaulttypearg, (Type{Float32},)) + sig, body = children(tsn) + arg = child(sig, 1, 2, 1) + @test kind(arg) == K"::" && arg.typ === Type{Float32} + tsn = TypedSyntaxNode(TSN.hasdefaulttypearg, ()) + sig, body = children(tsn) + arg = child(sig, 1, 2, 1) + @test kind(arg) == K"::" && arg.typ === Type{Rational{Int}} # macros in function definition tsn = TypedSyntaxNode(TSN.mysin, (Int,))