Skip to content

Commit a119204

Browse files
Merge pull request #1353 from karlwessel/master
add flag for activating robust calculation of expand_derivatives
2 parents ab3fcd6 + 3bf685a commit a119204

File tree

2 files changed

+157
-111
lines changed

2 files changed

+157
-111
lines changed

src/diff.jl

+129-111
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,127 @@ function recursive_hasoperator(op, O)
150150
end
151151
end
152152

153+
"""
154+
executediff(D, arg, simplify=false; occurrences=nothing)
155+
156+
Apply the passed Differential D on the passed argument.
157+
158+
This function differs to `expand_derivatives` in that in only expands the
159+
passed differential and not any other Differentials it encounters.
160+
161+
# Arguments
162+
- `D::Differential`: The differential to apply
163+
- `arg::Symbolic`: The symbolic expression to apply the differential on.
164+
- `simplify::Bool=false`: Whether to simplify the resulting expression using
165+
[`SymbolicUtils.simplify`](@ref).
166+
- `occurrences=nothing`: Information about the occurrences of the independent
167+
variable in the argument of the derivative. This is used internally for
168+
optimization purposes.
169+
"""
170+
function executediff(D, arg, simplify=false; occurrences=nothing)
171+
if occurrences == nothing
172+
occurrences = occursin_info(D.x, arg)
173+
end
174+
175+
_isfalse(occurrences) && return 0
176+
occurrences isa Bool && return 1 # means it's a `true`
177+
178+
if !iscall(arg)
179+
return D(arg) # Cannot expand
180+
elseif (op = operation(arg); issym(op))
181+
inner_args = arguments(arg)
182+
if any(isequal(D.x), inner_args)
183+
return D(arg) # base case if any argument is directly equal to the i.v.
184+
else
185+
return sum(inner_args, init=0) do a
186+
return executediff(Differential(a), arg) *
187+
executediff(D, a)
188+
end
189+
end
190+
elseif op === (IfElse.ifelse)
191+
args = arguments(arg)
192+
O = op(args[1],
193+
executediff(D, args[2], simplify; occurrences=arguments(occurrences)[2]),
194+
executediff(D, args[3], simplify; occurrences=arguments(occurrences)[3]))
195+
return O
196+
elseif isa(op, Differential)
197+
# The recursive expand_derivatives was not able to remove
198+
# a nested Differential. We can attempt to differentiate the
199+
# inner expression wrt to the outer iv. And leave the
200+
# unexpandable Differential outside.
201+
if isequal(op.x, D.x)
202+
return D(arg)
203+
else
204+
inner = executediff(D, arguments(arg)[1], false)
205+
# if the inner expression is not expandable either, return
206+
if iscall(inner) && operation(inner) isa Differential
207+
return D(arg)
208+
else
209+
# otherwise give the nested Differential another try
210+
return executediff(op, inner, simplify)
211+
end
212+
end
213+
elseif isa(op, Integral)
214+
if isa(op.domain.domain, AbstractInterval)
215+
domain = op.domain.domain
216+
a, b = DomainSets.endpoints(domain)
217+
c = 0
218+
inner_function = arguments(arg)[1]
219+
if iscall(value(a))
220+
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
221+
t2 = D(a)
222+
c -= t1*t2
223+
end
224+
if iscall(value(b))
225+
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b)))
226+
t2 = D(b)
227+
c += t1*t2
228+
end
229+
inner = executediff(D, arguments(arg)[1])
230+
c += op(inner)
231+
return value(c)
232+
end
233+
end
234+
235+
inner_args = arguments(arg)
236+
l = length(inner_args)
237+
exprs = []
238+
c = 0
239+
240+
for i in 1:l
241+
t2 = executediff(D, inner_args[i],false; occurrences=arguments(occurrences)[i])
242+
243+
x = if _iszero(t2)
244+
t2
245+
elseif _isone(t2)
246+
d = derivative_idx(arg, i)
247+
d isa NoDeriv ? D(arg) : d
248+
else
249+
t1 = derivative_idx(arg, i)
250+
t1 = t1 isa NoDeriv ? D(arg) : t1
251+
t1 * t2
252+
end
253+
254+
if _iszero(x)
255+
continue
256+
elseif x isa Symbolic
257+
push!(exprs, x)
258+
else
259+
c += x
260+
end
261+
end
262+
263+
if isempty(exprs)
264+
return c
265+
elseif length(exprs) == 1
266+
term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1])
267+
return _iszero(c) ? term : c + term
268+
else
269+
x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...)
270+
return simplify ? SymbolicUtils.simplify(x) : x
271+
end
272+
end
273+
153274
"""
154275
$(SIGNATURES)
155276
@@ -162,9 +283,6 @@ and other derivative rules to expand any derivatives it encounters.
162283
- `O::Symbolic`: The symbolic expression to expand.
163284
- `simplify::Bool=false`: Whether to simplify the resulting expression using
164285
[`SymbolicUtils.simplify`](@ref).
165-
- `occurrences=nothing`: Information about the occurrences of the independent
166-
variable in the argument of the derivative. This is used internally for
167-
optimization purposes.
168286
169287
# Examples
170288
```jldoctest
@@ -180,111 +298,11 @@ julia> dfx = expand_derivatives(Dx(f))
180298
(k*((2abs(x - y)) / y - 2z)*IfElse.ifelse(signbit(x - y), -1, 1)) / y
181299
```
182300
"""
183-
function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
301+
function expand_derivatives(O::Symbolic, simplify=false)
184302
if iscall(O) && isa(operation(O), Differential)
185303
arg = only(arguments(O))
186304
arg = expand_derivatives(arg, false)
187-
188-
if occurrences == nothing
189-
occurrences = occursin_info(operation(O).x, arg)
190-
end
191-
192-
_isfalse(occurrences) && return 0
193-
occurrences isa Bool && return 1 # means it's a `true`
194-
195-
D = operation(O)
196-
197-
if !iscall(arg)
198-
return D(arg) # Cannot expand
199-
elseif (op = operation(arg); issym(op))
200-
inner_args = arguments(arg)
201-
if any(isequal(D.x), inner_args)
202-
return D(arg) # base case if any argument is directly equal to the i.v.
203-
else
204-
return sum(inner_args, init=0) do a
205-
return expand_derivatives(Differential(a)(arg)) *
206-
expand_derivatives(D(a))
207-
end
208-
end
209-
elseif op === (IfElse.ifelse)
210-
args = arguments(arg)
211-
O = op(args[1], D(args[2]), D(args[3]))
212-
return expand_derivatives(O, simplify; occurrences)
213-
elseif isa(op, Differential)
214-
# The recursive expand_derivatives was not able to remove
215-
# a nested Differential. We can attempt to differentiate the
216-
# inner expression wrt to the outer iv. And leave the
217-
# unexpandable Differential outside.
218-
if isequal(op.x, D.x)
219-
return D(arg)
220-
else
221-
inner = expand_derivatives(D(arguments(arg)[1]), false)
222-
# if the inner expression is not expandable either, return
223-
if iscall(inner) && operation(inner) isa Differential
224-
return D(arg)
225-
else
226-
return expand_derivatives(op(inner), simplify)
227-
end
228-
end
229-
elseif isa(op, Integral)
230-
if isa(op.domain.domain, AbstractInterval)
231-
domain = op.domain.domain
232-
a, b = DomainSets.endpoints(domain)
233-
c = 0
234-
inner_function = expand_derivatives(arguments(arg)[1])
235-
if iscall(value(a))
236-
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
237-
t2 = D(a)
238-
c -= t1*t2
239-
end
240-
if iscall(value(b))
241-
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b)))
242-
t2 = D(b)
243-
c += t1*t2
244-
end
245-
inner = expand_derivatives(D(arguments(arg)[1]))
246-
c += op(inner)
247-
return value(c)
248-
end
249-
end
250-
251-
inner_args = arguments(arg)
252-
l = length(inner_args)
253-
exprs = []
254-
c = 0
255-
256-
for i in 1:l
257-
t2 = expand_derivatives(D(inner_args[i]),false, occurrences=arguments(occurrences)[i])
258-
259-
x = if _iszero(t2)
260-
t2
261-
elseif _isone(t2)
262-
d = derivative_idx(arg, i)
263-
d isa NoDeriv ? D(arg) : d
264-
else
265-
t1 = derivative_idx(arg, i)
266-
t1 = t1 isa NoDeriv ? D(arg) : t1
267-
t1 * t2
268-
end
269-
270-
if _iszero(x)
271-
continue
272-
elseif x isa Symbolic
273-
push!(exprs, x)
274-
else
275-
c += x
276-
end
277-
end
278-
279-
if isempty(exprs)
280-
return c
281-
elseif length(exprs) == 1
282-
term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1])
283-
return _iszero(c) ? term : c + term
284-
else
285-
x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...)
286-
return simplify ? SymbolicUtils.simplify(x) : x
287-
end
305+
return executediff(operation(O), arg, simplify)
288306
elseif iscall(O) && isa(operation(O), Integral)
289307
return operation(O)(expand_derivatives(arguments(O)[1]))
290308
elseif !hasderiv(O)
@@ -295,14 +313,14 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
295313
return simplify ? SymbolicUtils.simplify(O1) : O1
296314
end
297315
end
298-
function expand_derivatives(n::Num, simplify=false; occurrences=nothing)
299-
wrap(expand_derivatives(value(n), simplify; occurrences=occurrences))
316+
function expand_derivatives(n::Num, simplify=false)
317+
wrap(expand_derivatives(value(n), simplify))
300318
end
301-
function expand_derivatives(n::Complex{Num}, simplify=false; occurrences=nothing)
302-
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; occurrences=occurrences),
303-
expand_derivatives(imag(n), simplify; occurrences=occurrences)))
319+
function expand_derivatives(n::Complex{Num}, simplify=false)
320+
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify),
321+
expand_derivatives(imag(n), simplify)))
304322
end
305-
expand_derivatives(x, simplify=false; occurrences=nothing) = x
323+
expand_derivatives(x, simplify=false) = x
306324

307325
_iszero(x) = false
308326
_isone(x) = false

test/diff.jl

+28
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,34 @@ let
349349
@test isequal(expand_derivatives(Differential(t)(t^2 + im*t)), 2t + im)
350350
end
351351

352+
# 1262
353+
#
354+
let
355+
@variables t b(t)
356+
D = Differential(t)
357+
expr = b - ((D(b))^2) * D(D(b))
358+
expr2 = D(expr)
359+
@test isequal(expand_derivatives(expr), expr)
360+
@test isequal(expand_derivatives(expr2), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2))
361+
end
362+
363+
# 1126
364+
#
365+
let
366+
@syms y f(y) g(y) h(y)
367+
D = Differential(y)
368+
369+
expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y))))
370+
371+
expr = expr_gen(g(y))
372+
# just make sure that no errors are thrown in the following, the results are to complicated to compare
373+
expand_derivatives(expr)
374+
expr = expr_gen(h(y))
375+
expand_derivatives(expr)
376+
377+
expr = expr_gen(f(y))
378+
expand_derivatives(expr)
379+
end
352380

353381
# Check `is_derivative` function
354382
let

0 commit comments

Comments
 (0)