From 598b424ad29068904878d9d4ca2ec51b923405d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gamze=20=C4=B0slamo=C4=9Flu?= <54476562+gamzeisl@users.noreply.github.com> Date: Wed, 30 Oct 2024 14:15:35 +0100 Subject: [PATCH] Support for single attention layer (#7) * [Makefile] Pass simulation path * [change] Alter time format in tb * [change] Add trigger task to tb * [feature] Add single attention layer * [fix] Forward correct biases if transposed * [fix] Correct typo * [feature] Add single attention tests --- .gitlab-ci.yml | 42 ++++++- Makefile | 5 +- src/hwpe/ita_hwpe_input_bias_buffer.sv | 9 +- src/hwpe/tb/ita_hwpe_tb.sv | 81 +++++++++++- src/ita.sv | 16 +-- src/ita_controller.sv | 8 +- src/ita_package.sv | 2 +- src/tb/ita_tb.sv | 168 +++++++++++++++++++------ 8 files changed, 274 insertions(+), 57 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a8a7dcd..7b4c3ce 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -46,33 +46,52 @@ run_sim: F: 64 activation: gelu no_stalls: 0 + single_attention: 0 - S: 64 E: 64 P: 64 F: 64 activation: gelu no_stalls: 1 + single_attention: 0 - S: 128 E: 192 P: 256 F: 256 activation: gelu no_stalls: 0 + single_attention: 0 - S: 128 E: 192 P: 256 F: 256 activation: gelu no_stalls: 1 + single_attention: 0 - S: 192 E: 256 P: 128 F: 128 activation: relu no_stalls: 1 + single_attention: 0 + - S: 128 + E: 192 + P: 256 + F: 256 + activation: gelu + no_stalls: 0 + single_attention: 1 + - S: 192 + E: 256 + P: 128 + F: 128 + activation: relu + no_stalls: 0 + single_attention: 1 script: - make bender - - make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls + - make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention - ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F ita_tb run_hwpe_sim: @@ -87,31 +106,50 @@ run_hwpe_sim: F: 64 activation: gelu no_stalls: 0 + single_attention: 0 - S: 64 E: 64 P: 64 F: 64 activation: gelu no_stalls: 1 + single_attention: 0 - S: 128 E: 192 P: 256 F: 256 activation: gelu no_stalls: 0 + single_attention: 0 - S: 128 E: 192 P: 256 F: 256 activation: gelu no_stalls: 1 + single_attention: 0 - S: 192 E: 256 P: 128 F: 128 activation: relu no_stalls: 1 + single_attention: 0 + - S: 128 + E: 192 + P: 256 + F: 256 + activation: gelu + no_stalls: 0 + single_attention: 1 + - S: 192 + E: 256 + P: 128 + F: 128 + activation: relu + no_stalls: 0 + single_attention: 1 script: - make bender - - make sim VSIM_FLAGS=-c DEBUG=OFF target=sim_ita_hwpe_tb s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls + - make sim VSIM_FLAGS=-c DEBUG=OFF target=sim_ita_hwpe_tb s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention - ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F hwpe_tb diff --git a/Makefile b/Makefile index 5996908..3359ca7 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,7 @@ BENDER_TARGETS = -t rtl -t test target ?= sim_ita_tb no_stalls ?= 0 +single_attention ?= 0 s ?= 64 e ?= 128 p ?= 192 @@ -33,7 +34,7 @@ else ifeq ($(activation), relu) else activation_int = 0 endif -vlog_defs += -DNO_STALLS=$(no_stalls) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int) +vlog_defs += -DNO_STALLS=$(no_stalls) -DSINGLE_ATTENTION=$(single_attention) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int) ifeq ($(target), sim_ita_hwpe_tb) BENDER_TARGETS += -t ita_hwpe -t ita_hwpe_test @@ -60,7 +61,7 @@ sim-script: clean-sim sim: sim-script cd modelsim && \ - $(MAKE) $(target) + $(MAKE) $(target) buildpath=$(ROOT_DIR)/$(SIM_PATH) synopsys-script: rm ../ita-gf22/$(SYNTH_PATH)/scripts/analyze.tcl diff --git a/src/hwpe/ita_hwpe_input_bias_buffer.sv b/src/hwpe/ita_hwpe_input_bias_buffer.sv index 6def5c0..eaf9b2f 100644 --- a/src/hwpe/ita_hwpe_input_bias_buffer.sv +++ b/src/hwpe/ita_hwpe_input_bias_buffer.sv @@ -30,6 +30,8 @@ module ita_hwpe_input_bias_buffer #( logic write_enable; logic [OUTPUT_DATA_WIDTH-1:0] read_data; + bias_t bias_reshape; + always_comb begin // Default assignments state_d = state_q; @@ -43,6 +45,7 @@ module ita_hwpe_input_bias_buffer #( data_o.valid = 0; data_o.strb = 48'hFFFFFFFFFFFF; data_o.data = '0; + bias_reshape = '0; case(state_q) Write: begin @@ -60,8 +63,10 @@ module ita_hwpe_input_bias_buffer #( data_o.valid = read_enable_q; if (read_enable_q) begin data_o.data = read_data; - if (bias_dir_i) - data_o.data = read_data >> read_cnt_q[3:0] * 24; + if (bias_dir_i) begin + bias_reshape = read_data >> read_cnt_q[3:0] * 24; + data_o.data = {N {bias_reshape[0]}}; + end end read_enable = 1; if(data_o.valid && data_o.ready) begin diff --git a/src/hwpe/tb/ita_hwpe_tb.sv b/src/hwpe/tb/ita_hwpe_tb.sv index 712e779..e24e70d 100644 --- a/src/hwpe/tb/ita_hwpe_tb.sv +++ b/src/hwpe/tb/ita_hwpe_tb.sv @@ -27,6 +27,7 @@ module ita_hwpe_tb; parameter integer EMBEDDING_SIZE = `ifdef EMBED_SIZE `EMBED_SIZE `else M_TILE_LEN `endif; parameter integer FEEDFORWARD_SIZE = `ifdef FF_SIZE `FF_SIZE `else M_TILE_LEN `endif; parameter activation_e ACTIVATION = `ifdef ACTIVATION `ACTIVATION `else Identity `endif; + parameter integer SINGLE_ATTENTION = `ifdef SINGLE_ATTENTION `SINGLE_ATTENTION `else 0 `endif; integer N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, N_TILES_FEEDFORWARD_DIM; integer N_ELEMENTS_PER_TILE; @@ -118,6 +119,8 @@ module ita_hwpe_tb; `HCI_INTF_ARRAY(tcdm_mem, clk_i, MP-1:0); initial begin + $timeformat(-9, 1, " ns", 11); + simdir = { "../../simvectors/data_S", $sformatf("%0d", SEQUENCE_LEN), @@ -356,11 +359,27 @@ endfunction ita_compute_step(Q, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // 2: Step K + if (SINGLE_ATTENTION == 1) begin + // move corresponding ita_reg_rqs_val because linear layers use array[0] + ita_reg_rqs_val[0] = ita_reg_rqs_val[0] >> 8; + ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8; + ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8; + end ita_compute_step(K, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // 3: Step V + if (SINGLE_ATTENTION == 1) begin + // move corresponding ita_reg_rqs_val because linear layers use array[0] + ita_reg_rqs_val[0] = ita_reg_rqs_val[0] >> 8; + ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8; + ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8; + end ita_compute_step(V, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + if (SINGLE_ATTENTION == 1) begin + // Reset the RQS values + ita_reg_eps_mult_val_compute(ita_reg_rqs_val); + end for (int group = 0; group < N_TILES_SEQUENCE_DIM; group++) begin BASE_PTR_INPUT[QK] = BASE_PTR[15] + group * N_TILES_INNER_DIM[QK] * N_ELEMENTS_PER_TILE; @@ -382,12 +401,36 @@ endfunction end // 6: Step OW + if (SINGLE_ATTENTION == 1) begin + // Change order of P and E + ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_PROJECTION_DIM, N_TILES_EMBEDDING_DIM, N_TILES_FEEDFORWARD_DIM, ita_reg_tiles_val); + // move corresponding ita_reg_rqs_val because linear layers use array[0] + ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 8; + ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 8; + ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 8; + end ita_compute_step(OW, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // 7: Step FF1 + if (SINGLE_ATTENTION == 1) begin + // Change order of P and F + ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_FEEDFORWARD_DIM, N_TILES_PROJECTION_DIM, ita_reg_tiles_val); + // move corresponding ita_reg_rqs_val because linear layers use array[0] + ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 16; + ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 16; + ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 16; + end ita_compute_step(F1, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); - // 8: Step FF1 + // 8: Step FF2 + if (SINGLE_ATTENTION == 1) begin + // Change order of E and F + ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_FEEDFORWARD_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, ita_reg_tiles_val); + // move corresponding ita_reg_rqs_val because linear layers use array[0] + ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 24; + ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 24; + ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 24; + end ita_compute_step(F2, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // Wait for the last step to finish @@ -452,8 +495,14 @@ endfunction // Calculate input_ptr, weight_ptr0, weight_ptr1, bias_ptr, and output_ptr ita_ptrs_compute(input_base_ptr, weight_base_ptr0, weight_base_ptr1, bias_base_ptr, output_base_ptr, step, tile, tile_x, tile_y, tile_inner, input_ptr, weight_ptr0, weight_ptr1, bias_ptr, output_ptr); - // Calculate ita_reg_en - ita_reg_en_compute(step, tile, ita_reg_en); + if (SINGLE_ATTENTION == 1) begin + // Enable ita_reg_en + ita_reg_en = 1'b1; + end else begin + // Calculate ita_reg_en + ita_reg_en_compute(step, tile, ita_reg_en); + end + // Calculate ctrl_stream_val, weight_ptr_en, and bias_ptr_en ctrl_val_compute(step, tile, ctrl_engine_val, ctrl_stream_val, weight_ptr_en, bias_ptr_en); @@ -564,7 +613,13 @@ endfunction ctrl_stream_val = 32'h0; reg_weight_en = 1'b0; reg_bias_en = 1'b0; - layer_type = Attention; + + if (SINGLE_ATTENTION == 1) begin + layer_type = Linear; + end else begin + layer_type = Attention; + end + activation_function = Identity; ctrl_engine_val = layer_type | activation_function << 2; @@ -598,11 +653,17 @@ endfunction reg_bias_en = 1'b1; end QK : begin + if (SINGLE_ATTENTION == 1) begin + ctrl_engine_val = SingleAttention | Identity << 2; + end ctrl_stream_val = {28'b0, 4'b0110}; // weight nextload and disable bias reg_weight_en = 1'b1; reg_bias_en = 1'b0; end AV : begin + if (SINGLE_ATTENTION == 1) begin + ctrl_engine_val = SingleAttention | Identity << 2; + end ctrl_stream_val = {28'b0, 4'b0110}; // weight nextload and disable bias reg_weight_en = 1'b1; reg_bias_en = 1'b0; @@ -618,7 +679,11 @@ endfunction reg_bias_en = 1'b1; end F1 : begin - ctrl_engine_val = Feedforward | ACTIVATION << 2; + if (SINGLE_ATTENTION == 1) begin + ctrl_engine_val = Linear | ACTIVATION << 2; + end else begin + ctrl_engine_val = Feedforward | ACTIVATION << 2; + end if (tile == 0) begin ctrl_stream_val = {28'b0, 4'b0011}; // weight preload and weight nextload end else begin @@ -628,7 +693,11 @@ endfunction reg_bias_en = 1'b1; end F2 : begin - ctrl_engine_val = Feedforward | Identity << 2; + if (SINGLE_ATTENTION == 1) begin + ctrl_engine_val = Linear | Identity << 2; + end else begin + ctrl_engine_val = Feedforward | Identity << 2; + end if (tile == (N_TILES_OUTER_X[F2]*N_TILES_OUTER_Y[F2]*N_TILES_INNER_DIM[F2])-1) begin ctrl_stream_val = {28'b0, 4'b0000}; reg_weight_en = 1'b0; diff --git a/src/ita.sv b/src/ita.sv index 9bdb1bb..2dad263 100644 --- a/src/ita.sv +++ b/src/ita.sv @@ -194,14 +194,14 @@ module ita ); ita_input_sampler i_input_sampler ( - .clk_i (clk_i ), - .rst_ni (rst_ni ), - .valid_i (inp_valid_i ), - .ready_i (inp_ready_o ), - .inp_i (inp_i ), - .inp_bias_i ((step == QK || step == AV) ? '0 : ((step == V) ? {N {inp_bias_i[0]}} : inp_bias_i) ), // TODO: temporary fix - .inp_o (inp ), - .inp_bias_o (inp_bias ) + .clk_i (clk_i ), + .rst_ni (rst_ni ), + .valid_i (inp_valid_i ), + .ready_i (inp_ready_o ), + .inp_i (inp_i ), + .inp_bias_i (inp_bias_i ), + .inp_o (inp ), + .inp_bias_o (inp_bias ) ); ita_inp1_mux i_inp1_mux ( diff --git a/src/ita_controller.sv b/src/ita_controller.sv index b4156f5..0fa8034 100644 --- a/src/ita_controller.sv +++ b/src/ita_controller.sv @@ -119,6 +119,8 @@ module ita_controller step_d = F1; end else if (ctrl_i.layer == Linear) begin step_d = MatMul; + end else if (ctrl_i.layer == SingleAttention) begin + step_d = QK; end end end @@ -187,7 +189,11 @@ module ita_controller softmax_tile_d = softmax_tile_q + 1; if (softmax_tile_d == ctrl_i.tile_s) begin softmax_tile_d = '0; - step_d = OW; + if (ctrl_i.layer == Attention) begin + step_d = OW; + end else if (ctrl_i.layer == SingleAttention) begin + step_d = Idle; + end end else begin step_d = QK; end diff --git a/src/ita_package.sv b/src/ita_package.sv index 3d602d8..335e173 100644 --- a/src/ita_package.sv +++ b/src/ita_package.sv @@ -35,7 +35,7 @@ package ita_package; parameter int unsigned N_WRITE_EN = `ifdef TARGET_ITA_HWPE 8 `else M `endif; // Feedforward - typedef enum bit [1:0] {Attention=0, Feedforward=1, Linear=2} layer_e; + typedef enum bit [1:0] {Attention=0, Feedforward=1, Linear=2, SingleAttention=3} layer_e; typedef enum bit [1:0] {Identity=0, Gelu=1, Relu=2} activation_e; typedef logic signed [GELU_CONSTANTS_WIDTH-1:0] gelu_const_t; typedef logic signed [GELU_OUT_WIDTH-1:0] gelu_out_t; diff --git a/src/tb/ita_tb.sv b/src/tb/ita_tb.sv index d721247..e8f84a6 100644 --- a/src/tb/ita_tb.sv +++ b/src/tb/ita_tb.sv @@ -15,6 +15,8 @@ module ita_tb; localparam unsigned RST_CLK_CYCLES = 10; // Set to 1 to run the simulation without stalls localparam unsigned CONT = `ifdef NO_STALLS `NO_STALLS `else 0 `endif; + // Set to 1 to run the simulation with single attention layer and linear layers + localparam unsigned SINGLE_ATTENTION = `ifdef SINGLE_ATTENTION `SINGLE_ATTENTION `else 0 `endif; localparam unsigned ITERS = 1; localparam unsigned N_PHASES = 7; @@ -56,12 +58,17 @@ module ita_tb; requant_oup_t requant_oup ; requant_oup_t exp_res; logic oup_valid, oup_ready; + requant_const_array_t stim_eps_mult; + requant_const_array_t stim_right_shift; + requant_array_t stim_add; // Variables string simdir; integer stim_applied; initial begin + $timeformat(-9, 1, " ns", 11); + N_PE = `ifdef ITA_N `ITA_N `else 16 `endif; M_TILE_LEN = `ifdef ITA_M `ITA_M `else 64 `endif; SEQUENCE_LEN = `ifdef SEQ_LENGTH `SEQ_LENGTH `else M_TILE_LEN `endif; @@ -285,6 +292,16 @@ task automatic read_activation_constants( $fclose(add_fd); endtask +task automatic trigger_ITA (); + @(posedge clk); + #(APPL_DELAY); + ita_ctrl.start = 1'b1; + + @(posedge clk); + #(APPL_DELAY); + ita_ctrl.start = 1'b0; +endtask + task automatic apply_ITA_inputs(input integer phase); integer stim_fd_inp_attn[2]; bit input_file_index = 0; @@ -298,7 +315,7 @@ task automatic apply_ITA_inputs(input integer phase); integer stim_fd_inp; integer stim_fd_bias; - $display("[TB] ITA: Applying inputs in phase %0d at %0t.", phase, $time); + $display("[TB] ITA: Applying inputs in phase %0d at %t.", phase, $time); group = 0; tile = 0; @@ -323,7 +340,7 @@ task automatic apply_ITA_inputs(input integer phase); if(successful_handshake(inp_valid, inp_ready)) begin tile_entry += 1; if (should_toggle_input(tile_entry, group) && phase == 3) begin - $display("[TB] ITA: Input Switch: tile_entry: %d, group: %d at %t.", tile_entry, group, $time); + $display("[TB] ITA: Input Switch: tile_entry: %0d, group: %0d at %t.", tile_entry, group, $time); toggle_input(tile_entry, group, input_file_index); end if (is_end_of_tile(tile_entry) && phase != 3) @@ -351,7 +368,7 @@ task automatic apply_ITA_weights(input integer phase); integer group; integer stim_fd_weight; - $display("[TB] ITA: Applying weights in phase %0d at %0t.", phase, $time); + $display("[TB] ITA: Applying weights in phase %0d at %t.", phase, $time); group = 0; tile = 0; @@ -375,7 +392,7 @@ task automatic apply_ITA_weights(input integer phase); if (successful_handshake(inp_weight_valid, inp_weight_ready)) begin tile_entry += 1; if (should_toggle_input(tile_entry, group) && phase == 3) begin - $display("[TB] ITA: Weight Switch: tile_entry: %0d, group: %0d at %0t.", tile_entry, group, $time); + $display("[TB] ITA: Weight Switch: tile_entry: %0d, group: %0d at %t.", tile_entry, group, $time); toggle_input(tile_entry, group, input_file_index); end stim_fd_weight = stim_fd_weight_attn[input_file_index]; @@ -394,9 +411,9 @@ task automatic apply_ITA_weights(input integer phase); stim_fd_add = open_stim_file("RQS_ATTN_ADD.txt"); for (int j = 0; j < N_ATTENTION_STEPS; j++) begin - ret_code = $fscanf(stim_fd_mul, "%d\n", ita_ctrl.eps_mult[j]); - ret_code = $fscanf(stim_fd_shift, "%d\n", ita_ctrl.right_shift[j]); - ret_code = $fscanf(stim_fd_add, "%d\n", ita_ctrl.add[j]); + ret_code = $fscanf(stim_fd_mul, "%d\n", stim_eps_mult[j]); + ret_code = $fscanf(stim_fd_shift, "%d\n", stim_right_shift[j]); + ret_code = $fscanf(stim_fd_add, "%d\n", stim_add[j]); end stim_fd_mul = open_stim_file("RQS_FFN_MUL.txt"); @@ -404,9 +421,9 @@ task automatic apply_ITA_weights(input integer phase); stim_fd_add = open_stim_file("RQS_FFN_ADD.txt"); for (int j = 0; j < N_FEEDFORWARD_STEPS; j++) begin - ret_code = $fscanf(stim_fd_mul, "%d\n", ita_ctrl.eps_mult[j+N_ATTENTION_STEPS]); - ret_code = $fscanf(stim_fd_shift, "%d\n", ita_ctrl.right_shift[j+N_ATTENTION_STEPS]); - ret_code = $fscanf(stim_fd_add, "%d\n", ita_ctrl.add[j+N_ATTENTION_STEPS]); + ret_code = $fscanf(stim_fd_mul, "%d\n", stim_eps_mult[j+N_ATTENTION_STEPS]); + ret_code = $fscanf(stim_fd_shift, "%d\n", stim_right_shift[j+N_ATTENTION_STEPS]); + ret_code = $fscanf(stim_fd_add, "%d\n", stim_add[j+N_ATTENTION_STEPS]); end $fclose(stim_fd_mul); @@ -425,7 +442,7 @@ task automatic apply_ITA_weights(input integer phase); integer group; integer exp_resp_fd; - $display("[TB] ITA: Checking outputs in phase %d at %t.", phase, $time); + $display("[TB] ITA: Checking outputs in phase %0d at %t.", phase, $time); group = 0; tile_entry = 0; @@ -447,11 +464,11 @@ task automatic apply_ITA_weights(input integer phase); if (successful_handshake(oup_valid, oup_ready)) begin tile_entry += 1; if (requant_oup !== exp_res) begin - $display("[TB] ITA: Wrong value received %x, instead of %x at %0t. (phase: %d)", requant_oup, exp_res, $time, phase); + $display("[TB] ITA: Wrong value received %x, instead of %x at %t. (phase: %0d)", requant_oup, exp_res, $time, phase); end if (!is_last_group(group) && phase == 3 && should_toggle_output(input_file_index, tile_entry)) begin $display("[TB] ITA: %0d outputs were checked in phase %0d.",tile_entry, phase); - $display("[TB] ITA: Output Switch: tile_entry: %0d, group: %0d at %0t.", tile_entry, group, $time); + $display("[TB] ITA: Output Switch: tile_entry: %0d, group: %0d at %t.", tile_entry, group, $time); toggle_input(tile_entry, group, input_file_index); end exp_resp_fd = exp_resp_fd_attn[input_file_index]; @@ -485,34 +502,117 @@ task automatic apply_ITA_weights(input integer phase); for (int i = 0; i < ITERS; i++) begin @(posedge clk); #(APPL_DELAY); - ita_ctrl.start = 1'b1; - ita_ctrl.layer = Attention; + if (SINGLE_ATTENTION == 1) begin + ita_ctrl.layer = Linear; + end else begin + ita_ctrl.layer = Attention; + end ita_ctrl.activation = Identity; stim_applied = 1; - @(posedge clk); - #(APPL_DELAY); - ita_ctrl.start = 1'b0; + if (SINGLE_ATTENTION == 1) begin + // QKV Generation + for (int phase = 0; phase < 3; phase++) begin + @(posedge clk); + #(APPL_DELAY); + ita_ctrl.eps_mult[0] = stim_eps_mult[phase]; + ita_ctrl.right_shift[0] = stim_right_shift[phase]; + ita_ctrl.add[0] = stim_add[phase]; - for (int phase = 0; phase < 5; phase++) begin - apply_ITA_inputs(phase); - end + trigger_ITA(); - @(posedge clk); - #(APPL_DELAY); - ita_ctrl.start = 1'b1; - ita_ctrl.layer = Feedforward; - ita_ctrl.activation = ACTIVATION; + apply_ITA_inputs(phase); - @(posedge clk); - #(APPL_DELAY); - ita_ctrl.start = 1'b0; + #(10*CLK_PERIOD); + end - apply_ITA_inputs(5); + // Attention + @(posedge clk); + #(APPL_DELAY); + ita_ctrl.layer = SingleAttention; + ita_ctrl.eps_mult[3] = stim_eps_mult[3]; + ita_ctrl.right_shift[3] = stim_right_shift[3]; + ita_ctrl.add[3] = stim_add[3]; + ita_ctrl.eps_mult[4] = stim_eps_mult[4]; + ita_ctrl.right_shift[4] = stim_right_shift[4]; + ita_ctrl.add[4] = stim_add[4]; - ita_ctrl.activation = Identity; - - apply_ITA_inputs(6); + trigger_ITA(); + + apply_ITA_inputs(3); + + #(10*CLK_PERIOD); + + // OW Generation + @(posedge clk); + #(APPL_DELAY); + ita_ctrl.layer = Linear; + ita_ctrl.eps_mult[0] = stim_eps_mult[5]; + ita_ctrl.right_shift[0] = stim_right_shift[5]; + ita_ctrl.add[0] = stim_add[5]; + ita_ctrl.tile_e = N_TILES_PROJECTION_DIM; + ita_ctrl.tile_p = N_TILES_EMBEDDING_DIM; + + trigger_ITA(); + + apply_ITA_inputs(4); + + #(10*CLK_PERIOD); + + // FF1 + @(posedge clk); + #(APPL_DELAY); + ita_ctrl.layer = Linear; + ita_ctrl.activation = ACTIVATION; + ita_ctrl.tile_e = N_TILES_EMBEDDING_DIM; + ita_ctrl.tile_p = N_TILES_FEEDFORWARD; + ita_ctrl.eps_mult[0] = stim_eps_mult[6]; + ita_ctrl.right_shift[0] = stim_right_shift[6]; + ita_ctrl.add[0] = stim_add[6]; + + trigger_ITA(); + + apply_ITA_inputs(5); + + #(10*CLK_PERIOD); + + // FF2 + @(posedge clk); + #(APPL_DELAY); + ita_ctrl.activation = Identity; + ita_ctrl.tile_e = N_TILES_FEEDFORWARD; + ita_ctrl.tile_p = N_TILES_EMBEDDING_DIM; + ita_ctrl.eps_mult[0] = stim_eps_mult[7]; + ita_ctrl.right_shift[0] = stim_right_shift[7]; + ita_ctrl.add[0] = stim_add[7]; + + trigger_ITA(); + + apply_ITA_inputs(6); + end else begin + ita_ctrl.eps_mult = stim_eps_mult; + ita_ctrl.right_shift = stim_right_shift; + ita_ctrl.add = stim_add; + + trigger_ITA(); + + for (int phase = 0; phase < 5; phase++) begin + apply_ITA_inputs(phase); + end + + @(posedge clk); + #(APPL_DELAY); + ita_ctrl.layer = Feedforward; + ita_ctrl.activation = ACTIVATION; + + trigger_ITA(); + + apply_ITA_inputs(5); + + ita_ctrl.activation = Identity; + + apply_ITA_inputs(6); + end @(posedge clk); #(APPL_DELAY); @@ -548,8 +648,6 @@ task automatic apply_ITA_weights(input integer phase); end initial begin: rqs_application_block - wait (rst_n); - for (int i = 0; i < ITERS; i++) begin @(posedge clk); #(APPL_DELAY);