From 5487e666443364bdd52ccd3d880ad95b7617fe1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Tue, 22 Oct 2024 17:09:33 +0000 Subject: [PATCH] Separating shift and sign extend --- pil/zisk.pil | 17 ++- state-machines/binary/pil/binary.pil | 88 ++++++++------ .../binary/pil/binary_extension.pil | 111 ------------------ state-machines/binary/pil/shift.pil | 99 ++++++++++++++++ ...ry_extension_table.pil => shift_table.pil} | 54 ++------- state-machines/binary/pil/sign_extension.pil | 88 ++++++++++++++ 6 files changed, 259 insertions(+), 198 deletions(-) delete mode 100644 state-machines/binary/pil/binary_extension.pil create mode 100644 state-machines/binary/pil/shift.pil rename state-machines/binary/pil/{binary_extension_table.pil => shift_table.pil} (66%) create mode 100644 state-machines/binary/pil/sign_extension.pil diff --git a/pil/zisk.pil b/pil/zisk.pil index 59903524..edeffeee 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -3,8 +3,9 @@ require "constants.pil" require "main/pil/main.pil" require "binary/pil/binary.pil" require "binary/pil/binary_table.pil" -require "binary/pil/binary_extension.pil" -require "binary/pil/binary_extension_table.pil" +require "binary/pil/shift.pil" +require "binary/pil/shift_table.pil" +require "binary/pil/sign_extension.pil" // require "mem/pil/mem.pil" const int OPERATION_BUS_ID = 5000; @@ -24,10 +25,14 @@ airgroup BinaryTable { BinaryTable(disable_fixed: 0); } -airgroup BinaryExtension { - BinaryExtension(N: 2**21, operation_bus_id: OPERATION_BUS_ID); +airgroup Shift { + Shift(N: 2**21, operation_bus_id: OPERATION_BUS_ID); } -airgroup BinaryExtensionTable { - BinaryExtensionTable(disable_fixed: 0); +airgroup ShiftTable { + ShiftTable(disable_fixed: 0); +} + +airgroup SignExtension { + SignExtension(N: 2**21, operation_bus_id: OPERATION_BUS_ID); } diff --git a/state-machines/binary/pil/binary.pil b/state-machines/binary/pil/binary.pil index b6c48f73..77c4f3f5 100644 --- a/state-machines/binary/pil/binary.pil +++ b/state-machines/binary/pil/binary.pil @@ -1,8 +1,7 @@ require "std_lookup.pil" - -// Coprocessor in charge of performing standard RISCV binary operations - /* + Coprocessor in charge of performing the following binary operations: + List 64-bit operations: name │ op │ m_op │ carry │ use_last_carry │ NOTES ────────┼──────────┼──────────┼───────┼────────────────┼─────────────────────────────────── @@ -65,6 +64,7 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID) // Default values const int bits = 64; const int bytes = bits / 8; + const int half_bytes = bytes / 2; // Main values const int input_chunks = 2; @@ -80,51 +80,69 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID) // Secondary columns col witness use_last_carry; // 1 if the operation uses the last carry as its result - col witness op_is_min_max; // 1 if op ∈ {MINU,MIN,MAXU,MAX} + col witness op_is_min_max; // 1 if the operation is any of the MIN/MAX operations - const expr cout32 = carry[bytes/2-1]; + const expr mode64 = 1 - mode32; + const expr cout32 = carry[half_bytes-1]; const expr cout64 = carry[bytes-1]; - expr cout = (1-mode32) * (cout64 - cout32) + cout32; use_last_carry * (1 - use_last_carry) === 0; op_is_min_max * (1 - op_is_min_max) === 0; cout32*(1 - cout32) === 0; cout64*(1 - cout64) === 0; - // Constraints to check the correctness of each binary operation + // Auxiliary columns (primarily used to optimize lookups, but can be substituted with expressions) + col witness cout; + col witness result_is_a; + col witness use_last_carry_mode32; + col witness use_last_carry_mode64; + cout === mode64 * (cout64 - cout32) + cout32; + result_is_a === op_is_min_max * cout; + use_last_carry_mode32 === mode32 * use_last_carry; + use_last_carry_mode64 === mode64 * use_last_carry; + /* - opid last a b c cin cout - ─────────────────────────────────────────────────────────────── - m_op 0 a0 b0 c0 0 carry0 - m_op 0 a1 b1 c1 carry0 carry1 - m_op 0 a2 b2 c2 carry1 carry2 - m_op 0 a3 b3 c3 carry2 carry3 + 2*use_last_carry - m_op|EXT_32 0 a4|c3 b4|0 c4 carry3 carry4 - m_op|EXT_32 0 a5|c3 b5|0 c5 carry4 carry5 - m_op|EXT_32 0 a6|c3 b6|0 c6 carry5 carry6 - m_op|EXT_32 1 a7|c3 b7|0 c7 carry6 carry7 + 2*use_last_carry + Constraints to check the correctness of each binary operation + opid last a b c cin cout + flags + ───────────────────────────────────────────────────────────────------------------------------------------------- + m_op 0 a0 b0 c0 0 carry0 + 2*op_is_min_max + 4*result_is_a + m_op 0 a1 b1 c1 carry0 carry1 + 2*op_is_min_max + 4*result_is_a + m_op 0 a2 b2 c2 carry1 carry2 + 2*op_is_min_max + 4*result_is_a + m_op 0|1 a3 b3 c3 carry2 carry3 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode32 + m_op|EXT_32 0 a4|c3 b4|0 c4 carry3 carry4 + 2*op_is_min_max + 4*result_is_a + m_op|EXT_32 0 a5|c3 b5|0 c5 carry4 carry5 + 2*op_is_min_max + 4*result_is_a + m_op|EXT_32 0 a6|c3 b6|0 c6 carry5 carry6 + 2*op_is_min_max + 4*result_is_a + m_op|EXT_32 0|1 a7|c3 b7|0 c7 carry6 carry7 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode64 + ───────────────────────────────────────────────────────────────------------------------------------------------- + Perform, at the byte level, lookups against the binary table on inputs: + [last, m_op, a, b, cin, c, cout + flags] + where last indicates whether the byte is the last one in the operation */ - // Perform, at the byte level, lookups against the binary table on inputs: - // [last, m_op, a, b, cin, c, cout + flags] - // where last indicates whether the byte is the last one in the operation + lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*result_is_a]); - lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*op_is_min_max*cout]); - - expr _m_op = (1-mode32) * (m_op - EXT_32_OP) + EXT_32_OP; + // More auxiliary columns + col witness m_op_or_ext; + col witness free_in_a_or_c[half_bytes]; + col witness free_in_b_or_zero[half_bytes]; + m_op_or_ext === mode64 * (m_op - EXT_32_OP) + EXT_32_OP; + int index = 0; for (int i = 1; i < bytes; i++) { - expr _free_in_a = (1-mode32) * (free_in_a[i] - free_in_c[bytes/2-1]) + free_in_c[bytes/2-1]; - expr _free_in_b = (1-mode32) * free_in_b[i]; - - if (i < bytes/2 - 1) { - lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*op_is_min_max*cout]); - } else if (i == bytes/2 - 1) { - lookup_assumes(BINARY_TABLE_ID, [mode32, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], cout32 + 2*op_is_min_max + 4*op_is_min_max*cout + 8*use_last_carry*mode32]); - } else if (i < bytes - 1) { - lookup_assumes(BINARY_TABLE_ID, [0, _m_op, _free_in_a, _free_in_b, carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*op_is_min_max*cout]); - } else { - lookup_assumes(BINARY_TABLE_ID, [1-mode32, _m_op, _free_in_a, _free_in_b, carry[i-1], free_in_c[i], cout64 + 2*op_is_min_max + 4*op_is_min_max*cout + 8*use_last_carry*(1-mode32)]); - } + if (i >= half_bytes) { + index = i - half_bytes; + free_in_a_or_c[index] === mode64 * (free_in_a[i] - free_in_c[half_bytes-1]) + free_in_c[half_bytes-1]; + free_in_b_or_zero[index] === mode64 * free_in_b[i]; + } + + if (i < half_bytes - 1) { + lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*result_is_a]); + } else if (i == half_bytes - 1) { + lookup_assumes(BINARY_TABLE_ID, [mode32, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], cout32 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode32]); + } else if (i < bytes - 1) { + lookup_assumes(BINARY_TABLE_ID, [0, m_op_or_ext, free_in_a_or_c[index], free_in_b_or_zero[index], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*result_is_a]); + } else { + lookup_assumes(BINARY_TABLE_ID, [mode64, m_op_or_ext, free_in_a_or_c[index], free_in_b_or_zero[index], carry[i-1], free_in_c[i], cout64 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode64]); + } } // Constraints to make sure that this component is called from the main component diff --git a/state-machines/binary/pil/binary_extension.pil b/state-machines/binary/pil/binary_extension.pil deleted file mode 100644 index 2e1587f4..00000000 --- a/state-machines/binary/pil/binary_extension.pil +++ /dev/null @@ -1,111 +0,0 @@ -require "std_permutation.pil" -require "std_lookup.pil" -require "std_range_check.pil" - -// Coprocessor in charge of performing standard RISCV binary operations - -/* -List: - ┼────────┼────────┼──────────┼ - │ name │ bits │ op │ - ┼────────┼────────┼──────────┼ - │ SLL │ 64 │ 0x0d │ - │ SRL │ 64 │ 0x0e │ - │ SRA │ 64 │ 0x0f │ - │ SLL_W │ 32 │ 0x1d │ - │ SRL_W │ 32 │ 0x1e │ - │ SRA_W │ 32 │ 0x1f │ - │ SE_B │ 32 │ 0x23 │ - │ SE_H │ 32 │ 0x24 │ - │ SE_W │ 32 │ 0x25 │ - ┼────────┼────────┼──────────┼ - -Examples: -======================================= - -SLL 28 -x in1[x] out[x][0] out[x][1] ---------------------------------------- -0 0x11 0x10000000 0x00000001 -1 0x22 0x00000000 0x00000220 -2 0x33 0x00000000 0x00033000 -3 0x44 0x00000000 0x04400000 -4 0x55 0x00000000 0x50000000 -5 0x66 0x00000000 0x00000000 -6 0x77 0x00000000 0x00000000 -7 0x88 0x00000000 0x00000000 ---------------------------------------- -Result: 0x10000000 0x54433221 - -SLL_W 8 -x in1[x] out[x][0] out[x][1] ---------------------------------------- -0 0x11 0x00001100 0x00000000 -1 0x22 0x00220000 0x00000000 -2 0x33 0x33000000 0x00000000 -3 0x44 0x00000000 0x00000044 -4 0x55 0x00000000 0x00000000 (since 0x44 & 0x80 = 0, we stop here and set the remaining bytes to 0x00) -5 0x66 0x00000000 0x00000000 (bytes of in1 are ignored from here) -6 0x77 0x00000000 0x00000000 -7 0x88 0x00000000 0x00000000 ---------------------------------------- -Result: 0x33221100 0x00000000 - -SE_H -x in2[x] out[x][0] out[x][1] ---------------------------------------- -0 0xbc 0x000000bc 0x00000000 -1 0x8a 0xFFFF8a00 0xFFFFFFFF (since 0x8a & 0x80 = 0x80, we stop here and set the remaining bytes to 0xFF) -2 0x33 0x00000000 0x00000000 (bytes of in2 are ignored from here) -3 0x44 0x00000000 0x00000000 -4 0x55 0x00000000 0x00000000 -5 0x66 0x00000000 0x00000000 -6 0x77 0x00000000 0x00000000 -7 0x88 0x00000000 0x00000000 ---------------------------------------- -Result: 0xFFFF8abc 0xFFFFFFFF -*/ - -const int BINARY_EXTENSION_ID = 21; - -airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id = BINARY_EXTENSION_ID) { - const int bits = 64; - const int bytes = bits / 8; - - col witness op; - col witness in1[bytes]; - col witness in2_low; // Note: if in2_low∊[0,2^5-1], else in2_low∊[0,2^6-1] (checked by the table) - col witness out[bytes][2]; - col witness op_is_shift; // 1 if op is shift, 0 otherwise - - // Constraints to check the correctness of each binary operation - for (int j = 0; j < bytes; j++) { - lookup_assumes(BINARY_EXTENSION_TABLE_ID, [op, j, in1[j], in2_low, out[j][0], out[j][1], op_is_shift]); - } - - // Constraints to make sure that this component is called from the main component - col witness in2[2]; - - expr in1_low = in1[0] + in1[1]*2**8 + in1[2]*2**16 + in1[3]*2**24; - expr in1_high = in1[4] + in1[5]*2**8 + in1[6]*2**16 + in1[7]*2**24; - - col witness main_step; - col witness multiplicity; - lookup_proves( - operation_bus_id, - [ - main_step, - op, - op_is_shift * (in1_low - in2[0]) + in2[0], - op_is_shift * (in1_high - in2[1]) + in2[1], - op_is_shift * (in2_low + 256 * in2[0] - in1_low) + in1_low, - op_is_shift * (in2[1] - in1_high) + in1_high, - out[0][0] + out[1][0] + out[2][0] + out[3][0] + out[4][0] + out[5][0] + out[6][0] + out[7][0], - out[0][1] + out[1][1] + out[2][1] + out[3][1] + out[4][1] + out[5][1] + out[6][1] + out[7][1], - 0 - ], - multiplicity - ); - - range_check(colu: in2[0], min: 0, max: 2**24-1, sel: op_is_shift); -} \ No newline at end of file diff --git a/state-machines/binary/pil/shift.pil b/state-machines/binary/pil/shift.pil new file mode 100644 index 00000000..3ab3575e --- /dev/null +++ b/state-machines/binary/pil/shift.pil @@ -0,0 +1,99 @@ +require "std_lookup.pil" +require "std_range_check.pil" + +/* + Coprocessor in charge of performing shift operations: + + ┼────────┼────────┼──────────┼ + │ name │ bits │ op │ + ┼────────┼────────┼──────────┼ + │ SLL │ 64 │ 0x0d │ + │ SRL │ 64 │ 0x0e │ + │ SRA │ 64 │ 0x0f │ + │ SLL_W │ 32 │ 0x1d │ + │ SRL_W │ 32 │ 0x1e │ + │ SRA_W │ 32 │ 0x1f │ + ┼────────┼────────┼──────────┼ + + Examples: + ======================================= + + SLL 28 + x in1[x] out[x][0] out[x][1] + --------------------------------------- + 0 0x11 0x10000000 0x00000001 + 1 0x22 0x00000000 0x00000220 + 2 0x33 0x00000000 0x00033000 + 3 0x44 0x00000000 0x04400000 + 4 0x55 0x00000000 0x50000000 + 5 0x66 0x00000000 0x00000000 + 6 0x77 0x00000000 0x00000000 + 7 0x88 0x00000000 0x00000000 + --------------------------------------- + Result: 0x10000000 0x54433221 + + SLL_W 8 + x in1[x] out[x][0] out[x][1] + --------------------------------------- + 0 0x11 0x00001100 0x00000000 + 1 0x22 0x00220000 0x00000000 + 2 0x33 0x33000000 0x00000000 + 3 0x44 0x00000000 0x00000044 + 4 0x55 0x00000000 0x00000000 (since 0x44 & 0x80 = 0, we stop here and set the remaining bytes to 0x00) + 5 0x66 0x00000000 0x00000000 (bytes of in1 are ignored from here) + 6 0x77 0x00000000 0x00000000 + 7 0x88 0x00000000 0x00000000 + --------------------------------------- + Result: 0x33221100 0x00000000 +*/ + +const int SHIFT_ID = 21; + +airtemplate Shift(const int N = 2**18, const int operation_bus_id = SHIFT_ID) { + const int bits = 64; + const int bytes = bits / 8; + const int half_bytes = bytes / 2; + + col witness op; + col witness in1[bytes]; + col witness in2_low; // Note: if in2_low∊[0,2^5-1], else in2_low∊[0,2^6-1] (checked by the table) + col witness out[bytes][2]; + + // Constraints to check the correctness of each shift operation + for (int j = 0; j < bytes; j++) { + lookup_assumes(SHIFT_TABLE_ID, [op, j, in1[j], in2_low, out[j][0], out[j][1]]); + } + + // Constraints to make sure that this component is called from the main component + expr in1_low = 0; + expr in1_high = 0; + expr out_low = 0; + expr out_high = 0; + for (int i = 0; i < half_bytes; i++) { + in1_low += in1[i] * (0xFF ** i); + in1_high += in1[i + half_bytes] * (0xFF ** i); + out_low += out[i][0] + out[i + half_bytes][0]; + out_high += out[i][1] + out[i + half_bytes][1]; + } + + col witness in2[2]; + col witness main_step; + col witness multiplicity; + lookup_proves( + operation_bus_id, + [ + main_step, + op, + in1_low, + in1_high, + in2_low + 256 * in2[0], + in2[1], + out_low, + out_high, + 0 + ], + multiplicity + ); + + range_check(in2[0], 0, 2**24-1); +} \ No newline at end of file diff --git a/state-machines/binary/pil/binary_extension_table.pil b/state-machines/binary/pil/shift_table.pil similarity index 66% rename from state-machines/binary/pil/binary_extension_table.pil rename to state-machines/binary/pil/shift_table.pil index 871debe4..3d4a4f34 100644 --- a/state-machines/binary/pil/binary_extension_table.pil +++ b/state-machines/binary/pil/shift_table.pil @@ -8,20 +8,15 @@ require "constants.pil" // SRA (OP:0x0f) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 + 2^19 // SLL_W (OP:0x1d) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 // SRL_W (OP:0x1e) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^19 -// SRA_W (OP:0x1f) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^20 -// SE_B (OP:0x23) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 + 2^11 -// SE_H (OP:0x24) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 2^12 -// SE_W (OP:0x25) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 2^12 + 2^11 => 2^22 +// SRA_W (OP:0x1f) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^20 => 2^22 -const int BINARY_EXTENSION_TABLE_ID = 124; +const int SHIFT_TABLE_ID = 124; -airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = 0) { +airtemplate ShiftTable(const int N = 2**22, const int disable_fixed = 0) { #pragma memory m1 start const int SE_MASK_32 = 0xFFFFFFFF00000000; - const int SE_MASK_16 = 0xFFFFFFFFFFFF0000; - const int SE_MASK_8 = 0xFFFFFFFFFFFFFF00; const int SIGN_32_BIT = 0x80000000; const int SIGN_BYTE = 0x80; @@ -57,10 +52,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = 0:(P2_11*3)]...; col fixed OP = [0x0d:P2_19, 0x0e:P2_19, 0x0f:P2_19, // SLL, SRL, SRA - 0x1d:P2_19, 0x1e:P2_19, 0x1f:P2_19, // SLL_W, SRL_W, SRA_W - 0x23:P2_11, 0x24:P2_11, 0x25:P2_11]...; // SE_B, SE_H, SE_W - - col fixed OP_IS_SHIFT = [1:(P2_19*6), 0:(P2_11*3)]...; + 0x1d:P2_19, 0x1e:P2_19, 0x1f:P2_19]...; // SLL_W, SRL_W, SRA_W #pragma timer t1 end #pragma timer t2 start @@ -72,7 +64,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = #pragma transpile for (int i = 0; i < N; i++) { - int [op, offset, a, b, op_is_shift] = [OP[i], OFFSET[i], A[i], B[i], OP_IS_SHIFT[i]]; + int [op, offset, a, b] = [OP[i], OFFSET[i], A[i], B[i]]; int _out = 0; const int _a = a << (8*offset); switch (op) { @@ -93,6 +85,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = } } } + case 0x1d: // SLL_W if (offset >= 4) { // last most significant bytes are ignored because it's 32-bit operation @@ -131,37 +124,6 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = } } - case 0x23: // SE_B - if (offset == 0) { - // the most significant bit of first byte determines the sign extend - _out = (a & SIGN_BYTE) ? a | SE_MASK_8 : a - } else { - // the rest of the bytes are ignored - _out = 0; - } - - case 0x24: // SE_H - if (offset == 0) { - // fist byte not define the sign extend, but participate of result - _out = a; - } else if (offset == 1) { - // the most significant bit of second byte determines the sign extend - _out = (a & SIGN_BYTE) ? _a | SE_MASK_16 : _a - } else { - // the rest of the bytes are ignored - _out = 0; - } - - case 0x25: // SE_W - if (offset <= 3) { - _out = _a; - if (offset == 3) { - if (a & SIGN_BYTE) { - // the most significant bit of fourth byte determines the sign extend - _out = _out | SE_MASK_32 - } - } - } default: error(`Invalid operation ${op}`); } @@ -170,11 +132,11 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = const int _c1 = (_out >> 32) & MASK_32; C0[i] = _c0; C1[i] = _c1; - log(`T ${op},${offset},${a},${b},${_c0},${_c1},${op_is_shift},${i},${_out}`) + log(`T ${i},${op},${offset},${a},${b},${_c0},${_c1},${_out}`) } #pragma timer t2 end #pragma timer tt end - lookup_proves(BINARY_EXTENSION_TABLE_ID, [OP, OFFSET, A, B, C0, C1, OP_IS_SHIFT], multiplicity); + lookup_proves(SHIFT_TABLE_ID, [OP, OFFSET, A, B, C0, C1], multiplicity); #pragma memory m1 end } \ No newline at end of file diff --git a/state-machines/binary/pil/sign_extension.pil b/state-machines/binary/pil/sign_extension.pil new file mode 100644 index 00000000..c093409d --- /dev/null +++ b/state-machines/binary/pil/sign_extension.pil @@ -0,0 +1,88 @@ +require "std_lookup.pil" + +/* + Coprocessor in charge of performing sign extension operations: + + ┼────────┼────────┼──────────┼ + │ name │ bits │ op │ + ┼────────┼────────┼──────────┼ + │ SE_B │ 32 │ 0x23 │ + │ SE_H │ 32 │ 0x24 │ + │ SE_W │ 32 │ 0x25 │ + ┼────────┼────────┼──────────┼ + + Examples: + ======================================= + + SE_B + input input[0] sign_bit input[1] output + ------------------------------------------------------------ + 0xd72678a7 0x27 1 0xd72678 0xFFFFFFa7 (0xa7 & 0x80 = 0x80) + ------------------------------------------------------------ + + SE_H + input input[0] sign_bit input[1] output + ------------------------------------------------------------ + 0x443370bc 0x70bc 0 0x4433 0x000070bc (0x70 & 0x80 = 0x00) + ------------------------------------------------------------ + + SE_W + input input[0] sign_bit input[1] output + ------------------------------------------------------------ + 0x8a3f7a40 0x7a40 1 0xa3f 0x8a3f7a40 (0x8a & 0x80 = 0x80) + ------------------------------------------------------------ +*/ + +const int SIGN_EXTENSION_ID = 22; + +airtemplate SignExtension(const int N = 2**18, const int operation_bus_id = SIGN_EXTENSION_ID) { + + col witness sel_8; + col witness sel_16; + expr sel_32 = 1 - sel_8 - sel_16; + col witness input[2]; + col witness sign_bit; + + // Constraints to check the correctness of each sign extension operation + sel_8 * (1 - sel_8) === 0; + sel_16 * (1 - sel_16) === 0; + sel_8 * sel_16 === 0; + sign_bit * (1 - sign_bit) === 0; + + expr bit_offset = sel_8 * 2**7 + sel_16 * 2**15 + sel_32 * 2**31; + expr padding = sel_8 * 0xFFFFFF00 + sel_16 * 0xFFFF0000; + expr factor_32 = sel_32 * 2**16; + expr factor = sel_8 * 2**8 + sel_16 * 2**16 + factor_32; + expr input_low = input[0] + sign_bit * bit_offset + + input[1] * factor; + expr output_low = input[0] + sign_bit * (bit_offset + padding) + + input[1] * factor_32; + + int id_range_24 = range_check_id(0, 2**24-1); + int id_range_16 = range_check_id(0, 2**16-1); + int id_range_15 = range_check_id(0, 2**15-1); + int id_range_7 = range_check_id(0, 2**7-1); + + range_check_dynamic(input[0], id_range_7 * sel_8 + id_range_15 * sel_16 + id_range_16 * sel_32); + range_check_dynamic(input[1], id_range_24 * sel_8 + id_range_16 * sel_16 + id_range_15 * sel_32); + + // Constraints to make sure that this component is called from the main component + col witness in2_low; + col witness main_step; + col witness multiplicity; + lookup_proves( + operation_bus_id, + [ + main_step, + sel_8 * 0x23 + sel_16 * 0x24 + sel_32 * 0x25, + input_low, + 0, + in2_low, + 0, + output_low, + sign_bit * 0xFFFFFFFF, + 0 + ], + multiplicity + ); +} \ No newline at end of file