Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework of coo2csr conversion #174

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 231 additions & 2 deletions resolve/matrix/Utilities.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <algorithm>
#include <list>
#include <resolve/Common.hpp>
#include <resolve/matrix/Coo.hpp>
#include <resolve/matrix/Csr.hpp>
Expand All @@ -7,6 +8,8 @@

namespace ReSolve
{
using out = io::Logger;

/// @brief Helper class for COO matrix sorting
class IndexValuePair
{
Expand Down Expand Up @@ -36,16 +39,242 @@ namespace ReSolve
bool operator < (const IndexValuePair& str) const
{
return (idx_ < str.idx_);
}
}

private:
index_type idx_;
real_type value_;
};

using out = io::Logger;
/// @brief Helper class for COO matrix sorting
class CooTriplet
{
public:
CooTriplet() : rowidx_(0), colidx_(0), value_(0.0)
{}
~CooTriplet()
{}
void setColIdx (index_type new_idx)
{
colidx_ = new_idx;
}
void setValue (real_type new_value)
{
value_ = new_value;
}
void set(index_type rowidx, index_type colidx, real_type value)
{
rowidx_ = rowidx;
colidx_ = colidx;
value_ = value;
}

index_type getRowIdx()
{
return rowidx_;
}
index_type getColIdx()
{
return colidx_;
}
real_type getValue()
{
return value_;
}

bool operator < (const CooTriplet& str) const
{
if (rowidx_ < str.rowidx_)
return true;

if ((rowidx_ == str.rowidx_) && (colidx_ < str.colidx_))
return true;

return false;
}

bool operator == (const CooTriplet& str) const
{
return (rowidx_ == str.rowidx_) && (colidx_ == str.colidx_);
}

CooTriplet& operator += (const CooTriplet t)
{
if ((rowidx_ != t.rowidx_) || (colidx_ != t.colidx_)) {
out::error() << "Adding values into non-matching triplet.\n";
}
value_ += t.value_;
return *this;
}

void print() const
{
// Add 1 to indices to restore indexing from MM format
std::cout << rowidx_ << " " << colidx_ << " " << value_ << "\n";
}

private:
index_type rowidx_{0};
index_type colidx_{0};
real_type value_{0.0};
};

void print_list(std::list<CooTriplet>& l)
{
// Print out the list
std::cout << "tmp list:\n";
for (CooTriplet& n : l)
n.print();
std::cout << "\n";
}


