-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
test: Added an Extensive Set of Tests #21
base: main
Are you sure you want to change the base?
test: Added an Extensive Set of Tests #21
Conversation
…o tests for them yet. `device_put` is actually a very powerfull operation, such as Memlets, where source and destination are on different devices. However, we do not support something like that yet, so we will keep it down yet.
It is not yet dynamic sclice, but soon.
…n value. This restirction was removed commit `05e4a885441c0cd`. However, I realized that it made some higher level translators impossible to write, such as `select_n`. Thus I removed this restriction again. Another solution would be to add another layer.
I also observe random failures for the `test_iota_broadcast()` test if I run all tests. However, if I only run it, then nothing happens, I have no idea why.
… the class description.
This should bring the development branch to the newest stage.
…Jax stuff about the primitives.
Jax adjust the start indexes if the window overruns, however, this is not done, instead an out of bound error happens.
During that work I also detected some [issue](spcl/dace#1579) in DaCe's simplification pipeline.
The cleaning was not correct. The function essentially created some deatached caches.
The cleaning was not correct. The function essentially created some deatached caches.
I just copied them and did not do a merge, which is not so nice. Furthermore, the tests are not yet there, in my view it makes sense to first have something that can be checked.
Before it was implementewd as a switch, for the case JAX would use the bool overload of XLA. The check for this was now essentially moved inside the function.
However, a similar issue (#1644) in DaCe is still open.
It is now better confiugured.
Now let's test if it works.
This is basically for testing them.
Before the `order` argument was a `Literal` but this caused more truble now it is a string.
I enabled the simplify pass in commit `411bd7bd` and it worked locally. However, this was because I was not running it inside nox and using my own version of DaCe. The bug in simplify was fixed in [PR#1603](spcl/dace#1603) which was merged _after_ 16.1 was released, thus the fix is not avaliable.
It seems JAX has updated the `make_jaxpr()` function and now that thing caches itself. This is now accounted for.
But the new function really tests if the loweing works, if teh strides are honored and infered.
71f9f86
to
5687f40
Compare
Co-authored-by: Enrique González Paredes <enriqueg@cscs.ch>
This PR introduces a series of primitive translators, most of them are based on the prototype, with some improvements. I just copied them over from the development branch, which is not so nice, but was the simplest thing to to without also introducing the other stuff. It is important that the tests from the development branch were not added, to keep the PR small. Furthermore, we need something to test, so this PR must go first. For organizational reasons, the development history of this PR happened to be contained in [PR#21](#21). --------- Co-authored-by: Enrique González Paredes <enriqueg@cscs.ch>
Codecov ReportAll modified and coverable lines are covered by tests ✅
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## main #21 +/- ##
=======================================
Coverage ? 88.66%
=======================================
Files ? 31
Lines ? 1235
Branches ? 251
=======================================
Hits ? 1095
Misses ? 82
Partials ? 58 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First round of review only looking at the test infrastructure.
|
||
- JaCe always traces with enabled `x64` mode. | ||
This is a restriction that might be lifted in the future. | ||
- JAX returns scalars as zero-dimensional arrays, JaCe returns them as array with shape `(1, )`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main issue is that DaCe does not have a real concept of zero dimensional arrays, as far as I know.
Consider the following two functions
@dace.program
def bar(a: dace.float64):
return a + 1
@dace.program
def baz(a: dace.float64[1]):
return a + 1
if you pass a zero dimensional array to bar()
then it will be casted to an scalar, if you pass it to baz()
an error will happen.
Furthermore, the binary interface of the SDFG can not return scalars, return values have to be arrays there is no way around that without patching the code generator and making a lot of changes to handle special cases.
So I decided to follow PEP20 and decided that this case is not special enough to change the rule.
If you want this feature then please open an issue.
- JaCe always traces with enabled `x64` mode. | ||
This is a restriction that might be lifted in the future. | ||
- JAX returns scalars as zero-dimensional arrays, JaCe returns them as array with shape `(1, )`. | ||
- In JAX parts of the computation runs on CPU parts on GPU, in JaCe everything runs (currently) either on CPU or GPU. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean here? Which parts run on CPU/GPU in JAX?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The JAX compiler, i.e. XLA can decide to do this.
The question is, if it really does it.
- JaCe does not return `jax.Array` instances, but NumPy/CuPy arrays. | ||
- The execution is not asynchronous. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two points could be also fixed in the future, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Return Type: Return
jax.Array
#22 - Asynchronous: Asynchronous Execution #23
tests/common_fixture.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move fixtures from conftest.py
here (in which case should be renamed to commom_fixtures.py
) or remove this empty file.
from jace import optimization, stages | ||
from jace.util import translation_cache as tcache | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For all fixtures in this module:
- remove the underscore of the name (all fixtures defined here are meant to be used outside of this module)
remove theautouse=true
and follow the same approach we used ingt4py-dace
(https://github.com/GridTools/gt4py/pull/1594/files#diff-4a5190762a7529af16fcf3d4f1726d2168726659f60f9d58495611d56eeb3008): an explicitpytestmark
definition at a module level with the fixtures that should be automatically used for all tests in the module.
Additionally, I would define another fixture requesting all other fixtures expected in the standard case and use this group
fixtures in the pytestmark
at module level. Example:
# This file
@pytest.fixture
def standard_jace_test_settings(enable_x64_mode_in_jax, disable_jit, ...) -> ...:
....
# Other test files
pytestmark = pytest.mark.usefixtures("standard_jace_test_settings")
Finally, I'd create a simpler type alias for the return type of generator fixtures as suggested here :
T = TypeVar("T")
YieldFixture = Generator[T, None, None]
@pytest.fixture
def foo() -> YieldFixture[str]:
yield "foo"
def make_array( | ||
shape: Sequence[int] | int, | ||
dtype: type = np.float64, | ||
order: str = "C", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
order: str = "C", | |
order: Literal["C", "F"] = "C", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was the original, however, it has serious problems.
Whenever the order
came from an argument (fixture to test if it works in both C
and F
order) then MyPy complained, and I had to add brainless # type: ignore[call-overload]
marks (see commit 090c3a2).
I do not think that cluttering the code with these annotations just to get some tiny bit of type security is not worth it.
low: Any = None, | ||
high: Any = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not?
low: Any = None, | |
high: Any = None, | |
low: int | float | np.number = 0, | |
high: int | float | np.number = 1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yours is a bit better.
But it must be still accept None
, otherwise you will only generate 0
or 1
in the integer case.
Furthermore, technically it must also accept complex, but they are an edge case that we ignore.
|
||
|
||
def make_array( | ||
shape: Sequence[int] | int, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shape: Sequence[int] | int, | |
shape: Sequence[int], |
This is optional, but I would stick to one single way to define the shape to increase the readability of the function and help the type checker to catch errors here. In my opinion, there is not too much value in providing automatic tuple conversion for ints.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do think that this change would increase readability.
For my these ..., shape=(some_scalar_variable,), ...
expression are much less readable than ..., shape=some_scalar_variable, ...
.
This typedef allows you to select the style that is most appropriate to the function.
Thus you do not have to replicate the "logic" that distinguish scalars from tuples everywhere you use this thing, but only at one central place.
Furthermore, it is the NumPy behaviour.
__all__ = ["make_array"] | ||
|
||
|
||
def make_array( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a question: this utility function is dealing with the generation of random array values with a certain shape, dtype and range, but what about packing them in the the correct ndarray type (numpy, cupy, Jax)? Isn't it needed ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At the end of the day, DaCe generates C code that operates on pointers to memory regions of a certain size.
Thus in the majority of cases you care about the data type and the size of the memory and not about the type of container that is used to store manage said memory it in the Python world.
However, you need to have tests to ensure that you can extract that memory from JAX, NumPy and CuPy arrays.
For JAX array we have:
tests/unit_tests/test_caching.py:test_caching_jax_numpy_array()
tests/unit_tests/test_jax_api.py:test_jax_array_as_input()
So in the end is composition at work.
This PR adds a large set of tests for various translators, especially the ones that where added by PR#20 but also other parts of the code are tested.
This is a very basic PR that does not add any functionality, except tests for the one that is already there.
This PR can only be merged after PR#20 has been merged.
Furthermore, for certain reasons this PR contains the whole development history of PR#20.