Skip to content

Commit

Permalink
elsi csc mat reader
Browse files Browse the repository at this point in the history
  • Loading branch information
minyez committed Jun 6, 2024
1 parent da82ef5 commit 410507b
Showing 1 changed file with 79 additions and 0 deletions.
79 changes: 79 additions & 0 deletions mushroom/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,82 @@ def write(self, output_path: Union[str, os.PathLike] = None, format: str = None)
else:
self._cell.write(format, filename=output_path)


def read_elsi_to_csc(fn, verbose=False):
"""Read ELSI CSC file format.
Adapted from the same function in FHI-aims utitlities
Args:
fn (str) : path to the ELSI CSC file
verbose (bool) : verbositiy control
Retunrs:
scipy sparse matrix
"""
import struct
import scipy.sparse as sp
import numpy as np

mat = open(fn, "rb")
data = mat.read()
mat.close()
i8 = "l"
i4 = "i"

# Get header
start = 0
end = 128
header = struct.unpack(i8 * 16, data[start:end])
if verbose:
print(header)

# Number of basis functions (matrix size)
n_basis = header[3]

# Total number of non-zero elements
nnz = header[5]

# Get column pointer
start = end
end = start + n_basis * 8
col_ptr = struct.unpack(i8 * n_basis, data[start:end])
# print(col_ptr)
col_ptr += (nnz + 1, )
col_ptr = np.array(col_ptr)

# Get row index
start = end
end = start + nnz * 4
row_idx = struct.unpack(i4 * nnz, data[start:end])
row_idx = np.array(row_idx)

# Get non-zero value
start = end

if header[2] == 0:
if verbose:
print("Reading real matrix")
# Real case
end = start + nnz * 8
nnz_val = struct.unpack("d" * nnz, data[start:end])
else:
if verbose:
print("Reading complex matrix")
# Complex case
end = start + nnz * 16
nnz_val = struct.unpack("d" * nnz * 2, data[start:end])
nnz_val_real = np.array(nnz_val[0::2])
nnz_val_imag = np.array(nnz_val[1::2])
nnz_val = nnz_val_real + 1j * nnz_val_imag

nnz_val = np.array(nnz_val)

# Change convention to starting index from 0
for i_val in range(nnz):
row_idx[i_val] -= 1

for i_col in range(n_basis + 1):
col_ptr[i_col] -= 1

return sp.csc_matrix((nnz_val, row_idx, col_ptr), shape=(n_basis, n_basis))

0 comments on commit 410507b

Please sign in to comment.