Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2.0 rust #179

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/rust-codestyle.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ jobs:
profile: minimal
toolchain: stable
override: true

- name: Rust build binding
run: bash copy.sh && cargo build --verbose
working-directory: binding/rust

- name: Run clippy
run: cargo clippy -- -D warnings
Expand All @@ -85,6 +89,10 @@ jobs:
profile: minimal
toolchain: stable
override: true

- name: Rust build binding
run: bash copy.sh && cargo build --verbose
working-directory: binding/rust

- name: Run clippy
run: cargo clippy -- -D warnings
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/rust-demos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ jobs:
profile: minimal
toolchain: stable
override: true

- name: Rust build binding
run: bash copy.sh && cargo build --verbose
working-directory: binding/rust

- name: Rust build micdemo
run: cargo build --verbose
Expand Down Expand Up @@ -72,6 +76,10 @@ jobs:
with:
toolchain: nightly
override: true

- name: Rust build binding
run: bash copy.sh && cargo build --verbose
working-directory: binding/rust

- name: Rust build micdemo
run: cargo build --verbose
Expand Down
2 changes: 1 addition & 1 deletion binding/rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion binding/rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pv_cobra"
version = "1.2.1"
version = "2.0.0"
edition = "2018"
description = "The Rust bindings for Picovoice's Cobra library"
license = "Apache-2.0"
Expand Down
160 changes: 112 additions & 48 deletions binding/rust/src/cobra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ type PvCobraVersionFn = unsafe extern "C" fn() -> *mut c_char;
type PvCobraProcessFn =
unsafe extern "C" fn(object: *mut CCobra, pcm: *const i16, is_voiced: *mut f32) -> PvStatus;
type PvCobraDeleteFn = unsafe extern "C" fn(object: *mut CCobra);
type PvGetErrorStackFn =
unsafe extern "C" fn(message_stack: *mut *mut *mut c_char, message_stack_depth: *mut i32);
type PvFreeErrorStackFn = unsafe extern "C" fn(message_stack: *mut *mut c_char);
type PvSetSdkFn = unsafe extern "C" fn(sdk: *const c_char);

#[derive(Clone, Debug)]
pub enum CobraErrorStatus {
Expand All @@ -66,24 +70,44 @@ pub enum CobraErrorStatus {
#[derive(Clone, Debug)]
pub struct CobraError {
pub status: CobraErrorStatus,
pub message: Option<String>,
pub message: String,
pub message_stack: Vec<String>,
}

impl CobraError {
pub fn new(status: CobraErrorStatus, message: impl Into<String>) -> Self {
Self {
status,
message: Some(message.into()),
message: message.into(),
message_stack: Vec::new()
}
}

pub fn new_with_stack(
status: CobraErrorStatus,
message: impl Into<String>,
message_stack: impl Into<Vec<String>>
) -> Self {
Self {
status,
message: message.into(),
message_stack: message_stack.into(),
}
}
}

impl std::fmt::Display for CobraError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.message {
Some(message) => write!(f, "{}: {:?}", message, self.status),
None => write!(f, "Cobra error: {:?}", self.status),
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
let mut message_string = String::new();
message_string.push_str(&format!("{} with status '{:?}'", self.message, self.status));

if !self.message_stack.is_empty() {
message_string.push(':');
for x in 0..self.message_stack.len() {
message_string.push_str(&format!(" [{}] {}\n", x, self.message_stack[x]))
};
}
write!(f, "{}", message_string)
}
}

Expand Down Expand Up @@ -145,44 +169,83 @@ unsafe fn load_library_fn<T>(
})
}

