Skip to content

Commit

Permalink
Tweak copy keyword logic
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Jun 6, 2024
1 parent 45571d5 commit 282b8ae
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
11 changes: 7 additions & 4 deletions src/finch/tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import builtins
from typing import Any, Callable, Optional, Iterable, Literal
import warnings

import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
Expand Down Expand Up @@ -91,7 +92,9 @@ def __init__(
copy: bool | None = None,
):
if isinstance(obj, (int, float, complex, bool, list)):
obj = np.array(obj, copy=copy)
if copy is False:
raise ValueError("copy=False isn't supported for scalar inputs and Python lists")
obj = np.asarray(obj)
if fill_value is None:
fill_value = 0.0

Expand Down Expand Up @@ -424,12 +427,12 @@ def _from_scipy_sparse(
fill_value: np.number | None = None,
copy: bool | None = None,
) -> JuliaObj:
if copy is False and not (x.has_canonical_format and x.format in ("coo", "csr", "csc")):
if copy is False and not (x.format in ("coo", "csr", "csc") and x.has_canonical_format):
raise ValueError("Unable to avoid copy while creating an array as requested.")
if copy or not x.has_canonical_format:
x = x.copy()
if x.format not in ("coo", "csr", "csc"):
x = x.asformat("coo")
if copy:
x = x.copy()
if not x.has_canonical_format:
x.sum_duplicates()
assert x.has_canonical_format
Expand Down
7 changes: 4 additions & 3 deletions tests/test_scipy_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ def test_from_scipy_sparse(format_with_pattern, fill_value):
def test_non_canonical_format(format):
sp_arr = sp.random(3, 4, density=0.5, format=format)

with pytest.warns(
UserWarning, match="SciPy sparse input must be in a canonical format."
with pytest.raises(
ValueError, match="Unable to avoid copy while creating an array"
):
finch_arr = finch.asarray(sp_arr)
finch.asarray(sp_arr, copy=False)

finch_arr = finch.asarray(sp_arr)
assert_equal(finch_arr.todense(), sp_arr.toarray())
8 changes: 5 additions & 3 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,16 @@ def test_wrappers(dtype, jl_dtype, order):
@pytest.mark.parametrize("dtype", [np.int64, np.float64, np.complex128])
@pytest.mark.parametrize("order", ["C", "F", None])
@pytest.mark.parametrize("copy", [True, False, None])
def test_no_copy_fully_dense(dtype, order, copy, arr3d):
def test_copy_fully_dense(dtype, order, copy, arr3d):
arr = np.array(arr3d, dtype=dtype, order=order)
arr_finch = finch.Tensor(arr, copy=copy)
arr_todense = arr_finch.todense()

assert_equal(arr_todense, arr)
assert np.shares_memory(arr_todense, arr)

if copy:
assert not np.shares_memory(arr_todense, arr)
else:
assert np.shares_memory(arr_todense, arr)

def test_coo(rng):
coords = (
Expand Down

0 comments on commit 282b8ae

Please sign in to comment.