diff --git a/src/core.jl b/src/core.jl index d7deab7..8f3dd24 100644 --- a/src/core.jl +++ b/src/core.jl @@ -93,6 +93,50 @@ end const FetchJLError = JLError_FetchMsgStr +@noinline function _load_call_args(funcProxy::JV, argProxies::TyList{JV}, kwargProxies::TyList{TyTuple{JSym,JV}}) + func = JV_LOAD(funcProxy)::Any + args = Any[JV_LOAD(unsafe_load(argProxies.data, i)) for i in 1:(argProxies.len)] + kwargs = Dict{Symbol,Any}() + for i in 1:(kwargProxies.len) + kv = unsafe_load(kwargProxies.data, i) + k = JSym_LOAD(kv.fst) + v = JV_LOAD(kv.snd) + kwargs[k] = v + end + return func, args, kwargs +end + +@noinline function _barrier_call_f(f, args...; kwargs...) + @nospecialize f, args, kwargs + f(args...; kwargs...) +end + +@noinline function _barrier_call_f_dot(f, args...; kwargs...) + @nospecialize f, args, kwargs + f.(args...; kwargs...) +end + +@noinline function _barrier_call_f_a0(f) + @nospecialize f + f() +end + +@noinline function _barrier_call_f_a1(f, arg) + @nospecialize f, arg + f(arg) +end + +@noinline function _barrier_call_f_a2(f, arg1, arg2) + @nospecialize f, arg1, arg2 + f(arg1, arg2) +end + +@noinline function _barrier_call_f_a3(f, arg1, args2, arg3) + @nospecialize f, arg1, arg2, arg3 + f(arg1, arg2, arg3) +end + + function JLCallImpl( out::Ptr{JV}, funcProxy::JV, @@ -101,32 +145,24 @@ function JLCallImpl( dotcall::Bool=false, ) try - func = JV_LOAD(funcProxy) - args = Any[JV_LOAD(unsafe_load(argProxies.data, i)) for i in 1:(argProxies.len)] - kwargs = Dict{Symbol,Any}() - for i in 1:(kwargProxies.len) - kv = unsafe_load(kwargProxies.data, i) - k = JSym_LOAD(kv.fst) - v = JV_LOAD(kv.snd) - kwargs[k] = v - end + func, args, kwargs = _load_call_args(funcProxy, argProxies, kwargProxies) ret = if dotcall - func.(args...; kwargs...) + _barrier_call_f_dot(func, args...; kwargs...) else if length(kwargs) == 0 if length(args) == 0 - func() + _barrier_call_f_a0(Base.inferencebarrier(func)) elseif length(args) == 1 - func(args[1]) + _barrier_call_f_a1(Base.inferencebarrier(func), args[1]) elseif length(args) == 2 - func(args[1], args[2]) + _barrier_call_f_a2(Base.inferencebarrier(func), args[1], args[2]) elseif length(args) == 3 - func(args[1], args[2], args[3]) + _barrier_call_f_a3(Base.inferencebarrier(func), args[1], args[2], args[3]) else - func(args...) + _barrier_call_f(Base.inferencebarrier(func), args...) end else - func(args...; kwargs...) + _barrier_call_f(Base.inferencebarrier(func), args...; kwargs...) end end retProxy = JV_ALLOC(ret) @@ -223,7 +259,7 @@ function JLGetProperty(out::Ptr{JV}, self::JV, property::JSym)::ErrorCode try self′ = JV_LOAD(self) property′ = JSym_LOAD(property) - propVal = getproperty(self′, property′) + propVal = getproperty(Base.inferencebarrier(self′), property′) unsafe_store!(out, JV_ALLOC(propVal)) return OK catch e @@ -237,7 +273,7 @@ function JLSetProperty(self::JV, property::JSym, value::JV)::ErrorCode self′ = JV_LOAD(self) property′ = JSym_LOAD(property) value′ = JV_LOAD(value) - setproperty!(self′, property′, value′) + setproperty!(Base.inferencebarrier(self′), property′, Base.inferencebarrier(value′)) return OK catch e @produce_error!(e) @@ -249,7 +285,7 @@ function JLHasProperty(out::Ptr{Bool}, self::JV, property::JSym)::ErrorCode try self′ = JV_LOAD(self) property′ = JSym_LOAD(property) - test = hasproperty(self′, property′) + test = hasproperty(Base.inferencebarrier(self′), property′) unsafe_store!(out, test) return OK catch e @@ -262,7 +298,7 @@ function JLGetIndex(out::Ptr{JV}, self::JV, indices::TyList{JV})::ErrorCode try self′ = JV_LOAD(self) indices′ = [JV_LOAD(unsafe_load(indices.data, i)) for i in 1:(indices.len)] - indexVal = getindex(self′, indices′...) + indexVal = getindex(Base.inferencebarrier(self′), indices′...) unsafe_store!(out, JV_ALLOC(indexVal)) return OK catch e @@ -274,7 +310,7 @@ end function JLGetIndexI(out::Ptr{JV}, self::JV, index::Int64)::ErrorCode try self′ = JV_LOAD(self) - indexVal = getindex(self′, index) + indexVal = getindex(Base.inferencebarrier(self′), index) unsafe_store!(out, JV_ALLOC(indexVal)) return OK catch e @@ -288,7 +324,7 @@ function JLSetIndex(self::JV, indices::TyList{JV}, value::JV)::ErrorCode self′ = JV_LOAD(self) value′ = JV_LOAD(value) indices′ = [JV_LOAD(unsafe_load(indices.data, i)) for i in 1:(indices.len)] - setindex!(self′, value′, indices′...) + setindex!(Base.inferencebarrier(self′), Base.inferencebarrier(value′), indices′...) return OK catch e @produce_error!(e) @@ -300,7 +336,7 @@ function JLSetIndexI(self::JV, index::Int64, value::JV)::ErrorCode try self′ = JV_LOAD(self) value′ = JV_LOAD(value) - setindex!(self′, value′, index) + setindex!(Base.inferencebarrier(self′), Base.inferencebarrier(value′), index) return OK catch e @produce_error!(e) @@ -446,7 +482,7 @@ end function JLGetArrayPointer(dataOut::Ptr{Ptr{UInt8}}, lenOut::Ptr{Int64}, array::JV)::ErrorCode try a = JV_LOAD(array) - (len, p) = _array_pointer_barrier(a) + (len, p) = _array_pointer_barrier(Base.inferencebarrier(a)) unsafe_store!(dataOut, p) unsafe_store!(lenOut, len) return OK