Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsch committed Aug 4, 2023
1 parent f0e5642 commit e3730f8
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
32 changes: 20 additions & 12 deletions src/DictVectors/pdvec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
Dictionary-based vector-like data structure for use with FCIQMC and
[KrylovKit.jl](https://github.com/Jutho/KrylovKit.jl). While mostly behaving like a `Dict`,
it supports various linear algebra operations such as `norm` and `dot`.
it supports various linear algebra operations such as `norm` and `dot`, and the interface defined in [VectorInterface](https://github.com/Jutho/VectorInterface.jl).
The P in `PDVec` stands for parallel. `PDVec`s perform `mapreduce`, `foreach`, and various
linear algebra operations in a threaded manner. If MPI is available, these operations are
Expand Down Expand Up @@ -44,9 +44,6 @@ dictionary a key-value pair is mapped to is determined by the hash of the key. T
of this segmentation is to allow parallel processing - functions such as `mapreduce`, `add!`
or `dot` (full list below) process each subdictionary on a separate thread.
For parallel binary operations, the numbers of segments in both vectors must match. To
ensure this, it is best to leave the number of segments at its default value.
### Example
```julia
Expand Down Expand Up @@ -144,7 +141,7 @@ julia> results[1][1:4]
The following functions are threaded MPI-compatible:
* From Base: `mapreduce` and derivatives (`sum`, `prod`, `reduce`...), `all`,
`any`,`map!` (on values only), `+`, `-`, `*`
`any`,`map!` (on `values` only), `+`, `-`, `*`
* From LinearAlgebra: `rmul!`, `lmul!`, `mul!`, `axpy!`, `axpby!`, `dot`, `norm`,
`normalize`, `normalize!`
Expand Down Expand Up @@ -284,16 +281,16 @@ function check_compatibility(t, u)
end
end

function Base.isequal(t::PDVec, u::PDVec)
check_compatibility(t, u)
if length(localpart(t)) == length(localpart(u))
result = Folds.all(zip(t.segments, u.segments)) do (t_seg, u_seg)
isequal(t_seg, u_seg)
function Base.isequal(l::PDVec, r::PDVec)
check_compatibility(l, r)
if length(localpart(l)) == length(localpart(r))
result = Folds.all(zip(l.segments, r.segments)) do (l_seg, r_seg)
isequal(l_seg, r_seg)
end
else
result = false
end
return merge_remote_reductions(t.communicator, &, result)
return merge_remote_reductions(l.communicator, &, result)
end

"""
Expand Down Expand Up @@ -324,6 +321,8 @@ end
function Base.setindex!(t::PDVec{K,V}, val, k::K) where {K,V}
v = V(val)
segment_id, is_local = target_segment(t, k)
# Adding a key that is not local is supported. This is done to allow easy construction
# of vectors even when using MPI.
if is_local
if iszero(v)
delete!(t.segments[segment_id], k)
Expand Down Expand Up @@ -660,7 +659,16 @@ Perform `y = A * x` in-place. The working memory `w` is required to facilitate
threaded/distributed operations. If not passed a new instance will be allocated. `y` and `x`
may be the same vector.
"""
function LinearAlgebra.mul!(y::PDVec, op::AbstractHamiltonian, x::PDVec, w=PDWorkingMemory(y; style=IsDeterministic()))
function LinearAlgebra.mul!(
y::PDVec, op::AbstractHamiltonian, x::PDVec,
w=PDWorkingMemory(y; style=IsDeterministic()),
)
if w.style IsDeterministic()
throw(ArgumentError(
"Attempted to use `mul!` with non-deterministic working memory. " *
"Use `apply_operator!` instead."
))
end
_, _, wm, y = apply_operator!(w, y, x, op)
return y
end
Expand Down
2 changes: 1 addition & 1 deletion src/DictVectors/pdworkingmemory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ function perform_spawns!(w::PDWorkingMemory, t::PDVec, ham, boost)
error("working memory incompatible with vector")
end
stat_names, init_stats = step_stats(w.style)
stats = Folds.sum(zip(w.columns, t.segments)) do (column, segment)
stats = Folds.sum(zip(w.columns, t.segments); init=init_stats) do (column, segment)
_spawn_column!(ham, column, segment, boost)
end ::typeof(init_stats)
return stat_names, stats
Expand Down
2 changes: 1 addition & 1 deletion src/DictVectors/projectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ Base.valtype(::FrozenDVec{<:Any,V}) where {V} = V
Base.eltype(::FrozenDVec{K,V}) where {K,V} = Pair{K,V}
Base.pairs(fd::FrozenDVec) = fd.pairs

freeze(dv) = FrozenDVec(collect(pairs(localpart(dv))))
freeze(dv::AbstractDVec) = FrozenDVec(collect(pairs(localpart(dv))))

freeze(p::AbstractProjector) = p

Expand Down

0 comments on commit e3730f8

Please sign in to comment.