
Support for single attention layer #7

Merged · 7 commits · Oct 30, 2024
42 changes: 40 additions & 2 deletions .gitlab-ci.yml
@@ -46,33 +46,52 @@ run_sim:
F: 64
activation: gelu
no_stalls: 0
single_attention: 0
- S: 64
E: 64
P: 64
F: 64
activation: gelu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 1
single_attention: 0
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 1
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 0
single_attention: 1
script:
- make bender
- make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls
- make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention
- ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F ita_tb

run_hwpe_sim:
@@ -87,31 +106,50 @@ run_hwpe_sim:
F: 64
activation: gelu
no_stalls: 0
single_attention: 0
- S: 64
E: 64
P: 64
F: 64
activation: gelu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 1
single_attention: 0
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 1
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 0
single_attention: 1
script:
- make bender
- make sim VSIM_FLAGS=-c DEBUG=OFF target=sim_ita_hwpe_tb s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls
- make sim VSIM_FLAGS=-c DEBUG=OFF target=sim_ita_hwpe_tb s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention
- ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F hwpe_tb
5 changes: 3 additions & 2 deletions Makefile
@@ -20,6 +20,7 @@ BENDER_TARGETS = -t rtl -t test
target ?= sim_ita_tb

no_stalls ?= 0
single_attention ?= 0
s ?= 64
e ?= 128
p ?= 192
@@ -33,7 +34,7 @@ else ifeq ($(activation), relu)
else
activation_int = 0
endif
vlog_defs += -DNO_STALLS=$(no_stalls) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int)
vlog_defs += -DNO_STALLS=$(no_stalls) -DSINGLE_ATTENTION=$(single_attention) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int)

ifeq ($(target), sim_ita_hwpe_tb)
BENDER_TARGETS += -t ita_hwpe -t ita_hwpe_test
@@ -60,7 +61,7 @@ sim-script: clean-sim

sim: sim-script
cd modelsim && \
$(MAKE) $(target)
$(MAKE) $(target) buildpath=$(ROOT_DIR)/$(SIM_PATH)

synopsys-script:
rm ../ita-gf22/$(SYNTH_PATH)/scripts/analyze.tcl
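With the new `single_attention ?= 0` default in place, a single-attention run can be launched the same way the CI jobs above do, e.g. with one of the parameter sets from the CI matrix (a usage sketch, not an additional documented target):

```
make sim VSIM_FLAGS=-c s=128 e=192 p=256 f=256 bias=1 activation=gelu no_stalls=0 single_attention=1
```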
9 changes: 7 additions & 2 deletions src/hwpe/ita_hwpe_input_bias_buffer.sv
@@ -30,6 +30,8 @@ module ita_hwpe_input_bias_buffer #(
logic write_enable;
logic [OUTPUT_DATA_WIDTH-1:0] read_data;

bias_t bias_reshape;

always_comb begin
// Default assignments
state_d = state_q;
@@ -43,6 +45,7 @@
data_o.valid = 0;
data_o.strb = 48'hFFFFFFFFFFFF;
data_o.data = '0;
bias_reshape = '0;

case(state_q)
Write: begin
@@ -60,8 +63,10 @@
data_o.valid = read_enable_q;
if (read_enable_q) begin
data_o.data = read_data;
if (bias_dir_i)
data_o.data = read_data >> read_cnt_q[3:0] * 24;
if (bias_dir_i) begin
bias_reshape = read_data >> read_cnt_q[3:0] * 24;
data_o.data = {N {bias_reshape[0]}};
end
Comment on lines +66 to +69

Collaborator:

I believe this is related to

> Transposition of biases for V generation is moved from ITA engine to HWPE side.

However, I do not see how this is now handled on the HWPE side. Can you give me a quick pointer?

Collaborator (Author):

Here, we fetch one set of biases, select the required one with a counter, and replicate it N times. Previously, ITA replicated them internally.
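In other words, a minimal standalone sketch of that select-and-replicate step (module name, port names, and widths here are assumptions for illustration, not the buffer's actual interface):

```systemverilog
// Illustrative sketch only: select one 24-bit bias element by counter,
// then replicate it N times, as ITA previously did internally.
module bias_select_replicate #(
  parameter int unsigned N  = 16, // replication factor (assumed)
  parameter int unsigned BW = 24  // bits per bias element (from the `* 24` shift)
) (
  input  logic [N*BW-1:0] read_data_i, // one full set of N biases
  input  logic [3:0]      read_cnt_i,  // index of the bias currently needed
  output logic [N*BW-1:0] data_o       // that bias replicated N times
);
  logic [BW-1:0] bias_reshape;
  // Shift the selected element down to bit position 0; the assignment
  // to a BW-wide signal truncates the rest.
  assign bias_reshape = read_data_i >> (read_cnt_i * BW);
  // Replicate the selected element N times.
  assign data_o = {N {bias_reshape}};
endmodule
```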

end
read_enable = 1;
if(data_o.valid && data_o.ready) begin
Expand Down
81 changes: 75 additions & 6 deletions src/hwpe/tb/ita_hwpe_tb.sv
@@ -27,6 +27,7 @@ module ita_hwpe_tb;
parameter integer EMBEDDING_SIZE = `ifdef EMBED_SIZE `EMBED_SIZE `else M_TILE_LEN `endif;
parameter integer FEEDFORWARD_SIZE = `ifdef FF_SIZE `FF_SIZE `else M_TILE_LEN `endif;
parameter activation_e ACTIVATION = `ifdef ACTIVATION `ACTIVATION `else Identity `endif;
parameter integer SINGLE_ATTENTION = `ifdef SINGLE_ATTENTION `SINGLE_ATTENTION `else 0 `endif;

integer N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, N_TILES_FEEDFORWARD_DIM;
integer N_ELEMENTS_PER_TILE;
@@ -118,6 +119,8 @@
`HCI_INTF_ARRAY(tcdm_mem, clk_i, MP-1:0);

initial begin
$timeformat(-9, 1, " ns", 11);

simdir = {
"../../simvectors/data_S",
$sformatf("%0d", SEQUENCE_LEN),
@@ -356,11 +359,27 @@ endfunction
ita_compute_step(Q, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 2: Step K
if (SINGLE_ATTENTION == 1) begin
// move corresponding ita_reg_rqs_val because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[0] >> 8;
ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8;
end
ita_compute_step(K, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 3: Step V
if (SINGLE_ATTENTION == 1) begin
// move corresponding ita_reg_rqs_val because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[0] >> 8;
ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8;
end
ita_compute_step(V, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

if (SINGLE_ATTENTION == 1) begin
// Reset the RQS values
ita_reg_eps_mult_val_compute(ita_reg_rqs_val);
end

for (int group = 0; group < N_TILES_SEQUENCE_DIM; group++) begin
BASE_PTR_INPUT[QK] = BASE_PTR[15] + group * N_TILES_INNER_DIM[QK] * N_ELEMENTS_PER_TILE;
@@ -382,12 +401,36 @@
end

// 6: Step OW
if (SINGLE_ATTENTION == 1) begin
// Change order of P and E
ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_PROJECTION_DIM, N_TILES_EMBEDDING_DIM, N_TILES_FEEDFORWARD_DIM, ita_reg_tiles_val);
// move corresponding ita_reg_rqs_val because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 8;
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 8;
end
ita_compute_step(OW, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 7: Step FF1
if (SINGLE_ATTENTION == 1) begin
// Change order of P and F
ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_FEEDFORWARD_DIM, N_TILES_PROJECTION_DIM, ita_reg_tiles_val);
// move corresponding ita_reg_rqs_val because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 16;
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 16;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 16;
end
ita_compute_step(F1, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 8: Step FF1
// 8: Step FF2
if (SINGLE_ATTENTION == 1) begin
// Change order of E and F
ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_FEEDFORWARD_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, ita_reg_tiles_val);
// move corresponding ita_reg_rqs_val because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 24;
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 24;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 24;
end
ita_compute_step(F2, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// Wait for the last step to finish
@@ -452,8 +495,14 @@ endfunction
// Calculate input_ptr, weight_ptr0, weight_ptr1, bias_ptr, and output_ptr
ita_ptrs_compute(input_base_ptr, weight_base_ptr0, weight_base_ptr1, bias_base_ptr, output_base_ptr, step, tile, tile_x, tile_y, tile_inner, input_ptr, weight_ptr0, weight_ptr1, bias_ptr, output_ptr);

// Calculate ita_reg_en
ita_reg_en_compute(step, tile, ita_reg_en);
if (SINGLE_ATTENTION == 1) begin
// Enable ita_reg_en
ita_reg_en = 1'b1;
end else begin
// Calculate ita_reg_en
ita_reg_en_compute(step, tile, ita_reg_en);
end

// Calculate ctrl_stream_val, weight_ptr_en, and bias_ptr_en
ctrl_val_compute(step, tile, ctrl_engine_val, ctrl_stream_val, weight_ptr_en, bias_ptr_en);

@@ -564,7 +613,13 @@ endfunction
ctrl_stream_val = 32'h0;
reg_weight_en = 1'b0;
reg_bias_en = 1'b0;
layer_type = Attention;

if (SINGLE_ATTENTION == 1) begin
layer_type = Linear;
end else begin
layer_type = Attention;
end

activation_function = Identity;

ctrl_engine_val = layer_type | activation_function << 2;
@@ -598,11 +653,17 @@ endfunction
reg_bias_en = 1'b1;
end
QK : begin
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = SingleAttention | Identity << 2;
end
ctrl_stream_val = {28'b0, 4'b0110}; // weight nextload and disable bias
reg_weight_en = 1'b1;
reg_bias_en = 1'b0;
end
AV : begin
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = SingleAttention | Identity << 2;
end
ctrl_stream_val = {28'b0, 4'b0110}; // weight nextload and disable bias
reg_weight_en = 1'b1;
reg_bias_en = 1'b0;
@@ -618,7 +679,11 @@ endfunction
reg_bias_en = 1'b1;
end
F1 : begin
ctrl_engine_val = Feedforward | ACTIVATION << 2;
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = Linear | ACTIVATION << 2;
end else begin
ctrl_engine_val = Feedforward | ACTIVATION << 2;
end
if (tile == 0) begin
ctrl_stream_val = {28'b0, 4'b0011}; // weight preload and weight nextload
end else begin
@@ -628,7 +693,11 @@ endfunction
reg_bias_en = 1'b1;
end
F2 : begin
ctrl_engine_val = Feedforward | Identity << 2;
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = Linear | Identity << 2;
end else begin
ctrl_engine_val = Feedforward | Identity << 2;
end
if (tile == (N_TILES_OUTER_X[F2]*N_TILES_OUTER_Y[F2]*N_TILES_INNER_DIM[F2])-1) begin
ctrl_stream_val = {28'b0, 4'b0000};
reg_weight_en = 1'b0;
16 changes: 8 additions & 8 deletions src/ita.sv
@@ -194,14 +194,14 @@ module ita
);

ita_input_sampler i_input_sampler (
.clk_i (clk_i ),
.rst_ni (rst_ni ),
.valid_i (inp_valid_i ),
.ready_i (inp_ready_o ),
.inp_i (inp_i ),
.inp_bias_i ((step == QK || step == AV) ? '0 : ((step == V) ? {N {inp_bias_i[0]}} : inp_bias_i) ), // TODO: temporary fix
.inp_o (inp ),
.inp_bias_o (inp_bias )
.clk_i (clk_i ),
.rst_ni (rst_ni ),
.valid_i (inp_valid_i ),
.ready_i (inp_ready_o ),
.inp_i (inp_i ),
.inp_bias_i (inp_bias_i ),
.inp_o (inp ),
.inp_bias_o (inp_bias )
);

ita_inp1_mux i_inp1_mux (
8 changes: 7 additions & 1 deletion src/ita_controller.sv
@@ -119,6 +119,8 @@ module ita_controller
step_d = F1;
end else if (ctrl_i.layer == Linear) begin
step_d = MatMul;
end else if (ctrl_i.layer == SingleAttention) begin
step_d = QK;
end
end
end
@@ -187,7 +189,11 @@
softmax_tile_d = softmax_tile_q + 1;
if (softmax_tile_d == ctrl_i.tile_s) begin
softmax_tile_d = '0;
step_d = OW;
if (ctrl_i.layer == Attention) begin
step_d = OW;
end else if (ctrl_i.layer == SingleAttention) begin
step_d = Idle;
end
end else begin
step_d = QK;
end
2 changes: 1 addition & 1 deletion src/ita_package.sv
@@ -35,7 +35,7 @@ package ita_package;
parameter int unsigned N_WRITE_EN = `ifdef TARGET_ITA_HWPE 8 `else M `endif;

// Feedforward
typedef enum bit [1:0] {Attention=0, Feedforward=1, Linear=2} layer_e;
typedef enum bit [1:0] {Attention=0, Feedforward=1, Linear=2, SingleAttention=3} layer_e;
typedef enum bit [1:0] {Identity=0, Gelu=1, Relu=2} activation_e;
typedef logic signed [GELU_CONSTANTS_WIDTH-1:0] gelu_const_t;
typedef logic signed [GELU_OUT_WIDTH-1:0] gelu_out_t;