Skip to content

Commit

Permalink
HWPE Settings (#10)
Browse files Browse the repository at this point in the history
* [change] Make default tile values 1

* Increase HWPE contexts to 4
  • Loading branch information
gamzeisl authored Dec 5, 2024
1 parent 39236c9 commit 8ef386a
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 39 deletions.
2 changes: 1 addition & 1 deletion Bender.lock
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ packages:
- hwpe-stream
- l2_tcdm_hybrid_interco
hwpe-ctrl:
revision: 2926867cafb3fb518a1ae849675f281b79ecab8a
revision: 7ba707d837697c2c7c6ea1396ec4e4ab094054a2
version: null
source:
Git: https://github.com/pulp-platform/hwpe-ctrl
Expand Down
2 changes: 1 addition & 1 deletion Bender.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies:
common_cells: { git: https://github.com/pulp-platform/common_cells, version: 1.23.0 }
hwpe-stream: { git: https://github.com/pulp-platform/hwpe-stream, rev: a20f35e62fe2842904797079dc7881e490ff7117 }
hci: { git: https://github.com/pulp-platform/hci, rev: 066c7ce7d24b61587e245decb592054669d7a2d1 }
hwpe-ctrl: { git: https://github.com/pulp-platform/hwpe-ctrl, rev: 2926867cafb3fb518a1ae849675f281b79ecab8a }
hwpe-ctrl: { git: https://github.com/pulp-platform/hwpe-ctrl, rev: 7ba707d837697c2c7c6ea1396ec4e4ab094054a2 }
scm: { git: https://github.com/pulp-platform/scm, rev: 998466d2a3c2d7d572e43d2666d93c4f767d8d60 }
tech_cells_generic: { git: https://github.com/pulp-platform/tech_cells_generic, version: 0.2.11 }

Expand Down
8 changes: 4 additions & 4 deletions src/hwpe/ita_hwpe_ctrl.sv
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ module ita_hwpe_ctrl
always_comb begin
ctrl_engine_o = '0;
ctrl_engine_o.start = slave_flags.start;
ctrl_engine_o.tile_s = reg_file.hwpe_params[ITA_REG_TILES][3:0];
ctrl_engine_o.tile_e = reg_file.hwpe_params[ITA_REG_TILES][7:4];
ctrl_engine_o.tile_p = reg_file.hwpe_params[ITA_REG_TILES][11:8];
ctrl_engine_o.tile_f = reg_file.hwpe_params[ITA_REG_TILES][15:12];
ctrl_engine_o.tile_s = reg_file.hwpe_params[ITA_REG_TILES][3:0] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][3:0];
ctrl_engine_o.tile_e = reg_file.hwpe_params[ITA_REG_TILES][7:4] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][7:4];
ctrl_engine_o.tile_p = reg_file.hwpe_params[ITA_REG_TILES][11:8] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][11:8];
ctrl_engine_o.tile_f = reg_file.hwpe_params[ITA_REG_TILES][15:12] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][15:12];
ctrl_engine_o.eps_mult[0] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][7:0];
ctrl_engine_o.eps_mult[1] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][15:8];
ctrl_engine_o.eps_mult[2] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][23:16];
Expand Down
2 changes: 1 addition & 1 deletion src/hwpe/ita_hwpe_package.sv
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ package ita_hwpe_package;

// HWPE Configuration
parameter int unsigned N_CORES = 9;
parameter int unsigned N_CONTEXT = 2;
parameter int unsigned N_CONTEXT = 4;
parameter int unsigned ID_WIDTH = 2;
parameter int unsigned ITA_IO_REGS = 17; // 5 address + 11 parameters + 1 sync

Expand Down
51 changes: 19 additions & 32 deletions src/hwpe/tb/ita_hwpe_tb.sv
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ endfunction
// Signals
logic [31:0] status;
string STIM_DATA;
int ita_reg_cnt;
logic [31:0] ita_reg_tiles_val;
logic [5:0][31:0] ita_reg_rqs_val;
logic [31:0] ita_reg_gelu_b_c_val;
Expand All @@ -338,6 +339,7 @@ endfunction

// Wait for reset to be released
wait (rst_n);
ita_reg_cnt = 0;

// Load memory
STIM_DATA = {simdir,"/hwpe/mem.txt"};
Expand All @@ -356,7 +358,7 @@ endfunction
PERIPH_READ( 32'h04, 32'h0, status, clk);

// 1: Step Q
ita_compute_step(Q, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(Q, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 2: Step K
if (SINGLE_ATTENTION == 1) begin
Expand All @@ -365,7 +367,7 @@ endfunction
ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8;
end
ita_compute_step(K, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(K, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 3: Step V
if (SINGLE_ATTENTION == 1) begin
Expand All @@ -374,7 +376,7 @@ endfunction
ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8;
end
ita_compute_step(V, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(V, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

if (SINGLE_ATTENTION == 1) begin
// Reset the RQS values
Expand All @@ -389,15 +391,15 @@ endfunction
BASE_PTR_OUTPUT[AV] = BASE_PTR[19] + group * N_TILES_OUTER_X[AV] * N_ELEMENTS_PER_TILE;

// 4: Step QK
ita_compute_step(QK, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(QK, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// WIESEP: Hack to ensure that during the last tile of AV, the weight pointer is set correctly
if (group == N_TILES_SEQUENCE_DIM-1) begin
BASE_PTR_WEIGHT0[QK] = BASE_PTR_WEIGHT0[OW];
end

// 5: Step AV
ita_compute_step(AV, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(AV, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
end

// 6: Step OW
Expand All @@ -409,7 +411,9 @@ endfunction
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 8;
end
ita_compute_step(OW, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(OW, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

ita_reg_cnt = 0;

// 7: Step FF1
if (SINGLE_ATTENTION == 1) begin
Expand All @@ -420,7 +424,7 @@ endfunction
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 16;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 16;
end
ita_compute_step(F1, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(F1, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 8: Step FF2
if (SINGLE_ATTENTION == 1) begin
Expand All @@ -431,7 +435,7 @@ endfunction
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 24;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 24;
end
ita_compute_step(F2, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(F2, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// Wait for the last step to finish
wait(evt);
Expand All @@ -456,6 +460,7 @@ endfunction

task automatic ita_compute_step(
input step_e step,
inout integer ita_reg_cnt,
input logic [31:0] ita_reg_tiles_val,
input logic [5:0][31:0] ita_reg_rqs_val,
input logic [31:0] ita_reg_gelu_b_c_val,
Expand Down Expand Up @@ -500,7 +505,12 @@ endfunction
ita_reg_en = 1'b1;
end else begin
// Calculate ita_reg_en
ita_reg_en_compute(step, tile, ita_reg_en);
if (ita_reg_cnt < N_CONTEXT) begin
ita_reg_en = 1'b1;
ita_reg_cnt++;
end else begin
ita_reg_en = 1'b0;
end
end

// Calculate ctrl_stream_val, weight_ptr_en, and bias_ptr_en
Expand Down Expand Up @@ -575,29 +585,6 @@ endfunction
$display(" - output_ptr 0x%08h (output_base_ptr 0x%08h)", output_ptr, output_base_ptr);
endtask


task automatic ita_reg_en_compute(
input step_e step,
input integer tile,
output logic enable
);
enable = 1'b0;
// Write requantization parameters only in first two programming phases
if (step == Q) begin
if (tile == 0 || tile == 1)
enable = 1'b1;
end else if (step == K && N_TILES_OUTER_X[Q]*N_TILES_OUTER_Y[Q]*N_TILES_INNER_DIM[Q] == 1) begin
if (tile == 0)
enable = 1'b1;
end else if (step == F1) begin
if (tile == 0 || tile == 1)
enable = 1'b1;
end else if (step == F2 && N_TILES_OUTER_X[F1]*N_TILES_OUTER_Y[F1]*N_TILES_INNER_DIM[F1] == 1) begin
if (tile == 0)
enable = 1'b1;
end
endtask

task automatic ctrl_val_compute(
input step_e step,
input integer tile,
Expand Down

0 comments on commit 8ef386a

Please sign in to comment.