Skip to content

Commit

Permalink
Merge pull request WasmEdge#2 from yanghaku/main
Browse files Browse the repository at this point in the history
Add text embedding task
  • Loading branch information
juntao authored Aug 14, 2023
2 parents c606375 + e3783cf commit 7eee049
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
* [ ] Pose Landmark Detection
* [x] Audio Classification
* [x] Text Classification
* [ ] Text Embedding
* [x] Text Embedding
* [ ] Language Detection

## Task APIs
Expand Down
18 changes: 18 additions & 0 deletions scripts/download-models.sh
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,23 @@ text_classification_init() {
popd
}

text_embedding_init() {
text_embedding_dir="${model_path}/text_embedding"
mkdir -p "${text_embedding_dir}"
pushd "${text_embedding_dir}"

model_urls=(
"https://storage.googleapis.com/mediapipe-models/text_embedder/bert_embedder/float32/1/bert_embedder.tflite"
"https://storage.googleapis.com/mediapipe-models/text_embedder/universal_sentence_encoder/float32/latest/universal_sentence_encoder.tflite"
)

for url in "${model_urls[@]}"; do
curl -sLO "${url}"
done

popd
}

object_detection_init
image_classification_init
gesture_recognition_init
Expand All @@ -167,3 +184,4 @@ image_embedding_init
face_detection_init
audio_classification_init
text_classification_init
text_embedding_init
2 changes: 2 additions & 0 deletions src/tasks/text/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
mod text_classification;
mod text_embedding;

pub use text_classification::{TextClassifier, TextClassifierBuilder, TextClassifierSession};
pub use text_embedding::{TextEmbedder, TextEmbedderBuilder, TextEmbedderSession};
73 changes: 73 additions & 0 deletions src/tasks/text/text_embedding/builder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use super::TextEmbedder;
use crate::tasks::common::{BaseTaskOptions, EmbeddingOptions};

/// Configure the build options of a new **Text Embedding** task instance.
///
/// Methods can be chained on it in order to configure it.
pub struct TextEmbedderBuilder {
pub(super) base_task_options: BaseTaskOptions,
pub(super) embedding_options: EmbeddingOptions,
}

impl Default for TextEmbedderBuilder {
#[inline(always)]
fn default() -> Self {
Self {
base_task_options: Default::default(),
embedding_options: Default::default(),
}
}
}

impl TextEmbedderBuilder {
/// Create a new builder with default options.
#[inline(always)]
pub fn new() -> Self {
Self::default()
}

base_task_options_impl!(TextEmbedder);

embedding_options_impl!();

/// Use the current build options and use the buffer as model data to create a new task instance.
#[inline]
pub fn build_from_buffer(self, buffer: impl AsRef<[u8]>) -> Result<TextEmbedder, crate::Error> {
let buf = buffer.as_ref();
// parse model and get model resources.
let model_resource = crate::model::parse_model(buf)?;

// check model
model_base_check_impl!(model_resource, 1);
model_resource_check_and_get_impl!(model_resource, to_tensor_info, 0).try_to_text()?;
let input_count = model_resource.input_tensor_count();
if input_count != 1 && input_count != 3 {
return Err(crate::Error::ModelInconsistentError(format!(
"Expect model input tensor count `1` or `3`, but got `{}`",
input_count
)));
}
for i in 0..input_count {
let t = model_resource_check_and_get_impl!(model_resource, input_tensor_type, i);
if t != crate::TensorType::I32 {
// todo: string type support
return Err(crate::Error::ModelInconsistentError(
"All input tensors should be int32 type".into(),
));
}
}

let graph = crate::GraphBuilder::new(
model_resource.model_backend(),
self.base_task_options.device,
)
.build_from_bytes([buf])?;

return Ok(TextEmbedder {
build_options: self,
model_resource,
graph,
input_count,
});
}
}
99 changes: 99 additions & 0 deletions src/tasks/text/text_embedding/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
mod builder;
pub use builder::TextEmbedderBuilder;

use crate::model::ModelResourceTrait;
use crate::postprocess::{EmbeddingResult, TensorsToEmbedding};
use crate::preprocess::text::{TextToTensorInfo, TextToTensors};
use crate::{Error, Graph, GraphExecutionContext, TensorType};

/// Performs embedding on texts.
pub struct TextEmbedder {
build_options: TextEmbedderBuilder,
model_resource: Box<dyn ModelResourceTrait>,
graph: Graph,
input_count: usize,
}

