Skip to content

Commit

Permalink
New implementation of the strided mask
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Kant committed Dec 9, 2024
1 parent b3068b1 commit ea6f6ff
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 17 deletions.
46 changes: 34 additions & 12 deletions src/ita_controller.sv
Original file line number Diff line number Diff line change
Expand Up @@ -505,25 +505,47 @@ module ita_controller
end
end
Strided: begin
mask_col_offset_d = '0;
mask_col_pos_d = (step_q == QK || step_q == AV) ? mask_col_pos_q : ctrl_i.mask_start_index;
mask_row_pos_d = (step_q == QK || step_q == AV) ? mask_row_pos_q : ctrl_i.mask_start_index;
//mask_col_offset_d = (step_q == QK || step_q == AV) ? mask_col_offset_q : ((mask_col_pos_q) & (N-1));
//Should be mask_col_pos
mask_tile_x_pos_d = '0;
mask_tile_y_pos_d = '0;
mask_pos_d = '0;
mask_d = '0;
//Should be mask_row_index
mask_tile_y_pos_d = (step_q == QK || step_q == AV) ? mask_tile_y_pos_q : ctrl_i.mask_start_index;
//Should be mask_col_pos_actual
mask_pos_d = (step_q == QK || step_q == AV) ? mask_pos_q : '0;
mask_d = '1;

if (step_q == QK) begin
if (last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
//col_pos = count_q/M + i + mask_tile_x_pos_q * M
//row_pos = count_q & (M-1) + mask_tile_y_pos_q * M
if ((((((count_q / M) * N) + i + (tile_x_q * M)) - ((count_q & (M-1)) + (tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0) begin
mask_d[i] = 1'b0;
end else begin
mask_d[i] = 1'b1;
if ((count_q & (M-1)) == 0) begin
for (int i = 0; i < N; i++) begin
if (i == 0 && ((count_q / M) * N) == 0) begin
mask_d[i] = 0;
end else if (i == (mask_col_pos_q & (N-1)) && ((((count_q / M) * N) + (tile_x_q * M)) == (mask_col_pos_d / N))) begin
mask_col_pos_d = mask_col_pos_q + ctrl_i.mask_start_index;
mask_d[i] = 0;
end else begin
mask_d[i] = 1;
end
end
end else if (count_q >= mask_pos_q && count_q < (mask_pos_q + N)) begin
//Circular shift
mask_d = (mask_q << (N-1)) | (mask_q >> 1)
if ((count_q & (M-1)) == mask_row_pos_q) begin
mask_row_pos_d = mask_row_pos_q + ctrl_i.mask_start_index;
end
if ((count_q & (N-1)) == (N-1)) begin
if (ctrl_i.mask_start_index < N) begin
mask_pos_d = mask_pos_q + N;
end else begin
mask_pos_d = mask_pos_q + ctrl_i.mask_start_index;
end
end
$display("Circular shift", mask_d);
end
end
end
end
end
endcase

Expand Down
8 changes: 4 additions & 4 deletions src/ita_package.sv
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ package ita_package;
typedef logic [N_REQUANT_CONSTS-1:0][EMS-1:0] requant_const_array_t;
typedef logic signed [WI-1:0] requant_t;
typedef logic signed [N_REQUANT_CONSTS-1:0][WI-1:0] requant_array_t;
typedef logic [WO-WI*2-2:0] seq_length_t;
typedef logic [WO-WI*2-2:0] proj_space_t;
typedef logic [WO-WI*2-2:0] embed_size_t;
typedef logic [WO-WI*2-2:0] ff_size_t;
typedef logic [WO-WI*2-1:0] seq_length_t;
typedef logic [WO-WI*2-1:0] proj_space_t;
typedef logic [WO-WI*2-1:0] embed_size_t;
typedef logic [WO-WI*2-1:0] ff_size_t;
typedef logic [ 32-1:0] tile_t;
typedef struct packed {
logic start ;
Expand Down
3 changes: 2 additions & 1 deletion src/ita_softmax.sv
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ module ita_softmax
write_max_addr_o = count_q3;
write_max_data_o = max_q;
for (int i = 0; i < N; i++) begin
if (shift_q[i] != 4'hF)
// MARCELK: This is most likely not required
// if (shift_q[i] != 4'hF)
exp_sum_d += unsigned'(9'h100)>>shift_q[i];
end
if (tile_q3 != '0 || count_q3>=M) begin // If not first part of the first row
Expand Down

0 comments on commit ea6f6ff

Please sign in to comment.