diff --git a/pyproject.toml b/pyproject.toml index 25ac2f6..a4d9b16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "finch-tensor" -version = "0.1.28" +version = "0.1.29" description = "" authors = ["Willow Ahrens "] readme = "README.md" diff --git a/src/finch/__init__.py b/src/finch/__init__.py index 87ab562..e794812 100644 --- a/src/finch/__init__.py +++ b/src/finch/__init__.py @@ -103,6 +103,10 @@ real, imag, conj, + empty, + empty_like, + arange, + linspace, ) from .compiled import ( lazy, @@ -258,6 +262,10 @@ "conj", "read", "write", + "empty", + "empty_like", + "arange", + "linspace", ] __array_api_version__: str = "2023.12" diff --git a/src/finch/tensor.py b/src/finch/tensor.py index 667aa9f..80c5646 100644 --- a/src/finch/tensor.py +++ b/src/finch/tensor.py @@ -785,6 +785,57 @@ def zeros_like( return zeros(x.shape, dtype=dtype, format=format, device=device) +def empty( + shape: int | tuple[int, ...], + *, + dtype: DType | None = None, + format: str = "coo", + device: Device = None, +) -> Tensor: + return full(shape, np.float64(0), dtype=dtype, format=format, device=device) + + +def empty_like( + x: Tensor, + /, + *, + dtype: DType | None = None, + format: str = "coo", + device: Device = None, +) -> Tensor: + dtype = x.dtype if dtype is None else dtype + return empty(x.shape, dtype=dtype, format=format, device=device) + + +def arange( + start: int | float, + /, + stop: int | float | None = None, + step: int | float = 1, + *, + dtype: DType | None = None, + device: Device = None +) -> Tensor: + _validate_device(device) + return Tensor(np.arange(start, stop, step, jl_dtypes.jl_to_np_dtype[dtype])) + + +def linspace( + start: int | float | complex, + stop: int | float | complex, + /, + num: int, + *, + dtype: DType | None = None, + device: Device = None, + endpoint: bool = True, +) -> Tensor: + _validate_device(device) + return Tensor( + np.linspace(start, stop, num=num, dtype=jl_dtypes.jl_to_np_dtype[dtype], endpoint=endpoint) + ) + + def permute_dims(x: Tensor, axes: tuple[int, ...]): return x.permute_dims(axes) diff --git a/tests/test_sparse.py b/tests/test_sparse.py index dae91f2..7e766d6 100644 --- a/tests/test_sparse.py +++ b/tests/test_sparse.py @@ -227,7 +227,7 @@ def test_reshape(arr, new_shape, order): @pytest.mark.parametrize("shape", [10, (3, 3), (2, 1, 5)]) @pytest.mark.parametrize("dtype_name", [None, "int64", "float64"]) @pytest.mark.parametrize("format", ["coo", "dense"]) -def test_full_ones_zeros(shape, dtype_name, format): +def test_full_ones_zeros_empty(shape, dtype_name, format): jl_dtype = getattr(finch, dtype_name) if dtype_name is not None else None np_dtype = getattr(np, dtype_name) if dtype_name is not None else None @@ -246,6 +246,11 @@ def test_full_ones_zeros(shape, dtype_name, format): res = finch.zeros_like(res, dtype=jl_dtype, format=format) assert_equal(res.todense(), np.zeros(shape, np_dtype)) + res = finch.empty(shape, dtype=jl_dtype, format=format) + assert_equal(res.todense(), np.empty(shape, np_dtype)) + res = finch.empty_like(res, dtype=jl_dtype, format=format) + assert_equal(res.todense(), np.empty(shape, np_dtype)) + @pytest.mark.parametrize("func,arg", [(finch.asarray, np.zeros(3)), (finch.zeros, 3)]) def test_device_keyword(func, arg): @@ -335,3 +340,20 @@ def test_to_scalar(): ValueError, match=" can be computed for one-element tensors only." ): tns.__int__() + + +@pytest.mark.parametrize("dtype_name", [None, "int16", "float64"]) +def test_arange_linspace(dtype_name): + if dtype_name is not None: + finch_dtype = getattr(finch, dtype_name) + np_dtype = getattr(np, dtype_name) + else: + finch_dtype = np_dtype = None + + result = finch.arange(10, 100, 5, dtype=finch_dtype) + expected = np.arange(10, 100, 5, dtype=np_dtype) + assert_equal(result.todense(), expected) + + result = finch.linspace(20, 80, 10, dtype=finch_dtype) + expected = np.linspace(20, 80, 10, dtype=np_dtype) + assert_equal(result.todense(), expected)