diff --git a/pyscf/lib/numpy_helper.py b/pyscf/lib/numpy_helper.py index dcc21b312f..8c7866cc3d 100644 --- a/pyscf/lib/numpy_helper.py +++ b/pyscf/lib/numpy_helper.py @@ -641,7 +641,7 @@ def hermi_sum(a, axes=None, hermi=HERMITIAN, inplace=False, out=None): if (not a.flags.c_contiguous or (a.dtype != numpy.double and a.dtype != numpy.complex128)): - if a.ndim == 2: + def _hermi_sum_2d(a, out): na = a.shape[0] for c0, c1 in misc.prange(0, na, BLOCK_DIM): for r0, r1 in misc.prange(0, c0, BLOCK_DIM): @@ -652,6 +652,14 @@ def hermi_sum(a, axes=None, hermi=HERMITIAN, inplace=False, out=None): tmp = a[c0:c1,c0:c1] + a[c0:c1,c0:c1].conj().T out[c0:c1,c0:c1] = tmp return out + + if a.ndim == 2: + return _hermi_sum_2d(a, out) + elif a.ndim == 3 and axes == (0,2,1): + assert a.shape[1] == a.shape[2] + for i in range(a.shape[0]): + _hermi_sum_2d(a[i], out[i]) + return out else: raise NotImplementedError('input array is not C-contiguous')