Skip to content

Commit

Permalink
Refactor object handle iterator and find objects function
Browse files Browse the repository at this point in the history
- documentation enhanced
- documentation example adjusted/simplified
- tests simplifications
- `find_objects()` method is now calling `iter_objects()`
  • Loading branch information
keldonin committed Sep 6, 2024
1 parent d577ace commit 4cd3c36
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 119 deletions.
135 changes: 52 additions & 83 deletions cryptoki/src/session/object_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ const MAX_OBJECT_COUNT: usize = 10;
/// Used to iterate over the object handles returned by underlying calls to `C_FindObjects`.
/// The iterator is created by calling the `iter_objects` and `iter_objects_with_cache_size` methods on a `Session` object.
///
/// # Note
///
/// The iterator `new()` method will call `C_FindObjectsInit`. It means that until the iterator is dropped,
/// creating another iterator will result in an error (typically `RvError::OperationActive` ).
///
/// # Example
///
/// ```no_run
Expand All @@ -30,47 +35,44 @@ const MAX_OBJECT_COUNT: usize = 10;
/// use cryptoki::types::AuthPin;
/// use std::env;
///
/// fn test() -> Result<(), Error> {
/// let pkcs11 = Pkcs11::new(
/// env::var("PKCS11_SOFTHSM2_MODULE")
/// .unwrap_or_else(|_| "/usr/local/lib/libsofthsm2.so".to_string()),
/// )?;
/// # fn main() -> testresult::TestResult {
/// let pkcs11 = Pkcs11::new(
/// env::var("PKCS11_SOFTHSM2_MODULE")
/// .unwrap_or_else(|_| "/usr/local/lib/libsofthsm2.so".to_string()),
/// )?;
///
/// pkcs11.initialize(CInitializeArgs::OsThreads)?;
/// let slot = pkcs11.get_slots_with_token()?.remove(0);
/// pkcs11.initialize(CInitializeArgs::OsThreads)?;
/// let slot = pkcs11.get_slots_with_token()?.remove(0);
///
/// let session = pkcs11.open_ro_session(slot).unwrap();
/// session.login(UserType::User, Some(&AuthPin::new("fedcba".into())))?;
/// let session = pkcs11.open_ro_session(slot).unwrap();
/// session.login(UserType::User, Some(&AuthPin::new("fedcba".into())))?;
///
/// let token_object = vec![Attribute::Token(true)];
/// let wanted_attr = vec![AttributeType::Label];
/// let token_object = vec![Attribute::Token(true)];
/// let wanted_attr = vec![AttributeType::Label];
///
/// for (idx, obj) in session.iter_objects(&token_object)?.enumerate() {
/// let obj = obj?; // handle potential error condition
/// for (idx, obj) in session.iter_objects(&token_object)?.enumerate() {
/// let obj = obj?; // handle potential error condition
///
/// let attributes = session.get_attributes(obj, &wanted_attr)?;
/// let attributes = session.get_attributes(obj, &wanted_attr)?;
///
/// match attributes.get(0) {
/// Some(Attribute::Label(l)) => {
/// println!(
/// "token object #{}: handle {}, label {}",
/// idx,
/// obj,
/// String::from_utf8(l.to_vec())
/// .unwrap_or_else(|_| "*** not valid utf8 ***".to_string())
/// );
/// }
/// _ => {
/// println!("token object #{}: handle {}, label not found", idx, obj);
/// }
/// match attributes.get(0) {
/// Some(Attribute::Label(l)) => {
/// println!(
/// "token object #{}: handle {}, label {}",
/// idx,
/// obj,
/// String::from_utf8(l.to_vec())
/// .unwrap_or_else(|_| "*** not valid utf8 ***".to_string())
/// );
/// }
/// _ => {
/// println!("token object #{}: handle {}, label not found", idx, obj);
/// }
/// }
/// Ok(())
/// }
/// # Ok(())
/// # }
///
/// pub fn main() {
/// test().unwrap();
/// }
/// ```
#[derive(Debug)]
pub struct ObjectHandleIterator<'a> {
Expand Down Expand Up @@ -154,8 +156,7 @@ impl<'a> Iterator for ObjectHandleIterator<'a> {
)
},
None => {
// C_FindObjects() is not implemented on this implementation
// sort of unexpected. TODO: Consider panic!() instead?
// C_FindObjects() is not implemented,, bark and return an error
log::error!("C_FindObjects() is not implemented on this library");
return Some(Err(Error::NullFunctionPointer) as Result<ObjectHandle>);
}
Expand All @@ -173,9 +174,9 @@ impl<'a> Iterator for ObjectHandleIterator<'a> {

impl Drop for ObjectHandleIterator<'_> {
fn drop(&mut self) {
// silently pass if C_FindObjectsFinal() is not implemented on this implementation
// this is unexpected. TODO: Consider panic!() instead?
// bark but pass if C_FindObjectsFinal() is not implemented
if let Some(f) = get_pkcs11_func!(self.session.client(), C_FindObjectsFinal) {
log::error!("C_FindObjectsFinal() is not implemented on this library");
// swallow the return value, as we can't do anything about it
let _ = unsafe { f(self.session.handle()) };
}
Expand Down Expand Up @@ -220,59 +221,27 @@ impl Session {
template: &[Attribute],
cache_size: usize,
) -> Result<ObjectHandleIterator> {
let template: Vec<CK_ATTRIBUTE> = template.iter().map(|attr| attr.into()).collect();
let template: Vec<CK_ATTRIBUTE> = template.iter().map(Into::into).collect();
ObjectHandleIterator::new(self, template, cache_size)
}

/// Search for session objects matching a template
/// # Arguments
/// * `template` - The template to match objects against
///
/// # Returns
///
/// Upon success, a vector of [ObjectHandle] wrapped in a Result.
/// Upon failure, the first error encountered.
///
/// It is a convenience function that will call [`Session::iter_objects`] and collect the results.
///
/// # See also
/// * [`Session::iter_objects`] for a way to specify the cache size
#[inline(always)]
pub fn find_objects(&self, template: &[Attribute]) -> Result<Vec<ObjectHandle>> {
let mut template: Vec<CK_ATTRIBUTE> = template.iter().map(|attr| attr.into()).collect();

unsafe {
Rv::from(get_pkcs11!(self.client(), C_FindObjectsInit)(
self.handle(),
template.as_mut_ptr(),
template.len().try_into()?,
))
.into_result(Function::FindObjectsInit)?;
}

let mut object_handles = [0; MAX_OBJECT_COUNT];
let mut object_count = MAX_OBJECT_COUNT as CK_ULONG; // set to MAX_OBJECT_COUNT to enter loop
let mut objects = Vec::new();

// as long as the number of objects returned equals the maximum number
// of objects that can be returned, we keep calling C_FindObjects
while object_count == MAX_OBJECT_COUNT as CK_ULONG {
unsafe {
Rv::from(get_pkcs11!(self.client(), C_FindObjects)(
self.handle(),
object_handles.as_mut_ptr() as CK_OBJECT_HANDLE_PTR,
MAX_OBJECT_COUNT.try_into()?,
&mut object_count,
))
.into_result(Function::FindObjects)?;
}

// exit loop, no more objects to be returned, no need to extend the objects vector
if object_count == 0 {
break;
}

// extend the objects vector with the new objects
objects.extend_from_slice(&object_handles[..object_count.try_into()?]);
}

unsafe {
Rv::from(get_pkcs11!(self.client(), C_FindObjectsFinal)(
self.handle(),
))
.into_result(Function::FindObjectsFinal)?;
}

let objects = objects.into_iter().map(ObjectHandle::new).collect();

Ok(objects)
self.iter_objects(template)?
.collect::<Result<Vec<ObjectHandle>>>()
}

/// Create a new object
Expand Down
67 changes: 31 additions & 36 deletions cryptoki/tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,15 +313,13 @@ fn get_token_info() -> TestResult {

#[test]
#[serial]
fn session_find_objects() {
fn session_find_objects() -> testresult::TestResult {
let (pkcs11, slot) = init_pins();
// open a session
let session = pkcs11.open_rw_session(slot).unwrap();
let session = pkcs11.open_rw_session(slot)?;

// log in the session
session
.login(UserType::User, Some(&AuthPin::new(USER_PIN.into())))
.unwrap();
session.login(UserType::User, Some(&AuthPin::new(USER_PIN.into())))?;

// we generate 11 keys with the same CKA_ID
// we will check 3 different use cases, this will cover all cases for Session.find_objects
Expand Down Expand Up @@ -351,32 +349,31 @@ fn session_find_objects() {
Attribute::KeyType(KeyType::DES3),
];

let mut found_keys = session.find_objects(&key_search_template).unwrap();
let mut found_keys = session.find_objects(&key_search_template)?;
assert_eq!(found_keys.len(), 11);

// destroy one key
session.destroy_object(found_keys.pop().unwrap()).unwrap();
session.destroy_object(found_keys.pop().unwrap())?;

let mut found_keys = session.find_objects(&key_search_template).unwrap();
let mut found_keys = session.find_objects(&key_search_template)?;
assert_eq!(found_keys.len(), 10);

// destroy another key
session.destroy_object(found_keys.pop().unwrap()).unwrap();
let found_keys = session.find_objects(&key_search_template).unwrap();
session.destroy_object(found_keys.pop().unwrap())?;
let found_keys = session.find_objects(&key_search_template)?;
assert_eq!(found_keys.len(), 9);
Ok(())
}

#[test]
#[serial]
fn session_objecthandle_iterator() {
fn session_objecthandle_iterator() -> testresult::TestResult {
let (pkcs11, slot) = init_pins();
// open a session
let session = pkcs11.open_rw_session(slot).unwrap();
let session = pkcs11.open_rw_session(slot)?;

// log in the session
session
.login(UserType::User, Some(&AuthPin::new(USER_PIN.into())))
.unwrap();
session.login(UserType::User, Some(&AuthPin::new(USER_PIN.into())))?;

// we generate 11 keys with the same CKA_ID

Expand All @@ -389,9 +386,7 @@ fn session_objecthandle_iterator() {
];

// generate a secret key
let _key = session
.generate_key(&Mechanism::Des3KeyGen, &key_template)
.unwrap();
let _key = session.generate_key(&Mechanism::Des3KeyGen, &key_template);
});

// retrieve these keys using this template
Expand All @@ -404,9 +399,7 @@ fn session_objecthandle_iterator() {

// test iter_objects_with_cache_size()
// count keys with cache size of 20
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, 20)
.unwrap();
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 20)?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 11);

Expand All @@ -415,23 +408,18 @@ fn session_objecthandle_iterator() {
assert!(found_keys.is_err());

// count keys with cache size of 1
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, 1)
.unwrap();
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 1)?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 11);

