diff --git a/README.md b/README.md index e0f46af..6963cab 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ * [ ] Pose Landmark Detection * [x] Audio Classification * [x] Text Classification -* [ ] Text Embedding +* [x] Text Embedding * [ ] Language Detection ## Task APIs diff --git a/scripts/download-models.sh b/scripts/download-models.sh index 3d60458..e545ec8 100755 --- a/scripts/download-models.sh +++ b/scripts/download-models.sh @@ -158,6 +158,23 @@ text_classification_init() { popd } +text_embedding_init() { + text_embedding_dir="${model_path}/text_embedding" + mkdir -p "${text_embedding_dir}" + pushd "${text_embedding_dir}" + + model_urls=( + "https://storage.googleapis.com/mediapipe-models/text_embedder/bert_embedder/float32/1/bert_embedder.tflite" + "https://storage.googleapis.com/mediapipe-models/text_embedder/universal_sentence_encoder/float32/latest/universal_sentence_encoder.tflite" + ) + + for url in "${model_urls[@]}"; do + curl -sLO "${url}" + done + + popd +} + object_detection_init image_classification_init gesture_recognition_init @@ -167,3 +184,4 @@ image_embedding_init face_detection_init audio_classification_init text_classification_init +text_embedding_init diff --git a/src/tasks/text/mod.rs b/src/tasks/text/mod.rs index 684483f..be75320 100644 --- a/src/tasks/text/mod.rs +++ b/src/tasks/text/mod.rs @@ -1,3 +1,5 @@ mod text_classification; +mod text_embedding; pub use text_classification::{TextClassifier, TextClassifierBuilder, TextClassifierSession}; +pub use text_embedding::{TextEmbedder, TextEmbedderBuilder, TextEmbedderSession}; diff --git a/src/tasks/text/text_embedding/builder.rs b/src/tasks/text/text_embedding/builder.rs new file mode 100644 index 0000000..baea296 --- /dev/null +++ b/src/tasks/text/text_embedding/builder.rs @@ -0,0 +1,73 @@ +use super::TextEmbedder; +use crate::tasks::common::{BaseTaskOptions, EmbeddingOptions}; + +/// Configure the build options of a new **Text Embedding** task instance. +/// +/// Methods can be chained on it in order to configure it. +pub struct TextEmbedderBuilder { + pub(super) base_task_options: BaseTaskOptions, + pub(super) embedding_options: EmbeddingOptions, +} + +impl Default for TextEmbedderBuilder { + #[inline(always)] + fn default() -> Self { + Self { + base_task_options: Default::default(), + embedding_options: Default::default(), + } + } +} + +impl TextEmbedderBuilder { + /// Create a new builder with default options. + #[inline(always)] + pub fn new() -> Self { + Self::default() + } + + base_task_options_impl!(TextEmbedder); + + embedding_options_impl!(); + + /// Use the current build options and use the buffer as model data to create a new task instance. + #[inline] + pub fn build_from_buffer(self, buffer: impl AsRef<[u8]>) -> Result { + let buf = buffer.as_ref(); + // parse model and get model resources. + let model_resource = crate::model::parse_model(buf)?; + + // check model + model_base_check_impl!(model_resource, 1); + model_resource_check_and_get_impl!(model_resource, to_tensor_info, 0).try_to_text()?; + let input_count = model_resource.input_tensor_count(); + if input_count != 1 && input_count != 3 { + return Err(crate::Error::ModelInconsistentError(format!( + "Expect model input tensor count `1` or `3`, but got `{}`", + input_count + ))); + } + for i in 0..input_count { + let t = model_resource_check_and_get_impl!(model_resource, input_tensor_type, i); + if t != crate::TensorType::I32 { + // todo: string type support + return Err(crate::Error::ModelInconsistentError( + "All input tensors should be int32 type".into(), + )); + } + } + + let graph = crate::GraphBuilder::new( + model_resource.model_backend(), + self.base_task_options.device, + ) + .build_from_bytes([buf])?; + + return Ok(TextEmbedder { + build_options: self, + model_resource, + graph, + input_count, + }); + } +} diff --git a/src/tasks/text/text_embedding/mod.rs b/src/tasks/text/text_embedding/mod.rs new file mode 100644 index 0000000..ed1e73a --- /dev/null +++ b/src/tasks/text/text_embedding/mod.rs @@ -0,0 +1,99 @@ +mod builder; +pub use builder::TextEmbedderBuilder; + +use crate::model::ModelResourceTrait; +use crate::postprocess::{EmbeddingResult, TensorsToEmbedding}; +use crate::preprocess::text::{TextToTensorInfo, TextToTensors}; +use crate::{Error, Graph, GraphExecutionContext, TensorType}; + +/// Performs embedding on texts. +pub struct TextEmbedder { + build_options: TextEmbedderBuilder, + model_resource: Box, + graph: Graph, + input_count: usize, +} + +impl TextEmbedder { + base_task_options_get_impl!(); + + embedding_options_get_impl!(); + + /// Create a new task session that contains processing buffers and can do inference. + #[inline(always)] + pub fn new_session(&self) -> Result { + let input_to_tensor_info = + model_resource_check_and_get_impl!(self.model_resource, to_tensor_info, 0) + .try_to_text()?; + let mut input_tensor_shapes = Vec::with_capacity(self.input_count); + let mut input_tensor_bufs = Vec::with_capacity(self.input_count); + for i in 0..self.input_count { + let input_tensor_shape = + model_resource_check_and_get_impl!(self.model_resource, input_tensor_shape, i); + let bytes = input_tensor_shape.iter().fold(4, |sum, b| sum * *b); + input_tensor_shapes.push(input_tensor_shape); + input_tensor_bufs.push(vec![0; bytes]); + } + let output_tensor_shape = + model_resource_check_and_get_impl!(self.model_resource, output_tensor_shape, 0); + let mut tensor_to_embedding = TensorsToEmbedding::new( + self.build_options.embedding_options.quantize, + self.build_options.embedding_options.l2_normalize, + ); + tensor_to_embedding.add_output_cfg( + get_type_and_quantization!(self.model_resource, 0), + output_tensor_shape, + None, + ); + + let execution_ctx = self.graph.init_execution_context()?; + Ok(TextEmbedderSession { + execution_ctx, + tensor_to_embedding, + input_to_tensor_info, + input_tensor_shapes, + input_tensor_bufs, + }) + } + + /// Embed one text using a new session. + #[inline(always)] + pub fn embed(&self, input: &impl TextToTensors) -> Result { + self.new_session()?.embed(input) + } +} + +/// Session to run inference. +/// If process multiple texts, reuse it can get better performance. +pub struct TextEmbedderSession<'a> { + execution_ctx: GraphExecutionContext<'a>, + tensor_to_embedding: TensorsToEmbedding, + + input_to_tensor_info: &'a TextToTensorInfo, + input_tensor_shapes: Vec<&'a [usize]>, + input_tensor_bufs: Vec>, +} + +impl<'a> TextEmbedderSession<'a> { + /// Embed one text use this session. + #[inline(always)] + pub fn embed(&mut self, input: &impl TextToTensors) -> Result { + input.to_tensors(self.input_to_tensor_info, &mut self.input_tensor_bufs)?; + + let tensor_type = TensorType::I32; + for index in 0..self.input_tensor_bufs.len() { + self.execution_ctx.set_input( + index, + tensor_type, + self.input_tensor_shapes[index], + self.input_tensor_bufs[index].as_slice(), + )?; + } + self.execution_ctx.compute()?; + + let output_buffer = self.tensor_to_embedding.output_buffer(0); + self.execution_ctx.get_output(0, output_buffer)?; + + Ok(self.tensor_to_embedding.result(None)) + } +} diff --git a/tests/text_embedding.rs b/tests/text_embedding.rs new file mode 100644 index 0000000..7b1348c --- /dev/null +++ b/tests/text_embedding.rs @@ -0,0 +1,41 @@ +use mediapipe_rs::tasks::text::TextEmbedderBuilder; + +const MODEL_1: &'static str = "assets/models/text_embedding/bert_embedder.tflite"; +const MODEL_2: &'static str = "assets/models/text_embedding/universal_sentence_encoder.tflite"; + +#[test] +fn test_text_embedding_model_1() { + text_embedding_tasks_run(MODEL_1) +} + +#[test] +fn test_text_embedding_model_2() { + // todo: add universal_sentence_encoder input support. + // text_embedding_tasks_run(MODEL_2) +} + +fn text_embedding_tasks_run(model_asset: &str) { + let text_embedder = TextEmbedderBuilder::new() + .l2_normalize(false) + .quantize(false) + .build_from_file(model_asset) + .unwrap(); + let mut session = text_embedder.new_session().unwrap(); + + let text_1 = "I'm feeling so good"; + let text_2 = "I'm okay I guess"; + + let embedding_1 = session.embed(&text_1).unwrap(); + let embedding_2 = session.embed(&text_2).unwrap(); + assert_eq!(embedding_1.embeddings.len(), 1); + assert_eq!(embedding_2.embeddings.len(), 1); + let e_1 = embedding_1.embeddings.get(0).unwrap(); + let e_2 = embedding_2.embeddings.get(0).unwrap(); + assert_eq!(e_1.quantized_embedding.len(), 0); + assert_eq!(e_2.quantized_embedding.len(), 0); + assert_eq!(e_1.float_embedding.len(), e_2.float_embedding.len()); + assert_ne!(e_1.float_embedding.len(), 0); + + let similarity = e_1.cosine_similarity(e_2).unwrap(); + eprintln!("'{}', '{}' similarity = {}", text_1, text_2, similarity); +}