fn check_fn_call_status(status: PvStatus, function_name: &str) -> Result<(), CobraError> {
fn check_fn_call_status(
vtable: &CobraInnerVTable,
status: PvStatus,
function_name: &str
) -> Result<(), CobraError> {
match status {
PvStatus::SUCCESS => Ok(()),
_ => Err(CobraError::new(
CobraErrorStatus::LibraryError(status),
format!("Function '{}' in the cobra library failed", function_name),
)),
_ => unsafe {
let mut message_stack_ptr: *mut c_char = std::ptr::null_mut();
let mut message_stack_ptr_ptr = addr_of_mut!(message_stack_ptr);

let mut message_stack_depth: i32 = 0;
(vtable.pv_get_error_stack)(
addr_of_mut!(message_stack_ptr_ptr),
addr_of_mut!(message_stack_depth),
);

let mut message_stack = Vec::new();
for i in 0..message_stack_depth as usize {
let message = CStr::from_ptr(*message_stack_ptr_ptr.add(i));
let message = message.to_string_lossy().into_owned();
message_stack.push(message);
}

(vtable.pv_free_error_stack)(message_stack_ptr_ptr);

Err(CobraError::new_with_stack(
CobraErrorStatus::LibraryError(status),
format!("'{function_name}' failed"),
message_stack,
))
},
}
}

struct CobraInnerVTable {
pv_cobra_init: RawSymbol<PvCobraInitFn>,
pv_cobra_process: RawSymbol<PvCobraProcessFn>,
pv_cobra_delete: RawSymbol<PvCobraDeleteFn>,
pv_sample_rate: RawSymbol<PvSampleRateFn>,
pv_cobra_frame_length: RawSymbol<PvCobraFrameLengthFn>,
pv_cobra_version: RawSymbol<PvCobraVersionFn>,
pv_get_error_stack: RawSymbol<PvGetErrorStackFn>,
pv_free_error_stack: RawSymbol<PvFreeErrorStackFn>,
pv_set_sdk: RawSymbol<PvSetSdkFn>,

_lib_guard: Library,
}

struct CobraInner {
ccobra: *mut CCobra,
frame_length: i32,
sample_rate: i32,
version: String,
vtable: CobraInnerVTable,
}

impl CobraInnerVTable {
pub fn new(lib: Library) -> Result<Self, CobraError> {
unsafe {
Ok(Self {
pv_cobra_init: load_library_fn(&lib, b"pv_cobra_init")?,
pv_cobra_process: load_library_fn::<PvCobraProcessFn>(&lib, b"pv_cobra_process")?,
pv_cobra_delete: load_library_fn::<PvCobraDeleteFn>(&lib, b"pv_cobra_delete")?,
pv_sample_rate: load_library_fn(&lib, b"pv_sample_rate")?,
pv_cobra_frame_length: load_library_fn(&lib, b"pv_cobra_frame_length")?,
pv_cobra_version: load_library_fn(&lib, b"pv_cobra_version")?,
pv_get_error_stack: load_library_fn(&lib, b"pv_get_error_stack")?,
pv_free_error_stack: load_library_fn(&lib, b"pv_free_error_stack")?,
pv_set_sdk: load_library_fn(&lib, b"pv_set_sdk")?,

_lib_guard: lib,
})
}
}
}

struct CobraInner {
ccobra: *mut CCobra,
frame_length: i32,
sample_rate: i32,
version: String,
vtable: CobraInnerVTable,
}

impl CobraInner {
pub fn init<S: Into<String>, P: Into<PathBuf>>(access_key: S, library_path: P) -> Result<Self, CobraError> {

Expand Down Expand Up @@ -212,6 +275,17 @@ impl CobraInner {
format!("Failed to load cobra dynamic library: {}", err),
)
})?;
let vtable = CobraInnerVTable::new(lib)?;

let sdk_string = match CString::new("rust") {
Ok(sdk_string) => sdk_string,
Err(err) => {
return Err(CobraError::new(
CobraErrorStatus::ArgumentError,
format!("sdk_string is not a valid C string {err}"),
))
}
};

let pv_access_key = CString::new(access_key).map_err(|err| {
CobraError::new(
Expand All @@ -220,37 +294,27 @@ impl CobraInner {
)
})?;

let (ccobra, sample_rate, frame_length, version) = unsafe {
let pv_cobra_init = load_library_fn::<PvCobraInitFn>(&lib, b"pv_cobra_init")?;
let pv_cobra_version = load_library_fn::<PvCobraVersionFn>(&lib, b"pv_cobra_version")?;
let pv_sample_rate = load_library_fn::<PvSampleRateFn>(&lib, b"pv_sample_rate")?;
let pv_cobra_frame_length =
load_library_fn::<PvCobraFrameLengthFn>(&lib, b"pv_cobra_frame_length")?;
let mut ccobra = std::ptr::null_mut();

let mut ccobra = std::ptr::null_mut();
// SAFETY: most of the unsafe comes from the `load_library_fn` which is
// safe, because we don't use the raw symbols after this function
// anymore.
let (sample_rate, frame_length, version) = unsafe {
(vtable.pv_set_sdk)(sdk_string.as_ptr());

check_fn_call_status(
pv_cobra_init(
pv_access_key.as_ptr(),
addr_of_mut!(ccobra),
),
"pv_cobra_init",
)?;

let version = match CStr::from_ptr(pv_cobra_version()).to_str() {
Ok(string) => string.to_string(),
Err(err) => {
return Err(CobraError::new(
CobraErrorStatus::LibraryLoadError,
format!("Failed to get version info from Cobra Library: {}", err),
))
}
};
let status = (vtable.pv_cobra_init)(
pv_access_key.as_ptr(),
addr_of_mut!(ccobra),
);
check_fn_call_status(&vtable, status, "pv_cobra_init")?;

let version = CStr::from_ptr((vtable.pv_cobra_version)())
.to_string_lossy()
.into_owned();

(
ccobra,
pv_sample_rate(),
pv_cobra_frame_length(),
(vtable.pv_sample_rate)(),
(vtable.pv_cobra_frame_length)(),
version,
)
};
Expand All @@ -260,7 +324,7 @@ impl CobraInner {
sample_rate,
frame_length,
version,
vtable: CobraInnerVTable::new(lib)?,
vtable,
})
}

Expand All @@ -280,7 +344,7 @@ impl CobraInner {
let status = unsafe {
(self.vtable.pv_cobra_process)(self.ccobra, pcm.as_ptr(), addr_of_mut!(result))
};
check_fn_call_status(status, "pv_cobra_process")?;
check_fn_call_status(&self.vtable, status, "pv_cobra_process")?;

Ok(result)
}
Expand Down
20 changes: 20 additions & 0 deletions binding/rust/tests/cobra_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,26 @@ mod tests {
assert!(loss.abs() < 0.1);
}

#[test]
fn test_error_stack() {
let mut error_stack = Vec::new();

let res = Cobra::new("invalid");
if let Err(err) = res {
error_stack = err.message_stack
}

assert!(0 < error_stack.len() && error_stack.len() <= 8);

let res = Cobra::new("invalid");
if let Err(err) = res {
assert_eq!(error_stack.len(), err.message_stack.len());
for i in 0..error_stack.len() {
assert_eq!(error_stack[i], err.message_stack[i])
}
}
}

#[test]
fn test_version() {
let access_key = env::var("PV_ACCESS_KEY")
Expand Down
4 changes: 1 addition & 3 deletions demo/rust/filedemo/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion demo/rust/filedemo/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ chrono = "0.4.23"
clap = "2.33.3"
hound = "3.5.0"
itertools = "0.10.1"
pv_cobra = "=1.2.1"
pv_cobra = { path = "../../../binding/rust" }
4 changes: 1 addition & 3 deletions demo/rust/micdemo/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion demo/rust/micdemo/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ clap = "2.33.3"
ctrlc = "3.1.9"
hound = "3.5.0"
itertools = "0.10.1"
pv_cobra = "=1.2.1"
pv_cobra = { path = "../../../binding/rust" }
pv_recorder = "=1.2.1"
Loading