Skip to content

Commit 4bdefb3

Browse files
Merge pull request #1496 from AayushSabharwal/as/better-build-function
fix: fix `iip_config` with RGF
2 parents 27dc858 + 8471af4 commit 4bdefb3

11 files changed

+80
-27
lines changed

src/arrays.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,7 @@ function _array_toexpr(x, st)
10051005
[
10061006
Assignment(outsym, term(zeros, Float64, term(map, length, shape(x)))),
10071007
Assignment(Symbol("%$outsym"), inplace_expr(x, outsym))
1008-
], outsym, true)
1008+
], outsym, false)
10091009

10101010
toexpr(ex, st)
10111011
end

src/build_function.jl

+22-17
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ function _build_function(target::JuliaTarget, op, args...;
113113
cse = false,
114114
nanmath = true,
115115
kwargs...)
116-
116+
op = _recursive_unwrap(op)
117117
states.rewrites[:nanmath] = nanmath
118118
dargs = map((x) -> destructure_arg(x[2], !checkbounds, default_arg_name(x[1])), enumerate(collect(args)))
119-
fun = Func(dargs, [], unwrap(op))
119+
fun = Func(dargs, [], op)
120120
if wrap_code !== nothing
121121
fun = wrap_code(fun)
122122
end
@@ -135,7 +135,9 @@ function _build_function(target::JuliaTarget, op, args...;
135135
end
136136
end
137137

138-
const UNIMPLEMENTED_EXPR = :(function (args...); $throw_missing_specialization(length(args)); end)
138+
function get_unimplemented_expr(dargs)
139+
Func(dargs, [], term(throw_missing_specialization, length(dargs)))
140+
end
139141

140142
SymbolicUtils.Code.get_rewrites(x::Arr) = SymbolicUtils.Code.get_rewrites(unwrap(x))
141143

@@ -151,21 +153,21 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp, SymbolicUt
151153
wrap_code = (identity, identity),
152154
iip_config = (true, true),
153155
kwargs...)
154-
156+
op = _recursive_unwrap(op)
155157
dargs = map((x) -> destructure_arg(x[2], !checkbounds, default_arg_name(x[1])), enumerate(collect(args)))
156158
states.rewrites[:nanmath] = nanmath
157159
if iip_config[1]
158-
oop_expr = wrap_code[1](Func(dargs, [], unwrap(op)))
160+
oop_expr = wrap_code[1](Func(dargs, [], op))
159161
else
160-
oop_expr = UNIMPLEMENTED_EXPR
162+
oop_expr = get_unimplemented_expr(dargs)
161163
end
162164

163165
outsym = DEFAULT_OUTSYM
164-
body = inplace_expr(unwrap(op), outsym)
166+
body = inplace_expr(op, outsym)
165167
if iip_config[2]
166168
iip_expr = wrap_code[2](Func(vcat(outsym, dargs), [], body))
167169
else
168-
iip_expr = UNIMPLEMENTED_EXPR
170+
iip_expr = get_unimplemented_expr([outsym; dargs])
169171
end
170172

171173
if cse
@@ -196,10 +198,7 @@ function _build_and_inject_function(mod::Module, ex)
196198
elseif ex.head == :(->)
197199
return _build_and_inject_function(mod, Expr(:function, ex.args...))
198200
end
199-
# XXX: Workaround to specify the module as both the cache module AND context module.
200-
# Currently, the @RuntimeGeneratedFunction macro only sets the context module.
201-
module_tag = getproperty(mod, RuntimeGeneratedFunctions._tagname)
202-
RuntimeGeneratedFunctions.RuntimeGeneratedFunction(module_tag, module_tag, ex; opaque_closures=false)
201+
RuntimeGeneratedFunction(mod, mod, ex)
203202
end
204203

205204
toexpr(n::Num, st) = toexpr(value(n), st)
@@ -305,7 +304,10 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
305304
iip_config = (true, true),
306305
nanmath = true,
307306
parallel=nothing, cse = false, kwargs...)
308-
307+
if rhss isa SubArray
308+
rhss = copy(rhss)
309+
end
310+
rhss = _recursive_unwrap(rhss)
309311
states.rewrites[:nanmath] = nanmath
310312
# We cannot switch to ShardedForm because it deadlocks with
311313
# RuntimeGeneratedFunctions
@@ -323,7 +325,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
323325
oop_expr = wrap_code[1](oop_expr)
324326
end
325327
else
326-
oop_expr = UNIMPLEMENTED_EXPR
328+
oop_expr = get_unimplemented_expr(dargs)
327329
end
328330

