Skip to content

Commit

Permalink
Merge pull request #101 from marc-vdm/technosphere_mat_format
Browse files Browse the repository at this point in the history
Force correct `self.technosphere_matrix` format for solver
  • Loading branch information
cmutel committed Jul 24, 2024
2 parents 8942c42 + b2ba342 commit fea496c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
22 changes: 19 additions & 3 deletions bw2calc/lca_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from collections.abc import Iterator
from functools import partial
from scipy.sparse import csc_matrix, csr_matrix
from typing import Optional, Tuple

import matrix_utils as mu
Expand Down Expand Up @@ -45,7 +46,13 @@ def load_lci_data(self, nonsquare_ok=False) -> None:
use_distributions=use_distributions,
seed_override=self.seed_override,
)
self.technosphere_matrix = self.technosphere_mm.matrix

# explicitly set technosphere format to optimal CSR/CSC format depending on solver
# see this link for discussion https://github.com/haasad/PyPardiso/issues/75#issuecomment-2186825609
if PYPARDISO:
self.technosphere_matrix = self.technosphere_mm.matrix.tocsr()
else:
self.technosphere_matrix = self.technosphere_mm.matrix.tocsc()
self.dicts.product = partial(self.technosphere_mm.row_mapper.to_dict)
self.dicts.activity = partial(self.technosphere_mm.col_mapper.to_dict)

Expand Down Expand Up @@ -119,7 +126,9 @@ def decompose_technosphere(self) -> None:
if PYPARDISO:
warnings.warn("PARDISO installed; this is a no-op")
else:
self.solver = factorized(self.technosphere_matrix.tocsc())
if isinstance(self.technosphere_matrix, csr_matrix):
self.technosphere_matrix.tocsc()
self.solver = factorized(self.technosphere_matrix)

def solve_linear_system(self, demand: Optional[np.ndarray] = None) -> None:
"""
Expand All @@ -145,7 +154,14 @@ def solve_linear_system(self, demand: Optional[np.ndarray] = None) -> None:
if hasattr(self, "solver"):
return self.solver(demand)
else:
return spsolve(self.technosphere_matrix, demand)
if ((PYPARDISO and isinstance(self.technosphere_matrix, csr_matrix)) or
(not PYPARDISO and isinstance(self.technosphere_matrix, csc_matrix))):
return spsolve(self.technosphere_matrix, demand)
elif PYPARDISO:
return spsolve(self.technosphere_matrix.tocsr(), demand)
else:
return spsolve(self.technosphere_matrix.tocsc(), demand)


def lci(self, demand: Optional[dict] = None, factorize: bool = False) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions bw2calc/multi_lca.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import warnings
from pathlib import Path
from scipy.sparse import csc_matrix, csr_matrix
from typing import Iterable, Optional, Union

import bw_processing as bwp
Expand Down

0 comments on commit fea496c

Please sign in to comment.