From 4c9a2766005103cabf1c2aec59c96dd407c6505c Mon Sep 17 00:00:00 2001 From: Ziyi Yin Date: Fri, 6 Jan 2023 14:49:17 -0500 Subject: [PATCH] fix reshape judivector in ad --- src/TimeModeling/Types/abstract.jl | 10 +++++++++- src/rrules.jl | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/TimeModeling/Types/abstract.jl b/src/TimeModeling/Types/abstract.jl index 8c7cf12ed..7227c00a4 100644 --- a/src/TimeModeling/Types/abstract.jl +++ b/src/TimeModeling/Types/abstract.jl @@ -73,7 +73,15 @@ vec(x::judiMultiSourceVector) = vcat(vec.(x.data)...) time_sampling(ms::judiMultiSourceVector) = [1 for i=1:ms.nsrc] -reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N = reshape(vec(ms), dims) +function reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N + try + return reshape(vec(ms), dims) + catch e + @assert dims[1] == ms.nsrc ### during AD, size(ms::judiVector) = ms.nsrc + return ms + end +end + ############################################################################################################################ # Linear algebra `*` (msv::judiMultiSourceVector{mT})(x::AbstractVector{T}) where {mT, T<:Number} = x diff --git a/src/rrules.jl b/src/rrules.jl index 96729e79f..bcb27e58f 100644 --- a/src/rrules.jl +++ b/src/rrules.jl @@ -72,6 +72,7 @@ broadcasted(::typeof(^), y::LazyPropagation, p::Real) = eval_prop(y).^(p) *(F::judiPropagator, q::LazyPropagation) = F*eval_prop(q) reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, F.q, F.val) +vec(F::LazyPropagation) = LazyPropagation(vec, F.F, F.q, F.val) copyto!(x::AbstractArray, F::LazyPropagation) = copyto!(x, eval_prop(F)) dot(x::AbstractArray, F::LazyPropagation) = dot(x, eval_prop(F)) dot(F::LazyPropagation, x::AbstractArray) = dot(x, F)