[Draft PR] Graphcore backend support. #1659

Open
wants to merge 77 commits into
base: main
77 commits
9ea89e7
cpu, gpu basic tests
Sameeranjoshi Jul 18, 2024
46a3c07
add cpu array test
Sameeranjoshi Jul 19, 2024
809048c
add optimization, helper file-check_external_library_used.py, this is…
Sameeranjoshi Jul 20, 2024
5c68cce
understood where is the source generated from, read codegen.py
Sameeranjoshi Jul 22, 2024
9a210f4
make more verbose comment
Sameeranjoshi Jul 24, 2024
8763d54
Tried using a custom codegen following the tutorial guide on dace web…
Sameeranjoshi Jul 24, 2024
e5ae4ee
make cpu, gpu, fpga tests to the most smallest and all doing vector a…
Sameeranjoshi Jul 24, 2024
a727cb9
add debug comments to understand the SDFG
Sameeranjoshi Jul 24, 2024
a575ef8
basic structure is dumped, using node as of now, build fails as well
Sameeranjoshi Jul 25, 2024
5934537
IPUTransformSDFG commented in python code, probably missing registrat…
Sameeranjoshi Jul 25, 2024
01c8bf5
MPI basic test
Sameeranjoshi Jul 25, 2024
429737d
Implement the LoopyLoop custom codegen on Map, will revert in the nex…
Sameeranjoshi Jul 25, 2024
77a7388
Revert "Implement the LoopyLoop custom codegen on Map, will revert in…
Sameeranjoshi Jul 25, 2024
fa78938
Debug: Find what are different types of nodes and how they are organized
Sameeranjoshi Jul 25, 2024
d1af971
Debug: make output more verbose from last commit
Sameeranjoshi Jul 25, 2024
c3171fb
print states(if-else, for)
Sameeranjoshi Jul 26, 2024
ed6f63e
convert from vector add to saclar add, name might be confusing
Sameeranjoshi Jul 26, 2024
c30f1f2
some debug comments, found control_flow_tree code, ipu.py has a lot o…
Sameeranjoshi Jul 26, 2024
78f19af
1. Fix IPUCodegen {used_targets}-{frame} error.
Sameeranjoshi Jul 27, 2024
af38f47
mpi_scalar.py, some debug comments, now move on to cpu only, don't lo…
Sameeranjoshi Jul 30, 2024
908a0f9
partial code works, read cpu.py and generate_{node, state}
Sameeranjoshi Jul 30, 2024
73fc0bb
add mapping GC program to dace
Sameeranjoshi Jul 31, 2024
c997147
[WIP] Register array, copy, and add some code for generating the head…
Sameeranjoshi Aug 2, 2024
01d0658
use dace.DeviceType.IPU to check and emit headers in framecode, not t…
Sameeranjoshi Aug 4, 2024
afbbdd1
learn sdfg by using the APIs and writing tests
Sameeranjoshi Aug 4, 2024
cff27f4
Add new test case, simple codes to understand writing SDFG by hand
Sameeranjoshi Aug 4, 2024
3715101
Add a new library, poplar
Sameeranjoshi Aug 5, 2024
c9c64dd
Fix the issue with checking if the device targer is IPU.
Sameeranjoshi Aug 6, 2024
30178f0
Comment the IPU type doesn't work as needs frontend support probably,…
Sameeranjoshi Aug 13, 2024
82d193d
Copied codegen from cpu.py, tweaked it and understood the structure o…
Sameeranjoshi Aug 13, 2024
8416d82
Add IPU in dtypes.py
Sameeranjoshi Aug 13, 2024
6a254c3
Use IPU from StorageType in sdfg.add_scalar
Sameeranjoshi Aug 13, 2024
3d6e96f
Revert "Add IPU in dtypes.py"
Sameeranjoshi Aug 13, 2024
6efc81e
Revert "Use IPU from StorageType in sdfg.add_scalar"
Sameeranjoshi Aug 13, 2024
cc02d9b
Replace pre_tasklet with generate_read, former comes from cpu.py late…
Sameeranjoshi Aug 13, 2024
0356005
create gpu_vector_add and cpu_* version
Sameeranjoshi Aug 23, 2024
5596ef8
Created the most simplest code for codegen of allocate_array and disp…
Sameeranjoshi Aug 24, 2024
2a424f3
Add IPU_Memory as new data type
Sameeranjoshi Aug 24, 2024
cbfdc42
Added
Sameeranjoshi Aug 24, 2024
623ab3b
generate addVariable() API
Sameeranjoshi Aug 24, 2024
6cd24ea
Add support for IPU types.
Sameeranjoshi Aug 25, 2024
e6610da
Fix shape issue, next add streams support
Sameeranjoshi Aug 25, 2024
a426141
Make more readable for humans.
Sameeranjoshi Aug 25, 2024
89722c3
Implement setTileMapping using 'mapdataontile'
Sameeranjoshi Aug 25, 2024
06d5dc3
Add poplar as a library
Sameeranjoshi Aug 28, 2024
2961353
1. Fix compilations issues from the previous commit.
Sameeranjoshi Aug 29, 2024
75821d7
Add missing ;
Sameeranjoshi Aug 29, 2024
1bb6ead
Fix missing directory paths
Sameeranjoshi Aug 29, 2024
3b47aae
1. Remove buggy/extra IPUModel and fix some more bug
Sameeranjoshi Aug 29, 2024
7ceb866
Fix curly braces not found in code as python f{} strings don't interp…
Sameeranjoshi Aug 29, 2024
3fa7a55
Fix 1. Compilation error (C++14/C++11). 2. Fix headers issue
Sameeranjoshi Aug 31, 2024
63fad92
Add arguments to the library, still we are not able to connect the in…
Sameeranjoshi Sep 1, 2024
3bae655
Fix link time libraries vs compile time libraries issue
Sameeranjoshi Sep 2, 2024
15d8023
changes to test, ipu_test is now the new base, added state dump, next…
Sameeranjoshi Sep 13, 2024
5207d2d
1. Insert all the golden file code from Poplar example.
Sameeranjoshi Sep 13, 2024
f6ac62f
fix bug where dace_init_target_ was missing
Sameeranjoshi Sep 13, 2024
e563d47
Revert "Add arguments to the library, still we are not able to connec…
Sameeranjoshi Sep 13, 2024
940f6bc
Add library node, register it, modify test for the same, goal is to h…
Sameeranjoshi Sep 15, 2024
2ffc0c9
Supress the building process
Sameeranjoshi Sep 15, 2024
ad13cfc
Attempt to add Node dispatcher
Sameeranjoshi Sep 16, 2024
6d5189f
Turn off the node dispatcher and generate a state using some code fro…
Sameeranjoshi Sep 19, 2024
e3193c4
Revert "Supress the building process"
Sameeranjoshi Sep 20, 2024
6ff38a6
Move the headers to a common runtime include/ folder dace/runtime/inc…
Sameeranjoshi Sep 30, 2024
ba32abe
some temporary changes
Sep 30, 2024
eeaa7d1
Resolve errors in compilation when using includes from runtime libraries
Sep 30, 2024
1a21f8e
Fix bug - wasn't generating proper kernel names, was not generic, tes…
Sep 30, 2024
9290acd
Support addVariables() and mapLinearlyOnTiles(), currently works only…
Oct 3, 2024
d1971bd
cosmetic changes, remove Dead code, iondent
Sameeranjoshi Oct 10, 2024
835bd81
Fix bug in is_ipu_kernel, was failing for tests where the first acces…
Sameeranjoshi Oct 10, 2024
0e43bb6
Fix mapping and variable allocation, remove Dead code
Sameeranjoshi Oct 11, 2024
1164d69
Try adding generate_node() - fails as the predicate fails, as there i…
Sameeranjoshi Oct 11, 2024
530e298
Add vector add test for dace and poplar
Sameeranjoshi Oct 17, 2024
1334015
Add scalar code using vector of size 1
Sameeranjoshi Oct 18, 2024
7e54d70
Remove prints
Sameeranjoshi Oct 18, 2024
3b1f4c7
new tests 1.copy a -> b on both IPU and dace test
Sameeranjoshi Oct 20, 2024
7537466
Add IPU_Memory to accessNode
Sameeranjoshi Oct 21, 2024
6db5588
Most of the codegen is correct, generate_node() doesn't trigger, copy…
Sameeranjoshi Oct 21, 2024
13 changes: 13 additions & 0 deletions check_external_library_used.py
@@ -0,0 +1,13 @@
from dace.libraries import blas

print('BLAS calls will expand by default to', blas.default_implementation)

if blas.IntelMKL.is_installed():
    blas.default_implementation = 'MKL'
elif blas.cuBLAS.is_installed():
    blas.default_implementation = 'cuBLAS'
elif blas.OpenBLAS.is_installed():
    blas.default_implementation = 'OpenBLAS'
elif not blas.BLAS.is_installed():
    # No BLAS library found, use the unoptimized native SDFG fallback
    blas.default_implementation = 'pure'
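The selection order in the script above can be sketched without DaCe installed. This is only an illustration: `pick_blas_implementation` and the `installed` set are hypothetical stand-ins for the `is_installed()` checks, not part of the DaCe API.

```python
def pick_blas_implementation(installed, current_default="pure"):
    """Return the name BLAS calls would expand to.

    `installed` is a hypothetical set of available library names; the
    script above asks each library class via is_installed() instead.
    """
    # Preference order copied from the if/elif chain in the script.
    for name in ("MKL", "cuBLAS", "OpenBLAS"):
        if name in installed:
            return name
    # No optimized library found: fall back to 'pure' only when a generic
    # BLAS is missing as well.
    if "BLAS" not in installed:
        return "pure"
    # A generic BLAS exists, so the default is left untouched.
    return current_default


print(pick_blas_implementation({"OpenBLAS", "BLAS"}))
print(pick_blas_implementation(set()))
```

Note that the last branch mirrors the script: when only a generic BLAS is present, the existing default is kept rather than overwritten.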
20 changes: 20 additions & 0 deletions cpu.py
@@ -0,0 +1,20 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np

@dace.program
def cpu_vector_add(A: dace.int32[20], B: dace.int32[20], C: dace.int32[20]):
    for i in dace.map[0:20]:  # parallelization construct
        C[i] = A[i] + B[i]

if __name__ == '__main__':
    sdfg = cpu_vector_add.to_sdfg(simplify=False)  # SDFG without simplification; compiled when called

    # call with values
    A = np.ones((20), dtype=np.int32)   # 1,1,1,1,...
    B = np.ones((20), dtype=np.int32)   # 1,1,1,1,...
    C = np.zeros((20), dtype=np.int32)  # 0,0,0,0,...
    sdfg(A, B, C)

    # ref = np.full(20, 2, dtype=np.int32)  # 2,2,2,2,...
    # assert np.array_equal(ref, C)
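The reference check is commented out in the test above. As a rough sketch of what it would verify, the element-wise semantics of the map can be written in plain Python. `vector_add_reference` is a hypothetical helper for illustration; it uses lists rather than NumPy arrays so that it runs without any dependencies.

```python
def vector_add_reference(a, b):
    """Element-wise sum, mirroring C[i] = A[i] + B[i] for every index i."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return [x + y for x, y in zip(a, b)]


# Same inputs as the test: two vectors of twenty ones.
A = [1] * 20
B = [1] * 20
expected = vector_add_reference(A, B)

# Every element of the result should be 2, matching the commented-out `ref`.
assert expected == [2] * 20
```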
25 changes: 25 additions & 0 deletions cpu_array_optimize.py
@@ -0,0 +1,25 @@
import dace
import numpy as np
from dace.transformation.optimizer import SDFGOptimizer


@dace.program
def cpu_getstarted_optimize(A, B, C):
    C = A + B
    return C

if __name__ == "__main__":
    #a = np.random.rand(2,3)
    # a = 10
    # b = 20
    # call with values
    A = np.ones((20), dtype=np.int32)   # 1,1,1,1,...
    B = np.ones((20), dtype=np.int32)   # 1,1,1,1,...
    C = np.zeros((20), dtype=np.int32)  # 0,0,0,0,...
    print ("before dace(CPU) (a,b)", A, B, C)
    print("after dace(CPU)", cpu_getstarted_optimize(A, B, C))
    sdfg = cpu_getstarted_optimize.to_sdfg(A, B, C)

    # VISUALLY OPTIMIZE
    sdfg = SDFGOptimizer(sdfg).optimize()
    # sdfg.apply_gpu_transformations()
107 changes: 107 additions & 0 deletions custom_codegen_external.py
@@ -0,0 +1,107 @@
import dace
from dace import registry
from dace.sdfg.scope import ScopeSubgraphView
from dace.codegen.prettycode import CodeIOStream
from dace.codegen.targets.target import TargetCodeGenerator
from dace.codegen.targets.framecode import DaCeCodeGenerator
from dace.codegen.targets.cpp import sym2cpp

@dace.program
def custom_kernel(A: dace.float64[20, 30]):
    for i, j in dace.map[0:20:2, 0:30]:
        A[i, j] += A[i, j]



dace.ScheduleType.register('LoopyLoop')
dace.SCOPEDEFAULT_SCHEDULE[dace.ScheduleType.LoopyLoop] = dace.ScheduleType.Sequential
dace.SCOPEDEFAULT_STORAGE[dace.ScheduleType.LoopyLoop] = dace.StorageType.CPU_Heap


@registry.autoregister_params(name='loopy')
class MyCustomLoop(TargetCodeGenerator):
    def __init__(self, frame_codegen: DaCeCodeGenerator, sdfg: dace.SDFG):
        ################################################################
        # Define some locals:
        # Can be used to call back to the frame-code generator
        self.frame = frame_codegen
        # Can be used to dispatch other code generators for allocation/nodes
        self.dispatcher = frame_codegen.dispatcher

        ################################################################
        # Register handlers/hooks through dispatcher: Can be used for
        # nodes, memory copy/allocation, scopes, states, and more.

        # In this case, register scopes
        self.dispatcher.register_map_dispatcher(dace.ScheduleType.LoopyLoop, self)

        # You can similarly use register_{array,copy,node,state}_dispatcher

    # A scope dispatcher will trigger a method called generate_scope whenever
    # an SDFG has a scope with that schedule
    def generate_scope(self, sdfg: dace.SDFG, scope: ScopeSubgraphView,
                       state_id: int, function_stream: CodeIOStream,
                       callsite_stream: CodeIOStream):
        # The parameters here are:
        # sdfg: The SDFG we are currently generating.
        # scope: The subgraph of the state containing only the scope (map contents)
        #        we want to generate the code for.
        # state_id: The state in the SDFG the subgraph is taken from (i.e.,
        #           `sdfg.node(state_id)` is the same as `scope.graph`)
        # function_stream: A cursor to the global code (which can be used to define
        #                  functions, hence the name).
        # callsite_stream: A cursor to the current location in the code, most of
        #                  the code is generated here.

        # We can get the map entry node from the scope graph
        entry_node = scope.source_nodes()[0]

        # First, generate an opening brace (for instrumentation and dynamic map ranges)
        callsite_stream.write('{', sdfg, state_id, entry_node)

        ################################################################
        # Generate specific code: We will generate a reversed loop with a
        # comment for each dimension of the map. For the sake of simplicity,
        # dynamic map ranges are not supported.

        for param, rng in zip(entry_node.map.params, entry_node.map.range):
            # We use the sym2cpp function from the cpp support functions
            # to convert symbolic expressions to proper C++
            begin, end, stride = (sym2cpp(r) for r in rng)

            # Every write is optionally (but recommended to be) tagged with
            # 1-3 extra arguments, serving as line information to match
            # SDFG, state, and graph nodes/edges to written code.
            callsite_stream.write(f'''// Loopy-loop {param}
            for (auto {param} = {end}; {param} >= {begin}; {param} -= {stride}) {{''',
                                  sdfg, state_id, entry_node
                                  )

            # NOTE: CodeIOStream will automatically take care of indentation for us.


        # Now that the loops have been defined, use the dispatcher to invoke any
        # code generator (including this one) that is registered to deal with
        # the internal nodes in the subgraph. We skip the MapEntry node.
        self.dispatcher.dispatch_subgraph(sdfg, scope, state_id,
                                          function_stream, callsite_stream,
                                          skip_entry_node=True)

        # NOTE: Since skip_exit_node above is set to False, closing braces will
        # be automatically generated

# Preview SDFG
sdfg = custom_kernel.to_sdfg()

# Change schedule
for node, _ in sdfg.all_nodes_recursive():
    if isinstance(node, dace.nodes.MapEntry):
        node.schedule = dace.ScheduleType.LoopyLoop

# Code(sdfg.generate_code()[0].clean_code, language='cpp')


# display
from IPython.display import Code
from IPython.display import display
display(Code(sdfg.generate_code()[0].clean_code, language='cpp'))
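The loop-emission step in `generate_scope` can be tried in isolation with a small dependency-free sketch. `emit_reversed_loops` is a hypothetical helper, and the `(begin, end, stride)` tuples are assumed to use inclusive ends, as the `>=` comparison in the generated loop suggests. It also shows the `{{` escape that an f-string needs in order to emit a literal C++ brace.

```python
def emit_reversed_loops(params, ranges):
    """Return the C++ lines a reversed-loop scope generator might write.

    params: loop variable names, e.g. ["i", "j"]
    ranges: (begin, end, stride) per variable, with an inclusive end (assumed)
    """
    lines = ["{"]  # opening brace for the whole scope
    for param, (begin, end, stride) in zip(params, ranges):
        lines.append(f"// Loopy-loop {param}")
        # '{{' inside an f-string yields one literal '{' in the output.
        lines.append(
            f"for (auto {param} = {end}; {param} >= {begin}; {param} -= {stride}) {{"
        )
    return lines


# Ranges corresponding to dace.map[0:20:2, 0:30] with inclusive ends.
code = emit_reversed_loops(["i", "j"], [(0, 19, 2), (0, 29, 1)])
print("\n".join(code))
```

The sketch stops at the loop headers; in the generator above, the loop bodies and closing braces come from the dispatcher call that follows.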
11 changes: 8 additions & 3 deletions dace/codegen/codegen.py
@@ -17,7 +17,7 @@
from dace.codegen.instrumentation import InstrumentationProvider
from dace.sdfg.state import SDFGState


# include/* files, containing the signature header code.
def generate_headers(sdfg: SDFG, frame: framecode.DaCeCodeGenerator) -> str:
    """ Generate a header file for the SDFG """
    proto = ""
@@ -34,7 +34,7 @@ def generate_headers(sdfg: SDFG, frame: framecode.DaCeCodeGenerator) -> str:
    proto += 'extern "C" void __program_%s(%sHandle_t handle%s);\n' % params
    return proto


# sample/* files - contains the main() function.
def generate_dummy(sdfg: SDFG, frame: framecode.DaCeCodeGenerator) -> str:
    """ Generates a C program calling this SDFG. Since we do not
        know the purpose/semantics of the program, we allocate
@@ -147,7 +147,10 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator):
    if sdfg.instrument != dtypes.InstrumentationType.No_Instrumentation:
        disp.instrumentation[sdfg.instrument] = provider_mapping[sdfg.instrument]


# 3 step process
# 1. Generate the code for the SDFG(.cpp file)(generate_code)(sdfg.generate_code()[0])
# 2. Generate the header file for the SDFG(.h file)(generate_headers)(sdfg.generate_code()[1])
# 3. Generate the main function to call the SDFG(.main file)(generate_dummy)(sdfg.generate_code()[2])
def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]:
    """
    Generates code as a list of code objects for a given SDFG.
@@ -230,6 +233,7 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]:
    # NOTE: THE SDFG IS ASSUMED TO BE FROZEN (not change) FROM THIS POINT ONWARDS

    # Generate frame code (and the rest of the code)
    # (, generated_code/clean_code, ...))
    (global_code, frame_code, used_targets, used_environments) = frame.generate_code(sdfg, None)
    target_objects = [
        CodeObject(sdfg.name,
@@ -246,6 +250,7 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]:
        target_objects.extend(tgt.get_generated_codeobjects())

    # Ensure that no new targets were dynamically added

    assert frame._dispatcher.used_targets == (frame.targets - {frame})

    # add a header file for calling the SDFG
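The assertion in this hunk compares two sets, which is the check that the commit "Fix IPUCodegen {used_targets}-{frame} error" refers to. The following toy model shows only the set arithmetic. `ToyTarget` and `target_mismatch` are hypothetical names, and how DaCe actually fills `frame.targets` and `used_targets` is not modeled here.

```python
class ToyTarget:
    """Hypothetical stand-in for a code generator instance."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"ToyTarget({self.name!r})"


def target_mismatch(all_targets, frame_target, used_targets):
    """Return the targets on which the two sides of the assertion disagree.

    An empty result means `used_targets == all_targets - {frame_target}`.
    """
    expected = all_targets - {frame_target}
    # Symmetric difference: expected-but-unused plus used-but-unexpected.
    return expected ^ used_targets


frame = ToyTarget("frame")
cpu = ToyTarget("cpu")
ipu = ToyTarget("ipu")
targets = {frame, cpu, ipu}

# The IPU target is expected but was never dispatched to.
print(target_mismatch(targets, frame, {cpu}))
# Both sides agree, so the assertion would pass.
print(target_mismatch(targets, frame, {cpu, ipu}))
```

In this model, a backend that is counted among the targets but never receives a dispatch shows up in the mismatch, which is one way such an assertion can fail while a new backend is being wired in.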
2 changes: 2 additions & 0 deletions dace/codegen/dispatcher.py
@@ -444,6 +444,8 @@ def dispatch_node(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphVi
        state = cfg.state(state_id)
        disp = self.get_node_dispatcher(sdfg, state, node)
        self._used_targets.add(disp)
        # print debugging for the dispatcher
        print("SJJ: Dispatching node", node, "to", disp)
        disp.generate_node(sdfg, cfg, dfg, state_id, node, function_stream, callsite_stream)

    def get_scope_dispatcher(self, schedule: dtypes.ScheduleType) -> target.TargetCodeGenerator:
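Several commits mention that `generate_node()` does not trigger because the predicate fails. A minimal sketch of predicate-based node dispatch may help when reasoning about that. `ToyDispatcher`, the dictionary nodes, and the two generator functions are hypothetical; the real dispatcher takes more arguments and has additional lookup rules that are not reproduced here.

```python
class ToyDispatcher:
    """Hypothetical, simplified model of predicate-based node dispatch."""

    def __init__(self, default_generator):
        self._default = default_generator
        self._node_dispatchers = []  # (predicate, generator) pairs, in registration order
        self.used_targets = set()

    def register_node_dispatcher(self, generator, predicate):
        self._node_dispatchers.append((predicate, generator))

    def get_node_dispatcher(self, node):
        # The first predicate that accepts the node wins.
        for predicate, generator in self._node_dispatchers:
            if predicate(node):
                return generator
        # No predicate matched: the default generator handles the node.
        return self._default

    def dispatch_node(self, node):
        generator = self.get_node_dispatcher(node)
        self.used_targets.add(generator)  # record which generators were used
        return generator(node)


def cpu_generate(node):
    return f"// cpu: {node['label']}"


def ipu_generate(node):
    return f"// ipu: {node['label']}"


dispatcher = ToyDispatcher(cpu_generate)
# Assumed predicate: route nodes whose storage is tagged 'IPU_Memory' to the IPU generator.
dispatcher.register_node_dispatcher(ipu_generate, lambda n: n.get("storage") == "IPU_Memory")

on_ipu = dispatcher.dispatch_node({"label": "A", "storage": "IPU_Memory"})
on_cpu = dispatcher.dispatch_node({"label": "B", "storage": "CPU_Heap"})
print(on_ipu)
print(on_cpu)
```

In this model a node silently falls through to the default generator when its predicate returns false, so printing the chosen generator per node, as the debug line in this hunk does, is a reasonable way to see where a node ends up.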
1 change: 1 addition & 0 deletions dace/codegen/targets/__init__.py
@@ -9,3 +9,4 @@
from .mlir.mlir import MLIRCodeGen
from .sve.codegen import SVECodeGen
from .snitch import SnitchCodeGen
from .ipu import IPUCodeGen