From 2e6e8d8ecc13acce656c2a9b3e59ec1694ead5af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gamze=20=C4=B0slamo=C4=9Flu?= <54476562+gamzeisl@users.noreply.github.com> Date: Tue, 24 Sep 2024 18:43:41 +0200 Subject: [PATCH] Activation unit (#1) * build(requirements): add pytorch * build(requirements): add pytest * test(py-ita): add feedforward size parameter * test(test-generator): add feedforward parameter and adjust output path * feat(py-ita): extend print_properties with feedforward size * feat(py-ita). implement floating point i_poly * feat(py-ita): add i_poly_wrapper function * build(requirements): add pytest-check * feat(py-ita): add symmetric (de-)quantization functions * feat(py-ita): add functions for floating point poly and erf * test(py-ita): add test for quantization, poly and erf functions * fix(py-ita): fix q_c typo from definition * feat(py-ita): implement floating point i_gelu * test(py-ita): add i_gelu tests * fix(py-ita): fix ipoly types to avoid overflow * refactor(py-ita): type i-gelu and i-erf * test(py-ita): clean up and pretty print * refactor(py-ita): fix typo * test(py-ita): make scaling factor output more readable * test(py-ita): verify i-gelu domain interval [b, -b] * test(py-ita): add error plotting * refactor(py-ita): rename variable * feat(py-ita): add missing imports * feat(py-ita): add gelu golden model based on random preactivation values * feat(ita-tb): adjustments for feedforward parameter * feat(return-status): adjust parameters and check number of outputs * build(makefile): add feedforward parameter * build(gitlab-ci): adjust for new parameters * fix(test-generator): typo in simvector path * fix(test-generator): fix and refactor simvector paths * feat(py-ita): write out gelu constants * feat(gelu-tb): add testbench skeleton * feat(accel-pkg): add gelu bit width parameters * feat(gelu): add initial scalar gelu * build(bender): add gelu and gelu testbench * feat(modelsim): add gelu simulation script * build(modelsim): add gelu simulation target * test(gelu-tb): validate scalar gelu for all activations from precomputed stimuli files * chore(gitignore): ignore vscode settings.json * fix(py-ita): clip i-gelu within valid range -127, 127 * test(py-ita): add tests for i-gelu edge cases * refactor(gelu): explicitly sign-extend input to output bit width * fix(gelu): clip input within valid bounds * stlye(gelu): auto-format * feat(py-ita): implement i-gelu with requantization * test(py-ita): add tests for i-gelu with requantization * feat(py-ita): use 8 bit eps_mul for requantizer i-gelu * test(py-ita): auto-format * feat(py-ita): write out i-gelu requantization constants * test(gelu): read and apply gelu requantization constants * feat(gelu): implement requantization * test(gelu): remove redundant print statements * refactor(gelu): rename types * refactor(gelu): use correct type for add * test(gelu): refactor to test higher-level activation module instead of just GELU * feat(relu): add RELU activation module * feat(activation): add module which lifts GELU and RELU to match output dimension of requantizer * feat(accel-pkg): add enum to select activation function * build: adjust simulation and build files for activation testbench * test(activation): read activations in blocks of N_PE size * test(activation): also verify ReLU and identity activations * test(activation-tb): extract validation function to reduce redandancies * test(activation-tb): rename variables holding gelu constants for clarity * refactor(py-ita): rename files for GELU constants for clarity * test(activation-tb): extract function which reads all GELU constants * test(activation-tb): fix equality check + refactor reading of GELU constants * test(activation-tb): parametrize function for reading GELU constants * refactor(activation): reorder inputs * feat(accel-pkg): extend controller interface with constants for GELU and activation mode * feat(ita): insert activation module in datapath between requantizer and output FIFO * test(ita-tb): supply ITA with additional activation-related control signals and constants * feat(py-ita): implement almost symmetric quantization * test(py-ita): add simple test for almost symmetric quantization * test(py-ita): fix last test case of quantize * test(py-ita): use almost symmetric quantization instead of symmetric quantization * fix(ita): compute scaling factor based on almost symmetric quantization * fix(gelu): remove edge case treatment of q=-128, no longer necessary with almost symmetric quantization * refactor(activation): rename requant variables and extract typedefs * refactor(gelu): extract type for gelu output * fix(py-ita): make sure to round before clipping and properly apply half-way rounding from zero * feat(py-ita): ensure requantization constants are unsigned and compensate for eps_mul sign flip * test(py-ita): reduce error tolerances * feat(activation): apply requantization using existing vector module instead of inside GELU * test(activation-tb): account for requantization latency * refactor(activation): reorder condition for clarity * test(activation-tb): make sure to use the correct input file when expected values for RELU * refactor(fifo-controller): rename port for clarity * fix(fifo-controller): delay insertion by two cycles to account for activation requantizer latency * feat(activation): don't introduce unnecessary delay for RELU and IDENTITY * test(activation-tb): extend and refactor for requantized GELU * feat(fifo-controller): let fifo insertion condition depend on activation latency * refactor(ita): rename fifo controller port * refactor(relu): use typedef for clarity * feat(py-ita): allow choosing activation function for step 6 * feat(ita): delay activation signal to keep in sync with activation input * test(ita-tb): make value mismatch info more verbose * test(ita-tb): use RELU activation by default * feat(activation): fix activation latency to two cycles * fix(fifo-controller): account for fixed latency of activation module * fix(ita): account for fixed latency of activation module * test(ita-tb): use GELU activation by default * feat(py-ita): use GELU activation by default * test(ita-tb): introduce activation parameter and set to identity by default * feat(py-ita): allow setting activation function using command line arguments * test(ita-tb): expose activation function as parameter * build(makefile): allow configuring activation function * test(activation-tb): fix latency * refactor(accel_pkg): clean up typedefs * refactor(ita): remove unused first_inner_tile signal * feat(activation): requantize relu output * feat(py-ita): requantize relu output * test(activation-tb): read requantized relu output from generated file * refactor(test-ITA): move plot files into subdir * chore(gitignore): ignore pytest cache * test(activation-tb): rewrite debug msg * refactor(ita): rename gelu requant constants to more general activation requant constants * build(requirements): set up pre-commit hooks for auto-formatting python files using yapf * style: auto-format files with yapf * build(pre-commit-config): add some linters like checking for trailing whitespace and verifying python ast correctness * style: fix trailing whitespace * test(py-ita): fix plot dir * refactor(ita): remove unused signals * refactor(accel-pkg): reorganize for clarity * feat(ita): add requantization controller which decouples state value from constants index using indirection * refactor(requantizer): incorporate new requant_mode type * test(activation-tb): use new requant_mode type * refactor: make enums camel case * feat(accel-pkg): add feedforward layer type * refactor(py-ita): extract apply_activation function * refactor(py-ita): vectorize apply_activation * test(activation-tb): update file names of activation requant constants * feat(py-ita): update file names of activation requant constants * test(ita-tb): rename files of activation and gelu constants * test(ita-tb): extend testbench to run a single feedforward layer with GELU activation after attention layer * feat(py-ita): generate testvectors for feedforward layer * feat(requantization-controller): reuse requantization constants at index 0 for feedforward layer * feat(accel-pkg): extend interface with layer mode and number of feedforward tiles * feat(controller): extend FSM for feedforward layer * feat(py-ita): execute arbitrary activation in feedforward layer * test(ita-tb): allow executing arbitrary activation function for feedforward layer * feat(py-ita): add second stage with identity activation to feedforward layer according to transformer architecture * test(ita-tb): execute second stage with identity activation in feedforward layer * fix(ita): use correct imported function * fix(py-ita): use instance field instead of removed local * fix(py-ita): call correct functions * refactor(makefile): remove duplicate variable definitions * refactor(ita-tb): rename projection size to projection space * fix(py-ita): use RQS constants at index 0 for 2nd FF layer * fix(ita-tb): add missing required finish_number argument for $fatal call * fix(hwpe-ita-tb): add missing required finish_number argument for $fatal call * refactor(accel-pkg): rename layer typedef * refactor(accel-pkg): reorder ctrl fields * test(ita-tb): use matching types * refactor(ita-tb): rename activation constant variables * refactor(ita-tb): rename index variable since it has no relationship to phase * fix(hwpe-ita-tb): fix simdir path * fix(ita-package): update control engine to match control struct in accel_pkg * test(hwpe-ita-tb): print correct base ptrs * fix(hwpe-ita-tb): create arrays of correct size so that indexing with the step state does not go out-of-bounds when number of states changes * feat(ita-package): reorder control engine type * refactor(ita-package): remove redundant definition of control engine structure * feat(ita-package): extend register mapping for feedforward layer * refactor(accel-pkg): explicitly state enum type to allow direct assignment without cast from implicit int * style(ita-package): formatting * feat(hwpe-ita-tb): load constants and prepare registers for feedforward layer * feat(ita-ctrl): pass constants and control signals for feedforward layer to control engine * refactor(accel-pkg): explicitely compute N_STATES parameter * refactor(py-ita): extract memfile constant * fix(py-ita): include all files in file list for hwpe * perf(gelu): apply strength reduction to remove signed mult when computing gelu_erf * perf(gelu): reduce intermediate bitwidths * perf(gelu): use lower bitwidth for poly_d * feat(py-ita): export mempool and snitch cluster on demand * chore(gitlab-ci): use gelu activation for FF layer by default * feat(activation): pipeline gelu * feat(accel-pkg): increase FIFO depth to account for gelu pipelining * test(activation-tb): adjust for increased latency * test(ita-tb): fix input, weight and output timing * test(activation-tb): fix latency for identity * refactor(gelu): split up combinational block by stages * perf(activation): reduce flop resources for RELU buffering by 70% by performing sign extension when reading out buffers * fix(ita): delay activation control signal until end of activation computations * perf(gelu): use calc_en signal to only compute during valid cycles * test(activation): add extra calc_en input * refactor(accel-pkg): removed unused fields in control_t * test(ita-tb): don't reference unused control_t fields * refactor(gelu): merge gelu one and gelu b constants * test(activation): removed unused signal * fix(py-ita): correctly compute L2 error * refactor(ita-package): remove unused gelu one constant * chore(accel-pkg): increase fifo depth * build(ci): fix mismatch of generated testvectors * feat(return-status): add ff checks * change(hwpe-pkg): do not reuse regs for activations * fix(ita_tb): lower input valid signal after handshake * feat: add support for two layer ffn * fix(PyITA): correct random vector gen for ffn * feat(PyITA): write hwpe files for ffn * feat(hwpe_tb): extend to test ffn * fix(PyITA): correct typecast in gelu * change(ci): add activation to hwpe sim * feat(PyITA): add separate requantization params for ffn * feat(hw): add separate ffn requant params * [PyITA] Move GELU functions * [ci] Add tests with relu * [PyITA] Modify license headers * Remove config yaml file * Add header * Fix python format * Add relu test vectors to ci --------- Co-authored-by: Timon Fercho --- .gitignore | 4 +- .gitlab-ci.yml | 31 ++- Bender.yml | 5 + Makefile | 15 +- PyITA/ITA.py | 224 +++++++++++++++-- PyITA/gelu.py | 103 ++++++++ PyITA/test_gelu.py | 262 ++++++++++++++++++++ PyITA/util.py | 78 +++++- README.md | 5 + modelsim/Makefile | 3 + modelsim/return_status.sh | 94 ++++--- modelsim/sim_activation_tb.tcl | 26 ++ requirements.txt | 4 + src/hwpe/ita_hwpe_ctrl.sv | 16 ++ src/hwpe/ita_hwpe_package.sv | 32 +-- src/hwpe/tb/ita_hwpe_tb.sv | 355 ++++++++++++++++++--------- src/ita.sv | 113 ++++++--- src/ita_activation.sv | 119 +++++++++ src/ita_controller.sv | 50 +++- src/ita_fifo_controller.sv | 5 +- src/ita_gelu.sv | 67 +++++ src/ita_package.sv | 71 +++--- src/ita_relu.sv | 14 ++ src/ita_requantization_controller.sv | 44 ++++ src/ita_requantizer.sv | 8 +- src/ita_softmax.sv | 4 +- src/tb/activation_tb.sv | 283 +++++++++++++++++++++ src/tb/ita_tb.sv | 163 ++++++++---- testGenerator.py | 15 +- tests/run_loop.sh | 2 +- 30 files changed, 1888 insertions(+), 327 deletions(-) create mode 100644 PyITA/gelu.py create mode 100644 PyITA/test_gelu.py create mode 100644 modelsim/sim_activation_tb.tcl create mode 100644 src/ita_activation.sv create mode 100644 src/ita_gelu.sv create mode 100644 src/ita_relu.sv create mode 100644 src/ita_requantization_controller.sv create mode 100644 src/tb/activation_tb.sv diff --git a/.gitignore b/.gitignore index 9a445c2..7b65df9 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ simvectors/ # Byte-compiled / optimized / DLL files __pycache__/ +.pytest_cache/ *.py[cod] *$py.class @@ -22,10 +23,9 @@ dist # Jupyter Notebook .ipynb_checkpoints -# Ignore everything in .vscode except launch.json +# Ignore everything in .vscode with some exceptions .vscode/* !.vscode/launch.json -!.vscode/settings.json !.vscode/c_cpp_properties.json simvectors diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 226a5eb..d87f8f2 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -26,8 +26,9 @@ generate_testvectors: stage: test script: - !reference [.setup_test, script] - - python testGenerator.py -H 1 -S 64 -E 64 -P 64 - - python testGenerator.py -H 1 -S 128 -E 192 -P 256 + - python testGenerator.py -H 1 -S 64 -E 64 -P 64 -F 64 --activation gelu + - python testGenerator.py -H 1 -S 128 -E 192 -P 256 -F 256 --activation gelu + - python testGenerator.py -H 1 -S 192 -E 256 -P 128 -F 128 --activation relu artifacts: paths: - simvectors @@ -42,13 +43,22 @@ run_sim: - S: 64 E: 64 P: 64 + F: 64 + activation: gelu - S: 128 E: 192 P: 256 + F: 256 + activation: gelu + - S: 192 + E: 256 + P: 128 + F: 128 + activation: relu script: - make bender - - make sim VSIM_FLAGS=-c s=$S e=$E p=$P bias=1 - - ./modelsim/return_status.sh modelsim/build/transcript $S $E ita_tb + - make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation + - ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F ita_tb run_hwpe_sim: stage: sim @@ -59,10 +69,19 @@ run_hwpe_sim: - S: 64 E: 64 P: 64 + F: 64 + activation: gelu - S: 128 E: 192 P: 256 + F: 256 + activation: gelu + - S: 192 + E: 256 + P: 128 + F: 128 + activation: relu script: - make bender - - make sim VSIM_FLAGS=-c DEBUG=OFF target=sim_ita_hwpe_tb s=$S e=$E p=$P bias=1 - - ./modelsim/return_status.sh modelsim/build/transcript $S $E hwpe_tb + - make sim VSIM_FLAGS=-c DEBUG=OFF target=sim_ita_hwpe_tb s=$S e=$E p=$P f=$F bias=1 activation=$activation + - ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F hwpe_tb diff --git a/Bender.yml b/Bender.yml index 27050c7..f0497c7 100644 --- a/Bender.yml +++ b/Bender.yml @@ -50,6 +50,10 @@ sources: - src/ita_weight_controller.sv - src/ita.sv - src/ita_max_finder.sv + - src/ita_gelu.sv + - src/ita_relu.sv + - src/ita_activation.sv + - src/ita_requantization_controller.sv # HWPE sources - target: ita_hwpe @@ -72,6 +76,7 @@ sources: - src/tb/ita_tb.sv - src/tb/clk_rst_gen.sv - src/tb/rst_gen.sv + - src/tb/activation_tb.sv # HWPE TB sources - target: ita_hwpe_test diff --git a/Makefile b/Makefile index 44d7a37..5996908 100644 --- a/Makefile +++ b/Makefile @@ -21,10 +21,19 @@ target ?= sim_ita_tb no_stalls ?= 0 s ?= 64 -p ?= 64 -e ?= 64 +e ?= 128 +p ?= 192 +f ?= 256 bias ?= 0 -vlog_defs += -DNO_STALLS=$(no_stalls) -DSEQ_LENGTH=$(s) -DPROJ_SPACE=$(p) -DEMBED_SIZE=$(e) -DBIAS=$(bias) +activation ?= identity +ifeq ($(activation), gelu) + activation_int = 1 +else ifeq ($(activation), relu) + activation_int = 2 +else + activation_int = 0 +endif +vlog_defs += -DNO_STALLS=$(no_stalls) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int) ifeq ($(target), sim_ita_hwpe_tb) BENDER_TARGETS += -t ita_hwpe -t ita_hwpe_test diff --git a/PyITA/ITA.py b/PyITA/ITA.py index bd81153..039771a 100644 --- a/PyITA/ITA.py +++ b/PyITA/ITA.py @@ -23,9 +23,10 @@ from numpy.typing import ArrayLike, DTypeLike from .softmax import fastSoftmax, realSoftmax, streamingPartialSoftmax +from .gelu import gelu_requantize, i_gelu_requantized, get_i_gelu_constants, get_i_gelu_requantized_constants from .util import (generate_matrix_mem, pack_8b_to_word, pack_array_8b_to_word, pack_hex_24b, pack_multihead_8b_to_word, pack_multihead_24b_to_word, random_shuffled_tensor, requantize, split_matrix, to_hex, write_matrix, - write_matrix_mem, write_matrix_mem_hex, write_vector_mem_hex) + write_matrix_mem, write_matrix_mem_hex, write_vector_mem_hex, get_almost_symmetric_scaling_factor) class Transformer: @@ -36,9 +37,11 @@ def __init__(self, S: int, P: int, E: int, + F: int, H: int, path: Union[str, os.PathLike], bias: bool = True, + activation: str = "identity", Q: ArrayLike = None, K: ArrayLike = None, V: ArrayLike = None, @@ -49,7 +52,12 @@ def __init__(self, Bq: ArrayLike = None, Bk: ArrayLike = None, Bv: ArrayLike = None, - Bo: ArrayLike = None): + Bo: ArrayLike = None, + FF_in: ArrayLike = None, + Wff: ArrayLike = None, + Wff2: ArrayLike = None, + Bff: ArrayLike = None, + Bff2: ArrayLike = None): self.ITA_N = 16 self.ITA_M = 64 @@ -63,14 +71,17 @@ def __init__(self, self.S_ITA = max(64, S) self.P_ITA = max(64, P) self.E_ITA = max(64, E) + self.F_ITA = max(64, F) self.H_ITA = 4 self.split = self.ITA_M // self.ITA_N self.S = S self.P = P self.E = E + self.F = F self.H = H self.bias = bias + self.activation = activation # Setup transformation functions self.split_m_m = partial(split_matrix, block_shape = (self.ITA_M, self.ITA_M)) @@ -78,7 +89,8 @@ def __init__(self, self._validate_matrix_constraints(K, V) self._initialize_quantization_parameters() - self._initialize_tensors(Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo) + self._init_gelu_constants() + self._initialize_tensors(Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, Wff2, Bff, Bff2) def split_multihead_m_m(self, multihead_array: np.ndarray): """ @@ -100,6 +112,7 @@ def _validate_matrix_constraints(self, K: ArrayLike, V: ArrayLike): assert (self.S % self.ITA_M == 0), "Sequence length must be divisible by ITA_M" assert (self.P % self.ITA_M == 0), "Projection space must be divisible by ITA_M" assert (self.E % self.ITA_M == 0), "Embedding size must be divisible by ITA_M" + assert (self.F % self.ITA_M == 0), "Feedforward size must be divisible by ITA_M" assert ( self.E <= 512 @@ -110,10 +123,13 @@ def _validate_matrix_constraints(self, K: ArrayLike, V: ArrayLike): assert ( self.S <= 512 ), f"Sequence length must be less than {int(2**(self.WO-17))} because the internal bit width is {self.WO} bits" + assert ( + self.F <= 512 + ), f"Feedforward size must be less than {int(2**(self.WO-17))} because the internal bit width is {self.WO} bits" # assert (self.H % self.H_ITA == 0 or self.H == 1), "Number of heads must be one or divisible by H_ITA" - def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo): + def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, Wff2, Bff, Bff2): self.exp_sum = np.zeros(self.S, dtype = np.int32) @@ -127,6 +143,9 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo): self.K_in = self.V_in self.K = self.V + self.FF_in = random_shuffled_tensor((self.S, self.E), self.WI - 1) if FF_in is None else FF_in + self.FF = np.pad(self.FF_in, ((0, self.S_ITA - self.S), (0, self.E_ITA - self.E))) + #### Weight matrices #### self.Wq_in = random_shuffled_tensor((self.H, self.E, self.P), self.WI - 1) if Wq is None else Wq self.Wq = np.pad(self.Wq_in, ((0, 0), (0, self.E_ITA - self.E), (0, self.P_ITA - self.P))) @@ -140,6 +159,11 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo): self.Wo_in = random_shuffled_tensor((self.H, self.P, self.E), self.WI - 1) if Wo is None else Wo self.Wo = np.pad(self.Wo_in, ((0, 0), (0, self.P_ITA - self.P), (0, self.E_ITA - self.E))) + self.Wff_in = random_shuffled_tensor((1, self.E, self.F), self.WI - 1) if Wff is None else Wff + self.Wff = np.pad(self.Wff_in, ((0, 0), (0, self.E_ITA - self.E), (0, self.F_ITA - self.F))) + self.Wff2_in = random_shuffled_tensor((1, self.F, self.E), self.WI - 1) if Wff2 is None else Wff2 + self.Wff2 = np.pad(self.Wff2_in, ((0, 0), (0, self.F_ITA - self.F), (0, self.E_ITA - self.E))) + #### Bias matrices #### if self.bias: self.Bq_in = random_shuffled_tensor( @@ -173,6 +197,21 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo): self.Bo = np.pad(self.Bo_in, ((0, 0), (0, self.E_ITA - self.E))) self.Bo_broadcast = np.reshape(np.repeat(self.Bo, self.S, axis = 0), (self.H, self.S, self.E)) + if self.bias: + self.Bff_in = random_shuffled_tensor( + (1, self.F), int(np.log2(self.F)) + 8, type = np.int32) if Bff is None else Bff + else: + self.Bff_in = np.zeros((1, self.F), dtype = np.int8) + self.Bff = np.pad(self.Bff_in, ((0, 0), (0, self.F_ITA - self.F))) + self.Bff_broadcast = np.reshape(np.repeat(self.Bff, self.S, axis = 0), (1, self.S, self.F)) + if self.bias: + self.Bff2_in = random_shuffled_tensor( + (1, self.E), int(np.log2(self.E)) + 8, type = np.int32) if Bff2 is None else Bff2 + else: + self.Bff2_in = np.zeros((1, self.E), dtype = np.int8) + self.Bff2 = np.pad(self.Bff2_in, ((0, 0), (0, self.E_ITA - self.E))) + self.Bff2_broadcast = np.reshape(np.repeat(self.Bff2, self.S, axis = 0), (1, self.S, self.E)) + #### Intermediate tensors #### self.Qp = None @@ -181,6 +220,10 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo): self.Kp_requant = None self.Vp = None self.Vp_requant = None + self.FFp = None + self.FFp_requant = None + self.FF2p = None + self.FF2p_requant = None self.A = None self.A_requant = None @@ -196,8 +239,11 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo): self.Out_soft_sum = None self.Out_soft_sum_requant = None + self.preactivation = np.random.randint(-128, 127, size = (self.S, self.F), dtype = np.int8) + self.postactivation = None + def _initialize_quantization_parameters(self): - # WIESEP: 6 steps for attention layer and one to requantize the accumulated output + # WIESEP: 6 steps for attention layer and one to requantize the accumulated output, 2 for feedforward self.requant_eps_mult = np.zeros((7, self.H), dtype = np.uint8) self.requant_right_shift = np.zeros((7, self.H), dtype = np.uint8) @@ -224,9 +270,44 @@ def _initialize_quantization_parameters(self): else: self.requant_right_shift[i, :] = max_bit_width - 8 + 2 - write_matrix([self.requant_eps_mult.T], "RQS_MUL", self.paths["base"]) - write_matrix([self.requant_right_shift.T], "RQS_SHIFT", self.paths["base"]) - write_matrix([self.requant_add.T], "RQS_ADD", self.paths["base"]) + write_matrix([self.requant_eps_mult.T], "RQS_ATTN_MUL", self.paths["base"]) + write_matrix([self.requant_right_shift.T], "RQS_ATTN_SHIFT", self.paths["base"]) + write_matrix([self.requant_add.T], "RQS_ATTN_ADD", self.paths["base"]) + + self.requant_eps_mult_ffn = np.zeros((2, 1), dtype = np.uint8) + self.requant_right_shift_ffn = np.zeros((2, 1), dtype = np.uint8) + self.requant_add_ffn = np.zeros((2, 1), dtype = np.int8) + + for i in range(2): + self.requant_eps_mult_ffn[i, :] = np.random.randint(64, 127, size = (1, 1), dtype = np.uint8) + + if i == 0: + max_bit_width = np.log2(self.requant_eps_mult_ffn[i, :].astype(np.uint32) * self.E * 2**9).astype( + np.uint32) + elif i == 1: + max_bit_width = np.log2(self.requant_eps_mult_ffn[i, :].astype(np.uint32) * self.F * 2**9).astype( + np.uint32) + + self.requant_right_shift_ffn[i, :] = max_bit_width - 8 + 2 + + write_matrix([self.requant_eps_mult_ffn.T], "RQS_FFN_MUL", self.paths["base"]) + write_matrix([self.requant_right_shift_ffn.T], "RQS_FFN_SHIFT", self.paths["base"]) + write_matrix([self.requant_add_ffn.T], "RQS_FFN_ADD", self.paths["base"]) + + def _init_gelu_constants(self): + CLIP_LO = -4 + D = 2**20 + + gelu_eps_mult, _ = get_almost_symmetric_scaling_factor(CLIP_LO, n_bits = 8) + self.q_1, self.q_b, self.q_c, _, _, _, self.gelu_rqs_mul, self.gelu_rqs_shift, self.gelu_rqs_add, S_out = get_i_gelu_requantized_constants( + gelu_eps_mult, D) + + write_matrix([[self.q_1]], "GELU_ONE", self.paths["base"]) + write_matrix([[self.q_b]], "GELU_B", self.paths["base"]) + write_matrix([[self.q_c]], "GELU_C", self.paths["base"]) + write_matrix([[self.gelu_rqs_mul]], "activation_requant_mult", self.paths["base"]) + write_matrix([[self.gelu_rqs_shift]], "activation_requant_shift", self.paths["base"]) + write_matrix([[self.gelu_rqs_add]], "activation_requant_add", self.paths["base"]) def _init_paths(self, base_path: Union[str, os.PathLike]): self.paths = { @@ -248,11 +329,15 @@ def print_properties(self, verbose: int, text_align = 30): print(f"{'Matrix Sequence Length ' :<{text_align}}: {self.S}") print(f"{'Matrix Projection Space' :<{text_align}}: {self.P}") print(f"{'Matrix Embedding Size ' :<{text_align}}: {self.E}") + print(f"{'Matrix Feedforward Size' :<{text_align}}: {self.F}") print(f"{'Matrix Number of Heads ' :<{text_align}}: {self.H}") print(f"{'Bias ' :<{text_align}}: {bool(self.bias)}") - print(f"{'Requant Mult ' :<{text_align}}: {list(self.requant_eps_mult)}") - print(f"{'Requant Shift ' :<{text_align}}: {list(self.requant_right_shift)}") - print(f"{'Requant Add ' :<{text_align}}: {list(self.requant_add)}") + print(f"{'Requant Mult Attention ' :<{text_align}}: {list(self.requant_eps_mult)}") + print(f"{'Requant Shift Attention ' :<{text_align}}: {list(self.requant_right_shift)}") + print(f"{'Requant Add Attention ' :<{text_align}}: {list(self.requant_add)}") + print(f"{'Requant Mult FFN ' :<{text_align}}: {list(self.requant_eps_mult_ffn)}") + print(f"{'Requant Shift FFN ' :<{text_align}}: {list(self.requant_right_shift_ffn)}") + print(f"{'Requant Add FFN ' :<{text_align}}: {list(self.requant_add_ffn)}") def tiler_QK(self, qk: np.ndarray, weight: np.ndarray, bias: np.ndarray, output: np.ndarray, input_file: str, weight_file: str, bias_file: str, output_file: str): @@ -479,20 +564,67 @@ def step5_AV(self): self.tiler_AV(self.A_requant, np.transpose(self.Vp_requant, (0, 2, 1)), self.O_soft_requant, "A_stream_soft_in", "Vp_in", "O_soft") + def apply_activation(self, preactivation, activation): + if activation not in ["gelu", "relu", "identity"]: + raise ValueError("Activation function not supported") + + if activation == "gelu": + vectorized_gelu = np.vectorize(i_gelu_requantized) + postactivation = vectorized_gelu(preactivation, self.q_1, self.q_b, self.q_c, self.gelu_rqs_mul, + self.gelu_rqs_shift, self.gelu_rqs_add) + elif activation == "relu": + postactivation = np.maximum(preactivation, 0) + vectorized_requantize = np.vectorize(gelu_requantize) + postactivation = vectorized_requantize(postactivation, self.gelu_rqs_mul, self.gelu_rqs_shift, + self.gelu_rqs_add) + elif activation == "identity": + postactivation = preactivation.copy() + + return postactivation + def step6_O(self): self.Out_soft = np.matmul(self.O_soft_requant, self.Wo, dtype = np.int32) + self.Bo_broadcast self.Out_soft = np.clip(self.Out_soft, -2**(self.WO - 1), 2**(self.WO - 1) - 1) self.Out_soft_requant = requantize(self.Out_soft, self.requant_eps_mult[5], self.requant_right_shift[5], self.requant_add[5]) - self.tiler_Out(self.O_soft_requant, self.Wo, self.Bo, self.Out_soft_requant, "O_soft_in", "Wo", "Bo", "Out_soft") + def feedforward_layer(self): + self.FFp = np.matmul(self.FF, self.Wff, dtype = np.int32) + self.Bff_broadcast + self.FFp = np.clip(self.FFp, -2**(self.WO - 1), 2**(self.WO - 1) - 1) + self.FFp_requant = requantize(self.FFp, self.requant_eps_mult_ffn[0], self.requant_right_shift_ffn[0], + self.requant_add_ffn[0]) + self.FFp_requant = self.apply_activation(self.FFp_requant, self.activation) + + self.tiler_QK(self.FF, self.Wff, self.Bff, self.FFp_requant, "FF", "Wff", "Bff", "FFp") + + self.FF2p = np.matmul(self.FFp_requant, self.Wff2, dtype = np.int32) + self.Bff2_broadcast + self.FF2p = np.clip(self.FF2p, -2**(self.WO - 1), 2**(self.WO - 1) - 1) + self.FF2p_requant = requantize(self.FF2p, self.requant_eps_mult_ffn[1], self.requant_right_shift_ffn[1], + self.requant_add_ffn[1]) + + self.tiler_Out(self.FFp_requant, self.Wff2, self.Bff2, self.FF2p_requant, "FFp_in", "Wff2", "Bff2", "FF2p") + def step7_Osum(self): self.Out_soft_sum = np.sum(self.Out_soft_requant, axis = 0, dtype = np.int32, keepdims = True) self.Out_soft_sum_requant = requantize(self.Out_soft_sum, self.requant_eps_mult[6], self.requant_right_shift[6], self.requant_add[6]) + def test_activations(self): + write_matrix(self.preactivation, "preactivation", self.paths["standalone"]) + gelu = np.zeros(self.preactivation.shape, dtype = np.int8) + relu = np.zeros(self.preactivation.shape, dtype = np.int8) + for i in range(self.preactivation.shape[0]): + for j in range(self.preactivation.shape[1]): + gelu[i, j] = i_gelu_requantized(self.preactivation[i, j], self.q_1, self.q_b, self.q_c, + self.gelu_rqs_mul, self.gelu_rqs_shift, self.gelu_rqs_add) + relu[i, j] = self.preactivation[i, j] if self.preactivation[i, j] > 0 else 0 + relu[i, j] = gelu_requantize(relu[i, j], self.gelu_rqs_mul, self.gelu_rqs_shift, self.gelu_rqs_add) + + write_matrix(gelu, "gelu", self.paths["standalone"]) + write_matrix(relu, "relu", self.paths["standalone"]) + def export_hwpe(self): path = self.paths["hwpe"] @@ -501,49 +633,54 @@ def remove_if_exists(file_name): os.remove(file_name) # WIESEP: Delete the old file otherwise it will lead to mismatches during RTL simulations as the files are memory mapped - files = ["mem.txt", "Output.txt", "Q.txt", "K.txt", "V.txt", "QK.txt", "A.txt", "AV.txt", "OW.txt"] + mem_file = "mem" + files = [ + f"{mem_file}.txt", "Output.txt", "Q.txt", "K.txt", "V.txt", "QK.txt", "A.txt", "AV.txt", "OW.txt", "F1.txt", + "F2.txt" + ] for file in files: remove_if_exists(f"{path}/{file}") # Write the new mem file + # Layer: Attention for h in range(self.H): q = split_matrix(self.Q, (self.ITA_M, self.ITA_M)) - write_matrix_mem_hex(pack_array_8b_to_word(q, hex_string = False), "mem", path) + write_matrix_mem_hex(pack_array_8b_to_word(q, hex_string = False), mem_file, path) k = split_matrix(self.K, (self.ITA_M, self.ITA_M)) - write_matrix_mem_hex(pack_array_8b_to_word(k, hex_string = False), "mem", path) + write_matrix_mem_hex(pack_array_8b_to_word(k, hex_string = False), mem_file, path) w1 = split_matrix(np.transpose(self.Wq[h]), (self.ITA_M, self.ITA_M)) - write_matrix_mem_hex(pack_array_8b_to_word(w1, hex_string = False), "mem", path) + write_matrix_mem_hex(pack_array_8b_to_word(w1, hex_string = False), mem_file, path) w2 = split_matrix(np.transpose(self.Wk[h]), (self.ITA_M, self.ITA_M)) - write_matrix_mem_hex(pack_array_8b_to_word(w2, hex_string = False), "mem", path) + write_matrix_mem_hex(pack_array_8b_to_word(w2, hex_string = False), mem_file, path) w3 = split_matrix(np.transpose(self.Wv[h]), (self.ITA_M, self.ITA_M)) - write_matrix_mem_hex(pack_array_8b_to_word(w3, hex_string = False), "mem", path) + write_matrix_mem_hex(pack_array_8b_to_word(w3, hex_string = False), mem_file, path) w4 = split_matrix(np.transpose(self.Wo[h]), (self.ITA_M, self.ITA_M)) - write_matrix_mem_hex(pack_array_8b_to_word(w4, hex_string = False), "mem", path) + write_matrix_mem_hex(pack_array_8b_to_word(w4, hex_string = False), mem_file, path) b1_hex = np.vectorize(lambda val: to_hex(val, bit_size = 24))(self.Bq[h]) # pack 24-bit values into 32-bit words packed_b1_hex = np.array(pack_hex_24b(b1_hex)) - write_vector_mem_hex(packed_b1_hex, "mem", path) + write_vector_mem_hex(packed_b1_hex, mem_file, path) b2_hex = np.vectorize(lambda val: to_hex(val, bit_size = 24))(self.Bk[h]) # pack 24-bit values into 32-bit words packed_b2_hex = np.array(pack_hex_24b(b2_hex)) - write_vector_mem_hex(packed_b2_hex, "mem", path) + write_vector_mem_hex(packed_b2_hex, mem_file, path) b3_hex = np.vectorize(lambda val: to_hex(val, bit_size = 24))(self.Bv[h]) # pack 24-bit values into 32-bit words packed_b3_hex = np.array(pack_hex_24b(b3_hex)) - write_vector_mem_hex(packed_b3_hex, "mem", path) + write_vector_mem_hex(packed_b3_hex, mem_file, path) b4_hex = np.vectorize(lambda val: to_hex(val, bit_size = 24))(self.Bo[h]) # pack 24-bit values into 32-bit words packed_b4_hex = np.array(pack_hex_24b(b4_hex)) - write_vector_mem_hex(packed_b4_hex, "mem", path) + write_vector_mem_hex(packed_b4_hex, mem_file, path) # Write output qp = split_matrix(self.Qp_requant[h], (self.ITA_M, self.ITA_M)) @@ -567,6 +704,33 @@ def remove_if_exists(file_name): out = split_matrix(self.Out_soft_requant[h], (self.ITA_M, self.ITA_M)) write_matrix_mem_hex(pack_array_8b_to_word(out, hex_string = False), "OW", path) + # Layer: Feedforward + ff = split_matrix(self.FF, (self.ITA_M, self.ITA_M)) + write_matrix_mem_hex(pack_array_8b_to_word(ff, hex_string = False), mem_file, path) + + wff = split_matrix(np.transpose(self.Wff[0]), (self.ITA_M, self.ITA_M)) + write_matrix_mem_hex(pack_array_8b_to_word(wff, hex_string = False), mem_file, path) + + wff2 = split_matrix(np.transpose(self.Wff2[0]), (self.ITA_M, self.ITA_M)) + write_matrix_mem_hex(pack_array_8b_to_word(wff2, hex_string = False), mem_file, path) + + bff_hex = np.vectorize(lambda val: to_hex(val, bit_size = 24))(self.Bff[0]) + # pack 24-bit values into 32-bit words + packed_bff_hex = np.array(pack_hex_24b(bff_hex)) + write_vector_mem_hex(packed_bff_hex, mem_file, path) + + bff2_hex = np.vectorize(lambda val: to_hex(val, bit_size = 24))(self.Bff2[0]) + # pack 24-bit values into 32-bit words + packed_bff2_hex = np.array(pack_hex_24b(bff2_hex)) + write_vector_mem_hex(packed_bff2_hex, mem_file, path) + + # Write output + ff = split_matrix(self.FFp_requant[0], (self.ITA_M, self.ITA_M)) + write_matrix_mem_hex(pack_array_8b_to_word(ff, hex_string = False), "F1", path) + + ff2 = split_matrix(self.FF2p_requant[0], (self.ITA_M, self.ITA_M)) + write_matrix_mem_hex(pack_array_8b_to_word(ff2, hex_string = False), "F2", path) + def generate_snitch_cluster(self) -> str: """ This function generates a header file for ITA integrated into the the Snitch cluster. @@ -801,10 +965,14 @@ def generateTestVectors(path, **kwargs): s = kwargs['S'] p = kwargs['P'] e = kwargs['E'] + f = kwargs['F'] h = kwargs['H'] + activation = kwargs['activation'] bias = int(not kwargs['no_bias']) + export_snitch_cluster = kwargs['export_snitch_cluster'] + export_mempool = kwargs['export_mempool'] - acc1 = Transformer(s, p, e, h, bias = bias, path = path) + acc1 = Transformer(s, p, e, f, h, bias = bias, path = path, activation = activation) if kwargs['verbose']: print("=> Generating test vectors...") @@ -816,9 +984,13 @@ def generateTestVectors(path, **kwargs): acc1.step5_AV() acc1.step6_O() acc1.step7_Osum() + acc1.feedforward_layer() + acc1.test_activations() - acc1.export_mempool(kwargs['mem_path']) - acc1.export_snitch_cluster(kwargs['mem_path']) + if export_mempool: + acc1.export_mempool(kwargs['mem_path']) + if export_snitch_cluster: + acc1.export_snitch_cluster(kwargs['mem_path']) acc1.export_hwpe() acc1.export_numpy() diff --git a/PyITA/gelu.py b/PyITA/gelu.py new file mode 100644 index 0000000..7648a5a --- /dev/null +++ b/PyITA/gelu.py @@ -0,0 +1,103 @@ +# Copyright 2024 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +from .util import (round_to_i8, round_to_u8, round_to_i16) +from typing import Tuple +from numpy import int8 as i8, int16 as i16, int32 as i32, float32 as f32, uint8 as u8, uint16 as u16 + + +def i_gelu(q: i8, q_1: i16, q_b: i16, q_c: i16) -> i32: + q_clipped = max(q, -2**7 + 1) + q_erf: i32 = i_erf(q_clipped, q_b, q_c) + q_out: i32 = q_clipped * (q_erf + q_1) + return q_out + + +def gelu_requantize(q: i32, eps_mul: i8, eps_shift: u8, eps_add: u8) -> i8: + q_mul: i64 = eps_mul * q + shifted: f32 = q_mul / 2**float(eps_shift) + eps_add + q_req: i8 = round_to_i8(shifted) + return q_req + + +def i_gelu_requantized(q: i8, q_1: i16, q_b: i16, q_c: i16, eps_mul: u8, eps_shift: u8, eps_add: u8) -> i8: + q_out: i32 = i_gelu(q, q_1, q_b, q_c) + q_req: i8 = gelu_requantize(q_out, eps_mul, eps_shift, eps_add) + return q_req + + +def get_i_gelu_constants(S: f32) -> Tuple[i16, i16, i16, float, float, float]: + a: float = -0.2888 + b: float = -1.769 + c: float = 1 + S_2: f32 = S / np.sqrt(2) + q_1: i16 = round_to_i16(1 / (a * S_2**2)) + q_b: i16 = round_to_i16(b / S_2) + q_c: i16 = round_to_i16(c / (a * S_2**2)) + return q_1, q_b, q_c, a, b, c + + +def get_i_gelu_requantized_constants(S: f32, D: i32) -> Tuple[i16, i16, i16, float, float, float, u8, u8, u8, f32]: + q_1, q_b, q_c, a, b, c = get_i_gelu_constants(S) + S_2: f32 = S / np.sqrt(2) + S_out: f32 = S * a * S_2**2 / 2 + # Flip sign of eps_mul to ensure its positive + eps_mul: u8 = round_to_u8(-S_out / S * D) + eps_shift: u8 = round_to_i8(np.log2(D)) + eps_add: u8 = 0 + # Compensate for the sign flip in eps_mul by negating S + return q_1, q_b, q_c, a, b, c, eps_mul, eps_shift, eps_add, -S + + +def i_gelu_wrapper(q: i8, S: f32) -> Tuple[i32, f32]: + S_2: f32 = S / np.sqrt(2) + q_1, q_b, q_c, a, _, _ = get_i_gelu_constants(S) + q_out: i32 = i_gelu(q, q_1, q_b, q_c) + S_out: f32 = S * a * S_2**2 / 2 + return q_out, S_out + + +def i_gelu_wrapper_requantized(q: i8, S: f32, D: i32) -> Tuple[i8, f32]: + q_1, q_b, q_c, a, _, _, eps_mul, eps_shift, eps_add, S_out = get_i_gelu_requantized_constants(S, D) + q_out: i32 = i_gelu_requantized(q, q_1, q_b, q_c, eps_mul, eps_shift, eps_add) + return q_out, S_out + + +def i_erf(q: i8, q_b: i16, q_c: i16) -> i32: + q_sgn: i8 = np.sign(q) + q_abs: i8 = np.abs(q) + q_clipped: i8 = np.clip(q_abs, 0, -q_b) + q_L: i32 = i_poly(q_clipped, q_b, q_c) + q_out: i32 = q_sgn * q_L + return q_out + + +def i_erf_wrapper(q: i8, S: i8) -> Tuple[i32, f32]: + a: float = -0.2888 + b: float = -1.769 + c: float = 1 + q_b: i16 = round_to_i16(b / S) + q_c: i16 = round_to_i16(c / (a * S**2)) + S_out: f32 = a * S**2 + q_out: i32 = i_erf(q, q_b, q_c) + return q_out, S_out + + +def i_poly(q: i8, q_b: i16, q_c: i16) -> i32: + q16: i16 = q.astype(i16) + q_c32: i32 = q_c.astype(i32) + d: i16 = q16 + q_b + d_sq: i16 = d**2 + q_out: i32 = d_sq + q_c32 + return q_out.astype(i32) + + +def i_poly_wrapper(q: i8, S: f32, a: f32, b: f32, c: f32) -> Tuple[i32, f32]: + q_b: i16 = round_to_i16(b / S) + q_c: i16 = round_to_i16(c / (a * S**2)) + S_out: f32 = a * S**2 + q_out: i32 = i_poly(q, q_b, q_c) + return q_out, S_out diff --git a/PyITA/test_gelu.py b/PyITA/test_gelu.py new file mode 100644 index 0000000..934ba88 --- /dev/null +++ b/PyITA/test_gelu.py @@ -0,0 +1,262 @@ +# Copyright 2024 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 + +# This test file is used to check the integer quantization of the GELU function. + +import pytest +import torch +import numpy as np +import pytest_check as check +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + +from .util import * +from .gelu import * +from .ITA import * + +N_SAMPLES = 75 + + +def pretty_print(x, x_q, S, res_q, res_S, deq_res, exp_res): + print( + f"x={x:>10.2f}, x_q={x_q:>10}, S={S:>10.1g}, res_q={res_q:>10}, res_S={res_S:>10.1g}, deq_res={deq_res:>10.2f}, exp_res={exp_res:>10.2f}, abs_err={(np.abs(deq_res - exp_res)):>10.3f}" + ) + + +file_dir = os.path.dirname(os.path.abspath(__file__)) +plot_dir = os.path.join(file_dir, 'plots') + + +def plot(data: pd.DataFrame, title: str, quantized_y_label: str, expected_y_label: str, alpha: float): + l2_error = np.mean(np.sqrt((data['deq_res'] - data['exp_res'])**2)) + l_inf_error = np.max(np.abs(data['deq_res'] - data['exp_res'])) + print(f'alpha: {alpha}, average L2 error: {l2_error:.4f}, Linf error: {l_inf_error:.3f}') + sns.set_theme() + fig, ax = plt.subplots(1, 1, figsize = (10, 6)) + sns.lineplot(data = data, + x = 'x', + y = 'deq_res', + label = quantized_y_label, + ax = ax, + marker = 'o', + linestyle = '--') + sns.lineplot(data = data, x = 'x', y = 'exp_res', label = expected_y_label, ax = ax) + ax.set_title( + f'{title}\n($\\alpha$: {alpha}, average $L_2$ error: {l2_error:.4f}, $L_{{\\infty}}$ error: {l_inf_error:.3f})') + + ax.set_xlabel('$x$') + ax.set_ylabel('Value') + filename = os.path.join(plot_dir, f'{title}.png') + if not os.path.exists(plot_dir): + os.makedirs(plot_dir) + plt.savefig(filename) + + +def test_i_gelu_requant(): + n_bits = 8 + D = 2**20 + xs = np.linspace(-4, 4, N_SAMPLES) + clip_lo = -np.abs(xs).max() + qs, S = almost_symmetric_quantize(xs, clip_lo, n_bits) + data = [] + for x, q in zip(xs, qs): + res_q, res_S = i_gelu_wrapper_requantized(q, S, D) + deq_res = res_q * res_S + exp_res = torch.nn.functional.gelu(torch.tensor(x, dtype = torch.float32)).item() + pretty_print(x, q, S, res_q, res_S, deq_res, exp_res) + data.append({'x': x, 'x_q': q, 'S': S, 'res_q': res_q, 'res_S': res_S, 'deq_res': deq_res, 'exp_res': exp_res}) + check.almost_equal(deq_res, exp_res, abs = 62e-2) + plot(pd.DataFrame(data), + quantized_y_label = 'I-GELU(x)', + expected_y_label = 'GELU(x)', + title = 'I-GELU with 8-bit almost symmetric quantization (output requantized to 8 bit)', + alpha = -clip_lo) + + +def test_i_gelu_edge_cases(): + n_bits = 8 + qs = np.array([-128, -127, -64, 0, 64, 127], dtype = np.int8) + clip_lo = -4 + S, _ = get_almost_symmetric_scaling_factor(clip_lo, n_bits) + xs = qs * S + for q, x in zip(qs, xs): + res_q, res_S = i_gelu_wrapper(q, S) + deq_res = res_q * res_S + exp_res = torch.nn.functional.gelu(torch.tensor(x, dtype = torch.float32)).item() + pretty_print(x, q, S, res_q, res_S, deq_res, exp_res) + check.almost_equal(deq_res, exp_res, abs = 1e-2) + + +def test_gelu(): + n_bits = 8 + xs = np.linspace(-4, 4, N_SAMPLES) + # xs = np.linspace(-1.769, 1.769, 25) + clip_lo = -np.abs(xs).max() + # alpha = 4 + x_qs, S = almost_symmetric_quantize(xs, clip_lo, n_bits) + data = [] + + for x, x_q in zip(xs, x_qs): + res_q, res_S = i_gelu_wrapper(x_q, S) + deq_res = res_q * res_S + exp_res = torch.nn.functional.gelu(torch.tensor(x, dtype = torch.float32)).item() + pretty_print(x, x_q, S, res_q, res_S, deq_res, exp_res) + data.append({ + 'x': x, + 'x_q': x_q, + 'S': S, + 'res_q': res_q, + 'res_S': res_S, + 'deq_res': deq_res, + 'exp_res': exp_res + }) + check.almost_equal(deq_res, exp_res, abs = 33e-3) + plot(pd.DataFrame(data), + quantized_y_label = 'I-GELU(x)', + expected_y_label = 'GELU(x)', + title = 'I-GELU with 8-bit almost symmetric quantization', + alpha = -clip_lo) + + +def test_gelu_simple(): + xs = np.array([-20, -10, -3, -2, -1, 0, 1, 2, 3, 10, 20]) * 0.1 + n_bits = 8 + clip_lo = -np.abs(xs).max() + x_qs, S = almost_symmetric_quantize(xs, clip_lo, n_bits) + + for x, x_q in zip(xs, x_qs): + res_q, res_S = i_gelu_wrapper(x_q, S) + deq_res = res_q * res_S + exp_res = torch.nn.functional.gelu(torch.tensor(x, dtype = torch.float32)).item() + pretty_print(x, x_q, S, res_q, res_S, deq_res, exp_res) + check.almost_equal(deq_res, exp_res, abs = 1e-1) + + +def test_erf(): + xs = np.linspace(-4, 4, N_SAMPLES) + # xs = np.linspace(-1.769, 1.769, 25) + n_bits = 8 + clip_lo = -np.abs(xs).max() + # alpha = 4 + x_qs, S = almost_symmetric_quantize(xs, clip_lo, n_bits) + data = [] + + for x, x_q in zip(xs, x_qs): + res_q, res_S = i_erf_wrapper(x_q, S) + deq_res = res_q * res_S + exp_res = torch.erf(torch.tensor(x, dtype = torch.float32)).item() + pretty_print(x, x_q, S, res_q, res_S, deq_res, exp_res) + data.append({ + 'x': x, + 'x_q': x_q, + 'S': S, + 'res_q': res_q, + 'res_S': res_S, + 'deq_res': deq_res, + 'exp_res': exp_res + }) + check.almost_equal(deq_res, exp_res, abs = 1e-1) + plot(pd.DataFrame(data), + quantized_y_label = 'I-ERF(x)', + expected_y_label = 'ERF(x)', + title = 'I-ERF with 8-bit almost symmetric quantization', + alpha = -clip_lo) + + +def test_erf_simple(): + xs = np.array([-20, -10, -3, -2, -1, 0, 1, 2, 3, 10, 20]) * 0.1 + n_bits = 8 + clip_lo = -np.abs(xs).max() + x_qs, S = almost_symmetric_quantize(xs, clip_lo, n_bits) + + for x, x_q in zip(xs, x_qs): + res_q, res_S = i_erf_wrapper(x_q, S) + deq_res = res_q * res_S + exp_res = torch.erf(torch.tensor(x, dtype = torch.float32)).item() + pretty_print(x, x_q, S, res_q, res_S, deq_res, exp_res) + check.almost_equal(deq_res, exp_res, abs = 9e-2) + + +def test_i_poly(): + n_bits = 8 + xs = np.linspace(0, 1.769, 10) + a, b, c = -0.2888, -1.769, 1 + clip_lo = -np.abs(xs).max() + x_qs, S = almost_symmetric_quantize(xs, clip_lo, n_bits) + for x, x_q in zip(xs, x_qs): + res_q, res_S = i_poly_wrapper(x_q, S, a, b, c) + deq_res = res_q * res_S + exp_res = a * (x + b)**2 + c + pretty_print(x, x_q, S, res_q, res_S, deq_res, exp_res) + check.almost_equal(deq_res, exp_res, abs = 5e-3) + + +def test_i_poly_simple(): + n_bits = 8 + a, b, c = 2, 1, 1 + xs = np.array([-3, -1, 0, 1, 2], dtype = np.int8) + clip_lo = xs.min() + x_qs, S = almost_symmetric_quantize(xs, clip_lo, n_bits) + for x, x_q in zip(xs, x_qs): + res_q, res_S = i_poly_wrapper(x_q, S, a, b, c) + deq_res = res_q * res_S + exp_res = a * (x + b)**2 + c + pretty_print(x, x_q, S, res_q, res_S, deq_res, exp_res) + check.almost_equal(deq_res, exp_res, abs = 13e-2) + + +def test_quantize(): + activations = np.array([-2, -1, 0, 1, 2, 3]) + alpha = 3 + n_bits = 3 + expected_output = np.array([-2, -1, 0, 1, 2, 3], dtype = np.int8) + x_q, _ = quantize(activations, alpha, n_bits) + assert np.array_equal(x_q, expected_output) + + activations = np.array([-4, -2, 0, 2, 4, 6]) + alpha = 3 + n_bits = 3 + expected_output = np.array([-3, -2, 0, 2, 3, 3], dtype = np.int8) + x_q, _ = quantize(activations, alpha, n_bits) + assert np.array_equal(x_q, expected_output) + + activations = np.array([-4, -2, 0, 2, 4, 6]) + alpha = 3 + n_bits = 2 + expected_output = np.array([-1, -1, 0, 1, 1, 1], dtype = np.int8) + output, _ = quantize(activations, alpha, n_bits) + assert np.array_equal(output, expected_output) + + activations = np.array([-4, -2, 0, 2, 4, 6]) + alpha = 4 + n_bits = 8 + expected_output = np.array([-127, -63, 0, 64, 127, 127], dtype = np.int8) + output, _ = quantize(activations, alpha, n_bits) + assert np.array_equal(output, expected_output) + + +def test_almost_symmetric_quantize(): + activations = np.array([-4, -2, 0, 2, 127 / 32, 4]) + clip_lo = -4 + n_bits = 8 + expected_S = 1 / 32 + S, _ = get_almost_symmetric_scaling_factor(clip_lo, n_bits) + assert np.isclose(S, expected_S) + expected_output = np.array([-128, -64, 0, 64, 127, 127], dtype = np.int8) + x_q, _ = almost_symmetric_quantize(activations, clip_lo, n_bits) + assert np.array_equal(x_q, expected_output) + + +def test_dequantize(): + quantized_activations = np.array([-2, -1, 0, 1, 2, 3], dtype = np.int8) + alpha = 3 + n_bits = 3 + expected_output = np.array([-2, -1, 0, 1, 2, 3]) + output = dequantize(quantized_activations, alpha, n_bits) + assert np.allclose(output, expected_output) + + +if __name__ == '__main__': + pytest.main(['-v', __file__]) diff --git a/PyITA/util.py b/PyITA/util.py index 9056658..958b99d 100644 --- a/PyITA/util.py +++ b/PyITA/util.py @@ -16,11 +16,13 @@ # ---------------------------------------------------------------------- import os -from typing import SupportsIndex, Tuple, Union +from typing import Optional, SupportsIndex, Tuple, Union import numpy as np from numpy.typing import DTypeLike +from numpy import int8 as i8, int16 as i16, int32 as i32, float32 as f32, uint8 as u8, uint16 as u16 + def random_shuffled_tensor(shape, bitwidth: int, type: DTypeLike = np.int8, scaling = 1 / 4) -> np.ndarray: """ @@ -441,3 +443,77 @@ def split_matrix(m: np.ndarray, block_shape: Tuple[SupportsIndex, SupportsIndex] return res else: raise ValueError("Matrix must be 2D") + + +def round(x: f32, n_bits: int = 8): + x_clip = np.clip(x, -2**(n_bits - 1), 2**(n_bits - 1) - 1) + return np.floor(x_clip + 0.5 + np.finfo(f32).eps).astype(int) + + +def clip(x: f32, n_bits: int = 8) -> f32: + return np.clip(x, -2**(n_bits - 1), 2**(n_bits - 1) - 1) + + +def round_and_clip(x: f32, n_bits: int = 8) -> f32: + x_rounded = np.floor(x + 0.5 + np.finfo(f32).eps) + x_clipped = clip(x_rounded, n_bits) + return x_clipped + + +def round_to_i8(x: f32) -> i8: + x_rounded_clipped: f32 = round_and_clip(x, 8) + return x_rounded_clipped.astype(i8) + + +def round_to_u8(x: f32) -> u8: + x_rounded_clipped: f32 = round_and_clip(x, 8) + return x_rounded_clipped.astype(u8) + + +def round_to_i16(x: f32) -> i16: + x_rounded_clipped: f32 = round_and_clip(x, 16) + return x_rounded_clipped.astype(i16) + + +def get_scaling_factor(alpha: f32, n_bits: int = 8) -> f32: + S: f32 = alpha / (2**(n_bits - 1) - 1) + return S + + +def quantize(activations: np.ndarray, alpha: f32, n_bits: int = 8, S: Optional[f32] = None) -> Tuple[np.ndarray, f32]: + x_q = np.clip(activations, -alpha, alpha) + if S is None: + S = get_scaling_factor(alpha, n_bits) + x_q = x_q / S + x_q = np.array(list(map(round, x_q))) + return x_q, S + + +def dequantize(quantized_activations: np.ndarray, alpha: f32, n_bits: int = 8) -> np.ndarray: + S = get_scaling_factor(alpha, n_bits) + activations = quantized_activations * S + return activations + + +def get_almost_symmetric_scaling_factor(clip_lo: f32, n_bits: int = 8) -> Tuple[f32, f32]: + if 2**n_bits == 2: + return 1 + n_levels = 2**n_bits + scale = (-n_levels + 2) / n_levels + clip_hi = clip_lo * scale + S = clip_hi / (n_levels / 2 - 1) + return S, clip_hi + + +def almost_symmetric_quantize(activations: np.ndarray, clip_lo: f32, n_bits: int = 8) -> Tuple[np.ndarray, f32]: + S, clip_hi = get_almost_symmetric_scaling_factor(clip_lo, n_bits) + x_q = np.clip(activations, clip_lo, clip_hi) + x_q = x_q / S + x_q = np.array(list(map(round, x_q))) + return x_q, S + + +def almost_symmetric_dequantize(quantized_activations: np.ndarray, clip_lo: f32, n_bits: int = 8) -> np.ndarray: + S, _ = get_almost_symmetric_scaling_factor(clip_lo, n_bits) + activations = quantized_activations * S + return activations diff --git a/README.md b/README.md index ba7de26..3a29e23 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,11 @@ $> source venv/bin/activate $> pip install -r requirements.txt ``` +If you want to enable pre-commit hooks, which perform code formatting and linting, run the following command: +```sh +$> pre-commit install +``` + In case you want to compare the softmax implementation with the QuantLib implementation, you need to install the QuantLib library and additional dependencies. To do so, create a virtual environment: ```sh diff --git a/modelsim/Makefile b/modelsim/Makefile index 77ce325..7d181aa 100644 --- a/modelsim/Makefile +++ b/modelsim/Makefile @@ -25,3 +25,6 @@ sim_ita_tb: lib build sim_ita_hwpe_tb: lib build cd $(buildpath) && $(VSIM) $(VSIM_FLAGS) -do 'set DEBUG $(DEBUG); source ../sim_ita_hwpe_tb.tcl' + +sim_activation_tb: lib build + cd $(buildpath) && $(VSIM) $(VSIM_FLAGS) -do 'source ../sim_activation_tb.tcl' diff --git a/modelsim/return_status.sh b/modelsim/return_status.sh index 6a3e3e0..7990c23 100755 --- a/modelsim/return_status.sh +++ b/modelsim/return_status.sh @@ -4,59 +4,77 @@ # Licensed under the Apache License, Version 2.0, see LICENSE for details. # SPDX-License-Identifier: Apache-2.0 +export SIM_LOG=$1 +export SEQUENCE_LEN=$2 +export EMBEDDING_SIZE=$3 +export PROJECTION_SIZE=$4 +export FEEDFORWARD_SIZE=$5 +export TEST_BENCH=$6 + # Number of dot product units export ITA_N=16 export ITA_M=64 +# Round to the clostest of multiple of $ITA_M +ita_s=$(( ITA_M * ( (( ${SEQUENCE_LEN} - 1) / ITA_M) + 1) )) +ita_e=$(( ITA_M * ( (( ${EMBEDDING_SIZE} - 1) / ITA_M) + 1) )) +ita_p=$(( ITA_M * ( (( ${PROJECTION_SIZE} - 1) / ITA_M) + 1) )) +ita_f=$(( ITA_M * ( (( ${FEEDFORWARD_SIZE} - 1) / ITA_M) + 1) )) + +exp_n_outputs_Q=$((ita_s * ita_p / ITA_N)) +exp_n_outputs_K=${exp_n_outputs_Q} +exp_n_outputs_V=${exp_n_outputs_Q} +exp_n_outputs_OW=$((ita_s * ita_e / ITA_N)) +exp_n_outputs_FF1=$((ita_s * ita_f / ITA_N)) +exp_n_outputs_FF2=$((ita_s * ita_e / ITA_N)) + # Check if the simulation log exists -if [ ! -f $1 ] -then - echo "❗ Simulation log not found." - exit 1 +if [[ ! -f ${SIM_LOG} ]]; then + echo "❗ Simulation log not found." + exit 1 fi +check_n_outputs() { + local n_outputs=$1 + local phase=$2 + + if ! grep -q "${n_outputs} outputs were checked in phase ${phase}." "${SIM_LOG}"; then + echo "❌ Simulation did not finish successfully. Expected ${n_outputs} outputs in phase ${phase}." + exit 1 + fi + + echo "✅ Checked ${n_outputs} outputs in phase ${phase}." +} + # ITA TB # Check if the simulation log has finished successfully by checking the number of outputs -if [ $4 == "ita_tb" ] +if [ ${TEST_BENCH} == "ita_tb" ] then - # Round $2 and $3 to the clostest of multiple of $ITA_M - ita_s=$(( ITA_M * ( (( $2 - 1) / ITA_M) + 1) )) - ita_e=$(( ITA_M * ( (( $3 - 1) / ITA_M) + 1) )) - num_outputs=$((ita_s * ita_e / ITA_N)) - if grep -q "${num_outputs} outputs were checked in phase 4." $1 - then - count=$(grep -c "Wrong value" $1) - - if [ $count -gt 0 ]; - then - echo "❌ Found ${count} errors in the simulation log." - exit 1 - else - echo "✅ No errors found in the simulation log." - exit 0 - fi - else - echo "❗ Simulation did not finish successfully." - exit 1 - fi + check_n_outputs "${exp_n_outputs_Q}" 0 + check_n_outputs "${exp_n_outputs_K}" 1 + check_n_outputs "${exp_n_outputs_V}" 2 + check_n_outputs "${exp_n_outputs_OW}" 4 + check_n_outputs "${exp_n_outputs_FF1}" 5 + check_n_outputs "${exp_n_outputs_FF2}" 6 + + n_error_lines=$(grep -c "Wrong value" "${SIM_LOG}") # HWPE TB # Check if the simulation log has finished successfully -elif [ $4 == "hwpe_tb" ] +elif [ ${TEST_BENCH} == "hwpe_tb" ] then - if grep -q "Comparing output" $1 && grep -q "\$finish" $1 + if ! grep -q "Comparing output" $1 && grep -q "\$finish" $1 then - count=$(grep -c "mismatch" $1) - - if [ $count -gt 0 ]; - then - echo "❌ Found ${count} errors in the simulation log." - exit 1 - else - echo "✅ No errors found in the simulation log." - exit 0 - fi - else echo "❗ Simulation did not finish successfully." exit 1 fi + + n_error_lines=$(grep -c "mismatch" "${SIM_LOG}") +fi + +if [[ ${n_error_lines} -gt 0 ]]; then + echo "❌ Found ${n_error_lines} errors in the simulation log." + exit 1 fi + +echo "✅ No errors found in the simulation log." +exit 0 \ No newline at end of file diff --git a/modelsim/sim_activation_tb.tcl b/modelsim/sim_activation_tb.tcl new file mode 100644 index 0000000..1155ec0 --- /dev/null +++ b/modelsim/sim_activation_tb.tcl @@ -0,0 +1,26 @@ +# Copyright 2024 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 + +set DEBUG ON + +# Set working library. +set LIB work + +if {$DEBUG == "ON"} { + set VOPT_ARG "+acc" + echo $VOPT_ARG + set DB_SW "-debugdb" +} else { + set DB_SW "" +} + +quit -sim + +vsim -voptargs=$VOPT_ARG $DB_SW -pedanticerrors -lib $LIB activation_tb + +if {$DEBUG == "ON"} { + add log -r /* +} + +run -a diff --git a/requirements.txt b/requirements.txt index 31ec016..4bfbc47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,7 @@ onnxruntime netron seaborn matplotlib +torch +pytest +pytest-check +pre-commit diff --git a/src/hwpe/ita_hwpe_ctrl.sv b/src/hwpe/ita_hwpe_ctrl.sv index ea06083..3b371bf 100644 --- a/src/hwpe/ita_hwpe_ctrl.sv +++ b/src/hwpe/ita_hwpe_ctrl.sv @@ -9,6 +9,8 @@ import ita_package::M; import ita_package::N; import hwpe_ctrl_package::*; import hwpe_stream_package::*; +import ita_package::layer_e; +import ita_package::activation_e; module ita_hwpe_ctrl ( @@ -65,24 +67,38 @@ module ita_hwpe_ctrl ctrl_engine_o.tile_s = reg_file.hwpe_params[ITA_REG_TILES][3:0]; ctrl_engine_o.tile_e = reg_file.hwpe_params[ITA_REG_TILES][7:4]; ctrl_engine_o.tile_p = reg_file.hwpe_params[ITA_REG_TILES][11:8]; + ctrl_engine_o.tile_f = reg_file.hwpe_params[ITA_REG_TILES][15:12]; ctrl_engine_o.eps_mult[0] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][7:0]; ctrl_engine_o.eps_mult[1] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][15:8]; ctrl_engine_o.eps_mult[2] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][23:16]; ctrl_engine_o.eps_mult[3] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][31:24]; ctrl_engine_o.eps_mult[4] = reg_file.hwpe_params[ITA_REG_EPS_MULT1][7:0]; ctrl_engine_o.eps_mult[5] = reg_file.hwpe_params[ITA_REG_EPS_MULT1][15:8]; + ctrl_engine_o.eps_mult[6] = reg_file.hwpe_params[ITA_REG_EPS_MULT1][23:16]; + ctrl_engine_o.eps_mult[7] = reg_file.hwpe_params[ITA_REG_EPS_MULT1][31:24]; ctrl_engine_o.right_shift[0] = reg_file.hwpe_params[ITA_REG_RIGHT_SHIFT0][7:0]; ctrl_engine_o.right_shift[1] = reg_file.hwpe_params[ITA_REG_RIGHT_SHIFT0][15:8]; ctrl_engine_o.right_shift[2] = reg_file.hwpe_params[ITA_REG_RIGHT_SHIFT0][23:16]; ctrl_engine_o.right_shift[3] = reg_file.hwpe_params[ITA_REG_RIGHT_SHIFT0][31:24]; ctrl_engine_o.right_shift[4] = reg_file.hwpe_params[ITA_REG_RIGHT_SHIFT1][7:0]; ctrl_engine_o.right_shift[5] = reg_file.hwpe_params[ITA_REG_RIGHT_SHIFT1][15:8]; + ctrl_engine_o.right_shift[6] = reg_file.hwpe_params[ITA_REG_RIGHT_SHIFT1][23:16]; + ctrl_engine_o.right_shift[7] = reg_file.hwpe_params[ITA_REG_RIGHT_SHIFT1][31:24]; ctrl_engine_o.add[0] = reg_file.hwpe_params[ITA_REG_ADD0][7:0]; ctrl_engine_o.add[1] = reg_file.hwpe_params[ITA_REG_ADD0][15:8]; ctrl_engine_o.add[2] = reg_file.hwpe_params[ITA_REG_ADD0][23:16]; ctrl_engine_o.add[3] = reg_file.hwpe_params[ITA_REG_ADD0][31:24]; ctrl_engine_o.add[4] = reg_file.hwpe_params[ITA_REG_ADD1][7:0]; ctrl_engine_o.add[5] = reg_file.hwpe_params[ITA_REG_ADD1][15:8]; + ctrl_engine_o.add[6] = reg_file.hwpe_params[ITA_REG_ADD1][23:16]; + ctrl_engine_o.add[7] = reg_file.hwpe_params[ITA_REG_ADD1][31:24]; + ctrl_engine_o.gelu_b = reg_file.hwpe_params[ITA_REG_GELU_B_C][15:0]; + ctrl_engine_o.gelu_c = reg_file.hwpe_params[ITA_REG_GELU_B_C][31:16]; + ctrl_engine_o.activation_requant_mult = reg_file.hwpe_params[ITA_REG_ACTIVATION_REQUANT][7:0]; + ctrl_engine_o.activation_requant_shift = reg_file.hwpe_params[ITA_REG_ACTIVATION_REQUANT][15:8]; + ctrl_engine_o.activation_requant_add = reg_file.hwpe_params[ITA_REG_ACTIVATION_REQUANT][23:16]; + ctrl_engine_o.layer = layer_e'(reg_file.hwpe_params[ITA_REG_CTRL_ENGINE][1:0]); + ctrl_engine_o.activation = activation_e'(reg_file.hwpe_params[ITA_REG_CTRL_ENGINE][3:2]); ctrl_stream_o.weight_preload = reg_file.hwpe_params[ITA_REG_CTRL_STREAM][0]; ctrl_stream_o.weight_nextload = reg_file.hwpe_params[ITA_REG_CTRL_STREAM][1]; ctrl_stream_o.bias_disable = reg_file.hwpe_params[ITA_REG_CTRL_STREAM][2]; diff --git a/src/hwpe/ita_hwpe_package.sv b/src/hwpe/ita_hwpe_package.sv index 313f80e..984685f 100644 --- a/src/hwpe/ita_hwpe_package.sv +++ b/src/hwpe/ita_hwpe_package.sv @@ -12,7 +12,7 @@ package ita_hwpe_package; parameter int unsigned N_CORES = 9; parameter int unsigned N_CONTEXT = 2; parameter int unsigned ID_WIDTH = 2; - parameter int unsigned ITA_IO_REGS = 14; // 5 address + 8 parameters + 1 sync + parameter int unsigned ITA_IO_REGS = 17; // 5 address + 11 parameters + 1 sync parameter int unsigned ITA_TCDM_DW = 1024; parameter int unsigned ITA_INPUT_DW = M*WI; @@ -27,14 +27,18 @@ package ita_hwpe_package; parameter int unsigned ITA_REG_BIAS_PTR = 3; parameter int unsigned ITA_REG_OUTPUT_PTR = 4; parameter int unsigned ITA_REG_SEQ_LENGTH = 5; - parameter int unsigned ITA_REG_TILES = 6; // tile_s [3:0], tile_e [7:4], tile_p [11:8] + parameter int unsigned ITA_REG_TILES = 6; // tile_s [3:0], tile_e [7:4], tile_p [11:8], tile_f [15:12] parameter int unsigned ITA_REG_EPS_MULT0 = 7; // eps_mult[0] [7:0], eps_mult[1] [15:8], eps_mult[2] [23:16], eps_mult[3] [31:24] - parameter int unsigned ITA_REG_EPS_MULT1 = 8; // eps_mult[4] [7:0], eps_mult[5] [15:8] + parameter int unsigned ITA_REG_EPS_MULT1 = 8; // eps_mult[4] [7:0], eps_mult[5] [15:8], eps_mult[6] [23:16], eps_mult[7] [31:24] parameter int unsigned ITA_REG_RIGHT_SHIFT0 = 9; // right_shift[0] [7:0], right_shift[1] [15:8], right_shift[2] [23:16], right_shift[3] [31:24] - parameter int unsigned ITA_REG_RIGHT_SHIFT1 = 10; // right_shift[4] [7:0], right_shift[5] [15:8] + parameter int unsigned ITA_REG_RIGHT_SHIFT1 = 10; // right_shift[4] [7:0], right_shift[5] [15:8], right_shift[6] [23:16], right_shift[7] [31:24] parameter int unsigned ITA_REG_ADD0 = 11; // add[0] [7:0], add[1] [15:8], add[2] [23:16], add[3] [31:24] - parameter int unsigned ITA_REG_ADD1 = 12; // add[4] [7:0], add[5] [15:8] - parameter int unsigned ITA_REG_CTRL_STREAM = 13; // ctrl_stream [0]: weight preload, ctrl_stream [1]: weight nextload, ctrl_stream [2]: bias disable, ctrl_stream [3]: bias direction, ctrl_stream [4]: output disable + parameter int unsigned ITA_REG_ADD1 = 12; // add[4] [7:0], add[5] [15:8], add[6] [23:16], add[7] [31:24] + parameter int unsigned ITA_REG_CTRL_ENGINE = 13; // layer [1:0], activation [3:2] + parameter int unsigned ITA_REG_CTRL_STREAM = 14; // ctrl_stream [0]: weight preload, ctrl_stream [1]: weight nextload, ctrl_stream [2]: bias disable, ctrl_stream [3]: bias direction, ctrl_stream [4]: output disable + parameter int unsigned ITA_REG_GELU_B_C = 15; // gelu_b [15:0], gelu_c [31:16] + parameter int unsigned ITA_REG_ACTIVATION_REQUANT = 16; // activation_requant_mult [7:0], activation_requant_shift [15:8], activation_requant_add [23:16] + typedef struct packed { hci_package::hci_streamer_ctrl_t input_source_ctrl; @@ -50,21 +54,7 @@ package ita_hwpe_package; hci_package::hci_streamer_flags_t output_sink_flags; } flags_streamer_t; - typedef struct packed { - logic start ; - seq_length_t seq_length ; - proj_space_t proj_space ; - embed_size_t embed_size ; - n_heads_t n_heads ; - logic [5:0][EMS-1:0] eps_mult ; - logic [5:0][EMS-1:0] right_shift ; - logic [5:0][WI-1:0] add ; - logic [32-1:0] lin_tiles ; - logic [32-1:0] attn_tiles ; - logic [32-1:0] tile_s; - logic [32-1:0] tile_e; - logic [32-1:0] tile_p; - } ctrl_engine_t; + typedef ctrl_t ctrl_engine_t; typedef struct packed { logic busy; diff --git a/src/hwpe/tb/ita_hwpe_tb.sv b/src/hwpe/tb/ita_hwpe_tb.sv index 949bbe9..c31ecbc 100644 --- a/src/hwpe/tb/ita_hwpe_tb.sv +++ b/src/hwpe/tb/ita_hwpe_tb.sv @@ -25,10 +25,12 @@ module ita_hwpe_tb; parameter integer SEQUENCE_LEN = `ifdef SEQ_LENGTH `SEQ_LENGTH `else M_TILE_LEN `endif; parameter integer PROJECTION_SPACE = `ifdef PROJ_SPACE `PROJ_SPACE `else M_TILE_LEN `endif; parameter integer EMBEDDING_SIZE = `ifdef EMBED_SIZE `EMBED_SIZE `else M_TILE_LEN `endif; + parameter integer FEEDFORWARD_SIZE = `ifdef FF_SIZE `FF_SIZE `else M_TILE_LEN `endif; + parameter activation_e ACTIVATION = `ifdef ACTIVATION `ACTIVATION `else Identity `endif; - integer N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM; + integer N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, N_TILES_FEEDFORWARD_DIM; integer N_ELEMENTS_PER_TILE; - integer N_TILES_OUTER_X[6], N_TILES_OUTER_Y [6], N_TILES_INNER_DIM[6]; + integer N_TILES_OUTER_X[N_STATES], N_TILES_OUTER_Y [N_STATES], N_TILES_INNER_DIM[N_STATES]; // Memory Map with // 0: q (SxE bytes) @@ -41,24 +43,32 @@ module ita_hwpe_tb; // 7: Bk (P*3 bytes) (four 24bit values per 32bit word) // 8: Bv (P*3 bytes) (four 24bit values per 32bit word) // 9: Bo (E*3 bytes) (four 24bit values per 32bit word) - // 10: Q (SxP bytes) - // 11: K (SxP bytes) - // 12: V (SxP bytes) - // 13: QK (SxS bytes) - // 14: AV (SxP bytes) - // 15: OW (SxE bytes) - integer BASE_PTR[16]; - - logic [6][31:0] BASE_PTR_INPUT; - logic [6][31:0] BASE_PTR_WEIGHT0; - logic [6][31:0] BASE_PTR_WEIGHT1; - logic [6][31:0] BASE_PTR_BIAS; - logic [6][31:0] BASE_PTR_OUTPUT; + // 10: ff (SxE bytes) + // 11: Wf1 (ExF bytes) + // 12: Wf2 (FxE bytes) + // 13: Bf1 (F*3 bytes) (four 24bit values per 32bit word) + // 14: Bf2 (E*3 bytes) (four 24bit values per 32bit word) + // 15: Q (SxP bytes) + // 16: K (SxP bytes) + // 17: V (SxP bytes) + // 18: QK (SxS bytes) + // 19: AV (SxP bytes) + // 20: OW (SxE bytes) + // 21: F1 (SxF bytes) + // 22: F2 (SxE bytes) + + integer BASE_PTR[23]; + + logic [N_STATES][31:0] BASE_PTR_INPUT; + logic [N_STATES][31:0] BASE_PTR_WEIGHT0; + logic [N_STATES][31:0] BASE_PTR_WEIGHT1; + logic [N_STATES][31:0] BASE_PTR_BIAS; + logic [N_STATES][31:0] BASE_PTR_OUTPUT; // HWPE Parameters localparam unsigned ITA_REG_OFFSET = 32'h20; parameter real PROB_STALL = 0.1; - parameter MEMORY_SIZE = SEQUENCE_LEN*EMBEDDING_SIZE*3+EMBEDDING_SIZE*PROJECTION_SPACE*4+PROJECTION_SPACE*3*3+EMBEDDING_SIZE*3+SEQUENCE_LEN*PROJECTION_SPACE*4+SEQUENCE_LEN*SEQUENCE_LEN; + parameter MEMORY_SIZE = SEQUENCE_LEN*EMBEDDING_SIZE*4+EMBEDDING_SIZE*PROJECTION_SPACE*4+PROJECTION_SPACE*3*3+EMBEDDING_SIZE*3+SEQUENCE_LEN*PROJECTION_SPACE*4+SEQUENCE_LEN*SEQUENCE_LEN+EMBEDDING_SIZE*FEEDFORWARD_SIZE*2+FEEDFORWARD_SIZE*3+EMBEDDING_SIZE*3; parameter int unsigned AccDataWidth = ITA_TCDM_DW; parameter int unsigned IdWidth = 8; @@ -69,6 +79,11 @@ module ita_hwpe_tb; // Variables string simdir; + string gelu_b_file = "GELU_B.txt"; + string gelu_c_file = "GELU_C.txt"; + string activation_requant_mult_file = "activation_requant_mult.txt"; + string activation_requant_shift_file = "activation_requant_shift.txt"; + string activation_requant_add_file = "activation_requant_add.txt"; // Signals logic clk, rst_n; @@ -85,10 +100,10 @@ module ita_hwpe_tb; logic [MP-1:0] tcdm_r_valid; logic [MP-1:0] tcdm_r_ready; - hwpe_ctrl_intf_periph #( - .ID_WIDTH (IdWidth) - ) periph ( - .clk (clk) + hwpe_ctrl_intf_periph #( + .ID_WIDTH (IdWidth) + ) periph ( + .clk (clk) ); localparam hci_size_parameter_t `HCI_SIZE_PARAM(tcdm_mem) = '{ @@ -110,6 +125,8 @@ module ita_hwpe_tb; $sformatf("%0d", EMBEDDING_SIZE), "_P", $sformatf("%0d", PROJECTION_SPACE), + "_F", + $sformatf("%0d", FEEDFORWARD_SIZE), "_H1_B", $sformatf("%0d", `ifdef BIAS `BIAS `else 0 `endif) }; @@ -119,6 +136,8 @@ module ita_hwpe_tb; N_TILES_EMBEDDING_DIM = EMBEDDING_SIZE / M_TILE_LEN; // Number of tiles in the projection dimension N_TILES_PROJECTION_DIM = PROJECTION_SPACE / M_TILE_LEN; + // Number of tiles in the feedforward dimension + N_TILES_FEEDFORWARD_DIM = FEEDFORWARD_SIZE / M_TILE_LEN; // Number of entries per tile N_ELEMENTS_PER_TILE = M_TILE_LEN * M_TILE_LEN; // Number of output tiles in X direction per step @@ -128,6 +147,8 @@ module ita_hwpe_tb; N_TILES_OUTER_X[QK] = N_TILES_SEQUENCE_DIM; N_TILES_OUTER_X[AV] = N_TILES_PROJECTION_DIM; N_TILES_OUTER_X[OW] = N_TILES_EMBEDDING_DIM; + N_TILES_OUTER_X[F1] = N_TILES_FEEDFORWARD_DIM; + N_TILES_OUTER_X[F2] = N_TILES_EMBEDDING_DIM; // Number of output tiles in Y direction per step N_TILES_OUTER_Y[Q ] = N_TILES_SEQUENCE_DIM; N_TILES_OUTER_Y[K ] = N_TILES_SEQUENCE_DIM; @@ -135,6 +156,8 @@ module ita_hwpe_tb; N_TILES_OUTER_Y[QK] = 1; // Only one tile row is calculated before switching to AV) N_TILES_OUTER_Y[AV] = 1; // Only one tile row is calculated before switching to QK) N_TILES_OUTER_Y[OW] = N_TILES_SEQUENCE_DIM; + N_TILES_OUTER_Y[F1] = N_TILES_SEQUENCE_DIM; + N_TILES_OUTER_Y[F2] = N_TILES_SEQUENCE_DIM; // Number of inner tiles per step N_TILES_INNER_DIM[Q ] = N_TILES_EMBEDDING_DIM; N_TILES_INNER_DIM[K ] = N_TILES_EMBEDDING_DIM; @@ -142,6 +165,8 @@ module ita_hwpe_tb; N_TILES_INNER_DIM[QK] = N_TILES_PROJECTION_DIM; N_TILES_INNER_DIM[AV] = N_TILES_SEQUENCE_DIM; N_TILES_INNER_DIM[OW] = N_TILES_PROJECTION_DIM; + N_TILES_INNER_DIM[F1] = N_TILES_EMBEDDING_DIM; + N_TILES_INNER_DIM[F2] = N_TILES_FEEDFORWARD_DIM; BASE_PTR[0 ] = 0; BASE_PTR[1 ] = BASE_PTR[0 ] + SEQUENCE_LEN * EMBEDDING_SIZE; @@ -154,42 +179,57 @@ module ita_hwpe_tb; BASE_PTR[8 ] = BASE_PTR[7 ] + PROJECTION_SPACE * 3; BASE_PTR[9 ] = BASE_PTR[8 ] + PROJECTION_SPACE * 3; BASE_PTR[10] = BASE_PTR[9 ] + EMBEDDING_SIZE * 3; - BASE_PTR[11] = BASE_PTR[10] + SEQUENCE_LEN * PROJECTION_SPACE; - BASE_PTR[12] = BASE_PTR[11] + SEQUENCE_LEN * PROJECTION_SPACE; - BASE_PTR[13] = BASE_PTR[12] + SEQUENCE_LEN * PROJECTION_SPACE; - BASE_PTR[14] = BASE_PTR[13] + SEQUENCE_LEN * SEQUENCE_LEN; - BASE_PTR[15] = BASE_PTR[14] + SEQUENCE_LEN * PROJECTION_SPACE; + BASE_PTR[11] = BASE_PTR[10] + SEQUENCE_LEN * EMBEDDING_SIZE; + BASE_PTR[12] = BASE_PTR[11] + EMBEDDING_SIZE * FEEDFORWARD_SIZE; + BASE_PTR[13] = BASE_PTR[12] + FEEDFORWARD_SIZE * EMBEDDING_SIZE; + BASE_PTR[14] = BASE_PTR[13] + FEEDFORWARD_SIZE * 3; + BASE_PTR[15] = BASE_PTR[14] + EMBEDDING_SIZE * 3; + BASE_PTR[16] = BASE_PTR[15] + SEQUENCE_LEN * PROJECTION_SPACE; + BASE_PTR[17] = BASE_PTR[16] + SEQUENCE_LEN * PROJECTION_SPACE; + BASE_PTR[18] = BASE_PTR[17] + SEQUENCE_LEN * PROJECTION_SPACE; + BASE_PTR[19] = BASE_PTR[18] + SEQUENCE_LEN * SEQUENCE_LEN; + BASE_PTR[20] = BASE_PTR[19] + SEQUENCE_LEN * PROJECTION_SPACE; + BASE_PTR[21] = BASE_PTR[20] + SEQUENCE_LEN * EMBEDDING_SIZE; + BASE_PTR[22] = BASE_PTR[21] + SEQUENCE_LEN * FEEDFORWARD_SIZE; // Base pointers BASE_PTR_INPUT[Q ] = BASE_PTR[0 ]; // q BASE_PTR_INPUT[K ] = BASE_PTR[1 ]; // k BASE_PTR_INPUT[V ] = BASE_PTR[4 ]; // Wv - BASE_PTR_INPUT[QK] = BASE_PTR[10]; // Q - BASE_PTR_INPUT[AV] = BASE_PTR[13]; // QK - BASE_PTR_INPUT[OW] = BASE_PTR[14]; // AV + BASE_PTR_INPUT[QK] = BASE_PTR[15]; // Q + BASE_PTR_INPUT[AV] = BASE_PTR[18]; // QK + BASE_PTR_INPUT[OW] = BASE_PTR[19]; // AV + BASE_PTR_INPUT[F1] = BASE_PTR[10]; // ff + BASE_PTR_INPUT[F2] = BASE_PTR[21]; // F1 BASE_PTR_WEIGHT0[Q ] = BASE_PTR[2 ]; // Wq BASE_PTR_WEIGHT0[K ] = BASE_PTR[3 ]; // Wk BASE_PTR_WEIGHT0[V ] = BASE_PTR[1 ]; // k - BASE_PTR_WEIGHT0[QK] = BASE_PTR[11]; // K - BASE_PTR_WEIGHT0[AV] = BASE_PTR[12]; // V + BASE_PTR_WEIGHT0[QK] = BASE_PTR[16]; // K + BASE_PTR_WEIGHT0[AV] = BASE_PTR[17]; // V BASE_PTR_WEIGHT0[OW] = BASE_PTR[5 ]; // Wo + BASE_PTR_WEIGHT0[F1] = BASE_PTR[11]; // Wf1 + BASE_PTR_WEIGHT0[F2] = BASE_PTR[12]; // Wf2 BASE_PTR_BIAS[Q ] = BASE_PTR[6 ]; // Bq BASE_PTR_BIAS[K ] = BASE_PTR[7 ]; // Bk BASE_PTR_BIAS[V ] = BASE_PTR[8 ]; // Bv BASE_PTR_BIAS[QK] = 32'hXXXX; BASE_PTR_BIAS[AV] = 32'hXXXX; BASE_PTR_BIAS[OW] = BASE_PTR[9 ]; // Bo - BASE_PTR_OUTPUT[Q ] = BASE_PTR[10]; // Q - BASE_PTR_OUTPUT[K ] = BASE_PTR[11]; // K - BASE_PTR_OUTPUT[V ] = BASE_PTR[12]; // V - BASE_PTR_OUTPUT[QK] = BASE_PTR[13]; // QK - BASE_PTR_OUTPUT[AV] = BASE_PTR[14]; // AV - BASE_PTR_OUTPUT[OW] = BASE_PTR[15]; // OW + BASE_PTR_BIAS[F1] = BASE_PTR[13]; // Bf1 + BASE_PTR_BIAS[F2] = BASE_PTR[14]; // Bf2 + BASE_PTR_OUTPUT[Q ] = BASE_PTR[15]; // Q + BASE_PTR_OUTPUT[K ] = BASE_PTR[16]; // K + BASE_PTR_OUTPUT[V ] = BASE_PTR[17]; // V + BASE_PTR_OUTPUT[QK] = BASE_PTR[18]; // QK + BASE_PTR_OUTPUT[AV] = BASE_PTR[19]; // AV + BASE_PTR_OUTPUT[OW] = BASE_PTR[20]; // OW + BASE_PTR_OUTPUT[F1] = BASE_PTR[21]; // F1 + BASE_PTR_OUTPUT[F2] = BASE_PTR[22]; // F2 for (int i = 0; i < 5; i++) begin BASE_PTR_WEIGHT1[i] = BASE_PTR_WEIGHT0[i+1]; end - BASE_PTR_WEIGHT1[5] = 32'hXXXX; + BASE_PTR_WEIGHT1[7] = BASE_PTR_WEIGHT0[F2]; end @@ -275,7 +315,7 @@ function automatic integer open_stim_file(string filename); return 0; stim_fd = $fopen({simdir,"/",filename}, "r"); if (stim_fd == 0) begin - $fatal("[TB] ITA: Could not open %s stim file!", filename); + $fatal(1, "[TB] ITA: Could not open %s stim file!", filename); end return stim_fd; endfunction @@ -286,6 +326,8 @@ endfunction string STIM_DATA; logic [31:0] ita_reg_tiles_val; logic [5:0][31:0] ita_reg_rqs_val; + logic [31:0] ita_reg_gelu_b_c_val; + logic [31:0] ita_reg_activation_rqs_val; $timeformat(-9, 2, " ns", 10); @@ -296,8 +338,9 @@ endfunction STIM_DATA = {simdir,"/hwpe/mem.txt"}; $readmemh(STIM_DATA, ita_hwpe_tb.i_data_memory.memory); - ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, ita_reg_tiles_val); + ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, N_TILES_FEEDFORWARD_DIM, ita_reg_tiles_val); ita_reg_eps_mult_val_compute(ita_reg_rqs_val); + ita_reg_activation_constants_compute(ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val); // soft clear PERIPH_WRITE( 32'h14, 32'h0, 32'h0, clk); @@ -308,24 +351,24 @@ endfunction PERIPH_READ( 32'h04, 32'h0, status, clk); // 1: Step Q - ita_compute_step(Q, ita_reg_tiles_val, ita_reg_rqs_val, clk); + ita_compute_step(Q, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // 2: Step K - ita_compute_step(K, ita_reg_tiles_val, ita_reg_rqs_val, clk); + ita_compute_step(K, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // 3: Step V - ita_compute_step(V, ita_reg_tiles_val, ita_reg_rqs_val, clk); + ita_compute_step(V, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); for (int group = 0; group < N_TILES_SEQUENCE_DIM; group++) begin - BASE_PTR_INPUT[QK] = BASE_PTR[10] + group * N_TILES_INNER_DIM[QK] * N_ELEMENTS_PER_TILE; - BASE_PTR_OUTPUT[QK] = BASE_PTR[13] + group * N_TILES_OUTER_X[QK] * N_ELEMENTS_PER_TILE; + BASE_PTR_INPUT[QK] = BASE_PTR[15] + group * N_TILES_INNER_DIM[QK] * N_ELEMENTS_PER_TILE; + BASE_PTR_OUTPUT[QK] = BASE_PTR[18] + group * N_TILES_OUTER_X[QK] * N_ELEMENTS_PER_TILE; - BASE_PTR_INPUT[AV] = BASE_PTR[13] + group * N_TILES_INNER_DIM[AV] * N_ELEMENTS_PER_TILE; - BASE_PTR_OUTPUT[AV] = BASE_PTR[14] + group * N_TILES_OUTER_X[AV] * N_ELEMENTS_PER_TILE; + BASE_PTR_INPUT[AV] = BASE_PTR[18] + group * N_TILES_INNER_DIM[AV] * N_ELEMENTS_PER_TILE; + BASE_PTR_OUTPUT[AV] = BASE_PTR[19] + group * N_TILES_OUTER_X[AV] * N_ELEMENTS_PER_TILE; // 4: Step QK - ita_compute_step(QK, ita_reg_tiles_val, ita_reg_rqs_val, clk); + ita_compute_step(QK, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // WIESEP: Hack to ensure that during the last tile of AV, the weight pointer is set correctly if (group == N_TILES_SEQUENCE_DIM-1) begin @@ -333,11 +376,17 @@ endfunction end // 5: Step AV - ita_compute_step(AV, ita_reg_tiles_val, ita_reg_rqs_val, clk); + ita_compute_step(AV, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); end // 6: Step OW - ita_compute_step(OW, ita_reg_tiles_val, ita_reg_rqs_val, clk); + ita_compute_step(OW, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + + // 7: Step FF1 + ita_compute_step(F1, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + + // 8: Step FF1 + ita_compute_step(F2, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // Wait for the last step to finish wait(evt); @@ -347,12 +396,14 @@ endfunction #(10ns); - compare_output("hwpe/Q.txt", BASE_PTR[10]); - compare_output("hwpe/K.txt", BASE_PTR[11]); - compare_output("hwpe/V.txt", BASE_PTR[12]); - compare_output("hwpe/QK.txt", BASE_PTR[13]); - compare_output("hwpe/AV.txt", BASE_PTR[14]); - compare_output("hwpe/OW.txt", BASE_PTR[15]); + compare_output("hwpe/Q.txt", BASE_PTR[15]); + compare_output("hwpe/K.txt", BASE_PTR[16]); + compare_output("hwpe/V.txt", BASE_PTR[17]); + compare_output("hwpe/QK.txt", BASE_PTR[18]); + compare_output("hwpe/AV.txt", BASE_PTR[19]); + compare_output("hwpe/OW.txt", BASE_PTR[20]); + compare_output("hwpe/F1.txt", BASE_PTR[21]); + compare_output("hwpe/F2.txt", BASE_PTR[22]); // Finish the simulation $finish; @@ -362,9 +413,12 @@ endfunction input step_e step, input logic [31:0] ita_reg_tiles_val, input logic [5:0][31:0] ita_reg_rqs_val, + input logic [31:0] ita_reg_gelu_b_c_val, + input logic [31:0] ita_reg_activation_rqs_val, ref logic clk_i ); + logic [31:0] ctrl_engine_val; logic [31:0] ctrl_stream_val; logic weight_ptr_en; logic bias_ptr_en; @@ -399,13 +453,13 @@ endfunction // Calculate ita_reg_en ita_reg_en_compute(step, tile, ita_reg_en); // Calculate ctrl_stream_val, weight_ptr_en, and bias_ptr_en - ctrl_stream_val_compute(step, tile, ctrl_stream_val, weight_ptr_en, bias_ptr_en); + ctrl_val_compute(step, tile, ctrl_engine_val, ctrl_stream_val, weight_ptr_en, bias_ptr_en); // $display(" - Input_ptr 0x%0h, Weight_ptr0 0x%0h, Weight_ptr1 0x%0h, Bias_ptr 0x%0h, Output_ptr 0x%0h", input_ptr, weight_ptr0, weight_ptr1, bias_ptr, output_ptr); $display(" - ITA Reg En 0x%0h, Ctrl Stream Val 0x%0h, Weight Ptr En %0d, Bias Ptr En %0d", ita_reg_en, ctrl_stream_val, weight_ptr_en, bias_ptr_en); // Program ITA - PROGRAM_ITA(input_ptr, weight_ptr0, weight_ptr1, weight_ptr_en, bias_ptr, bias_ptr_en, output_ptr, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_en, ctrl_stream_val, clk_i); + PROGRAM_ITA(input_ptr, weight_ptr0, weight_ptr1, weight_ptr_en, bias_ptr, bias_ptr_en, output_ptr, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_en, ctrl_engine_val, ctrl_stream_val, clk_i); // Wait for ITA to finish @(posedge clk_i); @@ -465,7 +519,7 @@ endfunction end $display(" - input_ptr 0x%08h (input_base_ptr 0x%08h)", input_ptr, input_base_ptr); $display(" - weight_ptr0 0x%08h (weight_base_ptr0 0x%08h)", weight_ptr0, weight_base_ptr0); - $display(" - weight_ptr1 0x%08h (weight_base_ptr0 0x%08h)", weight_ptr1, weight_base_ptr0); + $display(" - weight_ptr1 0x%08h (weight_base_ptr1 0x%08h)", weight_ptr1, weight_base_ptr1); $display(" - bias_ptr 0x%08h (bias_base_ptr 0x%08h)", bias_ptr, bias_base_ptr); $display(" - output_ptr 0x%08h (output_base_ptr 0x%08h)", output_ptr, output_base_ptr); endtask @@ -484,20 +538,34 @@ endfunction end else if (step == K && N_TILES_OUTER_X[Q]*N_TILES_OUTER_Y[Q]*N_TILES_INNER_DIM[Q] == 1) begin if (tile == 0) enable = 1'b1; + end else if (step == F1) begin + if (tile == 0 || tile == 1) + enable = 1'b1; + end else if (step == F2 && N_TILES_OUTER_X[F1]*N_TILES_OUTER_Y[F1]*N_TILES_INNER_DIM[F1] == 1) begin + if (tile == 0) + enable = 1'b1; end endtask - task automatic ctrl_stream_val_compute( + task automatic ctrl_val_compute( input step_e step, input integer tile, - output logic [31:0] reg_val, + output logic [31:0] ctrl_engine_val, + output logic [31:0] ctrl_stream_val, output logic reg_weight_en, output logic reg_bias_en ); + layer_e layer_type; + activation_e activation_function; + // Default values - reg_val = 32'h0; + ctrl_stream_val = 32'h0; reg_weight_en = 1'b0; reg_bias_en = 1'b0; + layer_type = Attention; + activation_function = Identity; + + ctrl_engine_val = layer_type | activation_function << 2; // ctrl_stream [0]: weight preload, // ctrl_stream [1]: weight nextload, @@ -510,70 +578,139 @@ endfunction case(step) Q : begin if (tile == 0) begin - reg_val = {28'b0, 4'b0011}; // weight preload and weight nextload + ctrl_stream_val = {28'b0, 4'b0011}; // weight preload and weight nextload end else begin - reg_val = {28'b0, 4'b0010}; // weight nextload + ctrl_stream_val = {28'b0, 4'b0010}; // weight nextload end reg_weight_en = 1'b1; reg_bias_en = 1'b1; end K : begin - reg_val = {28'b0, 4'b0010}; // weight nextload + ctrl_stream_val = {28'b0, 4'b0010}; // weight nextload reg_weight_en = 1'b1; reg_bias_en = 1'b1; end V : begin - reg_val = {28'b0, 4'b1010}; // weight nextload and invert bias direction + ctrl_stream_val = {28'b0, 4'b1010}; // weight nextload and invert bias direction reg_weight_en = 1'b1; reg_bias_en = 1'b1; end QK : begin - reg_val = {28'b0, 4'b0110}; // weight nextload and disable bias + ctrl_stream_val = {28'b0, 4'b0110}; // weight nextload and disable bias reg_weight_en = 1'b1; reg_bias_en = 1'b0; end AV : begin - reg_val = {28'b0, 4'b0110}; // weight nextload and disable bias + ctrl_stream_val = {28'b0, 4'b0110}; // weight nextload and disable bias reg_weight_en = 1'b1; reg_bias_en = 1'b0; end OW : begin if (tile == (N_TILES_OUTER_X[OW]*N_TILES_OUTER_Y[OW]*N_TILES_INNER_DIM[OW])-1) begin - reg_val = {28'b0, 4'b0000}; + ctrl_stream_val = {28'b0, 4'b0000}; + reg_weight_en = 1'b0; + end else begin + ctrl_stream_val = {28'b0, 4'b0010}; // weight nextload + reg_weight_en = 1'b1; + end + reg_bias_en = 1'b1; + end + F1 : begin + ctrl_engine_val = Feedforward | ACTIVATION << 2; + if (tile == 0) begin + ctrl_stream_val = {28'b0, 4'b0011}; // weight preload and weight nextload + end else begin + ctrl_stream_val = {28'b0, 4'b0010}; // weight nextload + end + reg_weight_en = 1'b1; + reg_bias_en = 1'b1; + end + F2 : begin + ctrl_engine_val = Feedforward | Identity << 2; + if (tile == (N_TILES_OUTER_X[F2]*N_TILES_OUTER_Y[F2]*N_TILES_INNER_DIM[F2])-1) begin + ctrl_stream_val = {28'b0, 4'b0000}; reg_weight_en = 1'b0; end else begin - reg_val = {28'b0, 4'b0010}; // weight nextload + ctrl_stream_val = {28'b0, 4'b0010}; // weight nextload reg_weight_en = 1'b1; end reg_bias_en = 1'b1; end endcase - reg_val[4] = ( (tile+1) % N_TILES_INNER_DIM[step] == 0) ? 1'b0 : 1'b1; + ctrl_stream_val[4] = ( (tile+1) % N_TILES_INNER_DIM[step] == 0) ? 1'b0 : 1'b1; endtask task automatic ita_reg_tiles_val_compute( input integer tile_s, input integer tile_e, input integer tile_p, + input integer tile_f, output logic [31:0] reg_val ); - reg_val = tile_s | tile_e << 4 | tile_p << 8; + reg_val = tile_s | tile_e << 4 | tile_p << 8 | tile_f << 12; + endtask + + task automatic ita_reg_activation_constants_compute( + output logic [31:0] gelu_b_c_reg, + output logic [31:0] activation_requant_reg + ); + gelu_const_t gelu_b; + gelu_const_t gelu_c; + requant_const_t activation_requant_mult; + requant_const_t activation_requant_shift; + requant_t activation_requant_add; + read_activation_constants(gelu_b, gelu_c, activation_requant_mult, activation_requant_shift, activation_requant_add); + gelu_b_c_reg = $unsigned(gelu_b) | gelu_c << 16; + activation_requant_reg = activation_requant_mult | activation_requant_shift << 8 | activation_requant_add << 16; + endtask + + task automatic read_activation_constants( + output gelu_const_t gelu_b, + output gelu_const_t gelu_c, + output requant_const_t gelu_eps_mult, + output requant_const_t gelu_right_shift, + output requant_t gelu_add + ); + integer b_fd; + integer c_fd; + integer rqs_mul_fd; + integer rqs_shift_fd; + integer add_fd; + int return_code; + + b_fd = open_stim_file(gelu_b_file); + c_fd = open_stim_file(gelu_c_file); + rqs_mul_fd = open_stim_file(activation_requant_mult_file); + rqs_shift_fd = open_stim_file(activation_requant_shift_file); + add_fd = open_stim_file(activation_requant_add_file); + + return_code = $fscanf(b_fd, "%d", gelu_b); + return_code = $fscanf(c_fd, "%d", gelu_c); + return_code = $fscanf(rqs_mul_fd, "%d", gelu_eps_mult); + return_code = $fscanf(rqs_shift_fd, "%d", gelu_right_shift); + return_code = $fscanf(add_fd, "%d", gelu_add); + + $fclose(b_fd); + $fclose(c_fd); + $fclose(rqs_mul_fd); + $fclose(rqs_shift_fd); + $fclose(add_fd); endtask task automatic ita_reg_eps_mult_val_compute( output logic [5:0][31:0] reg_val ); - logic [5:0][EMS-1:0] eps_mult; - logic [5:0][EMS-1:0] right_shift; - logic [5:0][ WI-1:0] add; + logic [N_REQUANT_CONSTS][EMS-1:0] eps_mult; + logic [N_REQUANT_CONSTS][EMS-1:0] right_shift; + logic [N_REQUANT_CONSTS][ WI-1:0] add; read_ITA_rqs(eps_mult, right_shift, add); reg_val[0] = eps_mult[0] | eps_mult[1] << 8 | eps_mult[2] << 16 | eps_mult[3] << 24; - reg_val[1] = eps_mult[4] | eps_mult[5] << 8; + reg_val[1] = eps_mult[4] | eps_mult[5] << 8 | eps_mult[6] << 16 | eps_mult[7] << 24; reg_val[2] = right_shift[0] | right_shift[1] << 8 | right_shift[2] << 16 | right_shift[3] << 24; - reg_val[3] = right_shift[4] | right_shift[5] << 8; + reg_val[3] = right_shift[4] | right_shift[5] << 8 | right_shift[6] << 16 | right_shift[7] << 24; reg_val[4] = add[0] | add[1] << 8 | add[2] << 16 | add[3] << 24; - reg_val[5] = add[4] | add[5] << 8; + reg_val[5] = add[4] | add[5] << 8 | add[6] << 16 | add[7] << 24; endtask task automatic compare_output(string STIM_DATA, integer address); @@ -599,46 +736,36 @@ endfunction endtask task read_ITA_rqs( - output logic [5:0][EMS-1:0] eps_mult, - output logic [5:0][EMS-1:0] right_shift, - output logic [5:0][ WI-1:0] add + output logic [N_REQUANT_CONSTS][EMS-1:0] eps_mult, + output logic [N_REQUANT_CONSTS][EMS-1:0] right_shift, + output logic [N_REQUANT_CONSTS][ WI-1:0] add ); - integer stim_fd_rqs; + integer stim_fd_mul, stim_fd_shift, stim_fd_add; integer ret_code; - for (int phase = 0; phase < 3; phase++) begin - case(phase) - 0 : begin - stim_fd_rqs = open_stim_file("RQS_MUL.txt"); - end - 1 : begin - stim_fd_rqs = open_stim_file("RQS_SHIFT.txt"); - end - 2 : begin - stim_fd_rqs = open_stim_file("RQS_ADD.txt"); - end - endcase + stim_fd_mul = open_stim_file("RQS_ATTN_MUL.txt"); + stim_fd_shift = open_stim_file("RQS_ATTN_SHIFT.txt"); + stim_fd_add = open_stim_file("RQS_ATTN_ADD.txt"); - case(phase) - 0 : begin - for (int j = 0; j < 6; j++) begin - ret_code = $fscanf(stim_fd_rqs, "%d\n", eps_mult[j]); - end - end - 1 : begin - for (int j = 0; j < 6; j++) begin - ret_code = $fscanf(stim_fd_rqs, "%d\n", right_shift[j]); - end - end - 2 : begin - for (int j = 0; j < 6; j++) begin - ret_code = $fscanf(stim_fd_rqs, "%d\n", add[j]); - end - end - endcase + for (int j = 0; j < N_ATTENTION_STEPS; j++) begin + ret_code = $fscanf(stim_fd_mul, "%d\n", eps_mult[j]); + ret_code = $fscanf(stim_fd_shift, "%d\n", right_shift[j]); + ret_code = $fscanf(stim_fd_add, "%d\n", add[j]); + end - $fclose(stim_fd_rqs); + stim_fd_mul = open_stim_file("RQS_FFN_MUL.txt"); + stim_fd_shift = open_stim_file("RQS_FFN_SHIFT.txt"); + stim_fd_add = open_stim_file("RQS_FFN_ADD.txt"); + + for (int j = 0; j < N_FEEDFORWARD_STEPS; j++) begin + ret_code = $fscanf(stim_fd_mul, "%d\n", eps_mult[j+N_ATTENTION_STEPS]); + ret_code = $fscanf(stim_fd_shift, "%d\n", right_shift[j+N_ATTENTION_STEPS]); + ret_code = $fscanf(stim_fd_add, "%d\n", add[j+N_ATTENTION_STEPS]); end + + $fclose(stim_fd_mul); + $fclose(stim_fd_shift); + $fclose(stim_fd_add); endtask task automatic PROGRAM_ITA( @@ -651,7 +778,10 @@ endfunction input logic [31:0] output_ptr, input logic [31:0] ita_reg_tiles_val, input logic [5:0][31:0] ita_reg_rqs_val, + input logic [31:0] ita_reg_gelu_b_c_val, + input logic [31:0] ita_reg_activation_rqs_val, input logic ita_reg_en, + input logic [31:0] ctrl_engine_val, input logic [31:0] ctrl_stream_val, ref logic clk_i ); @@ -671,8 +801,11 @@ endfunction PERIPH_WRITE( 4*ITA_REG_RIGHT_SHIFT1,ITA_REG_OFFSET, ita_reg_rqs_val[3], clk_i); PERIPH_WRITE( 4*ITA_REG_ADD0, ITA_REG_OFFSET, ita_reg_rqs_val[4], clk_i); PERIPH_WRITE( 4*ITA_REG_ADD1, ITA_REG_OFFSET, ita_reg_rqs_val[5], clk_i); + PERIPH_WRITE( 4*ITA_REG_GELU_B_C, ITA_REG_OFFSET, ita_reg_gelu_b_c_val, clk_i); + PERIPH_WRITE( 4*ITA_REG_ACTIVATION_REQUANT, ITA_REG_OFFSET, ita_reg_activation_rqs_val, clk_i); end + PERIPH_WRITE( 4*ITA_REG_CTRL_ENGINE, ITA_REG_OFFSET, ctrl_engine_val, clk_i); PERIPH_WRITE( 4*ITA_REG_CTRL_STREAM, ITA_REG_OFFSET, ctrl_stream_val, clk_i); endtask : PROGRAM_ITA diff --git a/src/ita.sv b/src/ita.sv index 3d9d712..f164ac9 100644 --- a/src/ita.sv +++ b/src/ita.sv @@ -33,16 +33,20 @@ module ita ); step_e step, step_q1, step_q2, step_q3, step_q4, step_q5, step_q6; - logic calc_en, calc_en_q1, calc_en_q2, calc_en_q3, calc_en_q4, calc_en_q5, calc_en_q6; - logic first_inner_tile, first_inner_tile_q1, first_inner_tile_q2, first_inner_tile_q3, first_inner_tile_q4, first_inner_tile_q5, first_inner_tile_q6; - logic last_inner_tile, last_inner_tile_q1, last_inner_tile_q2, last_inner_tile_q3, last_inner_tile_q4, last_inner_tile_q5, last_inner_tile_q6; + logic calc_en, calc_en_q1, calc_en_q2, calc_en_q3, calc_en_q4, calc_en_q5, calc_en_q6, calc_en_q7, calc_en_q8, calc_en_q9, calc_en_q10; + logic first_inner_tile, first_inner_tile_q1, first_inner_tile_q2, first_inner_tile_q3; + logic last_inner_tile, last_inner_tile_q1, last_inner_tile_q2, last_inner_tile_q3, last_inner_tile_q4, last_inner_tile_q5, last_inner_tile_q6, last_inner_tile_q7, last_inner_tile_q8, last_inner_tile_q9, last_inner_tile_q10; logic weight_valid, weight_ready; inp_t inp, inp_stream_soft; weight_t inp1, inp1_q, inp2, inp2_q; bias_t inp_bias, inp_bias_q1, inp_bias_q2; - oup_t oup, oup_q, result; + oup_t oup, oup_q, accumulator_oup; + requant_const_t requant_mult, requant_shift, activation_requant_mult, activation_requant_shift; requant_oup_t requant_oup; + requant_t requant_add, activation_requant_add; + requant_mode_e requant_mode, activation_requant_mode; + requant_oup_t post_activation; // FIFO signals logic fifo_full, fifo_empty, push_to_fifo, pop_from_fifo; @@ -60,20 +64,28 @@ module ita write_data_t write_data ; write_select_t write_select; + // Activation signals + activation_e activation_q1, activation_q2, activation_q3, activation_q4, activation_q5, activation_q6, activation_q7, activation_q8, activation_q9, activation_q10; + always_ff @(posedge clk_i, negedge rst_ni) begin if (!rst_ni) begin + calc_en_q10 <= 0; + calc_en_q9 <= 0; + calc_en_q8 <= 0; + calc_en_q7 <= 0; calc_en_q6 <= 0; calc_en_q5 <= 0; calc_en_q4 <= 0; calc_en_q3 <= 0; calc_en_q2 <= 0; calc_en_q1 <= 0; - first_inner_tile_q6 <= 0; - first_inner_tile_q5 <= 0; - first_inner_tile_q4 <= 0; first_inner_tile_q3 <= 0; first_inner_tile_q2 <= 0; first_inner_tile_q1 <= 0; + last_inner_tile_q10 <= 0; + last_inner_tile_q9 <= 0; + last_inner_tile_q8 <= 0; + last_inner_tile_q7 <= 0; last_inner_tile_q6 <= 1'b0; last_inner_tile_q5 <= 1'b0; last_inner_tile_q4 <= 1'b0; @@ -86,19 +98,32 @@ module ita step_q3 <= Idle; step_q2 <= Idle; step_q1 <= Idle; + activation_q8 <= Identity; + activation_q7 <= Identity; + activation_q6 <= Identity; + activation_q5 <= Identity; + activation_q4 <= Identity; + activation_q3 <= Identity; + activation_q2 <= Identity; + activation_q1 <= Identity; end else begin + calc_en_q10 <= calc_en_q9; + calc_en_q9 <= calc_en_q8; + calc_en_q8 <= calc_en_q7; + calc_en_q7 <= calc_en_q6; calc_en_q6 <= calc_en_q5; calc_en_q5 <= calc_en_q4; calc_en_q4 <= calc_en_q3; calc_en_q3 <= calc_en_q2; calc_en_q2 <= calc_en_q1; calc_en_q1 <= calc_en; - first_inner_tile_q6 <= first_inner_tile_q5; - first_inner_tile_q5 <= first_inner_tile_q4; - first_inner_tile_q4 <= first_inner_tile_q3; first_inner_tile_q3 <= first_inner_tile_q2; first_inner_tile_q2 <= first_inner_tile_q1; first_inner_tile_q1 <= first_inner_tile; + last_inner_tile_q10 <= last_inner_tile_q9; + last_inner_tile_q9 <= last_inner_tile_q8; + last_inner_tile_q8 <= last_inner_tile_q7; + last_inner_tile_q7 <= last_inner_tile_q6; last_inner_tile_q6 <= last_inner_tile_q5; last_inner_tile_q5 <= last_inner_tile_q4; last_inner_tile_q4 <= last_inner_tile_q3; @@ -111,6 +136,16 @@ module ita step_q3 <= step_q2; step_q2 <= step_q1; step_q1 <= step; + activation_q10 <= activation_q9; + activation_q9 <= activation_q8; + activation_q8 <= activation_q7; + activation_q7 <= activation_q6; + activation_q6 <= activation_q5; + activation_q5 <= activation_q4; + activation_q4 <= activation_q3; + activation_q3 <= activation_q2; + activation_q2 <= activation_q1; + activation_q1 <= ctrl_i.activation; end end @@ -205,7 +240,7 @@ module ita .oup_i (oup_q ), .inp_bias_i (inp_bias_q2 ), - .result_o (result ) + .result_o (accumulator_oup ) ); ita_softmax_top i_softmax_top ( @@ -223,20 +258,19 @@ module ita .inp_stream_soft_o (inp_stream_soft ) ); - oup_t requant_result; - logic requant_mode ; - eps_mult_t requant_mult ; - right_shift_t requant_shift ; - add_t requant_add ; - - assign requant_result = result; - assign requant_mode = 1'b0; - always_comb begin - requant_mult = ctrl_i.eps_mult[step_q4]; - requant_shift = ctrl_i.right_shift[step_q4]; - requant_add = ctrl_i.add[step_q4]; - end + ita_requatization_controller i_requantization_controller ( + .ctrl_i (ctrl_i ), + .requantizer_step_i (step_q4 ), + .requant_mult_o (requant_mult ), + .requant_shift_o (requant_shift ), + .requant_add_o (requant_add ), + .requant_mode_o (requant_mode ), + .activation_requant_mult_o (activation_requant_mult), + .activation_requant_shift_o(activation_requant_shift), + .activation_requant_add_o (activation_requant_add ), + .activation_requant_mode_o (activation_requant_mode ) + ); ita_requantizer i_requantizer ( .clk_i ( clk_i ), @@ -248,18 +282,33 @@ module ita .calc_en_i ( calc_en_q4 && last_inner_tile_q4 ), .calc_en_q_i ( calc_en_q5 && last_inner_tile_q5 ), - .result_i ( requant_result ), + .result_i ( accumulator_oup ), .add_i ( {N {requant_add}} ), .requant_oup_o( requant_oup ) ); - ita_fifo_controller i_fifo_controller ( + ita_activation i_activation ( .clk_i (clk_i ), .rst_ni (rst_ni ), + .activation_i (activation_q10), + .calc_en_i (calc_en_q6 && last_inner_tile_q6 ), + .calc_en_q_i (calc_en_q7 && last_inner_tile_q7 ), + .b_i (ctrl_i.gelu_b ), + .c_i (ctrl_i.gelu_c ), + .requant_mode_i (activation_requant_mode), + .requant_mult_i (activation_requant_mult), + .requant_shift_i (activation_requant_shift), + .requant_add_i (activation_requant_add), + .data_i (requant_oup), + .data_o (post_activation) + ); - .requant_oup_i (requant_oup ), - .ready_i (calc_en_q6 && last_inner_tile_q6 ), + ita_fifo_controller i_fifo_controller ( + .clk_i (clk_i ), + .rst_ni (rst_ni ), + .requant_oup_i (post_activation), + .activation_done_i (calc_en_q10 && last_inner_tile_q10 ), .fifo_full_i (fifo_full ), .push_to_fifo_o(push_to_fifo), .data_to_fifo_o(data_to_fifo) @@ -353,8 +402,12 @@ module ita fifo_usage_max <= FifoDepth; step_q <= step; end - if ((step_q==OW) && (step==Idle)) - $display("[ITA] Max FIFO usage: %d", fifo_usage_max); + if ((step_q==OW) && (step==Idle)) begin + $display("[ITA] Max FIFO usage during Attention: %d", fifo_usage_max); + end + if ((step_q==F2) && (step==Idle)) begin + $display("[ITA] Max FIFO usage during Feedforward: %d", fifo_usage_max); + end end // pragma translate_on endmodule diff --git a/src/ita_activation.sv b/src/ita_activation.sv new file mode 100644 index 0000000..95813f4 --- /dev/null +++ b/src/ita_activation.sv @@ -0,0 +1,119 @@ +// Copyright 2024 ETH Zurich and University of Bologna. +// Solderpad Hardware License, Version 0.51, see LICENSE for details. +// SPDX-License-Identifier: SHL-0.51 + +module ita_activation + import ita_package::*; + ( + input logic clk_i, + input logic rst_ni, + input gelu_const_t b_i, + input gelu_const_t c_i, + input requant_mode_e requant_mode_i, + input requant_const_t requant_mult_i, + input requant_const_t requant_shift_i, + input requant_t requant_add_i, + input activation_e activation_i, + input logic calc_en_i, + input logic calc_en_q_i, + input requant_oup_t data_i, + output requant_oup_t data_o + ); + + requant_oup_t data_q1, data_q2, data_q3, data_q4; + activation_e activation_q1, activation_q2; + oup_t gelu_out, requant_in; + requant_oup_t relu_out_d, relu_out_q1, relu_out_q2, requant_out; + logic calc_en_q2, calc_en_q3; + + ita_requantizer i_requantizer ( + .clk_i(clk_i), + .rst_ni(rst_ni), + .mode_i(requant_mode_i), + .eps_mult_i(requant_mult_i), + .right_shift_i(requant_shift_i), + .add_i({N{requant_add_i}}), + .calc_en_i(calc_en_q2), + .calc_en_q_i(calc_en_q3), + .result_i(requant_in), + .requant_oup_o(requant_out) + ); + + generate + for (genvar i = 0; i < N; i++) begin: relu_instances + ita_relu i_relu ( + .data_i(data_q2[i]), + .data_o(relu_out_d[i]) + ); + end + endgenerate + + generate + for (genvar i = 0; i < N; i++) begin: gelu_instances + ita_gelu i_gelu ( + .clk_i(clk_i), + .rst_ni(rst_ni), + .b_i(b_i), + .c_i(c_i), + .calc_en_i(calc_en_i), + .calc_en_q_i(calc_en_q_i), + .data_i(data_i[i]), + .data_o(gelu_out[i]) + ); + end + endgenerate + + always_comb begin + case (activation_i) + Gelu: begin + requant_in = gelu_out; + end + Relu: begin + for (int i = 0; i < N; i++) begin + requant_in[i] = {{(WO-WI){relu_out_q2[i][WI-1]}}, relu_out_q2[i]}; + end + end + default: begin + requant_in = '0; + end + endcase + end + + + always_comb begin + case (activation_q2) + Gelu, Relu: begin + data_o = requant_out; + end + default: begin + data_o = data_q4; + end + endcase + end + + always_ff @(posedge clk_i) begin + if (rst_ni == 0) begin + activation_q1 <= Identity; + activation_q2 <= Identity; + data_q1 <= '0; + data_q2 <= '0; + data_q3 <= '0; + data_q4 <= '0; + calc_en_q2 <= 0; + calc_en_q3 <= 0; + relu_out_q1 <= '0; + relu_out_q2 <= '0; + end else begin + activation_q1 <= activation_i; + activation_q2 <= activation_q1; + data_q1 <= data_i; + data_q2 <= data_q1; + data_q3 <= data_q2; + data_q4 <= data_q3; + calc_en_q2 <= calc_en_q_i; + calc_en_q3 <= calc_en_q2; + relu_out_q1 <= relu_out_d; + relu_out_q2 <= relu_out_q1; + end + end +endmodule \ No newline at end of file diff --git a/src/ita_controller.sv b/src/ita_controller.sv index 185b3db..b4156f5 100644 --- a/src/ita_controller.sv +++ b/src/ita_controller.sv @@ -113,9 +113,16 @@ module ita_controller softmax_div_done_d = 1'b0; busy_d = 1'b0; if (ctrl_i.start) begin - step_d = Q; + if(ctrl_i.layer == Attention) begin + step_d = Q; + end else if (ctrl_i.layer == Feedforward) begin + step_d = F1; + end else if (ctrl_i.layer == Linear) begin + step_d = MatMul; + end end end + // Attention Q : begin if (inner_tile_q == ctrl_i.tile_e-1) begin last_inner_tile_o = 1'b1; @@ -200,6 +207,47 @@ module ita_controller end end end + // Feedforward + F1: begin + if (inner_tile_q == ctrl_i.tile_e-1) begin + last_inner_tile_o = 1'b1; + end + if (inner_tile_d == ctrl_i.tile_e) begin // end of inner tile + inner_tile_d = '0; + tile_d = tile_q + 1; + if (tile_d == ctrl_i.tile_s*ctrl_i.tile_f) begin + tile_d = '0; + step_d = F2; + end + end + end + F2: begin + if (inner_tile_q == ctrl_i.tile_f-1) begin + last_inner_tile_o = 1'b1; + end + if (inner_tile_d == ctrl_i.tile_f) begin // end of inner tile + inner_tile_d = '0; + tile_d = tile_q + 1; + if (tile_d == ctrl_i.tile_s*ctrl_i.tile_e) begin + tile_d = '0; + step_d = Idle; + end + end + end + // Linear + MatMul: begin + if (inner_tile_q == ctrl_i.tile_e-1) begin + last_inner_tile_o = 1'b1; + end + if (inner_tile_d == ctrl_i.tile_e) begin // end of inner tile + inner_tile_d = '0; + tile_d = tile_q + 1; + if (tile_d == ctrl_i.tile_s*ctrl_i.tile_p) begin + tile_d = '0; + step_d = Idle; + end + end + end endcase if (inp_valid_i && inp_ready_o && oup_valid_i && oup_ready_i && last_inner_tile_o) begin ongoing_d = ongoing_q; diff --git a/src/ita_fifo_controller.sv b/src/ita_fifo_controller.sv index 78a8bdb..7265fc7 100644 --- a/src/ita_fifo_controller.sv +++ b/src/ita_fifo_controller.sv @@ -9,7 +9,7 @@ module ita_fifo_controller input logic clk_i , input logic rst_ni , input requant_oup_t requant_oup_i , - input logic ready_i , + input logic activation_done_i, input logic fifo_full_i , output logic push_to_fifo_o, output fifo_data_t data_to_fifo_o @@ -18,10 +18,9 @@ module ita_fifo_controller always_comb begin push_to_fifo_o = 0; data_to_fifo_o = '0; - if (&ready_i) begin + if (activation_done_i) begin push_to_fifo_o = 1; data_to_fifo_o = {>>WI{requant_oup_i}}; end end - endmodule diff --git a/src/ita_gelu.sv b/src/ita_gelu.sv new file mode 100644 index 0000000..cace920 --- /dev/null +++ b/src/ita_gelu.sv @@ -0,0 +1,67 @@ +// Copyright 2024 ETH Zurich and University of Bologna. +// Solderpad Hardware License, Version 0.51, see LICENSE for details. +// SPDX-License-Identifier: SHL-0.51 + +module ita_gelu + import ita_package::*; + ( + input logic clk_i, + input logic rst_ni, + input gelu_const_t b_i, + input gelu_const_t c_i, + input logic calc_en_i, + input logic calc_en_q_i, + input requant_t data_i, + output gelu_out_t data_o + ); + + logic erf_sgn_d, erf_sgn_q1; + gelu_const_t c_q1; + gelu_const_t data_sign_ext, erf_abs, erf_clipped, poly_d; + gelu_out_t erf_L_q1, gelu_erf_q1, gelu_sum_q1; + gelu_out_t gelu_out_q1, gelu_out_q2; + gelu_out_t poly_sq_d, poly_sq_q1; + requant_t data_q1; + + always_comb begin : first_stage + if (calc_en_i) begin + data_sign_ext = {{GELU_CONSTANTS_WIDTH-WI{data_i[WI-1]}}, data_i}; + + erf_sgn_d = data_i < 0; + erf_abs = erf_sgn_d ? -data_sign_ext : data_sign_ext; + erf_clipped = erf_abs > -b_i ? -b_i : erf_abs; + + poly_d = erf_clipped + b_i; + poly_sq_d = poly_d * poly_d; + end + end + + always_comb begin : second_stage + if (calc_en_q_i) begin + erf_L_q1 = poly_sq_q1 + c_q1; + + gelu_erf_q1 = erf_sgn_q1 ? -erf_L_q1 : erf_L_q1; + gelu_sum_q1 = gelu_erf_q1 + c_q1; + gelu_out_q1 = data_q1 * gelu_sum_q1; + end + end + + always_ff @(posedge clk_i or negedge rst_ni) begin + if (!rst_ni) begin + c_q1 <= '0; + data_q1 <= '0; + erf_sgn_q1 <= '0; + poly_sq_q1 <= '0; + gelu_out_q2 <= '0; + end else begin + c_q1 <= c_i; + data_q1 <= data_i; + erf_sgn_q1 <= erf_sgn_d; + poly_sq_q1 <= poly_sq_d; + gelu_out_q2 <= gelu_out_q1; + end + end + + assign data_o = gelu_out_q2; + +endmodule diff --git a/src/ita_package.sv b/src/ita_package.sv index a699877..3d602d8 100644 --- a/src/ita_package.sv +++ b/src/ita_package.sv @@ -17,63 +17,74 @@ package ita_package; localparam int unsigned WO = 26 ; localparam int unsigned EMS = 8 ; localparam int unsigned Latency = 7 ; + localparam int unsigned GELU_CONSTANTS_WIDTH = 16 ; + localparam int unsigned GELU_OUT_WIDTH = 26 ; + localparam int unsigned N_ATTENTION_STEPS = 6 ; + localparam int unsigned N_FEEDFORWARD_STEPS = 2 ; + localparam int unsigned N_STATES = N_ATTENTION_STEPS + N_FEEDFORWARD_STEPS + 1; + localparam int unsigned N_REQUANT_CONSTS = N_ATTENTION_STEPS + N_FEEDFORWARD_STEPS; + parameter int unsigned InputAddrWidth = idx_width(S) ; parameter int unsigned MAddrWidth = idx_width(H*S) ; parameter int unsigned M3AddrWidth = idx_width(S) ; parameter int unsigned NumReadPorts = N ; parameter int unsigned MNumReadPorts = N ; - parameter int unsigned FifoDepth = `ifdef ITA_OUTPUT_FIFO_DEPTH `ITA_OUTPUT_FIFO_DEPTH `else 8 `endif; + parameter int unsigned FifoDepth = `ifdef ITA_OUTPUT_FIFO_DEPTH `ITA_OUTPUT_FIFO_DEPTH `else 14 `endif; localparam int unsigned SplitFactor = 4 ; - parameter int unsigned N_WRITE_EN = `ifdef TARGET_ITA_HWPE 8 `else M `endif; + // Feedforward + typedef enum bit [1:0] {Attention=0, Feedforward=1, Linear=2} layer_e; + typedef enum bit [1:0] {Identity=0, Gelu=1, Relu=2} activation_e; + typedef logic signed [GELU_CONSTANTS_WIDTH-1:0] gelu_const_t; + typedef logic signed [GELU_OUT_WIDTH-1:0] gelu_out_t; + // IO + typedef logic [EMS-1:0] requant_const_t; + typedef logic [N_REQUANT_CONSTS-1:0][EMS-1:0] requant_const_array_t; + typedef logic signed [WI-1:0] requant_t; + typedef logic signed [N_REQUANT_CONSTS-1:0][WI-1:0] requant_array_t; typedef logic [idx_width(S+1)-1:0] seq_length_t; typedef logic [idx_width(P+1)-1:0] proj_space_t; typedef logic [idx_width(E+1)-1:0] embed_size_t; typedef logic [idx_width(H+1)-1:0] n_heads_t; - typedef logic [ EMS-1:0] eps_mult_t; - typedef logic [ EMS-1:0] right_shift_t; - typedef logic [ WI-1:0] add_t; - + typedef logic [ 32-1:0] tile_t; typedef struct packed { logic start ; - seq_length_t seq_length ; - proj_space_t proj_space ; - embed_size_t embed_size ; - n_heads_t n_heads ; - logic [5:0][EMS-1:0] eps_mult ; - logic [5:0][EMS-1:0] right_shift ; - logic [5:0][WI-1:0] add ; - logic [32-1:0] lin_tiles ; - logic [32-1:0] attn_tiles ; - logic [32-1:0] tile_s; - logic [32-1:0] tile_e; - logic [32-1:0] tile_p; + layer_e layer ; + activation_e activation ; + requant_const_array_t eps_mult ; + requant_const_array_t right_shift ; + requant_array_t add ; + gelu_const_t gelu_b; + gelu_const_t gelu_c; + requant_const_t activation_requant_mult; + requant_const_t activation_requant_shift; + requant_t activation_requant_add; + tile_t tile_s; + tile_t tile_e; + tile_t tile_p; + tile_t tile_f; } ctrl_t; - typedef struct packed { logic [InputAddrWidth-1:0] addr; logic [ E-1:0][WI-1:0] data; } write_port_t; - typedef logic signed [WI-1:0] requant_t; - - typedef logic signed [N-1:0][M-1:0][WI-1:0] weight_t; - typedef logic signed [N-1:0][(WO-2)-1:0] bias_t; - - typedef logic signed [N-1:0][WO-1:0] oup_t; - typedef requant_t [N-1:0] requant_oup_t; - // States - typedef enum {Idle=6, Q=0, K=1, V=2, QK=3, AV=4, OW=5} step_e; + typedef enum {Idle=0, Q=1, K=2, V=3, QK=4, AV=5, OW=6, F1=7, F2=8, MatMul=9} step_e; // Inputs and weights typedef logic signed [M-1:0][ WI-1:0] inp_t; typedef logic [(N*M/N_WRITE_EN)-1:0][ WI-1:0] inp_weight_t; typedef logic [N_WRITE_EN-1:0] write_select_t; typedef logic [N_WRITE_EN-1:0][(N*M*WI/N_WRITE_EN)-1:0] write_data_t; + typedef logic signed [N-1:0][(WO-2)-1:0] bias_t; + typedef logic signed [N-1:0][M-1:0][WI-1:0] weight_t; + + // Accumulator + typedef logic signed [N-1:0][WO-1:0] oup_t; // FIFO typedef logic [ N*WI-1:0] fifo_data_t; @@ -92,4 +103,8 @@ package ita_package; localparam int unsigned DividerWidth = SoftmaxAccDataWidth + 1; localparam int unsigned NumDiv = 5; + // Requantizer + typedef enum {Signed=0, Unsigned=1} requant_mode_e; + localparam requant_mode_e REQUANT_MODE = Signed; + typedef requant_t [N-1:0] requant_oup_t; endpackage : ita_package diff --git a/src/ita_relu.sv b/src/ita_relu.sv new file mode 100644 index 0000000..f639e0c --- /dev/null +++ b/src/ita_relu.sv @@ -0,0 +1,14 @@ +// Copyright 2024 ETH Zurich and University of Bologna. +// Solderpad Hardware License, Version 0.51, see LICENSE for details. +// SPDX-License-Identifier: SHL-0.51 + +module ita_relu + import ita_package::*; + ( + input requant_t data_i, + output requant_t data_o + ); + + assign data_o = data_i > 0 ? data_i : 0; + +endmodule \ No newline at end of file diff --git a/src/ita_requantization_controller.sv b/src/ita_requantization_controller.sv new file mode 100644 index 0000000..3b6c865 --- /dev/null +++ b/src/ita_requantization_controller.sv @@ -0,0 +1,44 @@ +// Copyright 2020 ETH Zurich and University of Bologna. +// Solderpad Hardware License, Version 0.51, see LICENSE for details. +// SPDX-License-Identifier: SHL-0.51 + + +module ita_requatization_controller + import ita_package::*; +( + input step_e requantizer_step_i, + input ctrl_t ctrl_i, + output requant_const_t requant_mult_o, + output requant_const_t requant_shift_o, + output requant_t requant_add_o, + output requant_mode_e requant_mode_o, + output requant_const_t activation_requant_mult_o, + output requant_const_t activation_requant_shift_o, + output requant_t activation_requant_add_o, + output requant_mode_e activation_requant_mode_o +); + logic [$clog2(N_REQUANT_CONSTS)-1:0] constant_idx; + + always_comb begin + case (requantizer_step_i) + Q: constant_idx = 0; + K: constant_idx = 1; + V: constant_idx = 2; + QK: constant_idx = 3; + AV: constant_idx = 4; + OW: constant_idx = 5; + F1: constant_idx = 6; + F2: constant_idx = 7; + default: constant_idx = 0; + endcase + end + + assign requant_mult_o = ctrl_i.eps_mult[constant_idx]; + assign requant_shift_o = ctrl_i.right_shift[constant_idx]; + assign requant_add_o = ctrl_i.add[constant_idx]; + assign activation_requant_mult_o = ctrl_i.activation_requant_mult; + assign activation_requant_shift_o = ctrl_i.activation_requant_shift; + assign activation_requant_add_o = ctrl_i.activation_requant_add; + assign requant_mode_o = requant_mode_e'(REQUANT_MODE); + assign activation_requant_mode_o = requant_mode_e'(REQUANT_MODE); + endmodule diff --git a/src/ita_requantizer.sv b/src/ita_requantizer.sv index 75d2516..6033c09 100644 --- a/src/ita_requantizer.sv +++ b/src/ita_requantizer.sv @@ -8,9 +8,9 @@ module ita_requantizer ( input logic clk_i , input logic rst_ni , - input logic mode_i , - input logic [EMS-1:0] eps_mult_i , - input logic [EMS-1:0] right_shift_i, + input requant_mode_e mode_i , + input requant_const_t eps_mult_i , + input requant_const_t right_shift_i, input logic calc_en_i , input logic calc_en_q_i , input oup_t result_i , @@ -34,7 +34,7 @@ module ita_requantizer shifted_added = '0; for (int i = 0; i < N; i++) begin - if (mode_i) begin + if (mode_i === Unsigned) begin mult_signed = {1'b0, result_i[i]}; end else begin mult_signed = signed'(result_i[i]); diff --git a/src/ita_softmax.sv b/src/ita_softmax.sv index efd93a7..675750c 100644 --- a/src/ita_softmax.sv +++ b/src/ita_softmax.sv @@ -39,7 +39,7 @@ module ita_softmax input requant_t [1:0] read_max_data_i, output logic write_max_en_o, output logic [InputAddrWidth-1:0] write_max_addr_o, - output requant_t write_max_data_o + output requant_t write_max_data_o ); counter_t tile_d, tile_q1, tile_q2, tile_q3, tile_q4; @@ -307,5 +307,5 @@ module ita_softmax .data_o (data_from_fifo), .pop_i (pop_from_fifo ) ); - + endmodule : ita_softmax diff --git a/src/tb/activation_tb.sv b/src/tb/activation_tb.sv new file mode 100644 index 0000000..2fb0f54 --- /dev/null +++ b/src/tb/activation_tb.sv @@ -0,0 +1,283 @@ +// Copyright 2024 ETH Zurich and University of Bologna. +// Solderpad Hardware License, Version 0.51, see LICENSE for details. +// SPDX-License-Identifier: SHL-0.51 + +module activation_tb; + + timeunit 10ps; + timeprecision 1ps; + + import ita_package::*; + + localparam time CLK_PERIOD = 2000ps; + localparam time APPL_DELAY = 400ps; + localparam time ACQ_DELAY = 1600ps; + localparam unsigned RST_CLK_CYCLES = 10; + + string gelu_b_file = "GELU_B.txt"; + string gelu_c_file = "GELU_C.txt"; + string input_file = "standalone/preactivation.txt"; + string gelu_output_file = "standalone/gelu.txt"; + string relu_output_file = "standalone/relu.txt"; + string activation_requant_mult_file = "activation_requant_mult.txt"; + string activation_requant_shift_file = "activation_requant_shift.txt"; + string activation_requant_add_file = "activation_requant_add.txt"; + + integer N_PE, M_TILE_LEN; + integer SEQUENCE_LEN, PROJECTION_SIZE, EMBEDDING_SIZE, FEEDFORWARD_SIZE; + + logic clk, rst_n; + requant_oup_t preactivation_input; + requant_oup_t preactivation_input_check; + requant_oup_t expected_postactivation; + requant_oup_t acquired_postactivation; + gelu_const_t gelu_b; + gelu_const_t gelu_c; + requant_const_t activation_requant_mult; + requant_const_t activation_requant_shift; + requant_t activation_requant_add; + activation_e selected_activation; + + string simdir; + + initial begin + N_PE = `ifdef ITA_N `ITA_N `else 16 `endif; + M_TILE_LEN = `ifdef ITA_M `ITA_M `else 64 `endif; + SEQUENCE_LEN = `ifdef SEQ_LENGTH `SEQ_LENGTH `else M_TILE_LEN `endif; + PROJECTION_SIZE = `ifdef PROJ_SPACE `PROJ_SPACE `else M_TILE_LEN `endif; + EMBEDDING_SIZE = `ifdef EMBED_SIZE `EMBED_SIZE `else M_TILE_LEN `endif; + FEEDFORWARD_SIZE = `ifdef FF_SIZE `FF_SIZE `else M_TILE_LEN `endif; + simdir = { + "../../simvectors/data_S", + $sformatf("%0d", SEQUENCE_LEN), + "_E", + $sformatf("%0d", EMBEDDING_SIZE), + "_P", + $sformatf("%0d", PROJECTION_SIZE), + "_F", + $sformatf("%0d", FEEDFORWARD_SIZE), + "_H1_B", + $sformatf("%0d", `ifdef BIAS `BIAS `else 0 `endif) + }; + end + + clk_rst_gen #( + .CLK_PERIOD (CLK_PERIOD ), + .RST_CLK_CYCLES(RST_CLK_CYCLES) + ) i_clk_rst_gen ( + .clk_o (clk ), + .rst_no(rst_n) + ); + + activation dut ( + .clk_i (clk ), + .rst_ni (rst_n), + .b_i (gelu_b ), + .c_i (gelu_c ), + .data_i (preactivation_input), + .activation_i (selected_activation), + .requant_mode_i (requant_mode_e'(REQUANT_MODE)), + .requant_mult_i (activation_requant_mult), + .requant_shift_i(activation_requant_shift), + .requant_add_i (activation_requant_add ), + .calc_en_i (1'b1), + .calc_en_q_i (1'b1), + .data_o (acquired_postactivation) + ); + + function automatic integer open_stim_file(string filename); + integer stim_fd; + if (filename == "") + return 0; + stim_fd = $fopen({simdir,"/",filename}, "r"); + if (stim_fd == 0) begin + $fatal(1, "[TB] ITA: Could not open %s stim file!", filename); + end + return stim_fd; + endfunction + + function automatic void read_preactivation(integer stim_fd); + int return_code; + for (int i = 0; i < N_PE; i++) begin + return_code = $fscanf(stim_fd, "%d", preactivation_input[i]); + end + endfunction + + function automatic void read_preactivation_check(integer stim_fd); + int return_code; + for (int i = 0; i < N_PE; i++) begin + return_code = $fscanf(stim_fd, "%d", preactivation_input_check[i]); + end + endfunction + + function automatic void read_postactivation(integer gelu_fd, integer relu_fd, input activation_e activation, input requant_oup_t preactivation, output requant_oup_t expected_postactivation); + int return_code; + if (activation == Gelu) begin + for (int i = 0; i < N_PE; i++) begin + return_code = $fscanf(gelu_fd, "%d", expected_postactivation[i]); + end + end else if (activation == Relu) begin + for (int i = 0; i < N_PE; i++) begin + return_code = $fscanf(relu_fd, "%d", expected_postactivation[i]); + end + end else if (activation == Identity) begin + for (int i = 0; i < N_PE; i++) begin + expected_postactivation[i] = preactivation[i]; + end + end + endfunction + + task automatic read_gelu_constants( + output gelu_const_t gelu_b, + output gelu_const_t gelu_c, + output requant_const_t activation_requant_mult, + output requant_const_t activation_requant_shift, + output requant_t activation_requant_add + ); + integer b_fd; + integer c_fd; + integer rqs_mul_fd; + integer rqs_shift_fd; + integer add_fd; + int return_code; + + b_fd = open_stim_file(gelu_b_file); + c_fd = open_stim_file(gelu_c_file); + rqs_mul_fd = open_stim_file(activation_requant_mult_file); + rqs_shift_fd = open_stim_file(activation_requant_shift_file); + add_fd = open_stim_file(activation_requant_add_file); + + return_code = $fscanf(b_fd, "%d", gelu_b); + return_code = $fscanf(c_fd, "%d", gelu_c); + return_code = $fscanf(rqs_mul_fd, "%d", activation_requant_mult); + return_code = $fscanf(rqs_shift_fd, "%d", activation_requant_shift); + return_code = $fscanf(add_fd, "%d", activation_requant_add); + + $fclose(b_fd); + $fclose(c_fd); + $fclose(rqs_mul_fd); + $fclose(rqs_shift_fd); + $fclose(add_fd); + endtask + + task apply_activations(input activation_e activation, int latency); + integer input_fd; + integer is_end_of_file; + + is_end_of_file = 0; + + input_fd = open_stim_file(input_file); + + if (activation == Gelu) begin + read_gelu_constants(gelu_b, gelu_c, activation_requant_mult, activation_requant_shift, activation_requant_add); + end + + $display("Starting to apply activations for %s with latency %0d after %0d", activation, latency, $time); + + while (!is_end_of_file) begin + @(posedge clk); + #(APPL_DELAY); + read_preactivation(input_fd); + is_end_of_file = $feof(input_fd); + selected_activation = activation; + end + + repeat(latency) @(posedge clk); + + $display("Finished applying activations for %s at %0d", activation, $time); + + $fclose(input_fd); + endtask + + initial begin: application_block + integer input_fd; + integer is_end_of_file; + + is_end_of_file = 0; + + wait (rst_n); + + apply_activations(Identity, 4); + apply_activations(Gelu, 4); + apply_activations(Relu, 4); + + @(posedge clk); + end : application_block + + function automatic void validate_postactivation(inout integer n_checks, inout integer n_errors, input activation_e activation); + n_checks += N_PE; + for (int i = 0; i < N_PE; i++) begin + if (acquired_postactivation[i] !== expected_postactivation[i]) begin + n_errors += 1; + if (n_errors <= 30) begin + $display(":=( expected %d, not %d for input %d and activation %s at %0d\n", expected_postactivation[i], acquired_postactivation[i], preactivation_input_check[i], activation, $time); + end + if (n_errors == 31) begin + $display(":=( suppressing further mismatches...\n"); + end + end + end + endfunction + + task check_activations(input activation_e activation, input int latency, inout integer n_checks, inout integer n_errors); + integer input_fd; + integer gelu_output_fd; + integer relu_output_fd; + integer is_end_of_file; + + is_end_of_file = 0; + + input_fd = open_stim_file(input_file); + gelu_output_fd = open_stim_file(gelu_output_file); + relu_output_fd = open_stim_file(relu_output_file); + + repeat(latency) @(posedge clk); + + $display("Starting to check activations for %s with latency %0d after %0d", activation, latency, $time); + + while (!is_end_of_file) begin + @(posedge clk); + #(ACQ_DELAY); + read_preactivation_check(input_fd); + is_end_of_file = $feof(input_fd); + read_postactivation(gelu_output_fd, relu_output_fd, activation, preactivation_input_check, expected_postactivation); + validate_postactivation(n_checks, n_errors, activation); + end + + $display("Finished checking activations for %s at %0d", activation, $time); + endtask + + + initial begin: checker_block + integer is_end_of_file; + integer n_checks; + integer n_errors; + integer input_fd; + integer output_fd; + + is_end_of_file = 0; + n_checks = 0; + n_errors = 0; + + wait (rst_n); + + check_activations(Identity, 4, n_checks, n_errors); + check_activations(Gelu, 4, n_checks, n_errors); + check_activations(Relu, 4, n_checks, n_errors); + + @(posedge clk); + + if (n_errors > 0) begin + $display(":=( Test failed with ", n_errors, " mismatches out of ", n_checks, " checks!"); + end else begin + $display(":=) Test passed with ", n_errors, " mismatches out of ", n_checks, " checks!"); + end + + $fclose(input_fd); + $fclose(output_fd); + + #(100*CLK_PERIOD); + $finish(); + end + +endmodule diff --git a/src/tb/ita_tb.sv b/src/tb/ita_tb.sv index 3a805ae..78cdb17 100644 --- a/src/tb/ita_tb.sv +++ b/src/tb/ita_tb.sv @@ -16,26 +16,34 @@ module ita_tb; // Set to 1 to run the simulation without stalls localparam unsigned CONT = `ifdef NO_STALLS `NO_STALLS `else 0 `endif; localparam unsigned ITERS = 1; + localparam unsigned N_PHASES = 7; // Stimuli files - string INPUT_FILES[5] = {"standalone/Q.txt", "standalone/K.txt", "standalone/Wv_0.txt", "standalone/Qp_in_0.txt", "standalone/O_soft_in_0.txt"}; + string INPUT_FILES[N_PHASES] = {"standalone/Q.txt", "standalone/K.txt", "standalone/Wv_0.txt", "standalone/Qp_in_0.txt", "standalone/O_soft_in_0.txt", "standalone/FF.txt", "standalone/FFp_in_0.txt"}; string ATTENTION_INPUT_FILES[1] = {"standalone/A_stream_soft_in_0.txt"}; - string INPUT_BIAS_FILES[5] = {"standalone/Bq_0.txt", "standalone/Bk_0.txt", "standalone/Bv_0.txt", "", "standalone/Bo_0.txt"}; - string WEIGHT_FILES[5] = {"standalone/Wq_0.txt", "standalone/Wk_0.txt", "standalone/V.txt", "standalone/Kp_in_0.txt", "standalone/Wo_0.txt"}; + string INPUT_BIAS_FILES[N_PHASES] = {"standalone/Bq_0.txt", "standalone/Bk_0.txt", "standalone/Bv_0.txt", "", "standalone/Bo_0.txt", "standalone/Bff_0.txt", "standalone/Bff2_0.txt"}; + string WEIGHT_FILES[N_PHASES] = {"standalone/Wq_0.txt", "standalone/Wk_0.txt", "standalone/V.txt", "standalone/Kp_in_0.txt", "standalone/Wo_0.txt", "standalone/Wff_0.txt", "standalone/Wff2_0.txt"}; string ATTENTION_WEIGHT_FILES[1] = {"standalone/Vp_in_0.txt"}; - string OUTPUT_FILES[5] = {"standalone/Qp_0.txt", "standalone/Kp_0.txt", "standalone/Vp_0.txt", "standalone/A_0.txt", "standalone/Out_soft_0.txt"}; + string OUTPUT_FILES[N_PHASES] = {"standalone/Qp_0.txt", "standalone/Kp_0.txt", "standalone/Vp_0.txt", "standalone/A_0.txt", "standalone/Out_soft_0.txt", "standalone/FFp_0.txt", "standalone/FF2p_0.txt"}; string ATTENTION_OUTPUT_FILES[2] = {"standalone/A_0.txt", "standalone/O_soft_0.txt"}; + string gelu_b_file = "GELU_B.txt"; + string gelu_c_file = "GELU_C.txt"; + string activation_requant_mult_file = "activation_requant_mult.txt"; + string activation_requant_shift_file = "activation_requant_shift.txt"; + string activation_requant_add_file = "activation_requant_add.txt"; // Parameters integer N_PE, M_TILE_LEN; integer N_ENTRIES_PER_TILE; - integer SEQUENCE_LEN, PROJECTION_SPACE, EMBEDDING_SIZE; + integer SEQUENCE_LEN, PROJECTION_SPACE, EMBEDDING_SIZE, FEEDFORWARD_SIZE; integer N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM; + integer N_TILES_FEEDFORWARD; integer N_TILES_LINEAR_PROJECTION, N_TILES_ATTENTION; integer N_TILES_LINEAR_OUTPUT; integer N_ENTRIES_LINEAR_OUTPUT, N_ENTRIES_PER_PROJECTION_DIM, N_ENTRIES_PER_SEQUENCE_DIM; - integer N_TILES_INNER_DIM_LINEAR_PROJECTION[5]; + integer N_TILES_INNER_DIM_LINEAR_PROJECTION[N_PHASES]; integer N_ATTENTION_TILE_ROWS, N_GROUPS; + activation_e ACTIVATION; // Signals logic clk, rst_n; @@ -59,6 +67,9 @@ module ita_tb; SEQUENCE_LEN = `ifdef SEQ_LENGTH `SEQ_LENGTH `else M_TILE_LEN `endif; PROJECTION_SPACE = `ifdef PROJ_SPACE `PROJ_SPACE `else M_TILE_LEN `endif; EMBEDDING_SIZE = `ifdef EMBED_SIZE `EMBED_SIZE `else M_TILE_LEN `endif; + FEEDFORWARD_SIZE = `ifdef FF_SIZE `FF_SIZE `else M_TILE_LEN `endif; + ACTIVATION = activation_e'(`ifdef ACTIVATION `ACTIVATION `else Identity `endif); + simdir = { "../../simvectors/data_S", $sformatf("%0d", SEQUENCE_LEN), @@ -66,6 +77,8 @@ module ita_tb; $sformatf("%0d", EMBEDDING_SIZE), "_P", $sformatf("%0d", PROJECTION_SPACE), + "_F", + $sformatf("%0d", FEEDFORWARD_SIZE), "_H1_B", $sformatf("%0d", `ifdef BIAS `BIAS `else 0 `endif) }; @@ -81,11 +94,14 @@ module ita_tb; N_ENTRIES_PER_SEQUENCE_DIM = N_ENTRIES_PER_TILE * N_TILES_SEQUENCE_DIM; N_ATTENTION_TILE_ROWS = N_TILES_SEQUENCE_DIM; N_GROUPS = 2 * N_ATTENTION_TILE_ROWS; + N_TILES_FEEDFORWARD = FEEDFORWARD_SIZE / M_TILE_LEN; N_TILES_INNER_DIM_LINEAR_PROJECTION[0] = N_TILES_EMBEDDING_DIM; N_TILES_INNER_DIM_LINEAR_PROJECTION[1] = N_TILES_EMBEDDING_DIM; N_TILES_INNER_DIM_LINEAR_PROJECTION[2] = N_TILES_EMBEDDING_DIM; N_TILES_INNER_DIM_LINEAR_PROJECTION[3] = '0; // Not used, no bias N_TILES_INNER_DIM_LINEAR_PROJECTION[4] = N_TILES_PROJECTION_DIM; + N_TILES_INNER_DIM_LINEAR_PROJECTION[5] = N_TILES_EMBEDDING_DIM; + N_TILES_INNER_DIM_LINEAR_PROJECTION[6] = N_TILES_FEEDFORWARD; end clk_rst_gen #( @@ -121,7 +137,7 @@ function automatic integer open_stim_file(string filename); return 0; stim_fd = $fopen({simdir,"/",filename}, "r"); if (stim_fd == 0) begin - $fatal("[TB] ITA: Could not open %s stim file!", filename); + $fatal(1, "[TB] ITA: Could not open %s stim file!", filename); end return stim_fd; endfunction @@ -193,7 +209,7 @@ task automatic toggle_input(inout integer tile_entry, inout integer group, inout group += 1; endtask -function integer get_random(); +function bit get_random(); logic value; integer ret_code; if (CONT) @@ -222,11 +238,11 @@ function bit did_finish_output_dot_product(input integer tile_entry); return tile_entry >= N_ENTRIES_PER_SEQUENCE_DIM; endfunction -function bit is_last_entry_of_output_group(input integer input_file_index, input integer tile_entry); +function bit is_last_entry_of_output_group(input bit input_file_index, input integer tile_entry); return is_output_group(input_file_index) && did_finish_output_dot_product(tile_entry); endfunction -function bit is_last_entry_of_attention_group(input integer input_file_index, input integer tile_entry); +function bit is_last_entry_of_attention_group(input bit input_file_index, input integer tile_entry); return is_attention_group(input_file_index) && did_finish_attention_dot_product(tile_entry); endfunction @@ -234,6 +250,39 @@ function bit should_toggle_output(input bit input_file_index, input integer tile return is_last_entry_of_output_group(input_file_index, tile_entry) || is_last_entry_of_attention_group(input_file_index, tile_entry); endfunction +task automatic read_activation_constants( + output gelu_const_t gelu_b, + output gelu_const_t gelu_c, + output requant_const_t activation_requant_mult, + output requant_const_t activation_requant_shift, + output requant_t activation_requant_add +); + integer b_fd; + integer c_fd; + integer rqs_mul_fd; + integer rqs_shift_fd; + integer add_fd; + int return_code; + + b_fd = open_stim_file(gelu_b_file); + c_fd = open_stim_file(gelu_c_file); + rqs_mul_fd = open_stim_file(activation_requant_mult_file); + rqs_shift_fd = open_stim_file(activation_requant_shift_file); + add_fd = open_stim_file(activation_requant_add_file); + + return_code = $fscanf(b_fd, "%d", gelu_b); + return_code = $fscanf(c_fd, "%d", gelu_c); + return_code = $fscanf(rqs_mul_fd, "%d", activation_requant_mult); + return_code = $fscanf(rqs_shift_fd, "%d", activation_requant_shift); + return_code = $fscanf(add_fd, "%d", activation_requant_add); + + $fclose(b_fd); + $fclose(c_fd); + $fclose(rqs_mul_fd); + $fclose(rqs_shift_fd); + $fclose(add_fd); +endtask + task automatic apply_ITA_inputs(input integer phase); integer stim_fd_inp_attn[2]; bit input_file_index = 0; @@ -278,9 +327,12 @@ task automatic apply_ITA_inputs(input integer phase); if (is_end_of_tile(tile_entry) && phase != 3) reset_tile(tile, tile_entry); stim_fd_inp = stim_fd_inp_attn[input_file_index]; - is_end_of_input = $feof(stim_fd_inp); + is_end_of_input = $feof(stim_fd_inp) != 0; end end + @(posedge clk); + #(APPL_DELAY); + inp_valid = 1'b0; // Set back to default $fclose(stim_fd_inp); $fclose(stim_fd_bias); endtask @@ -332,42 +384,32 @@ task automatic apply_ITA_weights(input integer phase); endtask task apply_ITA_rqs(); - integer stim_fd_rqs; - integer ret_code, rand_ret_code; + integer stim_fd_mul, stim_fd_shift, stim_fd_add; + integer ret_code; - for (int phase = 0; phase < 3; phase++) begin - case(phase) - 0 : begin - stim_fd_rqs = open_stim_file("RQS_MUL.txt"); - end - 1 : begin - stim_fd_rqs = open_stim_file("RQS_SHIFT.txt"); - end - 2 : begin - stim_fd_rqs = open_stim_file("RQS_ADD.txt"); - end - endcase + stim_fd_mul = open_stim_file("RQS_ATTN_MUL.txt"); + stim_fd_shift = open_stim_file("RQS_ATTN_SHIFT.txt"); + stim_fd_add = open_stim_file("RQS_ATTN_ADD.txt"); - case(phase) - 0 : begin - for (int j = 0; j < 6; j++) begin - ret_code = $fscanf(stim_fd_rqs, "%d\n", ita_ctrl.eps_mult[j]); - end - end - 1 : begin - for (int j = 0; j < 6; j++) begin - ret_code = $fscanf(stim_fd_rqs, "%d\n", ita_ctrl.right_shift[j]); - end - end - 2 : begin - for (int j = 0; j < 6; j++) begin - ret_code = $fscanf(stim_fd_rqs, "%d\n", ita_ctrl.add[j]); - end - end - endcase + for (int j = 0; j < N_ATTENTION_STEPS; j++) begin + ret_code = $fscanf(stim_fd_mul, "%d\n", ita_ctrl.eps_mult[j]); + ret_code = $fscanf(stim_fd_shift, "%d\n", ita_ctrl.right_shift[j]); + ret_code = $fscanf(stim_fd_add, "%d\n", ita_ctrl.add[j]); + end + + stim_fd_mul = open_stim_file("RQS_FFN_MUL.txt"); + stim_fd_shift = open_stim_file("RQS_FFN_SHIFT.txt"); + stim_fd_add = open_stim_file("RQS_FFN_ADD.txt"); - $fclose(stim_fd_rqs); + for (int j = 0; j < N_FEEDFORWARD_STEPS; j++) begin + ret_code = $fscanf(stim_fd_mul, "%d\n", ita_ctrl.eps_mult[j+N_ATTENTION_STEPS]); + ret_code = $fscanf(stim_fd_shift, "%d\n", ita_ctrl.right_shift[j+N_ATTENTION_STEPS]); + ret_code = $fscanf(stim_fd_add, "%d\n", ita_ctrl.add[j+N_ATTENTION_STEPS]); end + + $fclose(stim_fd_mul); + $fclose(stim_fd_shift); + $fclose(stim_fd_add); endtask task automatic check_ITA_outputs(input integer phase); @@ -402,8 +444,9 @@ task automatic apply_ITA_weights(input integer phase); oup_ready_q = oup_ready; if (successful_handshake(oup_valid, oup_ready)) begin tile_entry += 1; - if (requant_oup !== exp_res) - $display("[TB] ITA: Wrong value received %x, instead of %x at %0t.", requant_oup, exp_res, $time); + if (requant_oup !== exp_res) begin + $display("[TB] ITA: Wrong value received %x, instead of %x at %0t. (phase: %d)", requant_oup, exp_res, $time, phase); + end if (!is_last_group(group) && phase == 3 && should_toggle_output(input_file_index, tile_entry)) begin $display("[TB] ITA: %0d outputs were checked in phase %0d.",tile_entry, phase); $display("[TB] ITA: Output Switch: tile_entry: %0d, group: %0d at %0t.", tile_entry, group, $time); @@ -423,11 +466,12 @@ task automatic apply_ITA_weights(input integer phase); ita_ctrl.eps_mult = 1; ita_ctrl.right_shift = 8; ita_ctrl.add = 0; - ita_ctrl.lin_tiles = N_TILES_LINEAR_PROJECTION; - ita_ctrl.attn_tiles = N_TILES_ATTENTION; ita_ctrl.tile_e = N_TILES_EMBEDDING_DIM; ita_ctrl.tile_p = N_TILES_PROJECTION_DIM; ita_ctrl.tile_s = N_TILES_SEQUENCE_DIM; + ita_ctrl.tile_f = N_TILES_FEEDFORWARD; + + read_activation_constants(ita_ctrl.gelu_b, ita_ctrl.gelu_c, ita_ctrl.activation_requant_mult, ita_ctrl.activation_requant_shift, ita_ctrl.activation_requant_add); inp_valid = 1'b0; inp = '0; @@ -440,6 +484,8 @@ task automatic apply_ITA_weights(input integer phase); @(posedge clk); #(APPL_DELAY); ita_ctrl.start = 1'b1; + ita_ctrl.layer = Attention; + ita_ctrl.activation = Identity; stim_applied = 1; @(posedge clk); @@ -450,6 +496,22 @@ task automatic apply_ITA_weights(input integer phase); apply_ITA_inputs(phase); end + @(posedge clk); + #(APPL_DELAY); + ita_ctrl.start = 1'b1; + ita_ctrl.layer = Feedforward; + ita_ctrl.activation = ACTIVATION; + + @(posedge clk); + #(APPL_DELAY); + ita_ctrl.start = 1'b0; + + apply_ITA_inputs(5); + + ita_ctrl.activation = Identity; + + apply_ITA_inputs(6); + @(posedge clk); #(APPL_DELAY); inp = '0; @@ -472,6 +534,10 @@ task automatic apply_ITA_weights(input integer phase); apply_ITA_weights(phase); end + apply_ITA_weights(5); + + apply_ITA_weights(6); + @(posedge clk); #(APPL_DELAY); inp_weight = '0; @@ -496,9 +562,10 @@ task automatic apply_ITA_weights(input integer phase); for (int i = 0; i < ITERS; i++) begin @(posedge clk); - for (int phase = 0; phase < 5; phase++) begin + for (int phase = 0; phase < 7; phase++) begin check_ITA_outputs(phase); end + end #(50*CLK_PERIOD); diff --git a/testGenerator.py b/testGenerator.py index cd27e70..993ddfe 100644 --- a/testGenerator.py +++ b/testGenerator.py @@ -43,14 +43,16 @@ def generateMHA(**args): S = args['S'] P = args['P'] E = args['E'] + F = args['F'] H = args['H'] NO_BIAS = args['no_bias'] NO_PARTIAL_SOFTMAX = args['no_partial_softmax'] + base_path = f'{current_dir}/simvectors/data_S{S}_E{E}_P{P}_F{F}_H{H}_B{int(not NO_BIAS)}' if NO_PARTIAL_SOFTMAX: - path = f'{current_dir}/simvectors/data_S{S}_E{E}_P{P}_H{H}_B{int(not NO_BIAS)}_noPartialSoftmax/' + path = f'{base_path}_noPartialSoftmax/' else: - path = f'{current_dir}/simvectors/data_S{S}_E{E}_P{P}_H{H}_B{int(not NO_BIAS)}/' + path = f'{base_path}/' os.makedirs(path, exist_ok = True) ITA.generateTestVectors(path, **args) @@ -91,12 +93,21 @@ class ArgumentDefaultMetavarTypeFormatter(argparse.ArgumentDefaultsHelpFormatter self.group1.add_argument('-B', default = 1, type = int, help = 'Number of batches') self.group1.add_argument('-S', default = 64, type = int, help = 'Sequence length') self.group1.add_argument('-E', default = 64, type = int, help = 'Embedding size') + self.group1.add_argument('-F', default = 64, type = int, help = 'Feedforward size') self.group1.add_argument('-P', default = 64, type = int, help = 'Projection size') self.group1.add_argument('-H', default = 1, type = int, help = 'Number of heads') + self.group1.add_argument('--activation', + default = 'identity', + type = str, + help = 'Activation function', + choices = ['gelu', 'relu', 'identity']) self.group1.add_argument('--no-partial-softmax', action = 'store_true', help = 'Disable partial softmax calculation') self.group1.add_argument('--no-bias', default = False, action = 'store_true', help = 'Disable bias') + self.group1.add_argument('--export-snitch-cluster', action = 'store_true', help = 'Export for snitch cluster') + self.group1.add_argument('--export-mempool', action = 'store_true', help = 'Export for mempool') + self.group1.add_argument('--export-rom', action = 'store_true', help = 'Export ROM configuration') if __name__ == "__main__": diff --git a/tests/run_loop.sh b/tests/run_loop.sh index 257eea7..bf8b368 100755 --- a/tests/run_loop.sh +++ b/tests/run_loop.sh @@ -58,7 +58,7 @@ do echo "Testing S=$s E=$e P=$p bias=$bias" >> $log_file # Run the test - make sim VSIM_FLAGS=-c no_stalls=$no_stalls s=$s e=$e p=$p bias=$bias + make sim VSIM_FLAGS=-c no_stalls=$no_stalls s=$s e=$e p=$p bias=$bias ./modelsim/return_status.sh modelsim/build/transcript $s $e ita_tb >> $log_file # Remove the test vectors