Docs/word addressable memory (#519)
* Update docs for word-addressable memory

* Add code comments

* Update diagrams
moodlezoup authored Dec 4, 2024
1 parent abb0f07 commit 1c5fad8
Showing 12 changed files with 151 additions and 18 deletions.
2 changes: 1 addition & 1 deletion book/src/SUMMARY.md
@@ -29,7 +29,7 @@
- [GKR](./background/gkr.md)
- [Binius](./background/binius.md)
- [Multiplicative Generator](./background/binius/multiplicative-generator.md)
- [RISC-V](./background/riskv.md)
- [RISC-V](./background/risc-v.md)
- [Dev](./dev/README.md)
- [Install](./dev/install.md)
- [Tools](./dev/tools.md)
File renamed without changes.
99 changes: 98 additions & 1 deletion book/src/how/read_write_memory.md
@@ -52,4 +52,101 @@ The verification of $\text{read\_timestamp} \leq \text{global\_timestamp}$ is eq

The process of ensuring that both $\text{read\_timestamp}$ and $(\text{global\_timestamp} - \text{read\_timestamp})$ lie within the specified range is known as range-checking. This is the procedure implemented in [`timestamp_range_check.rs`](https://github.com/a16z/jolt/blob/main/jolt-core/src/jolt/vm/timestamp_range_check.rs), using a modified version of Lasso.

Intuitively, checking that each read timestamp does not exceed the global timestamp prevents an attacker from answering all read operations to a given cell with "the right set of values, but out of order". Such an attack requires the attacker to "jump forward and backward in time". That is, for this attack to succeed, at some timestamp t when the cell is read, the attacker would have to return a value that will be written to that cell in the future (and at some later timestamp t' when the same cell is read the attacker would have to return a value that was written to that cell much earlier). This attack is prevented by confirming that all values returned have a timestamp that does not exceed the current global timestamp.
Intuitively, checking that each read timestamp does not exceed the global timestamp prevents an attacker from answering all read operations to a given cell with "the right set of values, but out of order". Such an attack requires the attacker to "jump forward and backward in time". That is, for this attack to succeed, at some timestamp $t$ when the cell is read, the attacker would have to return a value that will be written to that cell in the future (and at some later timestamp $t'$ when the same cell is read the attacker would have to return a value that was written to that cell much earlier). This attack is prevented by confirming that all values returned have a timestamp that does not exceed the current global timestamp.
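The two range checks can be sketched as follows (a hypothetical standalone sketch, not Jolt's actual API; `K` and the wrapping arithmetic are illustrative stand-ins for the field arithmetic used in the real protocol):

```rust
// Both timestamps fit in K bits, so read_ts <= global_ts holds iff
// read_ts and (global_ts - read_ts) both lie in [0, 2^K).
// K is an illustrative bit-width; the subtraction wraps to mimic how an
// out-of-order read yields an out-of-range difference.
const K: u32 = 20;

fn in_range(x: u64) -> bool {
    x < (1u64 << K)
}

fn timestamp_check(read_ts: u64, global_ts: u64) -> bool {
    in_range(read_ts) && in_range(global_ts.wrapping_sub(read_ts))
}
```

A read timestamp greater than the global timestamp makes the difference wrap around to a value far outside $[0, 2^K)$, so the second range check fails.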

## Word-addressable memory

According to the RISC-V specification, the RISC-V memory is **byte-addressable**,
i.e. load and store instructions can access any memory address, with no restrictions on whether the address is aligned to a 4-byte word.
However, the specification caveats that unaligned accesses may incur performance penalties on hardware,
and some RISC-V implementations may choose to disallow unaligned memory accesses entirely.
In such implementations, `LW` and `SW` must access memory addresses that are word-aligned (i.e. multiples of 4),
and `LH`, `LHU`, and `SH` must access memory addresses that are halfword-aligned (i.e. multiples of 2).
There are no such restrictions on `LB`, `LBU` and `SB`, as they only access a single byte.

Jolt disallows unaligned memory accesses for prover performance reasons;
one cannot generate a valid Jolt proof for an execution trace that includes unaligned memory accesses.
Enforcing aligned accesses allows Jolt to treat memory as **word-addressable**,
which reduces the number of committed polynomials and instances of offline memory-checking.

In more detail, Jolt
1. constrains all `LW` and `SW` instructions to only allow word-aligned accesses, and
constrains all `LH`, `LHU`, and `SH` instructions to only allow halfword-aligned accesses.
This is accomplished using virtual `ASSERT` instructions (see Section 6.1.1 of the Jolt paper).
2. replaces all `LH`, `LHU`, `SH`, `LB`, `LBU`, and `SB` instructions with [virtual sequences](./m-extension.md)
that only perform word-aligned accesses.
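The two steps above can be sketched as follows (helper names hypothetical, not Jolt code):

```rust
// What the virtual ASSERT instructions check: the effective address must be
// a multiple of the operand size (4 for LW/SW, 2 for LH/LHU/SH).
fn assert_aligned(address: u32, operand_bytes: u32) {
    assert!(address % operand_bytes == 0, "unaligned memory access");
}

// How a byte address is reduced to the word-aligned address that the
// replacement LW/SW actually touch: clear the low two bits, which is what
// the ANDI with `(1 << 32) - 4` in the sequences below computes.
fn word_aligned(address: u32) -> u32 {
    address & !0b11
}
```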

### `LB` virtual sequence

1. `ADDI` `rs1`, --, `imm`, $v_0$ `// Compute the memory address being accessed`
1. `ANDI` $v_0$, --, `(1 << 32) - 4`, $v_1$ `// Mask out the lower bits to obtain the word-aligned address`
1. `LW` $v_1$, --, 0, $v_2$ `// Load the full word`
1. `XORI` $v_0$, --, `0b11`, $v_3$ `// Compute the number of bytes to shift the word (in the lower 2 bits)`
1. `SLLI` $v_3$, --, 3, $v_3$ `// Compute the number of bits to shift the word (in the lower 5 bits)`
1. `SLL` $v_2$, $v_3$, --, `rd` `// Shift the word so that the desired byte is left-aligned`
1. `SRAI` `rd`, --, 24, `rd` `// Right arithmetic shift to sign-extend and right-align the byte`
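The sequence above can be simulated step by step (a hypothetical helper over a little-endian array of words, not Jolt code; each local variable mirrors the virtual register of the same name):

```rust
fn lb(memory: &[u32], rs1: u32, imm: u32) -> u32 {
    let v0 = rs1.wrapping_add(imm);     // ADDI: effective address
    let v1 = v0 & !0b11;                // ANDI: word-aligned address
    let v2 = memory[(v1 / 4) as usize]; // LW: load the full word
    let v3 = (v0 ^ 0b11) & 0b11;        // XORI: bytes to shift (lower 2 bits)
    let v3 = v3 << 3;                   // SLLI: bits to shift (lower 5 bits)
    let rd = v2 << (v3 & 0x1f);         // SLL: left-align the desired byte
    ((rd as i32) >> 24) as u32          // SRAI: sign-extend and right-align
}
```

On the word `0x80FF_7001`, byte offset 0 yields `1` and byte offset 3 yields `0xFFFF_FF80` (i.e. `-128` sign-extended), matching the behavior of a native `LB`.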

### `LBU` virtual sequence

1. `ADDI` `rs1`, --, `imm`, $v_0$ `// Compute the memory address being accessed`
1. `ANDI` $v_0$, --, `(1 << 32) - 4`, $v_1$ `// Mask out the lower bits to obtain the word-aligned address`
1. `LW` $v_1$, --, 0, $v_2$ `// Load the full word`
1. `XORI` $v_0$, --, `0b11`, $v_3$ `// Compute the number of bytes to shift the word (in the lower 2 bits)`
1. `SLLI` $v_3$, --, 3, $v_3$ `// Compute the number of bits to shift the word (in the lower 5 bits)`
1. `SLL` $v_2$, $v_3$, --, `rd` `// Shift the word so that the desired byte is left-aligned`
1. `SRLI` `rd`, --, 24, `rd` `// Right logical shift to zero-extend and right-align the byte`

### `LH` virtual sequence

1. `ASSERT_HALFWORD_ALIGNMENT` `rs1`, --, `imm`, -- `// Virtual instruction to enforce aligned memory access`
1. `ADDI` `rs1`, --, `imm`, $v_0$ `// Compute the memory address being accessed`
1. `ANDI` $v_0$, --, `(1 << 32) - 4`, $v_1$ `// Mask out the lower bits to obtain the word-aligned address`
1. `LW` $v_1$, --, 0, $v_2$ `// Load the full word`
1. `XORI` $v_0$, --, `0b10`, $v_3$ `// Compute the number of bytes to shift the word (in the lower 2 bits)`
1. `SLLI` $v_3$, --, 3, $v_3$ `// Compute the number of bits to shift the word (in the lower 5 bits)`
1. `SLL` $v_2$, $v_3$, --, `rd` `// Shift the word so that the desired halfword is left-aligned`
1. `SRAI` `rd`, --, 16, `rd` `// Right arithmetic shift to sign-extend and right-align the halfword`

### `LHU` virtual sequence

1. `ASSERT_HALFWORD_ALIGNMENT` `rs1`, --, `imm`, -- `// Virtual instruction to enforce aligned memory access`
1. `ADDI` `rs1`, --, `imm`, $v_0$ `// Compute the memory address being accessed`
1. `ANDI` $v_0$, --, `(1 << 32) - 4`, $v_1$ `// Mask out the lower bits to obtain the word-aligned address`
1. `LW` $v_1$, --, 0, $v_2$ `// Load the full word`
1. `XORI` $v_0$, --, `0b10`, $v_3$ `// Compute the number of bytes to shift the word (in the lower 2 bits)`
1. `SLLI` $v_3$, --, 3, $v_3$ `// Compute the number of bits to shift the word (in the lower 5 bits)`
1. `SLL` $v_2$, $v_3$, --, `rd` `// Shift the word so that the desired halfword is left-aligned`
1. `SRLI` `rd`, --, 16, `rd` `// Right logical shift to zero-extend and right-align the halfword`
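The `LH` and `LHU` sequences differ only in the final shift, which is arithmetic (sign-extending) for `LH` and logical (zero-extending) for `LHU`. Both can be sketched with one hypothetical helper (not Jolt code):

```rust
fn load_halfword(memory: &[u32], addr: u32, signed: bool) -> u32 {
    assert!(addr % 2 == 0, "ASSERT_HALFWORD_ALIGNMENT");
    let word = memory[((addr & !0b11) / 4) as usize]; // ANDI + LW
    let bit_shift = ((addr ^ 0b10) & 0b10) << 3;      // XORI + SLLI
    let left_aligned = word << bit_shift;             // SLL: left-align halfword
    if signed {
        ((left_aligned as i32) >> 16) as u32 // SRAI 16: sign-extend (LH)
    } else {
        left_aligned >> 16                   // SRLI 16: zero-extend (LHU)
    }
}
```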

### `SB` virtual sequence

1. `ADDI` `rs1`, --, `imm`, $v_0$ `// Compute the memory address being accessed`
1. `ANDI` $v_0$, --, `(1 << 32) - 4`, $v_1$ `// Mask out the lower bits to obtain the word-aligned address`
1. `LW` $v_1$, --, 0, $v_2$ `// Load the full word`
1. `SLLI` $v_0$, --, 3, $v_3$ `// Compute the number of bits to shift the byte (in the lower 5 bits)`
1. `LUI` --, --, `0xff`, $v_4$ `// Load the bitmask into a virtual register`
1. `SLL` $v_4$, $v_3$, --, $v_4$ `// Shift the bitmask into the position of the desired byte`
1. `SLL` `rs2`, $v_3$, --, $v_5$ `// Shift the value being stored to align with the bitmask`
1. `XOR` $v_2$, $v_5$, --, $v_5$ `// XOR the loaded word with the shifted value`
1. `AND` $v_5$, $v_4$, --, $v_5$ `// Keep only the bits within the masked byte`
1. `XOR` $v_2$, $v_5$, --, $v_2$ `// Flip those bits of the word, splicing in the stored byte`
1. `SW` $v_1$, $v_2$, 0, -- `// Store the updated word`

Instructions 8-10 use a [bit-twiddling hack](https://graphics.stanford.edu/~seander/bithacks.html#MaskedMerge) to mask the stored byte into the word.
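The hack can be written as a standalone function (hypothetical sketch): `a ^ ((a ^ b) & mask)` takes the bits of `b` where `mask` is 1 and the bits of `a` elsewhere, replacing the targeted byte with one AND and two XORs.

```rust
fn masked_merge(a: u32, b: u32, mask: u32) -> u32 {
    a ^ ((a ^ b) & mask) // where mask is 1, take b's bit; elsewhere keep a's
}

// Splicing byte `byte` into `word` at byte offset `offset`, as the SB
// sequence does with its shifted bitmask:
fn store_byte_in_word(word: u32, byte: u32, offset: u32) -> u32 {
    let shift = offset * 8;
    masked_merge(word, byte << shift, 0xFF << shift)
}
```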

### `SH` virtual sequence

1. `ASSERT_HALFWORD_ALIGNMENT` `rs1`, --, `imm`, -- `// Virtual instruction to enforce aligned memory access`
1. `ADDI` `rs1`, --, `imm`, $v_0$ `// Compute the memory address being accessed`
1. `ANDI` $v_0$, --, `(1 << 32) - 4`, $v_1$ `// Mask out the lower bits to obtain the word-aligned address`
1. `LW` $v_1$, --, 0, $v_2$ `// Load the full word`
1. `SLLI` $v_0$, --, 3, $v_3$ `// Compute the number of bits to shift the halfword (in the lower 5 bits)`
1. `LUI` --, --, `0xffff`, $v_4$ `// Load the bitmask into a virtual register`
1. `SLL` $v_4$, $v_3$, --, $v_4$ `// Shift the bitmask into the position of the desired halfword`
1. `SLL` `rs2`, $v_3$, --, $v_5$ `// Shift the value being stored to align with the bitmask`
1. `XOR` $v_2$, $v_5$, --, $v_5$ `// XOR the loaded word with the shifted value`
1. `AND` $v_5$, $v_4$, --, $v_5$ `// Keep only the bits within the masked halfword`
1. `XOR` $v_2$, $v_5$, --, $v_2$ `// Flip those bits of the word, splicing in the stored halfword`
1. `SW` $v_1$, $v_2$, 0, -- `// Store the updated word`

Instructions 9-11 use a [bit-twiddling hack](https://graphics.stanford.edu/~seander/bithacks.html#MaskedMerge) to mask the stored halfword into the word.
Binary file modified book/src/imgs/final_memory_state.png
Binary file modified book/src/imgs/memory_layout.png
2 changes: 1 addition & 1 deletion book/src/people.md
@@ -2,7 +2,7 @@
- Michael Zhu
- Sam Ragsdale
- Arasu Arun
- Noah Cintron
- Noah Citron
- Justin Thaler
- Srinath Setty

8 changes: 4 additions & 4 deletions jolt-core/src/jolt/instruction/lh.rs
@@ -179,7 +179,7 @@ impl<const WORD_SIZE: usize> VirtualInstructionSequence for LHInstruction<WORD_S
advice_value: None,
});

let left_aligned_byte = SLLInstruction::<WORD_SIZE>(word, bit_shift).lookup_entry();
let left_aligned_halfword = SLLInstruction::<WORD_SIZE>(word, bit_shift).lookup_entry();
virtual_trace.push(RVTraceRow {
instruction: ELFInstruction {
address: trace_row.instruction.address,
@@ -193,14 +193,14 @@ impl<const WORD_SIZE: usize> VirtualInstructionSequence for LHInstruction<WORD_S
register_state: RegisterState {
rs1_val: Some(word),
rs2_val: Some(bit_shift),
rd_post_val: Some(left_aligned_byte),
rd_post_val: Some(left_aligned_halfword),
},
memory_state: None,
advice_value: None,
});

let sign_extended_halfword =
SRAInstruction::<WORD_SIZE>(left_aligned_byte, 16).lookup_entry();
SRAInstruction::<WORD_SIZE>(left_aligned_halfword, 16).lookup_entry();
assert_eq!(sign_extended_halfword, expected_rd_post_val);
virtual_trace.push(RVTraceRow {
instruction: ELFInstruction {
@@ -213,7 +213,7 @@ impl<const WORD_SIZE: usize> VirtualInstructionSequence for LHInstruction<WORD_S
virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1),
},
register_state: RegisterState {
rs1_val: Some(left_aligned_byte),
rs1_val: Some(left_aligned_halfword),
rs2_val: None,
rd_post_val: Some(sign_extended_halfword),
},
2 changes: 0 additions & 2 deletions jolt-core/src/jolt/instruction/sb.rs
@@ -296,8 +296,6 @@ impl<const WORD_SIZE: usize> VirtualInstructionSequence for SBInstruction<WORD_S
advice_value: None,
});

// println!("After: {:?}", virtual_trace.last().unwrap());

virtual_trace
}

31 changes: 27 additions & 4 deletions jolt-core/src/jolt/instruction/sh.rs
@@ -3,14 +3,15 @@ use tracer::{ELFInstruction, MemoryState, RVTraceRow, RegisterState, RV32IM};

use super::VirtualInstructionSequence;
use crate::jolt::instruction::{
add::ADDInstruction, and::ANDInstruction, sll::SLLInstruction, xor::XORInstruction,
JoltInstruction,
add::ADDInstruction, and::ANDInstruction, sll::SLLInstruction,
virtual_assert_aligned_memory_access::AssertAlignedMemoryAccessInstruction,
xor::XORInstruction, JoltInstruction,
};
/// Stores a halfword in memory
pub struct SHInstruction<const WORD_SIZE: usize>;

impl<const WORD_SIZE: usize> VirtualInstructionSequence for SHInstruction<WORD_SIZE> {
const SEQUENCE_LENGTH: usize = 11;
const SEQUENCE_LENGTH: usize = 12;

fn virtual_trace(trace_row: RVTraceRow) -> Vec<RVTraceRow> {
assert_eq!(trace_row.instruction.opcode, RV32IM::SH);
@@ -37,6 +38,29 @@ impl<const WORD_SIZE: usize> VirtualInstructionSequence for SHInstruction<WORD_S
_ => panic!("Unsupported WORD_SIZE: {}", WORD_SIZE),
};

let is_aligned =
AssertAlignedMemoryAccessInstruction::<WORD_SIZE, 2>(dest, offset_unsigned)
.lookup_entry();
debug_assert_eq!(is_aligned, 1);
virtual_trace.push(RVTraceRow {
instruction: ELFInstruction {
address: trace_row.instruction.address,
opcode: RV32IM::VIRTUAL_ASSERT_HALFWORD_ALIGNMENT,
rs1: r_dest,
rs2: None,
rd: None,
imm: Some(offset),
virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1),
},
register_state: RegisterState {
rs1_val: Some(dest),
rs2_val: None,
rd_post_val: None,
},
memory_state: None,
advice_value: None,
});

let ram_address = ADDInstruction::<WORD_SIZE>(dest, offset_unsigned).lookup_entry();
assert!(ram_address % 2 == 0);
virtual_trace.push(RVTraceRow {
@@ -57,7 +81,6 @@ impl<const WORD_SIZE: usize> VirtualInstructionSequence for SHInstruction<WORD_S
memory_state: None,
advice_value: None,
});
// TODO(moodlezoup): Assert aligned memory access

let word_address_bitmask = ((1u128 << WORD_SIZE) - 4) as u64;
let word_address =
12 changes: 7 additions & 5 deletions jolt-core/src/jolt/vm/read_write_memory.rs
@@ -61,6 +61,7 @@ impl ReadWriteMemoryPreprocessing {

let num_words = max_bytecode_address.next_multiple_of(4) / 4 - min_bytecode_address / 4 + 1;
let mut bytecode_words = vec![0u32; num_words as usize];
// Convert bytes into words and populate `bytecode_words`
for chunk in
memory_init.chunk_by(|(address_a, _), (address_b, _)| address_a / 4 == address_b / 4)
{
@@ -282,6 +283,7 @@ impl<F: JoltField> ReadWriteMemoryPolynomials<F> {
program_io.memory_layout.input_start,
&program_io.memory_layout,
);
// Convert input bytes into words and populate `v_init`
for chunk in program_io.inputs.chunks(4) {
let mut word = [0u8; 4];
for (i, byte) in chunk.iter().enumerate() {
@@ -718,8 +720,8 @@ where
v_init[v_init_index] = *word as u64;
v_init_index += 1;
}
// Copy input bytes
v_init_index = memory_address_to_witness_index(memory_layout.input_start, memory_layout);
// Convert input bytes into words and populate `v_init`
for chunk in preprocessing.program_io.as_ref().unwrap().inputs.chunks(4) {
let mut word = [0u8; 4];
for (i, byte) in chunk.iter().enumerate() {
@@ -860,11 +862,11 @@ where
.collect();

let mut v_io: Vec<u64> = vec![0; memory_size];
// Copy input bytes
let mut input_index = memory_address_to_witness_index(
program_io.memory_layout.input_start,
&program_io.memory_layout,
);
// Convert input bytes into words and populate `v_io`
for chunk in program_io.inputs.chunks(4) {
let mut word = [0u8; 4];
for (i, byte) in chunk.iter().enumerate() {
@@ -874,11 +876,11 @@ where
v_io[input_index] = word as u64;
input_index += 1;
}
// Copy output bytes
let mut output_index = memory_address_to_witness_index(
program_io.memory_layout.output_start,
&program_io.memory_layout,
);
// Convert output bytes into words and populate `v_io`
for chunk in program_io.outputs.chunks(4) {
let mut word = [0u8; 4];
for (i, byte) in chunk.iter().enumerate() {
@@ -988,9 +990,9 @@ where
io_witness_range_eval *= r_prod;

let mut v_io: Vec<u64> = vec![0; io_memory_size];
// Copy input bytes
let mut input_index =
memory_address_to_witness_index(memory_layout.input_start, memory_layout);
// Convert input bytes into words and populate `v_io`
for chunk in program_io.inputs.chunks(4) {
let mut word = [0u8; 4];
for (i, byte) in chunk.iter().enumerate() {
@@ -1000,9 +1002,9 @@ where
v_io[input_index] = word as u64;
input_index += 1;
}
// Copy output bytes
let mut output_index =
memory_address_to_witness_index(memory_layout.output_start, memory_layout);
// Convert output bytes into words and populate `v_io`
for chunk in program_io.outputs.chunks(4) {
let mut word = [0u8; 4];
for (i, byte) in chunk.iter().enumerate() {
2 changes: 2 additions & 0 deletions jolt-core/src/r1cs/constraints.rs
@@ -119,6 +119,8 @@ impl<const C: usize, F: JoltField> R1CSConstraints<C, F> for JoltRV32IMConstrain
let packed_query =
R1CSBuilder::<C, F, JoltR1CSInputs>::pack_be(query_chunks.clone(), LOG_M);

// For the `AssertAlignedMemoryAccessInstruction` lookups, we add the `rs1` and `imm` values
// to obtain the memory address being accessed.
let add_operands = JoltR1CSInputs::InstructionFlags(ADDInstruction::default().into())
+ JoltR1CSInputs::InstructionFlags(
AssertAlignedMemoryAccessInstruction::<32, 2>::default().into(),
11 changes: 11 additions & 0 deletions tracer/src/emulator/mmu.rs
@@ -549,6 +549,8 @@ impl Mmu {
}
}

/// Records the memory word being accessed by a load instruction. The memory
/// state is used in Jolt to construct the witnesses in `read_write_memory.rs`.
fn trace_load(&mut self, effective_address: u64) {
let word_address = (effective_address >> 2) << 2;
let bytes = match self.xlen {
@@ -582,6 +584,9 @@
}
}

/// Records the state of the memory word containing the accessed byte
/// before and after the store instruction. The memory state is used in Jolt to
/// construct the witnesses in `read_write_memory.rs`.
fn trace_store_byte(&mut self, effective_address: u64, value: u64) {
self.assert_effective_address(effective_address);
let bytes = match self.xlen {
@@ -620,6 +625,9 @@
});
}

/// Records the state of the memory word containing the accessed halfword
/// before and after the store instruction. The memory state is used in Jolt to
/// construct the witnesses in `read_write_memory.rs`.
fn trace_store_halfword(&mut self, effective_address: u64, value: u64) {
self.assert_effective_address(effective_address);
let bytes = match self.xlen {
@@ -658,6 +666,9 @@
});
}

/// Records the state of the accessed memory word before and after the store
/// instruction. The memory state is used in Jolt to construct the witnesses
/// in `read_write_memory.rs`.
fn trace_store(&mut self, effective_address: u64, value: u64) {
self.assert_effective_address(effective_address);
let bytes = match self.xlen {