namespace matrix
{
/**
* @brief
*
* @param A_coo
* @param A_csr
* @param memspace
* @return int
*
* @pre A_coo and A_csr matrix sizes must match.
*/
int coo2csr_new(matrix::Coo* A_coo, matrix::Csr* A_csr, memory::MemorySpace memspace)
{
index_type n = A_coo->getNumRows();
index_type m = A_coo->getNumColumns();
index_type nnz = A_coo->getNnz();
const index_type* rows_coo = A_coo->getRowData(memory::HOST);
const index_type* cols_coo = A_coo->getColData(memory::HOST);
const real_type* vals_coo = A_coo->getValues(memory::HOST);
index_type n_diagonal = 0;
bool is_symmetric = A_coo->symmetric();
bool is_expanded = A_coo->expanded();
bool is_upper_triangular = false;
bool is_lower_triangular = false;

// Compute size of the symmetric matrix when expanded as general.
// Check if matrix is upper- or lower-triangular, if it is
// defined as symmetric and not expanded.
// Complexity O(NNZ)
index_type nnz_expanded = nnz;
if (is_symmetric && !is_expanded) {
for (index_type i = 0; i <nnz; ++i) {
if (rows_coo[i] == cols_coo[i]) {
++n_diagonal;
}

is_lower_triangular += (rows_coo[i] > cols_coo[i]);
is_upper_triangular += (rows_coo[i] < cols_coo[i]);
if (is_lower_triangular && is_upper_triangular) {
out::error() << "Input COO matrix supposed to be storred as symmetric "
<< "but is neither upper- nor lower-triangular.\n"
<< "Now exiting coo2csr ...\n";
return 1;
}
}
nnz_expanded = 2*nnz - n_diagonal;
}

// Create temporary workspace for COO to CSR conversion
// Store COO data in the workspace and expand, if needed.
// Complexity O(NNZ)
std::list<CooTriplet> tmp(nnz_expanded);
std::list<CooTriplet>::iterator it = tmp.begin();
for (index_type i = 0; i < nnz; ++i) {
index_type row = rows_coo[i];
index_type col = cols_coo[i];
real_type val = vals_coo[i];
it->set(row, col, val);
it++;

if (is_symmetric && !is_expanded) {
if (row != col) {
it->set(col, row, val);
it++;
}
}
}
if (it != tmp.end()) {
out::error() << "NNZ computed inaccurately!\n";
}
// print_list(tmp);
// std::cout << "Size of tmp list = " << tmp.size() << "\n\n";

// Sort tmp
// Complexity NNZ*log(NNZ)
tmp.sort();
// print_list(tmp);
// std::cout << "Size of tmp list = " << tmp.size() << "\n\n";

// Deduplicate tmp
// Complexity O(NNZ)
it = tmp.begin();
while (it != tmp.end())
{
std::list<CooTriplet>::iterator it_tmp = it;
it++;
if (*it == *it_tmp) {
*it += *it_tmp;
tmp.erase(it_tmp);
}
}
// print_list(tmp);
index_type nnz_expanded_no_duplicates = tmp.size();
// std::cout << "Size of tmp list = " << tmp.size() << "\n\n";


// Convert to general CSR
// Complexity O(NNZ)
A_csr->setExpanded(true);
A_csr->setNnz(nnz_expanded_no_duplicates);
A_csr->setNnzExpanded(nnz_expanded_no_duplicates);
A_csr->allocateMatrixData(memory::HOST);
index_type* rows_csr = A_csr->getRowData(memory::HOST);
index_type* cols_csr = A_csr->getColData(memory::HOST);
real_type* vals_csr = A_csr->getValues(memory::HOST);

index_type csr_row_idx = 0;
index_type csr_row_counter = 0;
index_type csr_val_counter = 0;

// Set first row pointer here
rows_csr[0] = csr_row_idx;
++csr_row_counter;

// Loop throught the list of COO triplets
it = tmp.begin();
while (it != tmp.end())
{
// Set column indices and matrix values
cols_csr[csr_val_counter] = it->getColIdx();
vals_csr[csr_val_counter] = it->getValue();

// When row index changes, set next row pointer
if (csr_row_idx != it->getRowIdx()) {
csr_row_idx = it->getRowIdx();
++csr_row_counter;
rows_csr[csr_row_idx] = csr_val_counter;
}

csr_val_counter++;
it++;
}
// Set last row pointer
rows_csr[csr_row_counter] = nnz_expanded_no_duplicates;

// std::cout << "Size of tmp list = " << tmp.size() << "\n\n";
// std::cout << "Rows counted = " << csr_row_counter << "\n";
// std::cout << "NNZs counted = " << csr_val_counter << "\n\n";
// A_csr->print();

return 0;
}



/**
* @brief Creates a CSR from a COO matrix.
*
Expand Down
3 changes: 3 additions & 0 deletions resolve/matrix/Utilities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@ namespace ReSolve

/// @brief Converts symmetric or general COO to general CSR matrix
int coo2csr(matrix::Coo* A_coo, matrix::Csr* A_csr, memory::MemorySpace memspace);

/// @brief Converts symmetric or general COO to general CSR matrix
int coo2csr_new(matrix::Coo* A_coo, matrix::Csr* A_csr, memory::MemorySpace memspace);
}
}
7 changes: 6 additions & 1 deletion tests/unit/matrix/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
add_executable(runMatrixIoTests.exe runMatrixIoTests.cpp)
target_link_libraries(runMatrixIoTests.exe PRIVATE ReSolve resolve_matrix)

# Build matrix Conversion tests
add_executable(runMatrixConversionTests.exe runMatrixConversionTests.cpp)
target_link_libraries(runMatrixConversionTests.exe PRIVATE ReSolve resolve_matrix)

# Build matrix handler tests
add_executable(runMatrixHandlerTests.exe runMatrixHandlerTests.cpp)
target_link_libraries(runMatrixHandlerTests.exe PRIVATE ReSolve resolve_matrix)
Expand All @@ -25,14 +29,15 @@ if(RESOLVE_USE_LUSOL)
endif()

# Install tests
set(installable_tests runMatrixIoTests.exe runMatrixHandlerTests.exe runMatrixFactorizationTests.exe)
set(installable_tests runMatrixIoTests.exe runMatrixConversionTests.exe runMatrixHandlerTests.exe runMatrixFactorizationTests.exe)
if(RESOLVE_USE_LUSOL)
list(APPEND installable_tests runLUSOLTests.exe)
endif()
install(TARGETS ${installable_tests}
RUNTIME DESTINATION bin/resolve/tests/unit)

add_test(NAME matrix_test COMMAND $<TARGET_FILE:runMatrixIoTests.exe>)
add_test(NAME matrix_conversion_test COMMAND $<TARGET_FILE:runMatrixConversionTests.exe>)
add_test(NAME matrix_handler_test COMMAND $<TARGET_FILE:runMatrixHandlerTests.exe>)
add_test(NAME matrix_factorization_test COMMAND $<TARGET_FILE:runMatrixFactorizationTests.exe>)
if(RESOLVE_USE_LUSOL)
Expand Down
Loading