diff --git a/src/TimeModeling/LinearOperators/lazy.jl b/src/TimeModeling/LinearOperators/lazy.jl index a590d9476..ced81df02 100644 --- a/src/TimeModeling/LinearOperators/lazy.jl +++ b/src/TimeModeling/LinearOperators/lazy.jl @@ -130,6 +130,7 @@ end size(jA::jAdjoint) = (jA.op.n, jA.op.m) display(P::jAdjoint) = println("Adjoint($(P.op))") display(P::judiProjection{D}) where D = println("JUDI projection operator $(repr(P.n)) -> $(repr(P.m))") +display(P::judiWavelet{T}) where T = println("JUDI wavelet") ############################################################################################################################ # Indexing diff --git a/src/TimeModeling/Types/abstract.jl b/src/TimeModeling/Types/abstract.jl index 8c7cf12ed..be848d848 100644 --- a/src/TimeModeling/Types/abstract.jl +++ b/src/TimeModeling/Types/abstract.jl @@ -73,7 +73,9 @@ vec(x::judiMultiSourceVector) = vcat(vec.(x.data)...) time_sampling(ms::judiMultiSourceVector) = [1 for i=1:ms.nsrc] +reshape(ms::judiMultiSourceVector, dims::Dims{1}) = ms ### during AD, size(ms::judiVector) = ms.nsrc reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N = reshape(vec(ms), dims) + ############################################################################################################################ # Linear algebra `*` (msv::judiMultiSourceVector{mT})(x::AbstractVector{T}) where {mT, T<:Number} = x diff --git a/src/rrules.jl b/src/rrules.jl index b8f16907f..b496f6b1b 100644 --- a/src/rrules.jl +++ b/src/rrules.jl @@ -18,15 +18,20 @@ Parameters * `F`: the JUDI propgator * `q`: The source to compute F*q """ -struct LazyPropagation +mutable struct LazyPropagation post::Function F::judiPropagator q + val # store F * q end -eval_prop(F::LazyPropagation) = F.post(F.F * F.q) +function eval_prop(F::LazyPropagation) + isnothing(F.val) && (F.val = F.F * F.q) + return F.post(F.val) +end Base.collect(F::LazyPropagation) = eval_prop(F) -LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q) +LazyPropagation(post::Function, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing) +LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q, nothing) # Only a few arithmetic operation are supported @@ -45,10 +50,29 @@ for op in [:+, :-, :*, :/] end end +for op in [:*, :/] + @eval begin + $(op)(F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, $(op)(F.q, y), isnothing(F.val) ? nothing : $(op)(F.val, y)) + $(op)(y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, $(op)(y, F.q), isnothing(F.val) ? nothing : $(op)(y, F.val)) + broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), F.q, y), isnothing(F.val) ? nothing : broadcasted($(op), F.val, y)) + broadcasted(::typeof($op), y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), y, F.q), isnothing(F.val) ? nothing : broadcasted($(op), y, F.val)) + end +end + +for op in [:+, :-] + @eval begin + $(op)(F::LazyPropagation, y::T) where T <: Number = $(op)(eval_prop(F), y) + $(op)(y::T, F::LazyPropagation) where T <: Number = $(op)(y, eval_prop(F)) + broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = broadcasted($(op), eval_prop(F), y) + broadcasted(::typeof($op), y::T, F::LazyPropagation) where T <: Number = broadcasted($(op), y, eval_prop(F)) + end +end + broadcasted(::typeof(^), y::LazyPropagation, p::Real) = eval_prop(y).^(p) *(F::judiPropagator, q::LazyPropagation) = F*eval_prop(q) -reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, Q.q) +reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, F.q, F.val) +vec(F::LazyPropagation) = LazyPropagation(vec, F.F, F.q, F.val) copyto!(x::AbstractArray, F::LazyPropagation) = copyto!(x, eval_prop(F)) dot(x::AbstractArray, F::LazyPropagation) = dot(x, eval_prop(F)) dot(F::LazyPropagation, x::AbstractArray) = dot(x, F) diff --git a/test/test_rrules.jl b/test/test_rrules.jl index 0a4e76a68..404bdc000 100644 --- a/test/test_rrules.jl +++ b/test/test_rrules.jl @@ -32,7 +32,7 @@ perturb(x::judiVector) = judiVector(x.geometry, [randx(x.data[i]) for i=1:x.nsrc reverse(x::judiVector) = judiVector(x.geometry, [x.data[i][end:-1:1, :] for i=1:x.nsrc]) misfit_objective_2p(d_obs, q0, m0, F) = .5f0*norm(F(m0, q0) - d_obs)^2 -misfit_objective_1p(d_obs, q0, m0, F) = .5f0*norm(F(m0)*q0 - d_obs)^2 +misfit_objective_1p(d_obs, q0, m0, F) = .5f0*norm(F(1f0.*m0)*q0 - d_obs)^2 function loss(misfit, d_obs, q0, m0, F) local ϕ