-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d4ffed9
commit 078aa66
Showing
6 changed files
with
309 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
--- | ||
title: Gemini | ||
--- | ||
|
||
To use Gemini model, you have to set the `GEMINI_API_KEY` environment variable. You can obtain the Gemini API key from the [Google AI Studio](https://aistudio.google.com/app/apikey) | ||
|
||
## Usage | ||
|
||
```python | ||
import os | ||
from mem0 import Memory | ||
|
||
os.environ["OPENAI_API_KEY"] = "your-api-key" # used for embedding model | ||
os.environ["GEMINI_API_KEY"] = "your-api-key" | ||
|
||
config = { | ||
"llm": { | ||
"provider": "gemini", | ||
"config": { | ||
"model": "gemini-1.5-flash-latest", | ||
"temperature": 0.2, | ||
"max_tokens": 1500, | ||
} | ||
} | ||
} | ||
|
||
m = Memory.from_config(config) | ||
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"}) | ||
``` | ||
|
||
## Config | ||
|
||
All available parameters for the `Gemini` config are present in [Master List of All Params in Config](../config). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import os | ||
from typing import Dict, List, Optional | ||
|
||
try: | ||
import google.generativeai as genai | ||
from google.generativeai import GenerativeModel | ||
from google.generativeai.types import content_types | ||
except ImportError: | ||
raise ImportError("The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'.") | ||
|
||
from mem0.configs.llms.base import BaseLlmConfig | ||
from mem0.llms.base import LLMBase | ||
|
||
|
||
class GeminiLLM(LLMBase): | ||
def __init__(self, config: Optional[BaseLlmConfig] = None): | ||
super().__init__(config) | ||
|
||
if not self.config.model: | ||
self.config.model = "gemini-1.5-flash-latest" | ||
|
||
api_key = self.config.api_key or os.getenv("GEMINI_API_KEY") | ||
genai.configure(api_key=api_key) | ||
self.client = GenerativeModel(model_name=self.config.model) | ||
|
||
def _parse_response(self, response, tools): | ||
""" | ||
Process the response based on whether tools are used or not. | ||
Args: | ||
response: The raw response from API. | ||
tools: The list of tools provided in the request. | ||
Returns: | ||
str or dict: The processed response. | ||
""" | ||
if tools: | ||
processed_response = { | ||
"content": content if (content := response.candidates[0].content.parts[0].text) else None, | ||
"tool_calls": [], | ||
} | ||
|
||
for part in response.candidates[0].content.parts: | ||
if fn := part.function_call: | ||
processed_response["tool_calls"].append( | ||
{ | ||
"name": fn.name, | ||
"arguments": {key:val for key, val in fn.args.items()}, | ||
} | ||
) | ||
|
||
return processed_response | ||
else: | ||
return response.candidates[0].content.parts[0].text | ||
|
||
def _reformat_messages(self, messages : List[Dict[str, str]]): | ||
""" | ||
Reformat messages for Gemini. | ||
Args: | ||
messages: The list of messages provided in the request. | ||
Returns: | ||
list: The list of messages in the required format. | ||
""" | ||
new_messages = [] | ||
|
||
for message in messages: | ||
if message["role"] == "system": | ||
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"] | ||
|
||
else: | ||
content = message["content"] | ||
|
||
new_messages.append({"parts": content, | ||
"role": "model" if message["role"] == "model" else "user"}) | ||
|
||
return new_messages | ||
|
||
def _reformat_tools(self, tools: Optional[List[Dict]]): | ||
""" | ||
Reformat tools for Gemini. | ||
Args: | ||
tools: The list of tools provided in the request. | ||
Returns: | ||
list: The list of tools in the required format. | ||
""" | ||
|
||
def remove_additional_properties(data): | ||
"""Recursively removes 'additionalProperties' from nested dictionaries.""" | ||
|
||
if isinstance(data, dict): | ||
filtered_dict = { | ||
key: remove_additional_properties(value) | ||
for key, value in data.items() | ||
if not (key == "additionalProperties") | ||
} | ||
return filtered_dict | ||
else: | ||
return data | ||
|
||
new_tools = [] | ||
if tools: | ||
for tool in tools: | ||
func = tool['function'].copy() | ||
new_tools.append({"function_declarations":[remove_additional_properties(func)]}) | ||
|
||
return new_tools | ||
else: | ||
return None | ||
|
||
def generate_response( | ||
self, | ||
messages: List[Dict[str, str]], | ||
response_format=None, | ||
tools: Optional[List[Dict]] = None, | ||
tool_choice: str = "auto", | ||
): | ||
""" | ||
Generate a response based on the given messages using Gemini. | ||
Args: | ||
messages (list): List of message dicts containing 'role' and 'content'. | ||
response_format (str or object, optional): Format for the response. Defaults to "text". | ||
tools (list, optional): List of tools that the model can call. Defaults to None. | ||
tool_choice (str, optional): Tool choice method. Defaults to "auto". | ||
Returns: | ||
str: The generated response. | ||
""" | ||
|
||
params = { | ||
"temperature": self.config.temperature, | ||
"max_output_tokens": self.config.max_tokens, | ||
"top_p": self.config.top_p, | ||
} | ||
|
||
if response_format: | ||
params["response_mime_type"] = "application/json" | ||
params["response_schema"] = list[response_format] | ||
if tool_choice: | ||
tool_config = content_types.to_tool_config( | ||
{"function_calling_config": | ||
{"mode": tool_choice, "allowed_function_names": [tool['function']['name'] for tool in tools] if tool_choice == "any" else None} | ||
}) | ||
|
||
response = self.client.generate_content(contents = self._reformat_messages(messages), | ||
tools = self._reformat_tools(tools), | ||
generation_config = genai.GenerationConfig(**params), | ||
tool_config = tool_config) | ||
|
||
return self._parse_response(response, tools) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from unittest.mock import Mock, patch | ||
|
||
import pytest | ||
from google.generativeai import GenerationConfig | ||
from google.generativeai.types import content_types | ||
|
||
from mem0.configs.llms.base import BaseLlmConfig | ||
from mem0.llms.gemini import GeminiLLM | ||
|
||
|
||
@pytest.fixture | ||
def mock_gemini_client(): | ||
with patch("mem0.llms.gemini.GenerativeModel") as mock_gemini: | ||
mock_client = Mock() | ||
mock_gemini.return_value = mock_client | ||
yield mock_client | ||
|
||
|
||
def test_generate_response_without_tools(mock_gemini_client: Mock): | ||
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0) | ||
llm = GeminiLLM(config) | ||
messages = [ | ||
{"role": "system", "content": "You are a helpful assistant."}, | ||
{"role": "user", "content": "Hello, how are you?"}, | ||
] | ||
|
||
mock_part = Mock(text="I'm doing well, thank you for asking!") | ||
mock_content = Mock(parts=[mock_part]) | ||
mock_message = Mock(content=mock_content) | ||
mock_response = Mock(candidates=[mock_message]) | ||
mock_gemini_client.generate_content.return_value = mock_response | ||
|
||
response = llm.generate_response(messages) | ||
|
||
mock_gemini_client.generate_content.assert_called_once_with( | ||
contents = [ | ||
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, | ||
{"parts": "Hello, how are you?", "role": "user"} | ||
], | ||
generation_config = GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0), | ||
tools = None, | ||
tool_config = content_types.to_tool_config( | ||
{"function_calling_config": | ||
{"mode": 'auto', "allowed_function_names": None} | ||
}) | ||
) | ||
assert response == "I'm doing well, thank you for asking!" | ||
|
||
def test_generate_response_with_tools(mock_gemini_client: Mock): | ||
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0) | ||
llm = GeminiLLM(config) | ||
messages = [ | ||
{"role": "system", "content": "You are a helpful assistant."}, | ||
{"role": "user", "content": "Add a new memory: Today is a sunny day."}, | ||
] | ||
tools = [ | ||
{ | ||
"type": "function", | ||
"function": { | ||
"name": "add_memory", | ||
"description": "Add a memory", | ||
"parameters": { | ||
"type": "object", | ||
"properties": {"data": {"type": "string", "description": "Data to add to memory"}}, | ||
"required": ["data"], | ||
}, | ||
}, | ||
} | ||
] | ||
|
||
mock_tool_call = Mock() | ||
mock_tool_call.name = "add_memory" | ||
mock_tool_call.args = {"data": "Today is a sunny day."} | ||
|
||
mock_part = Mock() | ||
mock_part.function_call = mock_tool_call | ||
mock_part.text="I've added the memory for you." | ||
|
||
mock_content = Mock() | ||
mock_content.parts=[mock_part] | ||
|
||
mock_message = Mock() | ||
mock_message.content=mock_content | ||
|
||
mock_response = Mock(candidates=[mock_message]) | ||
mock_gemini_client.generate_content.return_value = mock_response | ||
|
||
response = llm.generate_response(messages, tools=tools) | ||
|
||
mock_gemini_client.generate_content.assert_called_once_with( | ||
contents = [ | ||
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, | ||
{"parts": "Add a new memory: Today is a sunny day.", "role": "user"} | ||
], | ||
generation_config = GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0), | ||
tools = [ | ||
{ | ||
"function_declarations": [{ | ||
"name": "add_memory", | ||
"description": "Add a memory", | ||
"parameters": { | ||
"type": "object", | ||
"properties": {"data": {"type": "string", "description": "Data to add to memory"}}, | ||
"required": ["data"] | ||
} | ||
}] | ||
} | ||
], | ||
tool_config = content_types.to_tool_config( | ||
{"function_calling_config": | ||
{"mode": 'auto', "allowed_function_names": None} | ||
}) | ||
) | ||
|
||
assert response["content"] == "I've added the memory for you." | ||
assert len(response["tool_calls"]) == 1 | ||
assert response["tool_calls"][0]["name"] == "add_memory" | ||
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} |