// count keys with cache size of 10
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, 10)
.unwrap();
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 10)?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 11);

// fetch keys into a vector
let found_keys: Vec<ObjectHandle> = session
.iter_objects_with_cache_size(&key_search_template, 10)
.unwrap()
.iter_objects_with_cache_size(&key_search_template, 10)?
.map_while(|key| key.ok())
.collect();
assert_eq!(found_keys.len(), 11);
Expand All @@ -440,24 +428,31 @@ fn session_objecthandle_iterator() {
let key1 = found_keys[1];

session.destroy_object(key0).unwrap();
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, 10)
.unwrap();
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 10)?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 10);

// destroy another key
session.destroy_object(key1).unwrap();
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, 10)
.unwrap();
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 10)?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 9);

// test iter_objects()
let found_keys = session.iter_objects(&key_search_template).unwrap();
let found_keys = session.iter_objects(&key_search_template)?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 9);

// test interleaved iterators - the second iterator should fail
let iter = session.iter_objects(&key_search_template);
let iter2 = session.iter_objects(&key_search_template);

assert!(matches!(iter, Ok(_)));
assert!(matches!(
iter2,
Err(Error::Pkcs11(RvError::OperationActive, _))
));
Ok(())
}

#[test]
Expand Down

0 comments on commit 4cd3c36

Please sign in to comment.