diff --git a/libs/community/langchain_community/embeddings/voyageai.py b/libs/community/langchain_community/embeddings/voyageai.py index 93109d45c65b6..f8b1a4059e6d5 100644 --- a/libs/community/langchain_community/embeddings/voyageai.py +++ b/libs/community/langchain_community/embeddings/voyageai.py @@ -86,6 +86,15 @@ class VoyageEmbeddings(BaseModel, Embeddings): show_progress_bar: bool = False """Whether to show a progress bar when embedding. Must have tqdm installed if set to True.""" + truncation: Optional[bool] = None + """Whether to truncate the input texts to fit within the context length. + + If True, over-length input texts will be truncated to fit within the context + length, before vectorized by the embedding model. If False, an error will be + raised if any given text exceeds the context length. If not specified + (defaults to None), we will truncate the input text before sending it to the + embedding model if it slightly exceeds the context window length. If it + significantly exceeds the context window length, an error will be raised.""" class Config: """Configuration for this pydantic object.""" @@ -104,12 +113,14 @@ def _invocation_params( self, input: List[str], input_type: Optional[str] = None ) -> Dict: api_key = cast(SecretStr, self.voyage_api_key).get_secret_value() - params = { + params: Dict = { "url": self.voyage_api_base, "headers": {"Authorization": f"Bearer {api_key}"}, "json": {"model": self.model, "input": input, "input_type": input_type}, "timeout": self.request_timeout, } + if self.truncation is not None: + params["json"]["truncation"] = self.truncation return params def _get_embeddings(