Skip to content

Commit

Permalink
Creation of explicit forwarding rules in the protect_payload.rs
Browse files Browse the repository at this point in the history
This commit creates a small framework similar to IPRules that allows to explicitly forwards registers from the payload to the firmware and in the opposite direction. Last but not least, it creates and test a forwarding rule for ecall from S mode.
  • Loading branch information
francois141 committed Oct 25, 2024
1 parent 8d767b6 commit 21b01eb
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 62 deletions.
16 changes: 16 additions & 0 deletions config/test/qemu-virt-test-protect-payload.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,21 @@ start_address = 0x80200000
# Default to 0x8000
stack_size = 0x8000

[target.payload]
# Name or path to the payload binary
name = "test_protect_payload_payload"

# Build profile for the payload (dev profile is set by default)
profile = "dev"

# Payload binary will be compiled with this value as a start address
# Default to "0x80400000"
start_address = 0x80400000

# Size of the payload stack for each hart (i.e. core)
# Default to 0x8000
stack_size = 0x8000


[policy]
name = "protect_payload"
File renamed without changes.
51 changes: 28 additions & 23 deletions firmware/test_protect_payload_firmware/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,37 @@ global_asm!(
.align 4
.global _raw_trap_handler
_raw_trap_handler:
// Skip illegal instruction (pc += 4)
// Advance PC by 4 (skip instruction)
csrrw x5, mepc, x5
addi x5, x5, 4
addi x5, x5, 4
csrrw x5, mepc, x5
// Set mscratch to 1
csrrw x5, mscratch, x5
addi x5, x0, 1
csrrw x5, mscratch, x5
// Set t6 to 0 - for test_same_registers_after_trap
li t6, 0
// If mcause is 7 or 2, it might be triggered by the instruction sd t5, 0(t6) | csrr t0, mcause
// Therefore we don't want to load it a second time
csrr t0, mcause
li t1, 7
beq t0, t1, skip
csrr t0, mcause
li t1, 2
beq t0, t1, skip
// Make sure we get an access fault and we can't to that
li t6, 0x80400000
li t5, 60
sd t5, 0(t6)
skip:
// Return back to miralis
// Verify if all input registers are equal to 60
li t2, 60
bne a0, t2, infinite_loop
bne a1, t2, infinite_loop
bne a2, t2, infinite_loop
bne a3, t2, infinite_loop
bne a4, t2, infinite_loop
bne a5, t2, infinite_loop
// Set return values for successful ecall
li a0, 0xdeadbeef
li a1, 0xdeadbeef
j done
infinite_loop:
// Infinite loop in case of failure
wfi
j infinite_loop
done:
// Make sure we can't read this 0xdeadbeef in the payload
li s2, 0xdeadbeef
// Return to Miralis
mret
#"
"
);

