Patch notes (free text ahead of the first "diff --git" header; patch tools skip it):
- src/ita_controller.sv: in the Strided mask case, replaces the per-element mask
  computation with mask_col_pos/mask_row_pos tracking and a circular shift of mask_q.
- src/ita_package.sv: widens seq_length_t, proj_space_t, embed_size_t and ff_size_t
  from [WO-WI*2-2:0] to [WO-WI*2-1:0] (one extra bit).
- src/ita_softmax.sv: comments out the `shift_q[i] != 4'hF` guard so the exp_sum_d
  accumulation runs for every i.
- NOTE(review): this patch was recovered from a whitespace-collapsed copy. Indentation
  and the two blank context lines in the first hunk are reconstructed; confirm them
  against the tree (or apply with whitespace-insensitive matching) before use.

diff --git a/src/ita_controller.sv b/src/ita_controller.sv
index d3c2f82..9817f34 100644
--- a/src/ita_controller.sv
+++ b/src/ita_controller.sv
@@ -505,25 +505,47 @@ module ita_controller
         end
       end
       Strided: begin
-        mask_col_offset_d = '0;
+        mask_col_pos_d = (step_q == QK || step_q == AV) ? mask_col_pos_q : ctrl_i.mask_start_index;
+        mask_row_pos_d = (step_q == QK || step_q == AV) ? mask_row_pos_q : ctrl_i.mask_start_index;
+        //mask_col_offset_d = (step_q == QK || step_q == AV) ? mask_col_offset_q : ((mask_col_pos_q) & (N-1));
+        //Should be mask_col_pos
         mask_tile_x_pos_d = '0;
-        mask_tile_y_pos_d = '0;
-        mask_pos_d = '0;
-        mask_d = '0;
+        //Should be mask_row_index
+        mask_tile_y_pos_d = (step_q == QK || step_q == AV) ? mask_tile_y_pos_q : ctrl_i.mask_start_index;
+        //Should be mask_col_pos_actual
+        mask_pos_d = (step_q == QK || step_q == AV) ? mask_pos_q : '0;
+        mask_d = '1;
 
         if (step_q == QK) begin
           if (last_inner_tile_o == 1'b1) begin
-            for (int i = 0; i < N; i++) begin
-              //col_pos = count_q/M + i + mask_tile_x_pos_q * M
-              //row_pos = count_q & (M-1) + mask_tile_y_pos_q * M
-              if ((((((count_q / M) * N) + i + (tile_x_q * M)) - ((count_q & (M-1)) + (tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0) begin
-                mask_d[i] = 1'b0;
-              end else begin
-                mask_d[i] = 1'b1;
+            if ((count_q & (M-1)) == 0) begin
+              for (int i = 0; i < N; i++) begin
+                if (i == 0 && ((count_q / M) * N) == 0) begin
+                  mask_d[i] = 0;
+                end else if (i == (mask_col_pos_q & (N-1)) && ((((count_q / M) * N) + (tile_x_q * M)) == (mask_col_pos_d / N))) begin
+                  mask_col_pos_d = mask_col_pos_q + ctrl_i.mask_start_index;
+                  mask_d[i] = 0;
+                end else begin
+                  mask_d[i] = 1;
+                end
+              end
+            end else if (count_q >= mask_pos_q && count_q < (mask_pos_q + N)) begin
+              //Circular shift
+              mask_d = (mask_q << (N-1)) | (mask_q >> 1);
+              if ((count_q & (M-1)) == mask_row_pos_q) begin
+                mask_row_pos_d = mask_row_pos_q + ctrl_i.mask_start_index;
+              end
+              if ((count_q & (N-1)) == (N-1)) begin
+                if (ctrl_i.mask_start_index < N) begin
+                  mask_pos_d = mask_pos_q + N;
+                end else begin
+                  mask_pos_d = mask_pos_q + ctrl_i.mask_start_index;
+                end
               end
+              $display("Circular shift", mask_d);
             end
           end
-        end
+        end
       end
     endcase
 
diff --git a/src/ita_package.sv b/src/ita_package.sv
index f63a299..7c6249b 100644
--- a/src/ita_package.sv
+++ b/src/ita_package.sv
@@ -49,10 +49,10 @@ package ita_package;
   typedef logic [N_REQUANT_CONSTS-1:0][EMS-1:0] requant_const_array_t;
   typedef logic signed [WI-1:0] requant_t;
   typedef logic signed [N_REQUANT_CONSTS-1:0][WI-1:0] requant_array_t;
-  typedef logic [WO-WI*2-2:0] seq_length_t;
-  typedef logic [WO-WI*2-2:0] proj_space_t;
-  typedef logic [WO-WI*2-2:0] embed_size_t;
-  typedef logic [WO-WI*2-2:0] ff_size_t;
+  typedef logic [WO-WI*2-1:0] seq_length_t;
+  typedef logic [WO-WI*2-1:0] proj_space_t;
+  typedef logic [WO-WI*2-1:0] embed_size_t;
+  typedef logic [WO-WI*2-1:0] ff_size_t;
   typedef logic [ 32-1:0] tile_t;
   typedef struct packed {
     logic start ;
diff --git a/src/ita_softmax.sv b/src/ita_softmax.sv
index e99653b..b452f76 100644
--- a/src/ita_softmax.sv
+++ b/src/ita_softmax.sv
@@ -187,7 +187,8 @@ module ita_softmax
         write_max_addr_o = count_q3;
         write_max_data_o = max_q;
         for (int i = 0; i < N; i++) begin
-          if (shift_q[i] != 4'hF)
+          // MARCELK: This is most likely not required
+          // if (shift_q[i] != 4'hF)
           exp_sum_d += unsigned'(9'h100)>>shift_q[i];
         end
         if (tile_q3 != '0 || count_q3>=M) begin // If not first part of the first row