I added this file, `ROADMAP.md` (`TODO.md` would have been fine as well), to keep track of the planned work and to help guide the development.
1 parent 09bcb3f, commit 0bc7ced
Showing 1 changed file with 67 additions and 0 deletions.
@@ -0,0 +1,67 @@
# Roadmap

A rough roadmap, in the form of a todo list, that outlines how the project will be continued and which features will be implemented.

- [x] Being able to perform _some_ translations ([PR#3](https://github.com/GridTools/jace/pull/3)).
- [ ] Basic functionalities:
  - [ ] Annotation `@jace.jit`.
  - [ ] Composable with Jax, i.e. being able to take the Jax derivative of a Jace-annotated function.
  - [ ] Implementing the `stages` model that is supported by Jax.
  - [ ] Handling Jax arrays as native input (only on single host).
  - [ ] Cache the compilation and lowering results for later reuse.
    In Jax these parts (together with the dispatch) are written in C++; thus, in the beginning, we will use a self-made cache.
  - [ ] Implementing some basic `PrimitiveTranslators` that allow us to run some early tests, such as:
    - [ ] Backporting the ones from the prototype.
    - [ ] Implement the `scatter` primitive (needed for pyhpc).
    - [ ] Implement the `scan` primitive (needed for pyhpc).
- [ ] _Initial_ optimization pipeline.
  In order to run benchmarks, we need to perform optimizations first.
  However, the ones offered by DaCe did not work that well, so we should, for now, backport the ones from the prototype.
- [ ] Support GPU code (relatively simple, but needs some detection logic).
- [ ] Initial benchmark:
  In the beginning we will not have the same dispatching performance as Jax.
  But passing these benchmarks could give us a better hint of how to proceed in this matter.
  - [ ] Passing the [pyhpc-benchmarks](https://github.com/dionhaefner/pyhpc-benchmarks).
  - [ ] Passing Felix' fluid project (a possibility).
- [ ] Support of static arguments.
- [ ] Stop relying on `jax.make_jaxpr()`.
  Look at the `jax._src.pjit.make_jit()` function for how to hijack the staging process.
- [ ] Implementing more advanced primitives:
  - [ ] Handling pytrees as arguments.
  - [ ] Implement random numbers.
  - [ ] `jax.numpy`.
  - [ ] `jax.scipy`.
- [ ] Passing the single host Jax unit tests.
- [ ] Multi-Device capabilities, i.e. multiple GPUs but all on the same host.
  - [ ] Passing the associated Jax unit tests.
- [ ] Multi-Host capabilities, i.e. MPI.
  - [ ] Passing the associated Jax unit tests.
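
The self-made cache mentioned under the basic functionalities could look roughly like the following sketch. It is only an illustration in plain Python, not Jace's or Jax's actual implementation: the names `AbstractArg`, `abstractify` and `CompileCache` are hypothetical, and the key is assumed to consist of the function plus the shape and dtype of every array argument.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Hashable, Tuple


@dataclass(frozen=True)
class AbstractArg:
    """Hashable stand-in for an array argument: only shape and dtype matter."""

    shape: Tuple[int, ...]
    dtype: str


def abstractify(arg: Any) -> Hashable:
    """Map a concrete argument to a hashable component of the cache key."""
    if hasattr(arg, "shape") and hasattr(arg, "dtype"):
        return AbstractArg(tuple(arg.shape), str(arg.dtype))
    # Anything else (assumed hashable) is keyed by value, like a static argument.
    return (type(arg).__name__, arg)


class CompileCache:
    """Remembers the result of an expensive `compile_fn(fun, *args)` call."""

    def __init__(self, compile_fn: Callable[..., Any]) -> None:
        self._compile_fn = compile_fn
        self._cache: Dict[Hashable, Any] = {}
        self.misses = 0

    def get(self, fun: Callable[..., Any], *args: Any) -> Any:
        key = (fun, tuple(abstractify(a) for a in args))
        if key not in self._cache:
            # First call with this signature: lower and compile, then store.
            self.misses += 1
            self._cache[key] = self._compile_fn(fun, *args)
        return self._cache[key]
```

Two calls whose arguments have the same shapes and dtypes share one compiled object, while a new shape triggers a recompilation. Keying non-array arguments by value is one possible way to treat static arguments in such a cache.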

## General

These are more general topics that should be addressed at some point.

- [ ] Integrating better with Jax.
  - [ ] Support its array type (probably implement this in DaCe).
- [ ] Increase the dispatching speed + Cache.
  Jax does this in C++, which is impossible to beat in Python; thus we have to go that route as well.
- [ ] Debugging information.
- [ ] Dynamic shapes.
  This could be done by making the inputs fully dynamic and then using the primitives to simplify.
  For example, in an addition the shapes of the two inputs and of the output are the same.
  That knowledge is inherent to the primitives themselves.
  However, the compiled object must know how to extract the sizes itself.
- [ ] Defining a logo:
  It should be green with a nice curly font.
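
The idea behind the dynamic shapes item, starting from fully dynamic inputs and letting the primitives simplify them, can be sketched as follows. This is a toy model under stated assumptions, not Jace code: every dimension starts as its own symbol, and the hypothetical `translate_add` rule encodes that an elementwise addition without broadcasting has identical input and output shapes.

```python
import itertools
from typing import Dict, Tuple

Shape = Tuple[str, ...]


class ShapeSymbols:
    """Union-find over symbolic dimension names such as 's0', 's1', ..."""

    def __init__(self) -> None:
        self._parent: Dict[str, str] = {}
        self._counter = itertools.count()

    def fresh_shape(self, rank: int) -> Shape:
        """Create a fully dynamic shape: one new symbol per dimension."""
        shape = tuple(f"s{next(self._counter)}" for _ in range(rank))
        for sym in shape:
            self._parent[sym] = sym
        return shape

    def find(self, sym: str) -> str:
        """Return the representative symbol of `sym`."""
        while self._parent[sym] != sym:
            self._parent[sym] = self._parent[self._parent[sym]]  # path halving
            sym = self._parent[sym]
        return sym

    def unify(self, a: str, b: str) -> None:
        """Record that the two symbols denote the same size."""
        root_a, root_b = self.find(a), self.find(b)
        if root_a != root_b:
            self._parent[root_b] = root_a

    def canonical(self, shape: Shape) -> Shape:
        return tuple(self.find(sym) for sym in shape)


def translate_add(symbols: ShapeSymbols, lhs: Shape, rhs: Shape) -> Shape:
    """Shape rule of an elementwise addition (no broadcasting assumed)."""
    if len(lhs) != len(rhs):
        raise ValueError("rank mismatch")
    for left, right in zip(lhs, rhs):
        symbols.unify(left, right)
    # The output has the same shape as both inputs.
    return symbols.canonical(lhs)
```

After translating a single addition of two rank-2 inputs, only two independent size symbols remain instead of four, so the compiled object would only need to extract two sizes from its arguments.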

## Optimization & Transformations

The SDFGs generated by Jace have a very particular structure; thus we could, and probably should, write some highly targeted optimization passes for them.
Our experiments with the prototype showed that the most important transformation is Map fusion, and that the one in DaCe is essentially broken.

- [ ] Modified state fusion.
  Because of the structure we have, this could make `Simplify` much more efficient.
- [ ] Trivial Tasklet removal.
  Since we will work a lot with Maps that are trivial (probably the best structure for fusing), we will end up with some trivial Tasklets, i.e. `__out = __in`.
  Thus, we should have a good way to get rid of them.
- [ ] Modified Map fusion transformation.
  We should still support parallel and serial fusion as the prototype did, but focus on serial fusion.
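
To illustrate what trivial Tasklet removal means, here is a minimal sketch on a toy dataflow representation. It does not use the DaCe API: the `Tasklet` class and `remove_trivial_tasklets` function are hypothetical, and the sketch assumes that every container is written exactly once and that the Tasklets are listed in topological order.

```python
from dataclasses import dataclass
from typing import Dict, List, Set


@dataclass
class Tasklet:
    """Toy stand-in for a Tasklet with one input and one output container."""

    name: str
    code: str
    src: str  # container that is read
    dst: str  # container that is written


def is_trivial(tasklet: Tasklet) -> bool:
    """A Tasklet is trivial if it only copies its input to its output."""
    return tasklet.code.replace(" ", "") == "__out=__in"


def remove_trivial_tasklets(tasklets: List[Tasklet], outputs: Set[str]) -> List[Tasklet]:
    """Drop copy Tasklets and let consumers read the copied source directly.

    Containers in `outputs` are visible to the caller, so Tasklets writing
    to them are kept even if they are trivial.
    """
    alias: Dict[str, str] = {}  # removed container -> container it copied
    kept: List[Tasklet] = []
    for tasklet in tasklets:
        src = alias.get(tasklet.src, tasklet.src)
        if is_trivial(tasklet) and tasklet.dst not in outputs:
            alias[tasklet.dst] = src
        else:
            kept.append(Tasklet(tasklet.name, tasklet.code, src, tasklet.dst))
    return kept
```

For a chain `a -> b -> c -> d` in which the middle Tasklet is `__out = __in`, the pass removes that Tasklet and rewires the last one to read `b` directly, which leaves two directly connected computations that a subsequent serial fusion could work on.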