forked from WasmEdge/mediapipe-rs
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request WasmEdge#2 from yanghaku/main
Add text embedding task
- Loading branch information
Showing
6 changed files
with
234 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
mod text_classification; | ||
mod text_embedding; | ||
|
||
pub use text_classification::{TextClassifier, TextClassifierBuilder, TextClassifierSession}; | ||
pub use text_embedding::{TextEmbedder, TextEmbedderBuilder, TextEmbedderSession}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
use super::TextEmbedder; | ||
use crate::tasks::common::{BaseTaskOptions, EmbeddingOptions}; | ||
|
||
/// Configure the build options of a new **Text Embedding** task instance. | ||
/// | ||
/// Methods can be chained on it in order to configure it. | ||
pub struct TextEmbedderBuilder { | ||
pub(super) base_task_options: BaseTaskOptions, | ||
pub(super) embedding_options: EmbeddingOptions, | ||
} | ||
|
||
impl Default for TextEmbedderBuilder { | ||
#[inline(always)] | ||
fn default() -> Self { | ||
Self { | ||
base_task_options: Default::default(), | ||
embedding_options: Default::default(), | ||
} | ||
} | ||
} | ||
|
||
impl TextEmbedderBuilder { | ||
/// Create a new builder with default options. | ||
#[inline(always)] | ||
pub fn new() -> Self { | ||
Self::default() | ||
} | ||
|
||
base_task_options_impl!(TextEmbedder); | ||
|
||
embedding_options_impl!(); | ||
|
||
/// Use the current build options and use the buffer as model data to create a new task instance. | ||
#[inline] | ||
pub fn build_from_buffer(self, buffer: impl AsRef<[u8]>) -> Result<TextEmbedder, crate::Error> { | ||
let buf = buffer.as_ref(); | ||
// parse model and get model resources. | ||
let model_resource = crate::model::parse_model(buf)?; | ||
|
||
// check model | ||
model_base_check_impl!(model_resource, 1); | ||
model_resource_check_and_get_impl!(model_resource, to_tensor_info, 0).try_to_text()?; | ||
let input_count = model_resource.input_tensor_count(); | ||
if input_count != 1 && input_count != 3 { | ||
return Err(crate::Error::ModelInconsistentError(format!( | ||
"Expect model input tensor count `1` or `3`, but got `{}`", | ||
input_count | ||
))); | ||
} | ||
for i in 0..input_count { | ||
let t = model_resource_check_and_get_impl!(model_resource, input_tensor_type, i); | ||
if t != crate::TensorType::I32 { | ||
// todo: string type support | ||
return Err(crate::Error::ModelInconsistentError( | ||
"All input tensors should be int32 type".into(), | ||
)); | ||
} | ||
} | ||
|
||
let graph = crate::GraphBuilder::new( | ||
model_resource.model_backend(), | ||
self.base_task_options.device, | ||
) | ||
.build_from_bytes([buf])?; | ||
|
||
return Ok(TextEmbedder { | ||
build_options: self, | ||
model_resource, | ||
graph, | ||
input_count, | ||
}); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
mod builder; | ||
pub use builder::TextEmbedderBuilder; | ||
|
||
use crate::model::ModelResourceTrait; | ||
use crate::postprocess::{EmbeddingResult, TensorsToEmbedding}; | ||
use crate::preprocess::text::{TextToTensorInfo, TextToTensors}; | ||
use crate::{Error, Graph, GraphExecutionContext, TensorType}; | ||
|
||
/// Performs embedding on texts. | ||
pub struct TextEmbedder { | ||
build_options: TextEmbedderBuilder, | ||
model_resource: Box<dyn ModelResourceTrait>, | ||
graph: Graph, | ||
input_count: usize, | ||
} | ||
|
||
impl TextEmbedder { | ||
base_task_options_get_impl!(); | ||
|
||
embedding_options_get_impl!(); | ||
|
||
/// Create a new task session that contains processing buffers and can do inference. | ||
#[inline(always)] | ||
pub fn new_session(&self) -> Result<TextEmbedderSession, Error> { | ||
let input_to_tensor_info = | ||
model_resource_check_and_get_impl!(self.model_resource, to_tensor_info, 0) | ||
.try_to_text()?; | ||
let mut input_tensor_shapes = Vec::with_capacity(self.input_count); | ||
let mut input_tensor_bufs = Vec::with_capacity(self.input_count); | ||
for i in 0..self.input_count { | ||
let input_tensor_shape = | ||
model_resource_check_and_get_impl!(self.model_resource, input_tensor_shape, i); | ||
let bytes = input_tensor_shape.iter().fold(4, |sum, b| sum * *b); | ||
input_tensor_shapes.push(input_tensor_shape); | ||
input_tensor_bufs.push(vec![0; bytes]); | ||
} | ||
let output_tensor_shape = | ||
model_resource_check_and_get_impl!(self.model_resource, output_tensor_shape, 0); | ||
let mut tensor_to_embedding = TensorsToEmbedding::new( | ||
self.build_options.embedding_options.quantize, | ||
self.build_options.embedding_options.l2_normalize, | ||
); | ||
tensor_to_embedding.add_output_cfg( | ||
get_type_and_quantization!(self.model_resource, 0), | ||
output_tensor_shape, | ||
None, | ||
); | ||
|
||
let execution_ctx = self.graph.init_execution_context()?; | ||
Ok(TextEmbedderSession { | ||
execution_ctx, | ||
tensor_to_embedding, | ||
input_to_tensor_info, | ||
input_tensor_shapes, | ||
input_tensor_bufs, | ||
}) | ||
} | ||
|
||
/// Embed one text using a new session. | ||
#[inline(always)] | ||
pub fn embed(&self, input: &impl TextToTensors) -> Result<EmbeddingResult, Error> { | ||
self.new_session()?.embed(input) | ||
} | ||
} | ||
|
||
/// Session to run inference. | ||
/// If process multiple texts, reuse it can get better performance. | ||
pub struct TextEmbedderSession<'a> { | ||
execution_ctx: GraphExecutionContext<'a>, | ||
tensor_to_embedding: TensorsToEmbedding, | ||
|
||
input_to_tensor_info: &'a TextToTensorInfo, | ||
input_tensor_shapes: Vec<&'a [usize]>, | ||
input_tensor_bufs: Vec<Vec<u8>>, | ||
} | ||
|
||
impl<'a> TextEmbedderSession<'a> { | ||
/// Embed one text use this session. | ||
#[inline(always)] | ||
pub fn embed(&mut self, input: &impl TextToTensors) -> Result<EmbeddingResult, Error> { | ||
input.to_tensors(self.input_to_tensor_info, &mut self.input_tensor_bufs)?; | ||
|
||
let tensor_type = TensorType::I32; | ||
for index in 0..self.input_tensor_bufs.len() { | ||
self.execution_ctx.set_input( | ||
index, | ||
tensor_type, | ||
self.input_tensor_shapes[index], | ||
self.input_tensor_bufs[index].as_slice(), | ||
)?; | ||
} | ||
self.execution_ctx.compute()?; | ||
|
||
let output_buffer = self.tensor_to_embedding.output_buffer(0); | ||
self.execution_ctx.get_output(0, output_buffer)?; | ||
|
||
Ok(self.tensor_to_embedding.result(None)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
use mediapipe_rs::tasks::text::TextEmbedderBuilder; | ||
|
||
const MODEL_1: &'static str = "assets/models/text_embedding/bert_embedder.tflite"; | ||
const MODEL_2: &'static str = "assets/models/text_embedding/universal_sentence_encoder.tflite"; | ||
|
||
#[test] | ||
fn test_text_embedding_model_1() { | ||
text_embedding_tasks_run(MODEL_1) | ||
} | ||
|
||
#[test] | ||
fn test_text_embedding_model_2() { | ||
// todo: add universal_sentence_encoder input support. | ||
// text_embedding_tasks_run(MODEL_2) | ||
} | ||
|
||
fn text_embedding_tasks_run(model_asset: &str) { | ||
let text_embedder = TextEmbedderBuilder::new() | ||
.l2_normalize(false) | ||
.quantize(false) | ||
.build_from_file(model_asset) | ||
.unwrap(); | ||
let mut session = text_embedder.new_session().unwrap(); | ||
|
||
let text_1 = "I'm feeling so good"; | ||
let text_2 = "I'm okay I guess"; | ||
|
||
let embedding_1 = session.embed(&text_1).unwrap(); | ||
let embedding_2 = session.embed(&text_2).unwrap(); | ||
assert_eq!(embedding_1.embeddings.len(), 1); | ||
assert_eq!(embedding_2.embeddings.len(), 1); | ||
let e_1 = embedding_1.embeddings.get(0).unwrap(); | ||
let e_2 = embedding_2.embeddings.get(0).unwrap(); | ||
assert_eq!(e_1.quantized_embedding.len(), 0); | ||
assert_eq!(e_2.quantized_embedding.len(), 0); | ||
assert_eq!(e_1.float_embedding.len(), e_2.float_embedding.len()); | ||
assert_ne!(e_1.float_embedding.len(), 0); | ||
|
||
let similarity = e_1.cosine_similarity(e_2).unwrap(); | ||
eprintln!("'{}', '{}' similarity = {}", text_1, text_2, similarity); | ||
} |