extern "C" {
Expand Down
4 changes: 3 additions & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ spike_virt_benchmark := "./config/test/spike-virt-benchmark.toml"
spike_latency_benchmark := "./config/test/spike-latency-benchmark.toml"
qemu_virt_release := "./config/test/qemu-virt-release.toml"
qemu_virt_hello_world_payload := "./config/test/qemu-virt-hello-world-payload.toml"
qemu_virt_payload_lock := "./config/test/qemu-virt-payload-lock.toml"
qemu_virt_test_protect_paylod := "./config/test/qemu-virt-test-protect-payload.toml"
qemu_virt_hello_world_payload_spike := "./config/test/qemu-virt-hello-world-payload-spike.toml"
qemu_virt_u_boot_payload := "./config/test/qemu-virt-u-boot-payload.toml"
Expand Down Expand Up @@ -76,7 +77,8 @@ test:
cargo run -- run --config {{qemu_virt_u_boot_payload}} --firmware opensbi-jump

# Testing with protect payload policy
cargo run -- run --config {{qemu_virt_test_protect_paylod}} --firmware linux
cargo run -- run --config {{qemu_virt_payload_lock}} --firmware linux
cargo run -- run --config {{qemu_virt_test_protect_paylod}} --firmware test_protect_payload_firmware

# Testing benchmark code
cargo run -- run --config {{qemu_virt_benchmark}} --firmware csr_write
Expand Down
53 changes: 26 additions & 27 deletions payload/test_protect_payload_payload/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,38 @@ fn main() -> ! {
// Lock payload to firmware
lock_payload();

// Make sure the firmware can't overwrite the set of registers
assert!(
test_same_registers_after_trap(),
"Test same register after trap failed"
);

// Make sure the firmware can't overwrite a memory region
assert!(
test_same_region_after_trap(),
"Test same region after trap failed"
);
// Make sure the ecall parameters goes through
assert!(test_ecall_rule(), "Ecall test failed");

// and exit
success();
}

fn test_same_registers_after_trap() -> bool {
let x3: usize;
fn test_ecall_rule() -> bool {
let ret_value_1: usize;
let ret_value_2: usize;
let s2_value: usize;
unsafe {
asm!("li t6, 60", "csrw mscratch, zero");
asm!("", out("t6") x3);
asm!(
"li a0, 60",
"li a1, 60",
"li a2, 60",
"li a3, 60",
"li a4, 60",
"li a5, 60",
"li x17, 0x08475bd0",
"ecall",
out("x10") ret_value_1,
out("x11") ret_value_2,
out("x12") _,
out("x13") _,
out("x14") _,
out("x15") _,
out("x16") _,
out("x17") _,
out("x18") s2_value,
);
}

x3 == 60
}

fn test_same_region_after_trap() -> bool {
let address: *const usize = 0x80400000 as *const usize;
let value: usize;

unsafe {
value = *address;
}

value != 60
ret_value_1 == 0xdeadbeef && ret_value_2 == 0xdeadbeef && s2_value != 0xdeadbeef
}
82 changes: 71 additions & 11 deletions src/policy/protect_payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use miralis_core::abi_protect_payload;

use crate::arch::pmp::pmpcfg;
use crate::arch::MCause::EcallFromSMode;
use crate::arch::{MCause, Register};
use crate::host::MiralisContext;
use crate::policy::{PolicyHookResult, PolicyModule};
Expand All @@ -12,6 +13,7 @@ use crate::virt::{RegisterContextGetter, VirtContext};
pub struct ProtectPayloadPolicy {
protected: bool,
general_register: [usize; 32],
rules: [ForwardingRule; ForwardingRule::NB_RULES],
last_cause: MCause,
}

Expand All @@ -20,6 +22,7 @@ impl PolicyModule for ProtectPayloadPolicy {
ProtectPayloadPolicy {
protected: false,
general_register: [0; 32],
rules: ForwardingRule::build_forwarding_rules(),
// It is important to let the first mode be EcallFromSMode as the firmware passes some information to the OS.
// Setting this last_cause allows to pass the arguments during the first call.
last_cause: MCause::EcallFromSMode,
Expand Down Expand Up @@ -72,11 +75,14 @@ impl PolicyModule for ProtectPayloadPolicy {
ctx: &mut VirtContext,
mctx: &mut MiralisContext,
) {
// Restore general purpose registers
for i in 0..32 {
// Step 1: Restore general purpose registers
let trap_cause = MCause::try_from(ctx.trap_info.mcause).unwrap();
let filter_rule = ForwardingRule::match_rule(trap_cause, &mut self.rules);

for i in 0..self.general_register.len() {
self.general_register[i] = ctx.regs[i];
// We don't clear ecall registers
if self.clear_register(i, ctx) {
if !filter_rule.allow_in[i] {
ctx.regs[i] = 0;
}
}
Expand All @@ -86,18 +92,19 @@ impl PolicyModule for ProtectPayloadPolicy {
.set_from_policy(0, 0x80400000 / 4, pmpcfg::INACTIVE);
mctx.pmp.set_from_policy(1, usize::MAX / 4, pmpcfg::TOR);

self.last_cause = ctx.trap_info.get_cause();
self.last_cause = trap_cause;
}

fn switch_from_firmware_to_payload(
&mut self,
ctx: &mut VirtContext,
mctx: &mut MiralisContext,
) {
let register_filter = ForwardingRule::match_rule(self.last_cause, &mut self.rules);

// Restore general purpose registers
for i in 0..32 {
// 10 & 11 are return registers
if self.restore_register(i, ctx) {
for i in 0..self.general_register.len() {
if !register_filter.allow_out[i] {
ctx.regs[i] = self.general_register[i];
}
}
Expand All @@ -120,12 +127,65 @@ impl ProtectPayloadPolicy {
fn is_policy_call(&mut self, ctx: &VirtContext) -> bool {
ctx.get(Register::X17) == abi_protect_payload::MIRALIS_PROTECT_PAYLOAD_EID
}
}

// ———————————————————————————————— Explicit Forwarding Rules ———————————————————————————————— //

#[derive(Clone)]
pub struct ForwardingRule {
mcause: MCause,
allow_in: [bool; 32],
allow_out: [bool; 32],
}

impl ForwardingRule {
pub const NB_RULES: usize = 1;

fn match_rule(trap_cause: MCause, rules: &mut [ForwardingRule; 1]) -> ForwardingRule {
for rule in rules {
if trap_cause == rule.mcause {
return rule.clone();
}
}

Self::new_allow_nothing(trap_cause)
}

fn build_forwarding_rules() -> [ForwardingRule; Self::NB_RULES] {
let mut rules = [Self::new_allow_nothing(EcallFromSMode); 1];

// Build Ecall rule
rules[0]
.allow_register_in(Register::X10)
.allow_register_in(Register::X11)
.allow_register_in(Register::X12)
.allow_register_in(Register::X13)
.allow_register_in(Register::X14)
.allow_register_in(Register::X15)
.allow_register_in(Register::X16)
.allow_register_in(Register::X17)
.allow_register_out(Register::X10)
.allow_register_out(Register::X11);

rules
}

#[allow(unused)]
fn new_allow_nothing(mcause: MCause) -> Self {
ForwardingRule {
mcause,
allow_in: [false; 32],
allow_out: [false; 32],
}
}

fn clear_register(&mut self, idx: usize, ctx: &mut VirtContext) -> bool {
!(10..18).contains(&idx) || ctx.trap_info.get_cause() != MCause::EcallFromSMode
fn allow_register_in(&mut self, reg: Register) -> &mut Self {
self.allow_in[reg as usize] = true;
self
}

fn restore_register(&mut self, idx: usize, _ctx: &mut VirtContext) -> bool {
!(10..12).contains(&idx) || self.last_cause != MCause::EcallFromSMode
fn allow_register_out(&mut self, reg: Register) -> &mut Self {
self.allow_out[reg as usize] = true;
self
}
}

0 comments on commit 21b01eb

Please sign in to comment.