Skip to content

Commit

Permalink
Merge pull request #72 from willow-ahrens/last-creation-funcs
Browse files Browse the repository at this point in the history
Add `empty`, `empty_like`, `arange`, and `linspace`
  • Loading branch information
mtsokol authored Jun 17, 2024
2 parents aa2f8fb + 0bfed32 commit 5f9a222
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "finch-tensor"
version = "0.1.28"
version = "0.1.29"
description = ""
authors = ["Willow Ahrens <willow.marie.ahrens@gmail.com>"]
readme = "README.md"
Expand Down
8 changes: 8 additions & 0 deletions src/finch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@
real,
imag,
conj,
empty,
empty_like,
arange,
linspace,
)
from .compiled import (
lazy,
Expand Down Expand Up @@ -258,6 +262,10 @@
"conj",
"read",
"write",
"empty",
"empty_like",
"arange",
"linspace",
]

__array_api_version__: str = "2023.12"
51 changes: 51 additions & 0 deletions src/finch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,57 @@ def zeros_like(
return zeros(x.shape, dtype=dtype, format=format, device=device)


def empty(
shape: int | tuple[int, ...],
*,
dtype: DType | None = None,
format: str = "coo",
device: Device = None,
) -> Tensor:
return full(shape, np.float64(0), dtype=dtype, format=format, device=device)


def empty_like(
x: Tensor,
/,
*,
dtype: DType | None = None,
format: str = "coo",
device: Device = None,
) -> Tensor:
dtype = x.dtype if dtype is None else dtype
return empty(x.shape, dtype=dtype, format=format, device=device)


def arange(
start: int | float,
/,
stop: int | float | None = None,
step: int | float = 1,
*,
dtype: DType | None = None,
device: Device = None
) -> Tensor:
_validate_device(device)
return Tensor(np.arange(start, stop, step, jl_dtypes.jl_to_np_dtype[dtype]))


def linspace(
start: int | float | complex,
stop: int | float | complex,
/,
num: int,
*,
dtype: DType | None = None,
device: Device = None,
endpoint: bool = True,
) -> Tensor:
_validate_device(device)
return Tensor(
np.linspace(start, stop, num=num, dtype=jl_dtypes.jl_to_np_dtype[dtype], endpoint=endpoint)
)


def permute_dims(x: Tensor, axes: tuple[int, ...]):
return x.permute_dims(axes)

Expand Down
24 changes: 23 additions & 1 deletion tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_reshape(arr, new_shape, order):
@pytest.mark.parametrize("shape", [10, (3, 3), (2, 1, 5)])
@pytest.mark.parametrize("dtype_name", [None, "int64", "float64"])
@pytest.mark.parametrize("format", ["coo", "dense"])
def test_full_ones_zeros(shape, dtype_name, format):
def test_full_ones_zeros_empty(shape, dtype_name, format):
jl_dtype = getattr(finch, dtype_name) if dtype_name is not None else None
np_dtype = getattr(np, dtype_name) if dtype_name is not None else None

Expand All @@ -246,6 +246,11 @@ def test_full_ones_zeros(shape, dtype_name, format):
res = finch.zeros_like(res, dtype=jl_dtype, format=format)
assert_equal(res.todense(), np.zeros(shape, np_dtype))

res = finch.empty(shape, dtype=jl_dtype, format=format)
assert_equal(res.todense(), np.empty(shape, np_dtype))
res = finch.empty_like(res, dtype=jl_dtype, format=format)
assert_equal(res.todense(), np.empty(shape, np_dtype))


@pytest.mark.parametrize("func,arg", [(finch.asarray, np.zeros(3)), (finch.zeros, 3)])
def test_device_keyword(func, arg):
Expand Down Expand Up @@ -335,3 +340,20 @@ def test_to_scalar():
ValueError, match="<class 'int'> can be computed for one-element tensors only."
):
tns.__int__()


@pytest.mark.parametrize("dtype_name", [None, "int16", "float64"])
def test_arange_linspace(dtype_name):
if dtype_name is not None:
finch_dtype = getattr(finch, dtype_name)
np_dtype = getattr(np, dtype_name)
else:
finch_dtype = np_dtype = None

result = finch.arange(10, 100, 5, dtype=finch_dtype)
expected = np.arange(10, 100, 5, dtype=np_dtype)
assert_equal(result.todense(), expected)

result = finch.linspace(20, 80, 10, dtype=finch_dtype)
expected = np.linspace(20, 80, 10, dtype=np_dtype)
assert_equal(result.todense(), expected)

0 comments on commit 5f9a222

Please sign in to comment.