329331

@@ -340,7 +342,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
340342
iip_expr = wrap_code[2](iip_expr)
341343
end
342344
else
343-
iip_expr = UNIMPLEMENTED_EXPR
345+
iip_expr = get_unimplemented_expr([DEFAULT_OUTSYM; dargs])
344346
end
345347

346348
if cse
@@ -466,7 +468,7 @@ function _make_sparse_array(arr, similarto)
466468
return term(setparent, nzmap(Returns(true), arr), newarr)
467469
else
468470
newarr = _make_array(arr.nzval, Vector{symtype(eltype(arr))})
469-
return Let([Assignment(:__reference, term(copy, nzmap(Returns(true), arr)))], term(set_nzval, :__reference, newarr), true)
471+
return Let([Assignment(:__reference, term(copy, nzmap(Returns(true), arr)))], term(set_nzval, :__reference, newarr), false)
470472
end
471473
end
472474

@@ -539,10 +541,13 @@ function set_array(s::ShardedForm, closed_args, out, outputidxs, rhss, checkboun
539541
end
540542

541543
function _set_array(out, outputidxs, rhss::AbstractSparseArray, checkbounds, skipzeros)
542-
Let([Assignment(Symbol("%$out"), _set_array(LiteralExpr(:($out.nzval)), nothing, rhss.nzval, checkbounds, skipzeros))], out)
544+
Let([Assignment(Symbol("%$out"), _set_array(LiteralExpr(:($out.nzval)), nothing, rhss.nzval, checkbounds, skipzeros))], out, false)
543545
end
544546

545547
function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros)
548+
if parent(rhss) !== rhss
549+
return _set_array(out, outputidxs, parent(rhss), checkbounds, skipzeros)
550+
end
546551
if outputidxs === nothing
547552
outputidxs = collect(eachindex(rhss))
548553
end

src/variable.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,10 @@ function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symboli
536536
end
537537

538538
function _recursive_unwrap(val)
539-
if symbolic_type(val) == NotSymbolic() && val isa AbstractArray
539+
if symbolic_type(val) == NotSymbolic() && val isa Union{AbstractArray, Tuple}
540+
if parent(val) !== val
541+
return Setfield.@set val.parent = _recursive_unwrap(parent(val))
542+
end
540543
return _recursive_unwrap.(val)
541544
else
542545
return unwrap(val)

test/build_function.jl

+29
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,32 @@ end
316316
fn = build_function(f(x), DestructuredArgs([f]; create_bindings = false), x; expression = Val{false})
317317
@test fn([isodd], 3)
318318
end
319+
320+
@testset "iip_config with RGF" begin
321+
@variables a b
322+
oop, iip = build_function([a + b, a - b], a, b; iip_config = (false, false), expression = Val{false})
323+
@test_throws ArgumentError oop(1, 2)
324+
@test_throws ArgumentError iip(ones(2), 1, 2)
325+
326+
@variables a[1:2]
327+
oop, iip = build_function(a .* 2, a; iip_config = (false, false), expression = Val{false})
328+
@test_throws ArgumentError oop(ones(2))
329+
@test_throws ArgumentError iip(ones(2), ones(2))
330+
end
331+
332+
@testset "unwrapping/CSE in array of symbolics codegen" begin
333+
@variables a b
334+
oop, _ = build_function([a^2 + b^2, a^2 + b^2], a, b; expression = Val{true}, cse = true)
335+
336+
function find_create_array(expr)
337+
while expr isa Expr && (!Meta.isexpr(expr, :call) || expr.args[1] != SymbolicUtils.Code.create_array)
338+
expr = expr.args[end]
339+
end
340+
return expr
341+
end
342+
343+
expr = find_create_array(oop)
344+
# CSE works, we just need to test that it's happening and OOP is the easiest way to do it
345+
@test Meta.isexpr(expr, :call) && expr.args[1] == SymbolicUtils.Code.create_array &&
346+
expr.args[end] isa Symbol && expr.args[end-1] isa Symbol
347+
end