impl TextEmbedder {
base_task_options_get_impl!();

embedding_options_get_impl!();

/// Create a new task session that contains processing buffers and can do inference.
#[inline(always)]
pub fn new_session(&self) -> Result<TextEmbedderSession, Error> {
let input_to_tensor_info =
model_resource_check_and_get_impl!(self.model_resource, to_tensor_info, 0)
.try_to_text()?;
let mut input_tensor_shapes = Vec::with_capacity(self.input_count);
let mut input_tensor_bufs = Vec::with_capacity(self.input_count);
for i in 0..self.input_count {
let input_tensor_shape =
model_resource_check_and_get_impl!(self.model_resource, input_tensor_shape, i);
let bytes = input_tensor_shape.iter().fold(4, |sum, b| sum * *b);
input_tensor_shapes.push(input_tensor_shape);
input_tensor_bufs.push(vec![0; bytes]);
}
let output_tensor_shape =
model_resource_check_and_get_impl!(self.model_resource, output_tensor_shape, 0);
let mut tensor_to_embedding = TensorsToEmbedding::new(
self.build_options.embedding_options.quantize,
self.build_options.embedding_options.l2_normalize,
);
tensor_to_embedding.add_output_cfg(
get_type_and_quantization!(self.model_resource, 0),
output_tensor_shape,
None,
);

let execution_ctx = self.graph.init_execution_context()?;
Ok(TextEmbedderSession {
execution_ctx,
tensor_to_embedding,
input_to_tensor_info,
input_tensor_shapes,
input_tensor_bufs,
})
}

/// Embed one text using a new session.
#[inline(always)]
pub fn embed(&self, input: &impl TextToTensors) -> Result<EmbeddingResult, Error> {
self.new_session()?.embed(input)
}
}

/// Session to run inference.
/// If process multiple texts, reuse it can get better performance.
pub struct TextEmbedderSession<'a> {
execution_ctx: GraphExecutionContext<'a>,
tensor_to_embedding: TensorsToEmbedding,

input_to_tensor_info: &'a TextToTensorInfo,
input_tensor_shapes: Vec<&'a [usize]>,
input_tensor_bufs: Vec<Vec<u8>>,
}

impl<'a> TextEmbedderSession<'a> {
/// Embed one text use this session.
#[inline(always)]
pub fn embed(&mut self, input: &impl TextToTensors) -> Result<EmbeddingResult, Error> {
input.to_tensors(self.input_to_tensor_info, &mut self.input_tensor_bufs)?;

let tensor_type = TensorType::I32;
for index in 0..self.input_tensor_bufs.len() {
self.execution_ctx.set_input(
index,
tensor_type,
self.input_tensor_shapes[index],
self.input_tensor_bufs[index].as_slice(),
)?;
}
self.execution_ctx.compute()?;

let output_buffer = self.tensor_to_embedding.output_buffer(0);
self.execution_ctx.get_output(0, output_buffer)?;

Ok(self.tensor_to_embedding.result(None))
}
}
41 changes: 41 additions & 0 deletions tests/text_embedding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use mediapipe_rs::tasks::text::TextEmbedderBuilder;

const MODEL_1: &'static str = "assets/models/text_embedding/bert_embedder.tflite";
const MODEL_2: &'static str = "assets/models/text_embedding/universal_sentence_encoder.tflite";

#[test]
fn test_text_embedding_model_1() {
text_embedding_tasks_run(MODEL_1)
}

#[test]
fn test_text_embedding_model_2() {
// todo: add universal_sentence_encoder input support.
// text_embedding_tasks_run(MODEL_2)
}

fn text_embedding_tasks_run(model_asset: &str) {
let text_embedder = TextEmbedderBuilder::new()
.l2_normalize(false)
.quantize(false)
.build_from_file(model_asset)
.unwrap();
let mut session = text_embedder.new_session().unwrap();

let text_1 = "I'm feeling so good";
let text_2 = "I'm okay I guess";

let embedding_1 = session.embed(&text_1).unwrap();
let embedding_2 = session.embed(&text_2).unwrap();
assert_eq!(embedding_1.embeddings.len(), 1);
assert_eq!(embedding_2.embeddings.len(), 1);
let e_1 = embedding_1.embeddings.get(0).unwrap();
let e_2 = embedding_2.embeddings.get(0).unwrap();
assert_eq!(e_1.quantized_embedding.len(), 0);
assert_eq!(e_2.quantized_embedding.len(), 0);
assert_eq!(e_1.float_embedding.len(), e_2.float_embedding.len());
assert_ne!(e_1.float_embedding.len(), 0);

let similarity = e_1.cosine_similarity(e_2).unwrap();
eprintln!("'{}', '{}' similarity = {}", text_1, text_2, similarity);
}

0 comments on commit 7eee049

Please sign in to comment.