Skip to content

Commit

Permalink
Support for single attention layer (#7)
Browse files (browse the repository at this point in the history)
* [Makefile] Pass simulation path

* [change] Alter time format in tb

* [change] Add trigger task to tb

* [feature] Add single attention layer

* [fix] Forward correct biases if transposed

* [fix] Correct typo

* [feature] Add single attention tests
  • Loading branch information
gamzeisl authored Oct 30, 2024
1 parent ebd365e commit 598b424
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 57 deletions.
42 changes: 40 additions & 2 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,33 +46,52 @@ run_sim:
F: 64
activation: gelu
no_stalls: 0
single_attention: 0
- S: 64
E: 64
P: 64
F: 64
activation: gelu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 1
single_attention: 0
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 1
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 0
single_attention: 1
script:
- make bender
- make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls
- make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention
- ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F ita_tb

run_hwpe_sim:
Expand All @@ -87,31 +106,50 @@ run_hwpe_sim:
F: 64
activation: gelu
no_stalls: 0
single_attention: 0
- S: 64
E: 64
P: 64
F: 64
activation: gelu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 1
single_attention: 0
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 1
single_attention: 0
- S: 128
E: 192
P: 256
F: 256
activation: gelu
no_stalls: 0
single_attention: 1
- S: 192
E: 256
P: 128
F: 128
activation: relu
no_stalls: 0
single_attention: 1
script:
- make bender
- make sim VSIM_FLAGS=-c DEBUG=OFF target=sim_ita_hwpe_tb s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls
- make sim VSIM_FLAGS=-c DEBUG=OFF target=sim_ita_hwpe_tb s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention
- ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F hwpe_tb
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ BENDER_TARGETS = -t rtl -t test
target ?= sim_ita_tb

no_stalls ?= 0
single_attention ?= 0
s ?= 64
e ?= 128
p ?= 192
Expand All @@ -33,7 +34,7 @@ else ifeq ($(activation), relu)
else
activation_int = 0
endif
vlog_defs += -DNO_STALLS=$(no_stalls) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int)
vlog_defs += -DNO_STALLS=$(no_stalls) -DSINGLE_ATTENTION=$(single_attention) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int)

ifeq ($(target), sim_ita_hwpe_tb)
BENDER_TARGETS += -t ita_hwpe -t ita_hwpe_test
Expand All @@ -60,7 +61,7 @@ sim-script: clean-sim

sim: sim-script
cd modelsim && \
$(MAKE) $(target)
$(MAKE) $(target) buildpath=$(ROOT_DIR)/$(SIM_PATH)

synopsys-script:
rm ../ita-gf22/$(SYNTH_PATH)/scripts/analyze.tcl
Expand Down
9 changes: 7 additions & 2 deletions src/hwpe/ita_hwpe_input_bias_buffer.sv
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ module ita_hwpe_input_bias_buffer #(
logic write_enable;
logic [OUTPUT_DATA_WIDTH-1:0] read_data;

bias_t bias_reshape;

always_comb begin
// Default assignments
state_d = state_q;
Expand All @@ -43,6 +45,7 @@ module ita_hwpe_input_bias_buffer #(
data_o.valid = 0;
data_o.strb = 48'hFFFFFFFFFFFF;
data_o.data = '0;
bias_reshape = '0;

case(state_q)
Write: begin
Expand All @@ -60,8 +63,10 @@ module ita_hwpe_input_bias_buffer #(
data_o.valid = read_enable_q;
if (read_enable_q) begin
data_o.data = read_data;
if (bias_dir_i)
data_o.data = read_data >> read_cnt_q[3:0] * 24;
if (bias_dir_i) begin
bias_reshape = read_data >> read_cnt_q[3:0] * 24;
data_o.data = {N {bias_reshape[0]}};
end
end
read_enable = 1;
if(data_o.valid && data_o.ready) begin
Expand Down
81 changes: 75 additions & 6 deletions src/hwpe/tb/ita_hwpe_tb.sv
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ module ita_hwpe_tb;
parameter integer EMBEDDING_SIZE = `ifdef EMBED_SIZE `EMBED_SIZE `else M_TILE_LEN `endif;
parameter integer FEEDFORWARD_SIZE = `ifdef FF_SIZE `FF_SIZE `else M_TILE_LEN `endif;
parameter activation_e ACTIVATION = `ifdef ACTIVATION `ACTIVATION `else Identity `endif;
parameter integer SINGLE_ATTENTION = `ifdef SINGLE_ATTENTION `SINGLE_ATTENTION `else 0 `endif;