test/build_function_tests/intermediate-exprs-inplace.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
:(function (ˍ₋out, u)
22
begin
3-
ˍ₋out_input_1 = let _out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5)))), var"%_out" = for var"%jj′" = (zip)(1:5, (Symbolics.reset_to_one)(1:5))
3+
ˍ₋out_input_1 = begin
4+
_out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5))))
5+
var"%_out" = for var"%jj′" = (zip)(1:5, (Symbolics.reset_to_one)(1:5))
46
begin
57
j = var"%jj′"[1]
68
j′ = var"%jj′"[2]

test/build_function_tests/intermediate-exprs-outplace.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
:(function (u,)
2-
let _out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5)))), var"%_out" = begin
3-
_out_input_1 = let _out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5)))), var"%_out" = for var"%jj′" = (zip)(1:5, (Symbolics.reset_to_one)(1:5))
2+
begin
3+
_out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5))))
4+
var"%_out" = begin
5+
_out_input_1 = begin
6+
_out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5))))
7+
var"%_out" = for var"%jj′" = (zip)(1:5, (Symbolics.reset_to_one)(1:5))
48
begin
59
j = var"%jj′"[1]
610
j′ = var"%jj′"[2]

test/build_function_tests/manual-limits-outplace.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
:(function (u,)
2-
let _out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5)))), var"%_out" = for var"%jj′" = (zip)(1:5, (Symbolics.reset_to_one)(1:5))
2+
begin
3+
_out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5))))
4+
var"%_out" = for var"%jj′" = (zip)(1:5, (Symbolics.reset_to_one)(1:5))
35
begin
46
j = var"%jj′"[1]
57
j′ = var"%jj′"[2]

test/build_function_tests/stencil-broadcast-outplace.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
:(function (x,)
2-
let _out = (zeros)(Float64, (map)(length, (1:6, 1:6))), var"%_out" = begin
2+
begin
3+
_out = (zeros)(Float64, (map)(length, (1:6, 1:6)))
4+
var"%_out" = begin
35
_out_2_input_1 = (broadcast)(+, x, (adjoint)(x))
46
_out_1 = (view)(_out, 1:6, 1:6)
57
var"%_out_1" = (Symbolics.broadcast_assign!)(_out_1, 0)

test/build_function_tests/stencil-extents-outplace.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
:(function (x,)
2-
let _out = (zeros)(Float64, (map)(length, (1:5, 1:5))), var"%_out" = begin
2+
begin
3+
_out = (zeros)(Float64, (map)(length, (1:5, 1:5)))
4+
var"%_out" = begin
35
_out_1 = (view)(_out, 1:5, 1:5)
46
var"%_out_1" = (Symbolics.broadcast_assign!)(_out_1, 0)
57
_out_2 = (view)(_out, 2:4, 2:4)

test/build_function_tests/stencil-transpose-arrayop-outplace.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
:(function (x,)
2-
let _out = (zeros)(Float64, (map)(length, (1:6, 1:6))), var"%_out" = begin
2+
begin
3+
_out = (zeros)(Float64, (map)(length, (1:6, 1:6)))
4+
var"%_out" = begin
35
_out_1 = (view)(_out, 2:5, 2:5)
46
var"%_out_1" = for var"%jj′" = (zip)(1:4, (Symbolics.reset_to_one)(1:4))
57
begin

test/build_function_tests/transpose-outplace.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
:(function (x,)
2-
let _out = (zeros)(Float64, (map)(length, (Base.OneTo(4), Base.OneTo(4)))), var"%_out" = for var"%jj′" = (zip)(1:4, (Symbolics.reset_to_one)(1:4))
2+
begin
3+
_out = (zeros)(Float64, (map)(length, (Base.OneTo(4), Base.OneTo(4))))
4+
var"%_out" = for var"%jj′" = (zip)(1:4, (Symbolics.reset_to_one)(1:4))
35
begin
46
j = var"%jj′"[1]
57
j′ = var"%jj′"[2]

0 commit comments

Comments
 (0)