@@ -113,10 +113,10 @@ function _build_function(target::JuliaTarget, op, args...;
113
113
cse = false ,
114
114
nanmath = true ,
115
115
kwargs... )
116
-
116
+ op = _recursive_unwrap (op)
117
117
states. rewrites[:nanmath ] = nanmath
118
118
dargs = map ((x) -> destructure_arg (x[2 ], ! checkbounds, default_arg_name (x[1 ])), enumerate (collect (args)))
119
- fun = Func (dargs, [], unwrap (op) )
119
+ fun = Func (dargs, [], op )
120
120
if wrap_code != = nothing
121
121
fun = wrap_code (fun)
122
122
end
@@ -135,7 +135,9 @@ function _build_function(target::JuliaTarget, op, args...;
135
135
end
136
136
end
137
137
138
- const UNIMPLEMENTED_EXPR = :(function (args... ); $ throw_missing_specialization (length (args)); end )
138
+ function get_unimplemented_expr (dargs)
139
+ Func (dargs, [], term (throw_missing_specialization, length (dargs)))
140
+ end
139
141
140
142
SymbolicUtils. Code. get_rewrites (x:: Arr ) = SymbolicUtils. Code. get_rewrites (unwrap (x))
141
143
@@ -151,21 +153,21 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp, SymbolicUt
151
153
wrap_code = (identity, identity),
152
154
iip_config = (true , true ),
153
155
kwargs... )
154
-
156
+ op = _recursive_unwrap (op)
155
157
dargs = map ((x) -> destructure_arg (x[2 ], ! checkbounds, default_arg_name (x[1 ])), enumerate (collect (args)))
156
158
states. rewrites[:nanmath ] = nanmath
157
159
if iip_config[1 ]
158
- oop_expr = wrap_code[1 ](Func (dargs, [], unwrap (op) ))
160
+ oop_expr = wrap_code[1 ](Func (dargs, [], op ))
159
161
else
160
- oop_expr = UNIMPLEMENTED_EXPR
162
+ oop_expr = get_unimplemented_expr (dargs)
161
163
end
162
164
163
165
outsym = DEFAULT_OUTSYM
164
- body = inplace_expr (unwrap (op) , outsym)
166
+ body = inplace_expr (op , outsym)
165
167
if iip_config[2 ]
166
168
iip_expr = wrap_code[2 ](Func (vcat (outsym, dargs), [], body))
167
169
else
168
- iip_expr = UNIMPLEMENTED_EXPR
170
+ iip_expr = get_unimplemented_expr ([outsym; dargs])
169
171
end
170
172
171
173
if cse
@@ -196,10 +198,7 @@ function _build_and_inject_function(mod::Module, ex)
196
198
elseif ex. head == :(-> )
197
199
return _build_and_inject_function (mod, Expr (:function , ex. args... ))
198
200
end
199
- # XXX : Workaround to specify the module as both the cache module AND context module.
200
- # Currently, the @RuntimeGeneratedFunction macro only sets the context module.
201
- module_tag = getproperty (mod, RuntimeGeneratedFunctions. _tagname)
202
- RuntimeGeneratedFunctions. RuntimeGeneratedFunction (module_tag, module_tag, ex; opaque_closures= false )
201
+ RuntimeGeneratedFunction (mod, mod, ex)
203
202
end
204
203
205
204
toexpr (n:: Num , st) = toexpr (value (n), st)
@@ -305,7 +304,10 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
305
304
iip_config = (true , true ),
306
305
nanmath = true ,
307
306
parallel= nothing , cse = false , kwargs... )
308
-
307
+ if rhss isa SubArray
308
+ rhss = copy (rhss)
309
+ end
310
+ rhss = _recursive_unwrap (rhss)
309
311
states. rewrites[:nanmath ] = nanmath
310
312
# We cannot switch to ShardedForm because it deadlocks with
311
313
# RuntimeGeneratedFunctions
@@ -323,7 +325,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
323
325
oop_expr = wrap_code[1 ](oop_expr)
324
326
end
325
327
else
326
- oop_expr = UNIMPLEMENTED_EXPR
328
+ oop_expr = get_unimplemented_expr (dargs)
327
329
end
328
330
329
331
@@ -340,7 +342,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
340
342
iip_expr = wrap_code[2 ](iip_expr)
341
343
end
342
344
else
343
- iip_expr = UNIMPLEMENTED_EXPR
345
+ iip_expr = get_unimplemented_expr ([DEFAULT_OUTSYM; dargs])
344
346
end
345
347
346
348
if cse
@@ -466,7 +468,7 @@ function _make_sparse_array(arr, similarto)
466
468
return term (setparent, nzmap (Returns (true ), arr), newarr)
467
469
else
468
470
newarr = _make_array (arr. nzval, Vector{symtype (eltype (arr))})
469
- return Let ([Assignment (:__reference , term (copy, nzmap (Returns (true ), arr)))], term (set_nzval, :__reference , newarr), true )
471
+ return Let ([Assignment (:__reference , term (copy, nzmap (Returns (true ), arr)))], term (set_nzval, :__reference , newarr), false )
470
472
end
471
473
end
472
474
@@ -539,10 +541,13 @@ function set_array(s::ShardedForm, closed_args, out, outputidxs, rhss, checkboun
539
541
end
540
542
541
543
function _set_array (out, outputidxs, rhss:: AbstractSparseArray , checkbounds, skipzeros)
542
- Let ([Assignment (Symbol (" %$out " ), _set_array (LiteralExpr (:($ out. nzval)), nothing , rhss. nzval, checkbounds, skipzeros))], out)
544
+ Let ([Assignment (Symbol (" %$out " ), _set_array (LiteralExpr (:($ out. nzval)), nothing , rhss. nzval, checkbounds, skipzeros))], out, false )
543
545
end
544
546
545
547
function _set_array (out, outputidxs, rhss:: AbstractArray , checkbounds, skipzeros)
548
+ if parent (rhss) != = rhss
549
+ return _set_array (out, outputidxs, parent (rhss), checkbounds, skipzeros)
550
+ end
546
551
if outputidxs === nothing
547
552
outputidxs = collect (eachindex (rhss))
548
553
end
0 commit comments