integer N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, N_TILES_FEEDFORWARD_DIM;
integer N_ELEMENTS_PER_TILE;
Expand Down Expand Up @@ -118,6 +119,8 @@ module ita_hwpe_tb;
`HCI_INTF_ARRAY(tcdm_mem, clk_i, MP-1:0);

initial begin
$timeformat(-9, 1, " ns", 11);

simdir = {
"../../simvectors/data_S",
$sformatf("%0d", SEQUENCE_LEN),
Expand Down Expand Up @@ -356,11 +359,27 @@ endfunction
ita_compute_step(Q, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 2: Step K
if (SINGLE_ATTENTION == 1) begin
// move corresponding ita_reg_rqs_val because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[0] >> 8;
ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8;
end
ita_compute_step(K, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 3: Step V
if (SINGLE_ATTENTION == 1) begin
// move corresponding ita_reg_rqs_val because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[0] >> 8;
ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8;
end
ita_compute_step(V, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

if (SINGLE_ATTENTION == 1) begin
// Reset the RQS values
ita_reg_eps_mult_val_compute(ita_reg_rqs_val);
end

for (int group = 0; group < N_TILES_SEQUENCE_DIM; group++) begin
BASE_PTR_INPUT[QK] = BASE_PTR[15] + group * N_TILES_INNER_DIM[QK] * N_ELEMENTS_PER_TILE;
Expand All @@ -382,12 +401,36 @@ endfunction
end

// 6: Step OW
if (SINGLE_ATTENTION == 1) begin
// Change order of P and E
ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_PROJECTION_DIM, N_TILES_EMBEDDING_DIM, N_TILES_FEEDFORWARD_DIM, ita_reg_tiles_val);
// move corresponding ita_reg_rqs_val because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 8;
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 8;
end
ita_compute_step(OW, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 7: Step FF1
if (SINGLE_ATTENTION == 1) begin
// Change order of P and F
ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_FEEDFORWARD_DIM, N_TILES_PROJECTION_DIM, ita_reg_tiles_val);
// move corresponding ita_reg_rqs_val because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 16;
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 16;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 16;
end
ita_compute_step(F1, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 8: Step FF1
// 8: Step FF2
if (SINGLE_ATTENTION == 1) begin
// Change order of E and F
ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_FEEDFORWARD_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, ita_reg_tiles_val);
// move corresponding ita_reg_rqs_val because linear layers use array[0]
ita_reg_rqs_val[0] = ita_reg_rqs_val[1] >> 24;
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 24;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 24;
end
ita_compute_step(F2, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// Wait for the last step to finish
Expand Down Expand Up @@ -452,8 +495,14 @@ endfunction
// Calculate input_ptr, weight_ptr0, weight_ptr1, bias_ptr, and output_ptr
ita_ptrs_compute(input_base_ptr, weight_base_ptr0, weight_base_ptr1, bias_base_ptr, output_base_ptr, step, tile, tile_x, tile_y, tile_inner, input_ptr, weight_ptr0, weight_ptr1, bias_ptr, output_ptr);

// Calculate ita_reg_en
ita_reg_en_compute(step, tile, ita_reg_en);
if (SINGLE_ATTENTION == 1) begin
// Enable ita_reg_en
ita_reg_en = 1'b1;
end else begin
// Calculate ita_reg_en
ita_reg_en_compute(step, tile, ita_reg_en);
end

// Calculate ctrl_stream_val, weight_ptr_en, and bias_ptr_en
ctrl_val_compute(step, tile, ctrl_engine_val, ctrl_stream_val, weight_ptr_en, bias_ptr_en);

Expand Down Expand Up @@ -564,7 +613,13 @@ endfunction
ctrl_stream_val = 32'h0;
reg_weight_en = 1'b0;
reg_bias_en = 1'b0;
layer_type = Attention;

if (SINGLE_ATTENTION == 1) begin
layer_type = Linear;
end else begin
layer_type = Attention;
end

activation_function = Identity;

ctrl_engine_val = layer_type | activation_function << 2;
Expand Down Expand Up @@ -598,11 +653,17 @@ endfunction
reg_bias_en = 1'b1;
end
QK : begin
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = SingleAttention | Identity << 2;
end
ctrl_stream_val = {28'b0, 4'b0110}; // weight nextload and disable bias
reg_weight_en = 1'b1;
reg_bias_en = 1'b0;
end
AV : begin
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = SingleAttention | Identity << 2;
end
ctrl_stream_val = {28'b0, 4'b0110}; // weight nextload and disable bias
reg_weight_en = 1'b1;
reg_bias_en = 1'b0;
Expand All @@ -618,7 +679,11 @@ endfunction
reg_bias_en = 1'b1;
end
F1 : begin
ctrl_engine_val = Feedforward | ACTIVATION << 2;
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = Linear | ACTIVATION << 2;
end else begin
ctrl_engine_val = Feedforward | ACTIVATION << 2;
end
if (tile == 0) begin
ctrl_stream_val = {28'b0, 4'b0011}; // weight preload and weight nextload
end else begin
Expand All @@ -628,7 +693,11 @@ endfunction
reg_bias_en = 1'b1;
end
F2 : begin
ctrl_engine_val = Feedforward | Identity << 2;
if (SINGLE_ATTENTION == 1) begin
ctrl_engine_val = Linear | Identity << 2;
end else begin
ctrl_engine_val = Feedforward | Identity << 2;
end
if (tile == (N_TILES_OUTER_X[F2]*N_TILES_OUTER_Y[F2]*N_TILES_INNER_DIM[F2])-1) begin
ctrl_stream_val = {28'b0, 4'b0000};
reg_weight_en = 1'b0;
Expand Down
16 changes: 8 additions & 8 deletions src/ita.sv
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,14 @@ module ita
);

ita_input_sampler i_input_sampler (
.clk_i (clk_i ),
.rst_ni (rst_ni ),
.valid_i (inp_valid_i ),
.ready_i (inp_ready_o ),
.inp_i (inp_i ),
.inp_bias_i ((step == QK || step == AV) ? '0 : ((step == V) ? {N {inp_bias_i[0]}} : inp_bias_i) ), // TODO: temporary fix
.inp_o (inp ),
.inp_bias_o (inp_bias )
.clk_i (clk_i ),
.rst_ni (rst_ni ),
.valid_i (inp_valid_i ),
.ready_i (inp_ready_o ),
.inp_i (inp_i ),
.inp_bias_i (inp_bias_i ),
.inp_o (inp ),
.inp_bias_o (inp_bias )
);

ita_inp1_mux i_inp1_mux (
Expand Down
8 changes: 7 additions & 1 deletion src/ita_controller.sv
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ module ita_controller
step_d = F1;
end else if (ctrl_i.layer == Linear) begin
step_d = MatMul;
end else if (ctrl_i.layer == SingleAttention) begin
step_d = QK;
end
end
end
Expand Down Expand Up @@ -187,7 +189,11 @@ module ita_controller
softmax_tile_d = softmax_tile_q + 1;
if (softmax_tile_d == ctrl_i.tile_s) begin
softmax_tile_d = '0;
step_d = OW;
if (ctrl_i.layer == Attention) begin
step_d = OW;
end else if (ctrl_i.layer == SingleAttention) begin
step_d = Idle;
end
end else begin
step_d = QK;
end
Expand Down
2 changes: 1 addition & 1 deletion src/ita_package.sv
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ package ita_package;
parameter int unsigned N_WRITE_EN = `ifdef TARGET_ITA_HWPE 8 `else M `endif;

// Feedforward
typedef enum bit [1:0] {Attention=0, Feedforward=1, Linear=2} layer_e;
typedef enum bit [1:0] {Attention=0, Feedforward=1, Linear=2, SingleAttention=3} layer_e;
typedef enum bit [1:0] {Identity=0, Gelu=1, Relu=2} activation_e;
typedef logic signed [GELU_CONSTANTS_WIDTH-1:0] gelu_const_t;
typedef logic signed [GELU_OUT_WIDTH-1:0] gelu_out_t;
Expand Down
Loading

0 comments on commit 598b424

Please sign in to comment.