From e3730f8e93b43b66ce964b8f091778cb2cb7c114 Mon Sep 17 00:00:00 2001 From: mtsch Date: Fri, 4 Aug 2023 17:20:12 +0200 Subject: [PATCH] Small fixes --- src/DictVectors/pdvec.jl | 32 +++++++++++++++++++----------- src/DictVectors/pdworkingmemory.jl | 2 +- src/DictVectors/projectors.jl | 2 +- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/DictVectors/pdvec.jl b/src/DictVectors/pdvec.jl index 8d81a93f0..c7bb4d768 100644 --- a/src/DictVectors/pdvec.jl +++ b/src/DictVectors/pdvec.jl @@ -15,7 +15,7 @@ end Dictionary-based vector-like data structure for use with FCIQMC and [KrylovKit.jl](https://github.com/Jutho/KrylovKit.jl). While mostly behaving like a `Dict`, -it supports various linear algebra operations such as `norm` and `dot`. +it supports various linear algebra operations such as `norm` and `dot`, and the interface defined in [VectorInterface](https://github.com/Jutho/VectorInterface.jl). The P in `PDVec` stands for parallel. `PDVec`s perform `mapreduce`, `foreach`, and various linear algebra operations in a threaded manner. If MPI is available, these operations are @@ -44,9 +44,6 @@ dictionary a key-value pair is mapped to is determined by the hash of the key. T of this segmentation is to allow parallel processing - functions such as `mapreduce`, `add!` or `dot` (full list below) process each subdictionary on a separate thread. -For parallel binary operations, the numbers of segments in both vectors must match. To -ensure this, it is best to leave the number of segments at its default value. - ### Example ```julia @@ -144,7 +141,7 @@ julia> results[1][1:4] The following functions are threaded MPI-compatible: * From Base: `mapreduce` and derivatives (`sum`, `prod`, `reduce`...), `all`, - `any`,`map!` (on values only), `+`, `-`, `*` + `any`,`map!` (on `values` only), `+`, `-`, `*` * From LinearAlgebra: `rmul!`, `lmul!`, `mul!`, `axpy!`, `axpby!`, `dot`, `norm`, `normalize`, `normalize!` @@ -284,16 +281,16 @@ function check_compatibility(t, u) end end -function Base.isequal(t::PDVec, u::PDVec) - check_compatibility(t, u) - if length(localpart(t)) == length(localpart(u)) - result = Folds.all(zip(t.segments, u.segments)) do (t_seg, u_seg) - isequal(t_seg, u_seg) +function Base.isequal(l::PDVec, r::PDVec) + check_compatibility(l, r) + if length(localpart(l)) == length(localpart(r)) + result = Folds.all(zip(l.segments, r.segments)) do (l_seg, r_seg) + isequal(l_seg, r_seg) end else result = false end - return merge_remote_reductions(t.communicator, &, result) + return merge_remote_reductions(l.communicator, &, result) end """ @@ -324,6 +321,8 @@ end function Base.setindex!(t::PDVec{K,V}, val, k::K) where {K,V} v = V(val) segment_id, is_local = target_segment(t, k) + # Adding a key that is not local is supported. This is done to allow easy construction + # of vectors even when using MPI. if is_local if iszero(v) delete!(t.segments[segment_id], k) @@ -660,7 +659,16 @@ Perform `y = A * x` in-place. The working memory `w` is required to facilitate threaded/distributed operations. If not passed a new instance will be allocated. `y` and `x` may be the same vector. """ -function LinearAlgebra.mul!(y::PDVec, op::AbstractHamiltonian, x::PDVec, w=PDWorkingMemory(y; style=IsDeterministic())) +function LinearAlgebra.mul!( + y::PDVec, op::AbstractHamiltonian, x::PDVec, + w=PDWorkingMemory(y; style=IsDeterministic()), +) + if w.style ≢ IsDeterministic() + throw(ArgumentError( + "Attempted to use `mul!` with non-deterministic working memory. " * + "Use `apply_operator!` instead." + )) + end _, _, wm, y = apply_operator!(w, y, x, op) return y end diff --git a/src/DictVectors/pdworkingmemory.jl b/src/DictVectors/pdworkingmemory.jl index 114d36b0b..631d4949d 100644 --- a/src/DictVectors/pdworkingmemory.jl +++ b/src/DictVectors/pdworkingmemory.jl @@ -168,7 +168,7 @@ function perform_spawns!(w::PDWorkingMemory, t::PDVec, ham, boost) error("working memory incompatible with vector") end stat_names, init_stats = step_stats(w.style) - stats = Folds.sum(zip(w.columns, t.segments)) do (column, segment) + stats = Folds.sum(zip(w.columns, t.segments); init=init_stats) do (column, segment) _spawn_column!(ham, column, segment, boost) end ::typeof(init_stats) return stat_names, stats diff --git a/src/DictVectors/projectors.jl b/src/DictVectors/projectors.jl index e6e245ccc..72b22bb27 100644 --- a/src/DictVectors/projectors.jl +++ b/src/DictVectors/projectors.jl @@ -145,7 +145,7 @@ Base.valtype(::FrozenDVec{<:Any,V}) where {V} = V Base.eltype(::FrozenDVec{K,V}) where {K,V} = Pair{K,V} Base.pairs(fd::FrozenDVec) = fd.pairs -freeze(dv) = FrozenDVec(collect(pairs(localpart(dv)))) +freeze(dv::AbstractDVec) = FrozenDVec(collect(pairs(localpart(dv)))) freeze(p::AbstractProjector) = p