Skip to content

Commit

Permalink
refactor(bindings): add general bindings error context (aws#4811)
Browse files Browse the repository at this point in the history
  • Loading branch information
lrstewart authored Oct 2, 2024
1 parent e5ef845 commit 50ad945
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions bindings/rust/s2n-tls/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use s2n_tls_sys::*;
use std::{convert::TryFrom, ffi::CStr};

#[non_exhaustive]
#[derive(Debug, PartialEq)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum ErrorType {
UnknownErrorType,
NoError,
Expand Down Expand Up @@ -47,8 +47,7 @@ impl From<libc::c_int> for ErrorType {
}

enum Context {
InvalidInput,
MissingWaker,
Bindings(ErrorType, &'static str, &'static str),
Code(s2n_status_code::Type, Errno),
Application(Box<dyn std::error::Error + Send + Sync + 'static>),
}
Expand Down Expand Up @@ -149,8 +148,16 @@ impl<T: Fallible> Pollable for T {
}

impl Error {
pub(crate) const INVALID_INPUT: Error = Self(Context::InvalidInput);
pub(crate) const MISSING_WAKER: Error = Self(Context::MissingWaker);
pub(crate) const INVALID_INPUT: Error = Self::bindings(
ErrorType::UsageError,
"InvalidInput",
"An input parameter was incorrect",
);
pub(crate) const MISSING_WAKER: Error = Self::bindings(
ErrorType::UsageError,
"MissingWaker",
"Tried to perform an asynchronous operation without a configured waker",
);

/// Converts an io::Error into an s2n-tls Error
pub fn io_error(err: std::io::Error) -> Error {
Expand All @@ -167,6 +174,15 @@ impl Error {
Self(Context::Application(error))
}

/// An error occured while running bindings code.
pub(crate) const fn bindings(
kind: ErrorType,
name: &'static str,
message: &'static str,
) -> Self {
Self(Context::Bindings(kind, name, message))
}

fn capture() -> Self {
unsafe {
let s2n_errno = s2n_errno_location();
Expand All @@ -184,8 +200,7 @@ impl Error {

pub fn name(&self) -> &'static str {
match self.0 {
Context::InvalidInput => "InvalidInput",
Context::MissingWaker => "MissingWaker",
Context::Bindings(_, name, _) => name,
Context::Application(_) => "ApplicationError",
Context::Code(code, _) => unsafe {
// Safety: we assume the string has a valid encoding coming from s2n
Expand All @@ -196,10 +211,7 @@ impl Error {

pub fn message(&self) -> &'static str {
match self.0 {
Context::InvalidInput => "A parameter was incorrect",
Context::MissingWaker => {
"Tried to perform an asynchronous operation without a configured waker"
}
Context::Bindings(_, _, msg) => msg,
Context::Application(_) => "An error occurred while executing application code",
Context::Code(code, _) => unsafe {
// Safety: we assume the string has a valid encoding coming from s2n
Expand All @@ -210,7 +222,7 @@ impl Error {

pub fn debug(&self) -> Option<&'static str> {
match self.0 {
Context::InvalidInput | Context::MissingWaker | Context::Application(_) => None,
Context::Bindings(_, _, _) | Context::Application(_) => None,
Context::Code(code, _) => unsafe {
let debug_info = s2n_strerror_debug(code, core::ptr::null());

Expand All @@ -230,15 +242,15 @@ impl Error {

pub fn kind(&self) -> ErrorType {
match self.0 {
Context::InvalidInput | Context::MissingWaker => ErrorType::UsageError,
Context::Bindings(error_type, _, _) => error_type,
Context::Application(_) => ErrorType::Application,
Context::Code(code, _) => unsafe { ErrorType::from(s2n_error_get_type(code)) },
}
}

pub fn source(&self) -> ErrorSource {
match self.0 {
Context::InvalidInput | Context::MissingWaker => ErrorSource::Bindings,
Context::Bindings(_, _, _) => ErrorSource::Bindings,
Context::Application(_) => ErrorSource::Application,
Context::Code(_, _) => ErrorSource::Library,
}
Expand Down Expand Up @@ -270,7 +282,7 @@ impl Error {
/// This API is currently incomplete and should not be relied upon.
pub fn alert(&self) -> Option<u8> {
match self.0 {
Context::InvalidInput | Context::MissingWaker | Context::Application(_) => None,
Context::Bindings(_, _, _) | Context::Application(_) => None,
Context::Code(code, _) => {
let mut alert = 0;
let r = unsafe { s2n_error_get_alert(code, &mut alert) };
Expand Down Expand Up @@ -465,4 +477,17 @@ mod tests {
.unwrap();
}
}

#[test]
fn bindings_error() {
let name = "TestError";
let message = "Custom error for test";
let kind = ErrorType::InternalError;
let error = Error::bindings(kind, name, message);
assert_eq!(error.kind(), kind);
assert_eq!(error.name(), name);
assert_eq!(error.message(), message);
assert_eq!(error.debug(), None);
assert_eq!(error.source(), ErrorSource::Bindings);
}
}

0 comments on commit 50ad945

Please sign in to comment.