>
);
diff --git a/gui/pages/_app.css b/gui/pages/_app.css
index d9fd436c2..33b676b8e 100644
--- a/gui/pages/_app.css
+++ b/gui/pages/_app.css
@@ -39,6 +39,7 @@
justify-content: center;
height: 120px;
cursor: pointer;
+ padding: 20px;
background: transparent;
}
@@ -209,7 +210,7 @@ input[type="range"]::-moz-range-track {
.dropdown_container {
width: 150px;
- height: fit-content;
+ height: auto;
background: #2E293F;
flex-direction: column;
justify-content: center;
@@ -617,6 +618,14 @@ p {
border-radius: 16px;
}
+.cancel_action {
+ margin-top: 10px;
+ width: fit-content;
+ color: #888888;
+ font-size: 12px;
+ cursor: pointer;
+}
+
.mt_6{margin-top: 6px;}
.mt_8{margin-top: 8px;}
.mt_10{margin-top: 10px;}
diff --git a/gui/pages/_app.js b/gui/pages/_app.js
index 0bd893e04..56a8af0c0 100644
--- a/gui/pages/_app.js
+++ b/gui/pages/_app.js
@@ -6,12 +6,20 @@ import 'bootstrap/dist/css/bootstrap.css';
import './_app.css'
import Head from 'next/head';
import Image from "next/image";
-import { getOrganisation, getProject, validateAccessToken, checkEnvironment, addUser } from "@/pages/api/DashboardService";
+import {
+ getOrganisation,
+ getProject,
+ validateAccessToken,
+ checkEnvironment,
+ addUser,
+ installToolkitTemplate, installAgentTemplate
+} from "@/pages/api/DashboardService";
import { githubClientId } from "@/pages/api/apiConfig";
import { useRouter } from 'next/router';
import querystring from 'querystring';
import {refreshUrl, loadingTextEffect} from "@/utils/utils";
import MarketplacePublic from "./Content/Marketplace/MarketplacePublic"
+import {toast} from "react-toastify";
export default function App() {
const [selectedView, setSelectedView] = useState('');
@@ -23,6 +31,7 @@ export default function App() {
const [loadingText, setLoadingText] = useState("Initializing SuperAGI");
const router = useRouter();
const [showMarketplace, setShowMarketplace] = useState(false);
+ const excludedKeys = ['repo_starred', 'popup_closed_time', 'twitter_toolkit_id', 'accessToken', 'agent_to_install', 'toolkit_to_install'];
function fetchOrganisation(userId) {
getOrganisation(userId)
@@ -34,9 +43,38 @@ export default function App() {
});
}
+ const installFromMarketplace = () => {
+ const toolkitName = localStorage.getItem('toolkit_to_install') || null;
+ const agentTemplateId = localStorage.getItem('agent_to_install') || null;
+
+ if(toolkitName !== null) {
+ installToolkitTemplate(toolkitName)
+ .then((response) => {
+ toast.success("Template installed", {autoClose: 1800});
+ })
+ .catch((error) => {
+ console.error('Error installing template:', error);
+ });
+ localStorage.removeItem('toolkit_to_install');
+ }
+
+ if(agentTemplateId !== null) {
+ installAgentTemplate(agentTemplateId)
+ .then((response) => {
+ toast.success("Template installed", {autoClose: 1800});
+ })
+ .catch((error) => {
+ console.error('Error installing template:', error);
+ });
+ localStorage.removeItem('agent_to_install');
+ }
+ }
+
useEffect(() => {
if(window.location.href.toLowerCase().includes('marketplace')) {
setShowMarketplace(true);
+ } else {
+ installFromMarketplace();
}
loadingTextEffect('Initializing SuperAGI', setLoadingText, 500);
@@ -118,6 +156,25 @@ export default function App() {
window.open(`https://github.com/login/oauth/authorize?scope=user:email&client_id=${github_client_id}`, '_self')
}
+ useEffect(() => {
+ const clearLocalStorage = () => {
+ Object.keys(localStorage).forEach((key) => {
+ if (!excludedKeys.includes(key)) {
+ localStorage.removeItem(key);
+ }
+ });
+ };
+
+ window.addEventListener('beforeunload', clearLocalStorage);
+ window.addEventListener('unload', clearLocalStorage);
+
+ return () => {
+ window.removeEventListener('beforeunload', clearLocalStorage);
+ window.removeEventListener('unload', clearLocalStorage);
+ };
+ }, []);
+
+
return (
diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js
index cc473ce8b..04862185f 100644
--- a/gui/pages/api/DashboardService.js
+++ b/gui/pages/api/DashboardService.js
@@ -44,6 +44,10 @@ export const createAgent = (agentData) => {
return api.post(`/agents/create`, agentData);
};
+export const addTool = (toolData) => {
+ return api.post(`/toolkits/get/local/install`, toolData);
+};
+
export const updateAgents = (agentData) => {
return api.put(`/agentconfigs/update/`, agentData);
};
@@ -128,6 +132,14 @@ export const authenticateGoogleCred = (toolKitId) => {
return api.get(`/google/get_google_creds/toolkit_id/${toolKitId}`);
}
+export const authenticateTwitterCred = (toolKitId) => {
+ return api.get(`/twitter/get_twitter_creds/toolkit_id/${toolKitId}`);
+}
+
+export const sendTwitterCreds = (twitter_creds) => {
+ return api.post(`/twitter/send_twitter_creds/${twitter_creds}`);
+}
+
export const fetchToolTemplateList = () => {
return api.get(`/toolkits/get/list?page=0`);
}
@@ -135,6 +147,11 @@ export const fetchToolTemplateList = () => {
export const fetchToolTemplateOverview = (toolTemplateName) => {
return api.get(`/toolkits/marketplace/readme/${toolTemplateName}`);
}
+
export const installToolkitTemplate = (templateName) => {
return api.get(`/toolkits/get/install/${templateName}`);
}
+
+export const getExecutionDetails = (executionId) => {
+ return api.get(`/agent_executions_configs/details/${executionId}`);
+}
\ No newline at end of file
diff --git a/gui/public/images/arrow_forward_secondary.svg b/gui/public/images/arrow_forward_secondary.svg
new file mode 100644
index 000000000..855bc6695
--- /dev/null
+++ b/gui/public/images/arrow_forward_secondary.svg
@@ -0,0 +1,8 @@
+
diff --git a/gui/public/images/download_icon.svg b/gui/public/images/download_icon.svg
new file mode 100644
index 000000000..af6e8d8d7
--- /dev/null
+++ b/gui/public/images/download_icon.svg
@@ -0,0 +1,8 @@
+
diff --git a/gui/public/images/filemanager_icon.svg b/gui/public/images/filemanager_icon.svg
new file mode 100644
index 000000000..55ec3478a
--- /dev/null
+++ b/gui/public/images/filemanager_icon.svg
@@ -0,0 +1,14 @@
+
diff --git a/gui/public/images/github_icon.svg b/gui/public/images/github_icon.svg
new file mode 100644
index 000000000..6d377ae16
--- /dev/null
+++ b/gui/public/images/github_icon.svg
@@ -0,0 +1,9 @@
+
diff --git a/gui/public/images/gmail_icon.svg b/gui/public/images/gmail_icon.svg
new file mode 100644
index 000000000..457ddaf13
--- /dev/null
+++ b/gui/public/images/gmail_icon.svg
@@ -0,0 +1,14 @@
+
diff --git a/gui/public/images/google_calender_icon.svg b/gui/public/images/google_calender_icon.svg
new file mode 100644
index 000000000..19793fa13
--- /dev/null
+++ b/gui/public/images/google_calender_icon.svg
@@ -0,0 +1,15 @@
+
diff --git a/gui/public/images/google_search_icon.svg b/gui/public/images/google_search_icon.svg
new file mode 100644
index 000000000..00875fbf2
--- /dev/null
+++ b/gui/public/images/google_search_icon.svg
@@ -0,0 +1,9 @@
+
diff --git a/gui/public/images/google_serp_icon.svg b/gui/public/images/google_serp_icon.svg
new file mode 100644
index 000000000..dd6e19f0f
--- /dev/null
+++ b/gui/public/images/google_serp_icon.svg
@@ -0,0 +1,9 @@
+
diff --git a/gui/public/images/jira_icon.svg b/gui/public/images/jira_icon.svg
new file mode 100644
index 000000000..5b68b443f
--- /dev/null
+++ b/gui/public/images/jira_icon.svg
@@ -0,0 +1,9 @@
+
diff --git a/gui/public/images/searx_icon.svg b/gui/public/images/searx_icon.svg
new file mode 100644
index 000000000..cbe8eb401
--- /dev/null
+++ b/gui/public/images/searx_icon.svg
@@ -0,0 +1,9 @@
+
diff --git a/gui/public/images/slack_icon.svg b/gui/public/images/slack_icon.svg
new file mode 100644
index 000000000..47267a585
--- /dev/null
+++ b/gui/public/images/slack_icon.svg
@@ -0,0 +1,9 @@
+
diff --git a/gui/public/images/twitter_icon.svg b/gui/public/images/twitter_icon.svg
new file mode 100644
index 000000000..cbd05ea9d
--- /dev/null
+++ b/gui/public/images/twitter_icon.svg
@@ -0,0 +1,9 @@
+
diff --git a/gui/public/images/webscraper_icon.svg b/gui/public/images/webscraper_icon.svg
new file mode 100644
index 000000000..e50b21be7
--- /dev/null
+++ b/gui/public/images/webscraper_icon.svg
@@ -0,0 +1,9 @@
+
diff --git a/gui/utils/utils.js b/gui/utils/utils.js
index ffaca79f3..3bc4def49 100644
--- a/gui/utils/utils.js
+++ b/gui/utils/utils.js
@@ -1,5 +1,10 @@
import {baseUrl} from "@/pages/api/apiConfig";
import {EventBus} from "@/utils/eventBus";
+import JSZip from "jszip";
+
+export const getUserTimezone = () => {
+ return Intl.DateTimeFormat().resolvedOptions().timeZone;
+}
export const formatTimeDifference = (timeDifference) => {
const units = ['years', 'months', 'days', 'hours', 'minutes'];
@@ -47,30 +52,80 @@ export const formatBytes = (bytes, decimals = 2) => {
return `${formattedValue} ${sizes[i]}`;
}
-export const downloadFile = (fileId) => {
+export const downloadFile = (fileId, fileName = null) => {
const authToken = localStorage.getItem('accessToken');
const url = `${baseUrl()}/resources/get/${fileId}`;
const env = localStorage.getItem('applicationEnvironment');
- if(env === 'PROD') {
+ if (env === 'PROD') {
const headers = {
Authorization: `Bearer ${authToken}`,
};
- fetch(url, { headers })
+ return fetch(url, { headers })
.then((response) => response.blob())
.then((blob) => {
- const fileUrl = window.URL.createObjectURL(blob);
- window.open(fileUrl, "_blank");
+ if (fileName) {
+ const fileUrl = window.URL.createObjectURL(blob);
+ const anchorElement = document.createElement('a');
+ anchorElement.href = fileUrl;
+ anchorElement.download = fileName;
+ anchorElement.click();
+ window.URL.revokeObjectURL(fileUrl);
+ } else {
+ return blob;
+ }
})
.catch((error) => {
- console.error("Error downloading file:", error);
+ console.error('Error downloading file:', error);
});
} else {
- window.open(url, "_blank");
+ if (fileName) {
+ window.open(url, '_blank');
+ } else {
+ return fetch(url)
+ .then((response) => response.blob())
+ .catch((error) => {
+ console.error('Error downloading file:', error);
+ });
+ }
}
};
+export const downloadAllFiles = (files) => {
+ const zip = new JSZip();
+ const promises = [];
+
+ files.forEach(file => {
+ const promise = downloadFile(file.id)
+ .then(blob => {
+ const fileBlob = new Blob([blob], { type: file.type });
+ zip.file(file.name, fileBlob);
+ })
+ .catch(error => {
+ console.error('Error downloading file:', error);
+ });
+
+ promises.push(promise);
+ });
+
+ Promise.all(promises)
+ .then(() => {
+ zip.generateAsync({ type: 'blob' })
+ .then(content => {
+ const timestamp = new Date().getTime();
+ const zipFilename = `files_${timestamp}.zip`;
+ const downloadLink = document.createElement('a');
+ downloadLink.href = URL.createObjectURL(content);
+ downloadLink.download = zipFilename;
+ downloadLink.click();
+ })
+ .catch(error => {
+ console.error('Error generating zip:', error);
+ });
+ });
+};
+
export const refreshUrl = () => {
if (typeof window === 'undefined') {
return;
@@ -94,7 +149,7 @@ export const loadingTextEffect = (loadingText, setLoadingText, timer) => {
export const openNewTab = (id, name, contentType) => {
EventBus.emit('openNewTab', {
- element: {id: id, name: name, contentType: contentType}
+ element: {id: id, name: name, contentType: contentType, internalId: createInternalId()}
});
}
@@ -102,4 +157,72 @@ export const removeTab = (id, name, contentType) => {
EventBus.emit('removeTab', {
element: {id: id, name: name, contentType: contentType}
});
+}
+
+export const setLocalStorageValue = (key, value, stateFunction) => {
+ stateFunction(value);
+ localStorage.setItem(key, value);
+}
+
+export const setLocalStorageArray = (key, value, stateFunction) => {
+ stateFunction(value);
+ const arrayString = JSON.stringify(value);
+ localStorage.setItem(key, arrayString);
+}
+
+export const removeInternalId = (internalId) => {
+ const internal_ids = localStorage.getItem("agi_internal_ids");
+ let idsArray = internal_ids ? internal_ids.split(",").map(Number) : [];
+
+ if(idsArray.length <= 0) {
+ return;
+ }
+
+ const internalIdIndex = idsArray.indexOf(internalId);
+ if (internalIdIndex !== -1) {
+ idsArray.splice(internalIdIndex, 1);
+ localStorage.setItem('agi_internal_ids', idsArray.join(','));
+ localStorage.removeItem("agent_create_click_" + String(internalId));
+ localStorage.removeItem("agent_name_" + String(internalId));
+ localStorage.removeItem("agent_description_" + String(internalId));
+ localStorage.removeItem("agent_goals_" + String(internalId));
+ localStorage.removeItem("agent_instructions_" + String(internalId));
+ localStorage.removeItem("agent_constraints_" + String(internalId));
+ localStorage.removeItem("agent_model_" + String(internalId));
+ localStorage.removeItem("agent_type_" + String(internalId));
+ localStorage.removeItem("tool_names_" + String(internalId));
+ localStorage.removeItem("tool_ids_" + String(internalId));
+ localStorage.removeItem("agent_rolling_window_" + String(internalId));
+ localStorage.removeItem("agent_database_" + String(internalId));
+ localStorage.removeItem("agent_permission_" + String(internalId));
+ localStorage.removeItem("agent_exit_criterion_" + String(internalId));
+ localStorage.removeItem("agent_iterations_" + String(internalId));
+ localStorage.removeItem("agent_step_time_" + String(internalId));
+ localStorage.removeItem("advanced_options_" + String(internalId));
+ localStorage.removeItem("has_LTM_" + String(internalId));
+ localStorage.removeItem("has_resource_" + String(internalId));
+ localStorage.removeItem("agent_files_" + String(internalId));
+ }
+}
+
+export const createInternalId = () => {
+ let newId = 1;
+
+ if (typeof window !== 'undefined') {
+ const internal_ids = localStorage.getItem("agi_internal_ids");
+ let idsArray = internal_ids ? internal_ids.split(",").map(Number) : [];
+ let found = false;
+
+ for (let i = 1; !found; i++) {
+ if (!idsArray.includes(i)) {
+ newId = i;
+ found = true;
+ }
+ }
+
+ idsArray.push(newId);
+ localStorage.setItem('agi_internal_ids', idsArray.join(','));
+ }
+
+ return newId;
}
\ No newline at end of file
diff --git a/main.py b/main.py
index 222e9dd4c..6ecefaa95 100644
--- a/main.py
+++ b/main.py
@@ -15,6 +15,11 @@
from sqlalchemy.orm import sessionmaker
import superagi
+import urllib.parse
+import json
+import http.client as http_client
+from superagi.helper.twitter_tokens import TwitterTokens
+from datetime import datetime, timedelta
from superagi.agent.agent_prompt_builder import AgentPromptBuilder
from superagi.config.config import get_config
from superagi.controllers.agent import router as agent_router
@@ -28,19 +33,23 @@
from superagi.controllers.config import router as config_router
from superagi.controllers.organisation import router as organisation_router
from superagi.controllers.project import router as project_router
+from superagi.controllers.twitter_oauth import router as twitter_oauth_router
from superagi.controllers.resources import router as resources_router
from superagi.controllers.tool import router as tool_router
from superagi.controllers.tool_config import router as tool_config_router
from superagi.controllers.toolkit import router as toolkit_router
from superagi.controllers.user import router as user_router
+from superagi.controllers.agent_execution_config import router as agent_execution_config
from superagi.helper.tool_helper import register_toolkits
from superagi.lib.logger import logger
from superagi.llms.openai import OpenAi
+from superagi.helper.auth import get_current_user
from superagi.models.agent_workflow import AgentWorkflow
from superagi.models.agent_workflow_step import AgentWorkflowStep
from superagi.models.organisation import Organisation
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
+from superagi.models.oauth_tokens import OauthTokens
from superagi.models.types.login_request import LoginRequest
from superagi.models.user import User
@@ -97,6 +106,8 @@
app.include_router(config_router, prefix="/configs")
app.include_router(agent_template_router, prefix="/agent_templates")
app.include_router(agent_workflow_router, prefix="/agent_workflows")
+app.include_router(twitter_oauth_router, prefix="/twitter")
+app.include_router(agent_execution_config, prefix="/agent_executions_configs")
# in production you can use Settings management
@@ -320,7 +331,6 @@ async def google_auth_calendar(code: str = Query(...), Authorize: AuthJWT = Depe
frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(frontend_url)
-
@app.get('/github-login')
def github_login():
"""GitHub login"""
@@ -411,7 +421,6 @@ def get_google_calendar_tool_configs(toolkit_id: int):
"client_id": google_calendar_config.value
}
-
@app.get("/validate-open-ai-key/{open_ai_key}")
async def root(open_ai_key: str, Authorize: AuthJWT = Depends()):
"""API to validate Open AI Key"""
diff --git a/migrations/versions/83424de1347e_added_agent_execution_config.py b/migrations/versions/83424de1347e_added_agent_execution_config.py
new file mode 100644
index 000000000..5c57a3963
--- /dev/null
+++ b/migrations/versions/83424de1347e_added_agent_execution_config.py
@@ -0,0 +1,36 @@
+"""added_agent_execution_config
+
+Revision ID: 83424de1347e
+Revises: c02f3d759bf3
+Create Date: 2023-07-03 22:42:50.091762
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '83424de1347e'
+down_revision = 'c02f3d759bf3'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('agent_execution_configs',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('agent_execution_id', sa.Integer(), nullable=True),
+ sa.Column('key', sa.String(), nullable=True),
+ sa.Column('value', sa.Text(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table('agent_execution_configs')
+ # ### end Alembic commands ###
diff --git a/migrations/versions/c02f3d759bf3_add_summary_to_resource.py b/migrations/versions/c02f3d759bf3_add_summary_to_resource.py
new file mode 100644
index 000000000..ef17aae07
--- /dev/null
+++ b/migrations/versions/c02f3d759bf3_add_summary_to_resource.py
@@ -0,0 +1,28 @@
+"""add summary to resource
+
+Revision ID: c02f3d759bf3
+Revises: 1d54db311055
+Create Date: 2023-06-27 05:07:29.016704
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'c02f3d759bf3'
+down_revision = 'c5c19944c90c'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ##
+ op.add_column('resources', sa.Column('summary', sa.Text(), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('resources', 'summary')
+ # ### end Alembic commands ###
diff --git a/migrations/versions/c5c19944c90c_create_oauth_tokens.py b/migrations/versions/c5c19944c90c_create_oauth_tokens.py
new file mode 100644
index 000000000..8986afcdc
--- /dev/null
+++ b/migrations/versions/c5c19944c90c_create_oauth_tokens.py
@@ -0,0 +1,48 @@
+"""Create Oauth Tokens
+
+Revision ID: c5c19944c90c
+Revises: 7a3e336c0fba
+Create Date: 2023-06-30 07:26:29.180784
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'c5c19944c90c'
+down_revision = '7a3e336c0fba'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('oauth_tokens',
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+ sa.Column('user_id', sa.Integer(), nullable=True),
+ sa.Column('organisation_id', sa.Integer(), nullable=True),
+ sa.Column('toolkit_id', sa.Integer(), nullable=True),
+ sa.Column('key', sa.String(), nullable=True),
+ sa.Column('value', sa.Text(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.drop_index('ix_agent_execution_permissions_agent_execution_id', table_name='agent_execution_permissions')
+ op.drop_index('ix_atc_agnt_template_id_key', table_name='agent_template_configs')
+ op.drop_index('ix_agt_agnt_name', table_name='agent_templates')
+ op.drop_index('ix_agt_agnt_organisation_id', table_name='agent_templates')
+ op.drop_index('ix_agt_agnt_workflow_id', table_name='agent_templates')
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_index('ix_agt_agnt_workflow_id', 'agent_templates', ['agent_workflow_id'], unique=False)
+ op.create_index('ix_agt_agnt_organisation_id', 'agent_templates', ['organisation_id'], unique=False)
+ op.create_index('ix_agt_agnt_name', 'agent_templates', ['name'], unique=False)
+ op.create_index('ix_atc_agnt_template_id_key', 'agent_template_configs', ['agent_template_id', 'key'], unique=False)
+ op.create_index('ix_agent_execution_permissions_agent_execution_id', 'agent_execution_permissions', ['agent_execution_id'], unique=False)
+ op.drop_table('oauth_tokens')
+ # ### end Alembic commands ###
diff --git a/nginx/default.conf b/nginx/default.conf
index 22ede4873..456ccc56e 100644
--- a/nginx/default.conf
+++ b/nginx/default.conf
@@ -10,6 +10,7 @@ server {
location /api {
proxy_pass http://backend:8001;
+ client_max_body_size 50M;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
diff --git a/requirements.txt b/requirements.txt
index 624a2d107..ff6a3725d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,8 +24,10 @@ colorama==0.4.6
confluent-kafka==2.1.1
cryptography==41.0.1
cssselect==1.2.0
+chromadb==0.3.26
dataclasses-json==0.5.7
defusedxml==0.7.1
+docx2txt==0.8
dnspython==2.3.0
email-validator==2.0.0.post2
exceptiongroup==1.1.1
@@ -63,6 +65,7 @@ json5==0.9.14
jsonmerge==1.9.0
jsonschema==4.17.3
kombu==5.2.4
+llama-index==0.6.35
log-symbols==0.0.14
loguru==0.7.0
lxml==4.9.2
@@ -89,7 +92,6 @@ prompt-toolkit==3.0.38
psycopg2==2.9.6
pycparser==2.21
pydantic==1.10.8
-pydantic-sqlalchemy==0.0.9
PyJWT==1.7.1
PyPDF2==3.0.1
pyquery==2.0.0
@@ -100,6 +102,7 @@ python-dotenv==1.0.0
python-multipart==0.0.6
pytz==2023.3
PyYAML==6.0
+qdrant-client==1.3.1
redis==4.5.5
regex==2023.5.5
requests==2.31.0
@@ -113,8 +116,8 @@ six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
spinners==0.0.24
-SQLAlchemy==1.4.48
starlette==0.27.0
+SQLAlchemy==2.0.16
tenacity==8.2.2
termcolor==2.3.0
tiktoken==0.4.0
@@ -123,7 +126,6 @@ tldextract==3.4.4
tqdm==4.65.0
tweepy==4.14.0
typing-inspect==0.8.0
-typing_extensions==4.6.2
ujson==5.7.0
urllib3==1.26.16
uvicorn==0.22.0
@@ -143,3 +145,9 @@ pylint==2.17.4
pre-commit==3.3.3
pytest-cov==4.1.0
pytest-mock==3.11.1
+transformers==4.30.2
+pypdf==3.11.0
+python-pptx==0.6.21
+Pillow==9.5.0
+EbookLib==0.18
+html2text==2020.1.16
\ No newline at end of file
diff --git a/superagi/agent/super_agi.py b/superagi/agent/super_agi.py
index 700ab6024..478373134 100644
--- a/superagi/agent/super_agi.py
+++ b/superagi/agent/super_agi.py
@@ -57,6 +57,7 @@ def __init__(self,
memory: VectorStore,
tools: List[BaseTool],
agent_config: Any,
+ agent_execution_config: Any,
output_parser: BaseOutputParser = AgentOutputParser(),
):
self.ai_name = ai_name
@@ -67,6 +68,7 @@ def __init__(self,
self.output_parser = output_parser
self.tools = tools
self.agent_config = agent_config
+ self.agent_execution_config = agent_execution_config
# Init Log
# print("\033[92m\033[1m" + "\nWelcome to SuperAGI - The future of AGI" + "\033[0m\033[0m")
@@ -165,7 +167,7 @@ def execute(self, workflow_step: AgentWorkflowStep):
total_tokens = current_tokens + TokenCounter.count_message_tokens(response, self.llm.get_model())
self.update_agent_execution_tokens(current_calls, total_tokens)
- if response['content'] is None:
+ if 'content' not in response or response['content'] is None:
raise RuntimeError(f"Failed to get response from llm")
assistant_reply = response['content']
@@ -286,7 +288,7 @@ def build_agent_prompt(self, prompt: str, task_queue: TaskQueue, max_token_limit
if len(pending_tasks) > 0 or len(completed_tasks) > 0:
add_finish_tool = False
- prompt = AgentPromptBuilder.replace_main_variables(prompt, self.agent_config["goal"], self.agent_config["instruction"],
+ prompt = AgentPromptBuilder.replace_main_variables(prompt, self.agent_execution_config["goal"], self.agent_execution_config["instruction"],
self.agent_config["constraints"], self.tools, add_finish_tool)
response = task_queue.get_last_task_details()
diff --git a/superagi/controllers/__init__.py b/superagi/controllers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/superagi/controllers/agent.py b/superagi/controllers/agent.py
index 5d3fb2f84..a790e27cb 100644
--- a/superagi/controllers/agent.py
+++ b/superagi/controllers/agent.py
@@ -1,32 +1,54 @@
from fastapi_sqlalchemy import db
-from fastapi import HTTPException, Depends, Request
+from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
+from pydantic import BaseModel
+
from superagi.models.agent import Agent
+from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_template import AgentTemplate
-from superagi.models.agent_template_config import AgentTemplateConfig
from superagi.models.project import Project
from fastapi import APIRouter
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
-
from superagi.models.agent_workflow import AgentWorkflow
from superagi.models.types.agent_with_config import AgentWithConfig
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
-from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.models.tool import Tool
from jsonmerge import merge
from superagi.worker import execute_agent
from datetime import datetime
import json
from sqlalchemy import func
-from superagi.helper.auth import check_auth, get_user_organisation
+from superagi.helper.auth import check_auth
+
+# from superagi.types.db import AgentOut, AgentIn
router = APIRouter()
+class AgentOut(BaseModel):
+ id: int
+ name: str
+ project_id: int
+ description: str
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+class AgentIn(BaseModel):
+ name: str
+ project_id: int
+ description: str
+
+ class Config:
+ orm_mode = True
+
+
# CRUD Operations
-@router.post("/add", response_model=sqlalchemy_to_pydantic(Agent), status_code=201)
-def create_agent(agent: sqlalchemy_to_pydantic(Agent, exclude=["id"]),
+@router.post("/add", response_model=AgentOut, status_code=201)
+def create_agent(agent: AgentIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Creates a new Agent
@@ -57,7 +79,7 @@ def create_agent(agent: sqlalchemy_to_pydantic(Agent, exclude=["id"]),
return db_agent
-@router.get("/get/{agent_id}", response_model=sqlalchemy_to_pydantic(Agent))
+@router.get("/get/{agent_id}", response_model=AgentOut)
def get_agent(agent_id: int,
Authorize: AuthJWT = Depends(check_auth)):
"""
@@ -79,8 +101,8 @@ def get_agent(agent_id: int,
return db_agent
-@router.put("/update/{agent_id}", response_model=sqlalchemy_to_pydantic(Agent))
-def update_agent(agent_id: int, agent: sqlalchemy_to_pydantic(Agent, exclude=["id"]),
+@router.put("/update/{agent_id}", response_model=AgentOut)
+def update_agent(agent_id: int, agent: AgentIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Update an existing Agent
@@ -164,12 +186,19 @@ def create_agent_with_config(agent_with_config: AgentWithConfig,
db_agent = Agent.create_agent_with_config(db, agent_with_config)
start_step_id = AgentWorkflow.fetch_trigger_step_id(db.session, db_agent.agent_workflow_id)
# Creating an execution with RUNNING status
- execution = AgentExecution(status='RUNNING', last_execution_time=datetime.now(), agent_id=db_agent.id,
+ execution = AgentExecution(status='CREATED', last_execution_time=datetime.now(), agent_id=db_agent.id,
name="New Run", current_step_id=start_step_id)
+ agent_execution_configs = {
+ "goal": agent_with_config.goal,
+ "instruction": agent_with_config.instruction
+ }
db.session.add(execution)
db.session.commit()
- execute_agent.delay(execution.id, datetime.now())
+ db.session.flush()
+ AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=execution,
+ agent_execution_configs=agent_execution_configs)
+ # execute_agent.delay(execution.id, datetime.now())
return {
"id": db_agent.id,
@@ -210,17 +239,18 @@ def get_agents_by_project_id(project_id: int,
# Query the AgentExecution table using the agent ID
executions = db.session.query(AgentExecution).filter_by(agent_id=agent_id).all()
- isRunning = False
+ is_running = False
for execution in executions:
if execution.status == "RUNNING":
- isRunning = True
+ is_running = True
break
new_agent = {
**agent_dict,
- 'status': isRunning
+ 'is_running': is_running
}
new_agents.append(new_agent)
- return new_agents
+ new_agents_sorted = sorted(new_agents, key=lambda agent: agent['is_running'] == True, reverse=True)
+ return new_agents_sorted
@router.get("/get/details/{agent_id}")
@@ -255,11 +285,10 @@ def get_agent_configuration(agent_id: int,
total_tokens = db.session.query(func.sum(AgentExecution.num_of_tokens)).filter(
AgentExecution.agent_id == agent_id).scalar()
-
# Construct the JSON response
response = {result.key: result.value for result in results}
response = merge(response, {"name": agent.name, "description": agent.description,
- # Query the AgentConfiguration table for the speci
+ # Query the AgentConfiguration table for the speci
"goal": eval(response["goal"]),
"instruction": eval(response.get("instruction", '[]')),
"calls": total_calls,
diff --git a/superagi/controllers/agent_config.py b/superagi/controllers/agent_config.py
index a923acea8..d57d4943e 100644
--- a/superagi/controllers/agent_config.py
+++ b/superagi/controllers/agent_config.py
@@ -1,19 +1,46 @@
+from typing import Union, List
+
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+from pydantic import BaseModel
+
from superagi.helper.auth import check_auth
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.types.agent_config import AgentConfig
+# from superagi.types.db import AgentConfigurationIn, AgentConfigurationOut
+from datetime import datetime
+
router = APIRouter()
+class AgentConfigurationOut(BaseModel):
+ id: int
+ agent_id: int
+ key: str
+ value: str
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+class AgentConfigurationIn(BaseModel):
+ agent_id: int
+ key: str
+ value: Union[str, List[str]]
+
+ class Config:
+ orm_mode = True
+
+
# CRUD Operations
-@router.post("/add", response_model=sqlalchemy_to_pydantic(AgentConfiguration), status_code=201)
-def create_agent_config(agent_config: sqlalchemy_to_pydantic(AgentConfiguration, exclude=["id"], ),
+@router.post("/add", response_model=AgentConfigurationOut, status_code=201)
+def create_agent_config(agent_config: AgentConfigurationIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Create a new agent configuration by setting a new key and value related to the agent.
@@ -39,7 +66,7 @@ def create_agent_config(agent_config: sqlalchemy_to_pydantic(AgentConfiguration,
return db_agent_config
-@router.get("/get/{agent_config_id}", response_model=sqlalchemy_to_pydantic(AgentConfiguration))
+@router.get("/get/{agent_config_id}", response_model=AgentConfigurationOut)
def get_agent(agent_config_id: int,
Authorize: AuthJWT = Depends(check_auth)):
"""
@@ -61,8 +88,8 @@ def get_agent(agent_config_id: int,
return db_agent_config
-@router.put("/update", response_model=sqlalchemy_to_pydantic(AgentConfiguration))
-def update_agent(agent_config: AgentConfig,
+@router.put("/update", response_model=AgentConfigurationOut)
+def update_agent(agent_config: AgentConfigurationIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Update a particular agent configuration value for the given agent_id and agent_config key.
diff --git a/superagi/controllers/agent_execution.py b/superagi/controllers/agent_execution.py
index 16373660b..61656379e 100644
--- a/superagi/controllers/agent_execution.py
+++ b/superagi/controllers/agent_execution.py
@@ -1,24 +1,61 @@
from datetime import datetime
+from typing import Optional
+
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
+from pydantic import BaseModel
+from typing import List
from superagi.helper.time_helper import get_time_difference
+from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_workflow import AgentWorkflow
from superagi.worker import execute_agent
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent import Agent
from fastapi import APIRouter
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
from sqlalchemy import desc
from superagi.helper.auth import check_auth
+# from superagi.types.db import AgentExecutionOut, AgentExecutionIn
router = APIRouter()
+class AgentExecutionOut(BaseModel):
+ id: int
+ status: str
+ name: str
+ agent_id: int
+ last_execution_time: datetime
+ num_of_calls: int
+ num_of_tokens: int
+ current_step_id: int
+ permission_id: Optional[int]
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+class AgentExecutionIn(BaseModel):
+ status: Optional[str]
+ name: Optional[str]
+ agent_id: Optional[int]
+ last_execution_time: Optional[datetime]
+ num_of_calls: Optional[int]
+ num_of_tokens: Optional[int]
+ current_step_id: Optional[int]
+ permission_id: Optional[int]
+ goal: Optional[List[str]]
+ instruction: Optional[List[str]]
+
+ class Config:
+ orm_mode = True
+
# CRUD Operations
-@router.post("/add", response_model=sqlalchemy_to_pydantic(AgentExecution), status_code=201)
-def create_agent_execution(agent_execution: sqlalchemy_to_pydantic(AgentExecution, exclude=["id"]),
+@router.post("/add", response_model=AgentExecutionOut, status_code=201)
+def create_agent_execution(agent_execution: AgentExecutionIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Create a new agent execution/run.
@@ -42,15 +79,23 @@ def create_agent_execution(agent_execution: sqlalchemy_to_pydantic(AgentExecutio
agent_id=agent_execution.agent_id, name=agent_execution.name, num_of_calls=0,
num_of_tokens=0,
current_step_id=start_step_id)
+ agent_execution_configs = {
+ "goal": agent_execution.goal,
+ "instruction": agent_execution.instruction
+ }
db.session.add(db_agent_execution)
db.session.commit()
+ db.session.flush()
+ AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution,
+ agent_execution_configs=agent_execution_configs)
+
if db_agent_execution.status == "RUNNING":
execute_agent.delay(db_agent_execution.id, datetime.now())
return db_agent_execution
-@router.get("/get/{agent_execution_id}", response_model=sqlalchemy_to_pydantic(AgentExecution))
+@router.get("/get/{agent_execution_id}", response_model=AgentExecutionOut)
def get_agent_execution(agent_execution_id: int,
Authorize: AuthJWT = Depends(check_auth)):
"""
@@ -72,14 +117,14 @@ def get_agent_execution(agent_execution_id: int,
return db_agent_execution
-@router.put("/update/{agent_execution_id}", response_model=sqlalchemy_to_pydantic(AgentExecution))
+@router.put("/update/{agent_execution_id}", response_model=AgentExecutionOut)
def update_agent_execution(agent_execution_id: int,
- agent_execution: sqlalchemy_to_pydantic(AgentExecution, exclude=["id"]),
+ agent_execution: AgentExecutionIn,
Authorize: AuthJWT = Depends(check_auth)):
"""Update details of particular agent_execution by agent_execution_id"""
db_agent_execution = db.session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
- if agent_execution == "COMPLETED":
+ if agent_execution.status == "COMPLETED":
raise HTTPException(status_code=400, detail="Invalid Request")
if not db_agent_execution:
diff --git a/superagi/controllers/agent_execution_config.py b/superagi/controllers/agent_execution_config.py
new file mode 100644
index 000000000..89decfd87
--- /dev/null
+++ b/superagi/controllers/agent_execution_config.py
@@ -0,0 +1,36 @@
+from fastapi import APIRouter
+from fastapi import HTTPException, Depends
+from fastapi_jwt_auth import AuthJWT
+from fastapi_sqlalchemy import db
+
+from superagi.helper.auth import check_auth
+from superagi.models.agent_execution_config import AgentExecutionConfiguration
+
+router = APIRouter()
+
+
+@router.get("/details/{agent_execution_id}")
+def get_agent_execution_configuration(agent_execution_id: int,
+ Authorize: AuthJWT = Depends(check_auth)):
+ """
+ Get the agent execution configuration using the agent execution ID.
+
+ Args:
+ agent_execution_id (int): Identifier of the agent.
+ Authorize (AuthJWT, optional): Authorization dependency. Defaults to Depends(check_auth).
+
+ Returns:
+ dict: Agent Execution configuration including its details.
+
+ Raises:
+ HTTPException (status_code=404): If the agent is not found.
+ """
+
+ agent_execution_config = db.session.query(AgentExecutionConfiguration).filter(
+ AgentExecutionConfiguration.agent_execution_id == agent_execution_id
+ ).all()
+ if not agent_execution_config:
+ raise HTTPException(status_code=404, detail="Agent Execution Configuration not found")
+    response = {result.key: eval(result.value) for result in agent_execution_config}  # FIXME: eval on DB-stored values can execute arbitrary code; use ast.literal_eval
+
+ return response
diff --git a/superagi/controllers/agent_execution_feed.py b/superagi/controllers/agent_execution_feed.py
index 81e1f3e35..262c0610e 100644
--- a/superagi/controllers/agent_execution_feed.py
+++ b/superagi/controllers/agent_execution_feed.py
@@ -1,10 +1,12 @@
from datetime import datetime
+from typing import Optional
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+from pydantic import BaseModel
+
from sqlalchemy.sql import asc
from superagi.agent.task_queue import TaskQueue
@@ -14,13 +16,39 @@
from superagi.helper.feed_parser import parse_feed
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
+# from superagi.types.db import AgentExecutionFeedOut, AgentExecutionFeedIn
router = APIRouter()
+class AgentExecutionFeedOut(BaseModel):
+ id: int
+ agent_execution_id: int
+ agent_id: int
+ feed: str
+ role: str
+ extra_info: Optional[str]
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+class AgentExecutionFeedIn(BaseModel):
+ id: int
+ agent_execution_id: int
+ agent_id: int
+ feed: str
+ role: str
+ extra_info: str
+
+ class Config:
+ orm_mode = True
+
# CRUD Operations
-@router.post("/add", response_model=sqlalchemy_to_pydantic(AgentExecutionFeed), status_code=201)
-def create_agent_execution_feed(agent_execution_feed: sqlalchemy_to_pydantic(AgentExecutionFeed, exclude=["id"]),
+@router.post("/add", response_model=AgentExecutionFeedOut, status_code=201)
+def create_agent_execution_feed(agent_execution_feed: AgentExecutionFeedIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Add a new agent execution feed.
@@ -48,7 +76,7 @@ def create_agent_execution_feed(agent_execution_feed: sqlalchemy_to_pydantic(Age
return db_agent_execution_feed
-@router.get("/get/{agent_execution_feed_id}", response_model=sqlalchemy_to_pydantic(AgentExecutionFeed))
+@router.get("/get/{agent_execution_feed_id}", response_model=AgentExecutionFeedOut)
def get_agent_execution_feed(agent_execution_feed_id: int,
Authorize: AuthJWT = Depends(check_auth)):
"""
@@ -71,9 +99,9 @@ def get_agent_execution_feed(agent_execution_feed_id: int,
return db_agent_execution_feed
-@router.put("/update/{agent_execution_feed_id}", response_model=sqlalchemy_to_pydantic(AgentExecutionFeed))
+@router.put("/update/{agent_execution_feed_id}", response_model=AgentExecutionFeedOut)
def update_agent_execution_feed(agent_execution_feed_id: int,
- agent_execution_feed: sqlalchemy_to_pydantic(AgentExecutionFeed, exclude=["id"]),
+ agent_execution_feed: AgentExecutionFeedIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Update a particular agent execution feed.
diff --git a/superagi/controllers/agent_execution_permission.py b/superagi/controllers/agent_execution_permission.py
index e7c4f9e3b..714ff4a47 100644
--- a/superagi/controllers/agent_execution_permission.py
+++ b/superagi/controllers/agent_execution_permission.py
@@ -4,16 +4,45 @@
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends, Body
from fastapi_jwt_auth import AuthJWT
+from pydantic import BaseModel
from superagi.models.agent_execution_permission import AgentExecutionPermission
from superagi.worker import execute_agent
from fastapi import APIRouter
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+
from superagi.helper.auth import check_auth
+# from superagi.types.db import AgentExecutionPermissionOut, AgentExecutionPermissionIn
router = APIRouter()
+class AgentExecutionPermissionOut(BaseModel):
+ id: int
+ agent_execution_id: int
+ agent_id: int
+ status: str
+ tool_name: str
+ user_feedback: str
+ assistant_reply: str
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+class AgentExecutionPermissionIn(BaseModel):
+ agent_execution_id: int
+ agent_id: int
+ status: str
+ tool_name: str
+ user_feedback: str
+ assistant_reply: str
+
+ class Config:
+ orm_mode = True
+
+
@router.get("/get/{agent_execution_permission_id}")
def get_agent_execution_permission(agent_execution_permission_id: int,
Authorize: AuthJWT = Depends(check_auth)):
@@ -37,9 +66,9 @@ def get_agent_execution_permission(agent_execution_permission_id: int,
return db_agent_execution_permission
-@router.post("/add", response_model=sqlalchemy_to_pydantic(AgentExecutionPermission))
+@router.post("/add", response_model=AgentExecutionPermissionOut)
def create_agent_execution_permission(
- agent_execution_permission: sqlalchemy_to_pydantic(AgentExecutionPermission, exclude=["id"])
+ agent_execution_permission: AgentExecutionPermissionIn
, Authorize: AuthJWT = Depends(check_auth)):
"""
Create a new agent execution permission.
@@ -58,10 +87,9 @@ def create_agent_execution_permission(
@router.patch("/update/{agent_execution_permission_id}",
- response_model=sqlalchemy_to_pydantic(AgentExecutionPermission, exclude=["id"]))
+ response_model=AgentExecutionPermissionIn)
def update_agent_execution_permission(agent_execution_permission_id: int,
- agent_execution_permission: sqlalchemy_to_pydantic(AgentExecutionPermission,
- exclude=["id"]),
+ agent_execution_permission: AgentExecutionPermissionIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Update an AgentExecutionPermission in the database.
diff --git a/superagi/controllers/agent_template.py b/superagi/controllers/agent_template.py
index 8839cdb3f..140622aee 100644
--- a/superagi/controllers/agent_template.py
+++ b/superagi/controllers/agent_template.py
@@ -1,7 +1,9 @@
+from datetime import datetime
+
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_sqlalchemy import db
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+from pydantic import BaseModel
from main import get_config
from superagi.helper.auth import get_user_organisation
@@ -11,12 +13,37 @@
from superagi.models.agent_template_config import AgentTemplateConfig
from superagi.models.agent_workflow import AgentWorkflow
from superagi.models.tool import Tool
+# from superagi.types.db import AgentTemplateIn, AgentTemplateOut
router = APIRouter()
-@router.post("/create", status_code=201, response_model=sqlalchemy_to_pydantic(AgentTemplate))
-def create_agent_template(agent_template: sqlalchemy_to_pydantic(AgentTemplate, exclude=["id"]),
+class AgentTemplateOut(BaseModel):
+ id: int
+ organisation_id: int
+ agent_workflow_id: int
+ name: str
+ description: str
+ marketplace_template_id: int
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+class AgentTemplateIn(BaseModel):
+ organisation_id: int
+ agent_workflow_id: int
+ name: str
+ description: str
+ marketplace_template_id: int
+
+ class Config:
+ orm_mode = True
+
+@router.post("/create", status_code=201, response_model=AgentTemplateOut)
+def create_agent_template(agent_template: AgentTemplateIn,
organisation=Depends(get_user_organisation)):
"""
Create an agent template.
@@ -81,7 +108,7 @@ def get_agent_template(template_source, agent_template_id: int, organisation=Dep
return template
-@router.post("/update_details/{agent_template_id}", response_model=sqlalchemy_to_pydantic(AgentTemplate))
+@router.post("/update_details/{agent_template_id}", response_model=AgentTemplateOut)
def update_agent_template(agent_template_id: int,
agent_configs: dict,
organisation=Depends(get_user_organisation)):
diff --git a/superagi/controllers/budget.py b/superagi/controllers/budget.py
index 41452e25a..9503acaaf 100644
--- a/superagi/controllers/budget.py
+++ b/superagi/controllers/budget.py
@@ -2,16 +2,33 @@
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+from pydantic import BaseModel
from superagi.helper.auth import check_auth
from superagi.models.budget import Budget
+# from superagi.types.db import BudgetIn, BudgetOut
router = APIRouter()
-@router.post("/add", response_model=sqlalchemy_to_pydantic(Budget), status_code=201)
-def create_budget(budget: sqlalchemy_to_pydantic(Budget, exclude=["id"]),
+class BudgetOut(BaseModel):
+ id: int
+ budget: float
+ cycle: str
+
+ class Config:
+ orm_mode = True
+
+
+class BudgetIn(BaseModel):
+ budget: float
+ cycle: str
+
+ class Config:
+ orm_mode = True
+
+@router.post("/add", response_model=BudgetOut, status_code=201)
+def create_budget(budget: BudgetIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Create a new budget.
@@ -34,7 +51,7 @@ def create_budget(budget: sqlalchemy_to_pydantic(Budget, exclude=["id"]),
return new_budget
-@router.get("/get/{budget_id}", response_model=sqlalchemy_to_pydantic(Budget))
+@router.get("/get/{budget_id}", response_model=BudgetOut)
def get_budget(budget_id: int,
Authorize: AuthJWT = Depends(check_auth)):
"""
@@ -54,8 +71,8 @@ def get_budget(budget_id: int,
return db_budget
-@router.put("/update/{budget_id}", response_model=sqlalchemy_to_pydantic(Budget))
-def update_budget(budget_id: int, budget: sqlalchemy_to_pydantic(Budget, exclude=["id"]),
+@router.put("/update/{budget_id}", response_model=BudgetOut)
+def update_budget(budget_id: int, budget: BudgetIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Update budget details by budget_id.
diff --git a/superagi/controllers/config.py b/superagi/controllers/config.py
index 0c16841f8..84ff35be3 100644
--- a/superagi/controllers/config.py
+++ b/superagi/controllers/config.py
@@ -1,5 +1,9 @@
+from datetime import datetime
+from typing import Optional
+
from fastapi import APIRouter
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+from pydantic import BaseModel
+
from superagi.models.configuration import Configuration
from superagi.models.organisation import Organisation
from fastapi_sqlalchemy import db
@@ -9,14 +13,35 @@
from fastapi_jwt_auth import AuthJWT
from superagi.helper.encyption_helper import encrypt_data,decrypt_data
from superagi.lib.logger import logger
+# from superagi.types.db import ConfigurationIn, ConfigurationOut
router = APIRouter()
+class ConfigurationOut(BaseModel):
+ id: int
+ organisation_id: int
+ key: str
+ value: str
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+class ConfigurationIn(BaseModel):
+ organisation_id: Optional[int]
+ key: str
+ value: str
+
+ class Config:
+ orm_mode = True
+
# CRUD Operations
@router.post("/add/organisation/{organisation_id}", status_code=201,
- response_model=sqlalchemy_to_pydantic(Configuration))
-def create_config(config: sqlalchemy_to_pydantic(Configuration, exclude=["id"]), organisation_id: int,
+ response_model=ConfigurationOut)
+def create_config(config: ConfigurationIn, organisation_id: int,
Authorize: AuthJWT = Depends(check_auth)):
"""
Creates a new Organisation level config.
diff --git a/superagi/controllers/organisation.py b/superagi/controllers/organisation.py
index f4e984293..dae46a532 100644
--- a/superagi/controllers/organisation.py
+++ b/superagi/controllers/organisation.py
@@ -1,8 +1,10 @@
+from datetime import datetime
+
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+from pydantic import BaseModel
from superagi.helper.auth import check_auth
from superagi.helper.tool_helper import register_toolkits
@@ -10,13 +12,32 @@
from superagi.models.project import Project
from superagi.models.user import User
from superagi.lib.logger import logger
+# from superagi.types.db import OrganisationIn, OrganisationOut
router = APIRouter()
+class OrganisationOut(BaseModel):
+ id: int
+ name: str
+ description: str
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+class OrganisationIn(BaseModel):
+ name: str
+ description: str
+
+ class Config:
+ orm_mode = True
+
# CRUD Operations
-@router.post("/add", response_model=sqlalchemy_to_pydantic(Organisation), status_code=201)
-def create_organisation(organisation: sqlalchemy_to_pydantic(Organisation, exclude=["id"]),
+@router.post("/add", response_model=OrganisationOut, status_code=201)
+def create_organisation(organisation: OrganisationIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Create a new organisation.
@@ -45,7 +66,7 @@ def create_organisation(organisation: sqlalchemy_to_pydantic(Organisation, exclu
return new_organisation
-@router.get("/get/{organisation_id}", response_model=sqlalchemy_to_pydantic(Organisation))
+@router.get("/get/{organisation_id}", response_model=OrganisationOut)
def get_organisation(organisation_id: int, Authorize: AuthJWT = Depends(check_auth)):
"""
Get organisation details by organisation_id.
@@ -67,8 +88,8 @@ def get_organisation(organisation_id: int, Authorize: AuthJWT = Depends(check_au
return db_organisation
-@router.put("/update/{organisation_id}", response_model=sqlalchemy_to_pydantic(Organisation))
-def update_organisation(organisation_id: int, organisation: sqlalchemy_to_pydantic(Organisation, exclude=["id"]),
+@router.put("/update/{organisation_id}", response_model=OrganisationOut)
+def update_organisation(organisation_id: int, organisation: OrganisationIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Update organisation details by organisation_id.
@@ -96,7 +117,7 @@ def update_organisation(organisation_id: int, organisation: sqlalchemy_to_pydant
return db_organisation
-@router.get("/get/user/{user_id}", response_model=sqlalchemy_to_pydantic(Organisation), status_code=201)
+@router.get("/get/user/{user_id}", response_model=OrganisationOut, status_code=201)
def get_organisations_by_user(user_id: int):
"""
Get organisations associated with a user.If Organisation does not exists a new organisation is created
diff --git a/superagi/controllers/project.py b/superagi/controllers/project.py
index 8894137cb..aed73b5f2 100644
--- a/superagi/controllers/project.py
+++ b/superagi/controllers/project.py
@@ -1,19 +1,39 @@
-from fastapi_sqlalchemy import DBSessionMiddleware, db
-from fastapi import HTTPException, Depends, Request
+from fastapi_sqlalchemy import db
+from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
+from pydantic import BaseModel
+
from superagi.models.project import Project
from superagi.models.organisation import Organisation
from fastapi import APIRouter
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
from superagi.helper.auth import check_auth
from superagi.lib.logger import logger
+# from superagi.types.db import ProjectIn, ProjectOut
router = APIRouter()
+class ProjectOut(BaseModel):
+ id: int
+ name: str
+ organisation_id: int
+ description: str
+
+ class Config:
+ orm_mode = True
+
+
+class ProjectIn(BaseModel):
+ name: str
+ organisation_id: int
+ description: str
+
+ class Config:
+ orm_mode = True
+
# CRUD Operations
-@router.post("/add", response_model=sqlalchemy_to_pydantic(Project), status_code=201)
-def create_project(project: sqlalchemy_to_pydantic(Project, exclude=["id"]),
+@router.post("/add", response_model=ProjectOut, status_code=201)
+def create_project(project: ProjectIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Create a new project.
@@ -47,7 +67,7 @@ def create_project(project: sqlalchemy_to_pydantic(Project, exclude=["id"]),
return project
-@router.get("/get/{project_id}", response_model=sqlalchemy_to_pydantic(Project))
+@router.get("/get/{project_id}", response_model=ProjectOut)
def get_project(project_id: int, Authorize: AuthJWT = Depends(check_auth)):
"""
Get project details by project_id.
@@ -69,8 +89,8 @@ def get_project(project_id: int, Authorize: AuthJWT = Depends(check_auth)):
return db_project
-@router.put("/update/{project_id}", response_model=sqlalchemy_to_pydantic(Project))
-def update_project(project_id: int, project: sqlalchemy_to_pydantic(Project, exclude=["id"]),
+@router.put("/update/{project_id}", response_model=ProjectOut)
+def update_project(project_id: int, project: ProjectIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Update a project detail by project_id.
diff --git a/superagi/controllers/resources.py b/superagi/controllers/resources.py
index ed96d4b60..eee0ddd62 100644
--- a/superagi/controllers/resources.py
+++ b/superagi/controllers/resources.py
@@ -17,10 +17,11 @@
from superagi.lib.logger import logger
from superagi.models.agent import Agent
from superagi.models.resource import Resource
+from superagi.worker import summarize_resource
+from superagi.types.storage_types import StorageType
router = APIRouter()
-
s3 = boto3.client(
's3',
aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"),
@@ -55,24 +56,25 @@ async def upload(agent_id: int, file: UploadFile = File(...), name=Form(...), si
if agent is None:
raise HTTPException(status_code=400, detail="Agent does not exists")
- if not name.endswith(".txt") and not name.endswith(".pdf"):
+ # accepted_file_types is a tuple because endswith() expects a tuple
+ accepted_file_types = (".pdf", ".docx", ".pptx", ".csv", ".txt", ".epub")
+ if not name.endswith(accepted_file_types):
raise HTTPException(status_code=400, detail="File type not supported!")
- storage_type = get_config("STORAGE_TYPE")
- Resource.validate_resource_type(storage_type)
+ storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE"))
save_directory = ResourceHelper.get_root_input_dir() + "/"
if "{agent_id}" in save_directory:
save_directory = save_directory.replace("{agent_id}", str(agent_id))
path = ""
os.makedirs(save_directory, exist_ok=True)
file_path = os.path.join(save_directory, file.filename)
- if storage_type == "FILE":
+ if storage_type == StorageType.FILE:
path = file_path
with open(file_path, "wb") as f:
contents = await file.read()
f.write(contents)
file.file.close()
- elif storage_type == "S3":
+ elif storage_type == StorageType.S3:
bucket_name = get_config("BUCKET_NAME")
file_name = file.filename.split('.')
path = 'input/' + file_name[0] + '_' + str(datetime.datetime.now()).replace(' ', '').replace('.', '').replace(
@@ -83,12 +85,16 @@ async def upload(agent_id: int, file: UploadFile = File(...), name=Form(...), si
except NoCredentialsError:
raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.")
- resource = Resource(name=name, path=path, storage_type=storage_type, size=size, type=type, channel="INPUT",
+ resource = Resource(name=name, path=path, storage_type=storage_type.value, size=size, type=type, channel="INPUT",
agent_id=agent.id)
+
db.session.add(resource)
db.session.commit()
db.session.flush()
+
+ summarize_resource.delay(agent_id, resource.id)
logger.info(resource)
+
return resource
@@ -136,7 +142,7 @@ def download_file_by_id(resource_id: int,
if not resource:
raise HTTPException(status_code=400, detail="Resource Not found!")
- if resource.storage_type == "S3":
+ if resource.storage_type == StorageType.S3.value:
bucket_name = get_config("BUCKET_NAME")
file_key = resource.path
response = s3.get_object(Bucket=bucket_name, Key=file_key)
diff --git a/superagi/controllers/tool.py b/superagi/controllers/tool.py
index f4c860167..ef55aa2c6 100644
--- a/superagi/controllers/tool.py
+++ b/superagi/controllers/tool.py
@@ -1,8 +1,10 @@
+from datetime import datetime
+
from fastapi import APIRouter
from fastapi import HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+from pydantic import BaseModel
from superagi.helper.auth import check_auth, get_user_organisation
from superagi.models.organisation import Organisation
@@ -12,17 +14,39 @@
router = APIRouter()
+class ToolOut(BaseModel):
+ id: int
+ name: str
+ folder_name: str
+ class_name: str
+ file_name: str
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+class ToolIn(BaseModel):
+ name: str
+ folder_name: str
+ class_name: str
+ file_name: str
+
+ class Config:
+ orm_mode = True
+
# CRUD Operations
-@router.post("/add", response_model=sqlalchemy_to_pydantic(Tool), status_code=201)
+@router.post("/add", response_model=ToolOut, status_code=201)
def create_tool(
- tool: sqlalchemy_to_pydantic(Tool, exclude=["id"]),
+ tool: ToolIn,
Authorize: AuthJWT = Depends(check_auth),
):
"""
Create a new tool.
Args:
- tool (sqlalchemy_to_pydantic(Tool, exclude=["id"])): Tool data.
+ tool (ToolIn): Tool data.
Returns:
Tool: The created tool.
@@ -43,7 +67,7 @@ def create_tool(
return db_tool
-@router.get("/get/{tool_id}", response_model=sqlalchemy_to_pydantic(Tool))
+@router.get("/get/{tool_id}", response_model=ToolOut)
def get_tool(
tool_id: int,
Authorize: AuthJWT = Depends(check_auth),
@@ -80,10 +104,10 @@ def get_tools(
return tools
-@router.put("/update/{tool_id}", response_model=sqlalchemy_to_pydantic(Tool))
+@router.put("/update/{tool_id}", response_model=ToolOut)
def update_tool(
tool_id: int,
- tool: sqlalchemy_to_pydantic(Tool, exclude=["id"]),
+ tool: ToolIn,
Authorize: AuthJWT = Depends(check_auth),
):
"""
@@ -91,7 +115,7 @@ def update_tool(
Args:
tool_id (int): ID of the tool.
- tool (sqlalchemy_to_pydantic(Tool, exclude=["id"])): Updated tool data.
+ tool (ToolIn): Updated tool data.
Returns:
Tool: The updated tool details.
diff --git a/superagi/controllers/tool_config.py b/superagi/controllers/tool_config.py
index 709b15bf4..c4a169760 100644
--- a/superagi/controllers/tool_config.py
+++ b/superagi/controllers/tool_config.py
@@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+from pydantic import BaseModel
from superagi.helper.auth import check_auth
from superagi.helper.auth import get_user_organisation
@@ -11,6 +11,14 @@
router = APIRouter()
+class ToolConfigOut(BaseModel):
+    id: int
+    key: str
+    value: str
+    toolkit_id: int
+
+ class Config:
+ orm_mode = True
@router.post("/add/{toolkit_name}", status_code=201)
def update_tool_config(toolkit_name: str, configs: list, organisation: Organisation = Depends(get_user_organisation)):
@@ -56,7 +64,7 @@ def update_tool_config(toolkit_name: str, configs: list, organisation: Organisat
raise HTTPException(status_code=500, detail=str(e))
-@router.post("/create-or-update/{toolkit_name}", status_code=201, response_model=sqlalchemy_to_pydantic(ToolConfig))
+@router.post("/create-or-update/{toolkit_name}", status_code=201, response_model=ToolConfigOut)
def create_or_update_tool_config(toolkit_name: str, tool_configs,
Authorize: AuthJWT = Depends(check_auth)):
"""
diff --git a/superagi/controllers/twitter_oauth.py b/superagi/controllers/twitter_oauth.py
new file mode 100644
index 000000000..b79b7be66
--- /dev/null
+++ b/superagi/controllers/twitter_oauth.py
@@ -0,0 +1,70 @@
+from fastapi import Depends, Query
+from fastapi import APIRouter
+from fastapi.responses import RedirectResponse
+from fastapi_jwt_auth import AuthJWT
+from fastapi_sqlalchemy import db
+from sqlalchemy.orm import sessionmaker
+
+import superagi
+import json
+from superagi.models.db import connect_db
+import http.client as http_client
+from superagi.helper.twitter_tokens import TwitterTokens
+from superagi.helper.auth import get_current_user
+from superagi.models.tool_config import ToolConfig
+from superagi.models.toolkit import Toolkit
+from superagi.models.oauth_tokens import OauthTokens
+
+router = APIRouter()
+
+@router.get('/oauth-tokens')
+async def twitter_oauth(oauth_token: str = Query(...),oauth_verifier: str = Query(...), Authorize: AuthJWT = Depends()):
+    # NOTE: do not log oauth_token here — it is a sensitive credential and
+    # must never be written to stdout (debug prints removed).
+ token_uri = f'https://api.twitter.com/oauth/access_token?oauth_verifier={oauth_verifier}&oauth_token={oauth_token}'
+ conn = http_client.HTTPSConnection("api.twitter.com")
+ conn.request("POST", token_uri, "")
+ res = conn.getresponse()
+ response_data = res.read().decode('utf-8')
+ frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000")
+ redirect_url_success = f"{frontend_url}/twitter_creds/?{response_data}"
+ return RedirectResponse(url=redirect_url_success)
+
+@router.post("/send_twitter_creds/{twitter_creds}")
+def send_twitter_tool_configs(twitter_creds: str, Authorize: AuthJWT = Depends()):
+ engine = connect_db()
+ Session = sessionmaker(bind=engine)
+ session = Session()
+ current_user = get_current_user()
+ user_id = current_user.id
+ credentials = json.loads(twitter_creds)
+ credentials["user_id"] = user_id
+ toolkit = db.session.query(Toolkit).filter(Toolkit.id == credentials["toolkit_id"]).first()
+ api_key = db.session.query(ToolConfig).filter(ToolConfig.key == "TWITTER_API_KEY", ToolConfig.toolkit_id == credentials["toolkit_id"]).first()
+ api_key_secret = db.session.query(ToolConfig).filter(ToolConfig.key == "TWITTER_API_SECRET", ToolConfig.toolkit_id == credentials["toolkit_id"]).first()
+ final_creds = {
+ "api_key": api_key.value,
+ "api_key_secret": api_key_secret.value,
+ "oauth_token": credentials["oauth_token"],
+ "oauth_token_secret": credentials["oauth_token_secret"]
+ }
+ tokens = OauthTokens().add_or_update(session, credentials["toolkit_id"], user_id, toolkit.organisation_id, "TWITTER_OAUTH_TOKENS", str(final_creds))
+ if tokens:
+ success = True
+ else:
+ success = False
+ return success
+
+@router.get("/get_twitter_creds/toolkit_id/{toolkit_id}")
+def get_twitter_tool_configs(toolkit_id: int):
+ engine = connect_db()
+ Session = sessionmaker(bind=engine)
+ session = Session()
+ twitter_config_key = session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,ToolConfig.key == "TWITTER_API_KEY").first()
+ twitter_config_secret = session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,ToolConfig.key == "TWITTER_API_SECRET").first()
+ api_data = {
+ "api_key": twitter_config_key.value,
+ "api_secret": twitter_config_secret.value
+ }
+ response = TwitterTokens(session).get_request_token(api_data)
+ return response
\ No newline at end of file
diff --git a/superagi/controllers/user.py b/superagi/controllers/user.py
index c1c278b96..19bf2c935 100644
--- a/superagi/controllers/user.py
+++ b/superagi/controllers/user.py
@@ -1,27 +1,56 @@
+from datetime import datetime
+from typing import Optional
+
from fastapi_sqlalchemy import db
from fastapi import HTTPException, Depends, Request
from fastapi_jwt_auth import AuthJWT
+from pydantic import BaseModel
from superagi.models.organisation import Organisation
from superagi.models.project import Project
from superagi.models.user import User
from fastapi import APIRouter
-from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+
from superagi.helper.auth import check_auth
from superagi.lib.logger import logger
+# from superagi.types.db import UserBase, UserIn, UserOut
router = APIRouter()
+class UserBase(BaseModel):
+ name: str
+ email: str
+ password: str
+
+ class Config:
+ orm_mode = True
+
+
+class UserOut(UserBase):
+ id: int
+ organisation_id: int
+ created_at: datetime
+ updated_at: datetime
+
+ class Config:
+ orm_mode = True
+
+
+class UserIn(UserBase):
+ organisation_id: Optional[int]
+
+ class Config:
+ orm_mode = True
# CRUD Operations
-@router.post("/add", response_model=sqlalchemy_to_pydantic(User), status_code=201)
-def create_user(user: sqlalchemy_to_pydantic(User, exclude=["id"]),
+@router.post("/add", response_model=UserOut, status_code=201)
+def create_user(user: UserIn,
Authorize: AuthJWT = Depends(check_auth)):
"""
Create a new user.
Args:
- user (sqlalchemy_to_pydantic(User, exclude=["id"])): User data.
+ user (UserIn): User data.
Returns:
User: The created user.
@@ -44,7 +73,7 @@ def create_user(user: sqlalchemy_to_pydantic(User, exclude=["id"]),
return db_user
-@router.get("/get/{user_id}", response_model=sqlalchemy_to_pydantic(User))
+@router.get("/get/{user_id}", response_model=UserOut)
def get_user(user_id: int,
Authorize: AuthJWT = Depends(check_auth)):
"""
@@ -68,16 +97,16 @@ def get_user(user_id: int,
return db_user
-@router.put("/update/{user_id}", response_model=sqlalchemy_to_pydantic(User))
+@router.put("/update/{user_id}", response_model=UserOut)
def update_user(user_id: int,
- user: sqlalchemy_to_pydantic(User, exclude=["id"]),
+ user: UserBase,
Authorize: AuthJWT = Depends(check_auth)):
"""
Update a particular user.
Args:
user_id (int): ID of the user.
- user (sqlalchemy_to_pydantic(User, exclude=["id"])): Updated user data.
+ user (UserBase): Updated user data.
Returns:
User: The updated user details.
diff --git a/superagi/helper/auth.py b/superagi/helper/auth.py
index 5a80677e6..a185ece83 100644
--- a/superagi/helper/auth.py
+++ b/superagi/helper/auth.py
@@ -33,6 +33,13 @@ def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)):
Returns:
Organisation: Instance of Organisation class to which the authenticated user belongs.
"""
+ user = get_current_user()
+ if user is None:
+ raise HTTPException(status_code=401, detail="Unauthenticated")
+ organisation = db.session.query(Organisation).filter(Organisation.id == user.organisation_id).first()
+ return organisation
+
+def get_current_user(Authorize: AuthJWT = Depends(check_auth)):
env = get_config("ENV", "DEV")
if env == "DEV":
@@ -43,7 +50,4 @@ def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)):
# Query the User table to find the user by their email
user = db.session.query(User).filter(User.email == email).first()
- if user is None:
- raise HTTPException(status_code=401, detail="Unauthenticated")
- organisation = db.session.query(Organisation).filter(Organisation.id == user.organisation_id).first()
- return organisation
\ No newline at end of file
+ return user
\ No newline at end of file
diff --git a/superagi/helper/google_calendar_creds.py b/superagi/helper/google_calendar_creds.py
index 394fa901c..9da707fdb 100644
--- a/superagi/helper/google_calendar_creds.py
+++ b/superagi/helper/google_calendar_creds.py
@@ -10,7 +10,7 @@
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connect_db
from superagi.models.tool_config import ToolConfig
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
class GoogleCalendarCreds:
@@ -28,7 +28,7 @@ def get_credentials(self, toolkit_id):
engine = connect_db()
Session = sessionmaker(bind=engine)
session = Session()
- resource_manager: ResourceManager = None
+ resource_manager: FileManager = None
with open(file_path,'rb') as file:
creds = pickle.load(file)
if isinstance(creds, str):
diff --git a/superagi/helper/resource_helper.py b/superagi/helper/resource_helper.py
index 76bf2adf0..7001840b8 100644
--- a/superagi/helper/resource_helper.py
+++ b/superagi/helper/resource_helper.py
@@ -3,6 +3,7 @@
import os
import datetime
from superagi.lib.logger import logger
+from superagi.types.storage_types import StorageType
class ResourceHelper:
@@ -21,7 +22,7 @@ def make_written_file_resource(file_name: str, agent_id: int, channel: str):
Resource: The Resource object.
"""
path = ResourceHelper.get_root_output_dir()
- storage_type = get_config("STORAGE_TYPE")
+ storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE"))
file_extension = os.path.splitext(file_name)[1][1:]
if file_extension in ["png", "jpg", "jpeg"]:
@@ -38,14 +39,14 @@ def make_written_file_resource(file_name: str, agent_id: int, channel: str):
final_path = ResourceHelper.get_resource_path(file_name)
file_size = os.path.getsize(final_path)
- if storage_type == "S3":
+ if storage_type == StorageType.S3:
file_name_parts = file_name.split('.')
file_name = file_name_parts[0] + '_' + str(datetime.datetime.now()).replace(' ', '') \
.replace('.', '').replace(':', '') + '.' + file_name_parts[1]
path = 'input/' if (channel == "INPUT") else 'output/'
logger.info(final_path)
- resource = Resource(name=file_name, path=path + file_name, storage_type=storage_type, size=file_size,
+ resource = Resource(name=file_name, path=path + file_name, storage_type=storage_type.value, size=file_size,
type=file_type,
channel="OUTPUT",
agent_id=agent_id)
diff --git a/superagi/helper/twitter_helper.py b/superagi/helper/twitter_helper.py
new file mode 100644
index 000000000..e47c3d60d
--- /dev/null
+++ b/superagi/helper/twitter_helper.py
@@ -0,0 +1,42 @@
+import os
+import json
+import base64
+import requests
+from requests_oauthlib import OAuth1
+from requests_oauthlib import OAuth1Session
+from superagi.helper.resource_helper import ResourceHelper
+
+class TwitterHelper:
+
+ def get_media_ids(self, media_files, creds, agent_id):
+ media_ids = []
+ oauth = OAuth1(creds.api_key,
+ client_secret=creds.api_key_secret,
+ resource_owner_key=creds.oauth_token,
+ resource_owner_secret=creds.oauth_token_secret)
+ for file in media_files:
+ file_path = self.get_file_path(file, agent_id)
+ image_data = open(file_path, 'rb').read()
+ b64_image = base64.b64encode(image_data)
+ upload_endpoint = 'https://upload.twitter.com/1.1/media/upload.json'
+ headers = {'Authorization': 'application/octet-stream'}
+ response = requests.post(upload_endpoint, headers=headers,
+ data={'media_data': b64_image},
+ auth=oauth)
+ ids = json.loads(response.text)['media_id']
+ media_ids.append(str(ids))
+ return media_ids
+
+ def get_file_path(self, file_name, agent_id):
+ final_path = ResourceHelper().get_agent_resource_path(file_name, agent_id)
+ return final_path
+
+ def send_tweets(self, params, creds):
+ tweet_endpoint = "https://api.twitter.com/2/tweets"
+ oauth = OAuth1Session(creds.api_key,
+ client_secret=creds.api_key_secret,
+ resource_owner_key=creds.oauth_token,
+ resource_owner_secret=creds.oauth_token_secret)
+
+ response = oauth.post(tweet_endpoint,json=params)
+ return response
diff --git a/superagi/helper/twitter_tokens.py b/superagi/helper/twitter_tokens.py
new file mode 100644
index 000000000..10b36d59a
--- /dev/null
+++ b/superagi/helper/twitter_tokens.py
@@ -0,0 +1,77 @@
+import hmac
+import time
+import random
+import base64
+import hashlib
+import urllib.parse
+import ast
+import http.client as http_client
+from sqlalchemy.orm import Session
+from superagi.models.toolkit import Toolkit
+from superagi.models.oauth_tokens import OauthTokens
+
+class Creds:
+
+ def __init__(self,api_key, api_key_secret, oauth_token, oauth_token_secret):
+ self.api_key = api_key
+ self.api_key_secret = api_key_secret
+ self.oauth_token = oauth_token
+ self.oauth_token_secret = oauth_token_secret
+
+class TwitterTokens:
+
+ def __init__(self, session: Session):
+ self.session = session
+
+ def get_request_token(self,api_data):
+ api_key = api_data["api_key"]
+ api_secret_key = api_data["api_secret"]
+ http_method = 'POST'
+ base_url = 'https://api.twitter.com/oauth/request_token'
+
+ params = {
+ 'oauth_callback': 'http://localhost:3000/api/twitter/oauth-tokens',
+ 'oauth_consumer_key': api_key,
+ 'oauth_nonce': self.gen_nonce(),
+ 'oauth_signature_method': 'HMAC-SHA1',
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_version': '1.0'
+ }
+
+ params_sorted = sorted(params.items())
+ params_qs = '&'.join([f'{k}={self.percent_encode(str(v))}' for k, v in params_sorted])
+
+ base_string = f'{http_method}&{self.percent_encode(base_url)}&{self.percent_encode(params_qs)}'
+
+ signing_key = f'{self.percent_encode(api_secret_key)}&'
+ signature = hmac.new(signing_key.encode(), base_string.encode(), hashlib.sha1)
+ params['oauth_signature'] = base64.b64encode(signature.digest()).decode()
+
+ auth_header = 'OAuth ' + ', '.join([f'{k}="{self.percent_encode(str(v))}"' for k, v in params.items()])
+
+ headers = {
+ 'Content-Type': 'application/x-www-form-urlencoded',
+ 'Authorization': auth_header
+ }
+ conn = http_client.HTTPSConnection("api.twitter.com")
+ conn.request("POST", "/oauth/request_token", "", headers)
+ res = conn.getresponse()
+ response_data = res.read().decode('utf-8')
+ conn.close()
+ request_token_resp = dict(urllib.parse.parse_qsl(response_data))
+ return request_token_resp
+
+ def percent_encode(self, val):
+ return urllib.parse.quote(val, safe='')
+
+ def gen_nonce(self):
+ nonce = ''.join([str(random.randint(0, 9)) for i in range(32)])
+ return nonce
+
+ def get_twitter_creds(self, toolkit_id):
+ toolkit = self.session.query(Toolkit).filter(Toolkit.id == toolkit_id).first()
+ organisation_id = toolkit.organisation_id
+ twitter_creds = self.session.query(OauthTokens).filter(OauthTokens.toolkit_id == toolkit_id, OauthTokens.organisation_id == organisation_id).first()
+ twitter_creds = ast.literal_eval(twitter_creds.value)
+ final_creds = Creds(twitter_creds['api_key'], twitter_creds['api_key_secret'], twitter_creds['oauth_token'], twitter_creds['oauth_token_secret'])
+ return final_creds
\ No newline at end of file
diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py
index 11c52a709..fdd49b8ce 100644
--- a/superagi/jobs/agent_executor.py
+++ b/superagi/jobs/agent_executor.py
@@ -1,17 +1,19 @@
import importlib
from datetime import datetime, timedelta
-
from fastapi import HTTPException
+
from sqlalchemy.orm import sessionmaker
import superagi.worker
from superagi.agent.super_agi import SuperAgi
from superagi.config.config import get_config
from superagi.helper.encyption_helper import decrypt_data
+from superagi.resource_manager.resource_summary import ResourceSummarizer
from superagi.lib.logger import logger
from superagi.llms.openai import OpenAi
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
+from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.models.agent_execution_permission import AgentExecutionPermission
from superagi.models.agent_workflow_step import AgentWorkflowStep
@@ -21,13 +23,18 @@
from superagi.models.project import Project
from superagi.models.tool import Tool
from superagi.models.tool_config import ToolConfig
+from superagi.models.resource import Resource
from superagi.tools.base_tool import BaseToolkitConfiguration
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
from superagi.tools.thinking.tools import ThinkingTool
+from superagi.tools.resource.query_resource import QueryResourceTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
from superagi.vector_store.embedding.openai import OpenAiEmbedding
from superagi.vector_store.vector_factory import VectorFactory
+from superagi.types.vector_store_types import VectorStoreType
+from superagi.models.agent_config import AgentConfiguration
import yaml
+
# from superagi.helper.tool_helper import get_tool_config_by_key
engine = connect_db()
@@ -48,6 +55,7 @@ def get_tool_config(self, key: str):
return tool_config.value
return super().get_tool_config(key=key)
+
class AgentExecutor:
@staticmethod
def validate_filename(filename):
@@ -65,7 +73,7 @@ def validate_filename(filename):
return filename
@staticmethod
- def create_object(tool,session):
+ def create_object(tool, session):
"""
Create an object of a agent usable tool dynamically.
@@ -157,6 +165,7 @@ def execute_next_action(self, agent_execution_id):
]
parsed_config = Agent.fetch_configuration(session, agent.id)
+ parsed_execution_config = AgentExecutionConfiguration.fetch_configuration(session, agent_execution)
max_iterations = (parsed_config["max_iterations"])
total_calls = agent_execution.num_of_calls
@@ -173,7 +182,7 @@ def execute_next_action(self, agent_execution_id):
try:
if parsed_config["LTM_DB"] == "Pinecone":
- memory = VectorFactory.get_vector_storage("PineCone", "super-agent-index1",
+ memory = VectorFactory.get_vector_storage(VectorStoreType.PINECONE, "super-agent-index1",
OpenAiEmbedding(model_api_key))
else:
memory = VectorFactory.get_vector_storage("PineCone", "super-agent-index1",
@@ -184,17 +193,23 @@ def execute_next_action(self, agent_execution_id):
user_tools = session.query(Tool).filter(Tool.id.in_(parsed_config["tools"])).all()
for tool in user_tools:
- tool = AgentExecutor.create_object(tool,session)
+ tool = AgentExecutor.create_object(tool, session)
tools.append(tool)
- tools = self.set_default_params_tools(tools, parsed_config, agent_execution.agent_id,
- model_api_key=model_api_key, session=session)
-
+ resource_summary = self.get_agent_resource_summary(agent_id=agent.id, session=session,
+ default_summary=parsed_config.get("resource_summary"))
+ if resource_summary is not None:
+ tools.append(QueryResourceTool())
+ tools = self.set_default_params_tools(tools, parsed_config,parsed_execution_config, agent_execution.agent_id,
+ model_api_key=model_api_key,
+ resource_description=resource_summary,
+ session=session)
spawned_agent = SuperAgi(ai_name=parsed_config["name"], ai_role=parsed_config["description"],
llm=OpenAi(model=parsed_config["model"], api_key=model_api_key), tools=tools,
memory=memory,
- agent_config=parsed_config)
+ agent_config=parsed_config,
+ agent_execution_config=parsed_execution_config)
try:
self.handle_wait_for_permission(agent_execution, spawned_agent, session)
@@ -203,7 +218,16 @@ def execute_next_action(self, agent_execution_id):
agent_workflow_step = session.query(AgentWorkflowStep).filter(
AgentWorkflowStep.id == agent_execution.current_step_id).first()
- response = spawned_agent.execute(agent_workflow_step)
+
+ try:
+ response = spawned_agent.execute(agent_workflow_step)
+ except RuntimeError as e:
+ superagi.worker.execute_agent.delay(agent_execution_id, datetime.now())
+ session.close()
+ # If our execution encounters an error we return and attempt to retry
+ return
+
+
if "retry" in response and response["retry"]:
response = spawned_agent.execute(agent_workflow_step)
agent_execution.current_step_id = agent_workflow_step.next_step_id
@@ -224,15 +248,18 @@ def execute_next_action(self, agent_execution_id):
session.close()
engine.dispose()
- def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key, session):
+ def set_default_params_tools(self, tools, parsed_config, parsed_execution_config, agent_id, model_api_key, session,
+ resource_description=None):
"""
Set the default parameters for the tools.
Args:
tools (list): The list of tools.
- parsed_config (dict): The parsed configuration.
+ parsed_config (dict): Parsed agent configuration.
+ parsed_execution_config (dict): Parsed execution configuration
agent_id (int): The ID of the agent.
model_api_key (str): The API key of the model.
+ resource_description (str): The description of the resource.
Returns:
list: The list of tools with default parameters.
@@ -240,23 +267,26 @@ def set_default_params_tools(self, tools, parsed_config, agent_id, model_api_key
new_tools = []
for tool in tools:
if hasattr(tool, 'goals'):
- tool.goals = parsed_config["goal"]
+ tool.goals = parsed_execution_config["goal"]
if hasattr(tool, 'instructions'):
- tool.instructions = parsed_config["instruction"]
- if hasattr(tool, 'llm') and (parsed_config["model"] == "gpt4" or parsed_config["model"] == "gpt-3.5-turbo"):
- tool.llm = OpenAi(model="gpt-3.5-turbo", api_key=model_api_key, temperature=0.3)
+ tool.instructions = parsed_execution_config["instruction"]
+ if hasattr(tool, 'llm') and (parsed_config["model"] == "gpt4" or parsed_config[
+ "model"] == "gpt-3.5-turbo") and tool.name != "Query Resource":
+ tool.llm = OpenAi(model="gpt-3.5-turbo", api_key=model_api_key, temperature=0.4)
elif hasattr(tool, 'llm'):
- tool.llm = OpenAi(model=parsed_config["model"], api_key=model_api_key, temperature=0.3)
+ tool.llm = OpenAi(model=parsed_config["model"], api_key=model_api_key, temperature=0.4)
if hasattr(tool, 'image_llm'):
tool.image_llm = OpenAi(model=parsed_config["model"], api_key=model_api_key)
if hasattr(tool, 'agent_id'):
tool.agent_id = agent_id
if hasattr(tool, 'resource_manager'):
- tool.resource_manager = ResourceManager(session=session, agent_id=agent_id)
+ tool.resource_manager = FileManager(session=session, agent_id=agent_id)
if hasattr(tool, 'tool_response_manager'):
tool.tool_response_manager = ToolResponseQueryManager(session=session, agent_execution_id=parsed_config[
"agent_execution_id"])
+ if tool.name == "Query Resource" and resource_description:
+ tool.description = tool.description.replace("{summary}", resource_description)
new_tools.append(tool)
return tools
@@ -292,3 +322,17 @@ def handle_wait_for_permission(self, agent_execution, spawned_agent, session):
session.add(agent_execution_feed)
agent_execution.status = "RUNNING"
session.commit()
+
+ def get_agent_resource_summary(self, agent_id: int, session: Session, default_summary: str):
+ ResourceSummarizer(session=session).generate_agent_summary(agent_id=agent_id,generate_all=True)
+ agent_config_resource_summary = session.query(AgentConfiguration). \
+ filter(AgentConfiguration.agent_id == agent_id,
+ AgentConfiguration.key == "resource_summary").first()
+ resource_summary = agent_config_resource_summary.value if agent_config_resource_summary is not None else default_summary
+ return resource_summary
+
+ def check_for_resource(self,agent_id: int, session: Session):
+ resource = session.query(Resource).filter(Resource.agent_id == agent_id,Resource.channel == 'INPUT').first()
+ if resource is None:
+ return False
+ return True
diff --git a/superagi/models/agent.py b/superagi/models/agent.py
index 2786300b7..e0206f4ca 100644
--- a/superagi/models/agent.py
+++ b/superagi/models/agent.py
@@ -12,7 +12,8 @@
# from superagi.models import AgentConfiguration
from superagi.models.base_model import DBBaseModel
from superagi.lib.logger import logger
-
+from superagi.models.organisation import Organisation
+from superagi.models.project import Project
class Agent(DBBaseModel):
"""
@@ -100,7 +101,7 @@ def eval_agent_config(cls, key, value):
"""
- if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB"]:
+ if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB", "resource_summary"]:
return value
elif key in ["project_id", "memory_window", "max_iterations", "iteration_interval"]:
return int(value)
@@ -229,3 +230,18 @@ def create_agent_with_marketplace_template_id(cls, db, project_id, agent_templat
db.session.commit()
db.session.flush()
return db_agent
+
+ def get_agent_organisation(self, session):
+ """
+ Get the organisation of the agent.
+
+ Args:
+ session: The database session.
+
+ Returns:
+ Organisation: The organisation of the agent.
+
+ """
+ project = session.query(Project).filter(Project.id == self.project_id).first()
+ organisation = session.query(Organisation).filter(Organisation.id == project.organisation_id).first()
+ return organisation
\ No newline at end of file
diff --git a/superagi/models/agent_execution_config.py b/superagi/models/agent_execution_config.py
new file mode 100644
index 000000000..1a451d8c7
--- /dev/null
+++ b/superagi/models/agent_execution_config.py
@@ -0,0 +1,102 @@
+from sqlalchemy import Column, Integer, String, Text
+
+from superagi.models.base_model import DBBaseModel
+
+
+class AgentExecutionConfiguration(DBBaseModel):
+ """
+ Stores execution-scoped agent configuration (e.g. goals, instructions) keyed per agent execution
+
+ Attributes:
+ id (int): The unique identifier of the agent execution config.
+ agent_execution_id (int): The identifier of the associated agent execution.
+ key (str): The key of the configuration setting.
+ value (str): The value of the configuration setting.
+ """
+
+ __tablename__ = 'agent_execution_configs'
+
+ id = Column(Integer, primary_key=True)
+ agent_execution_id = Column(Integer)
+ key = Column(String)
+ value = Column(Text)
+
+ def __repr__(self):
+ """
+ Returns a string representation of the AgentExecutionConfig object.
+
+ Returns:
+ str: String representation of the AgentExecutionConfiguration.
+ """
+
+ return f"AgentExecutionConfig(id={self.id}, agent_execution_id='{self.agent_execution_id}', " \
+ f"key='{self.key}', value='{self.value}')"
+
+ @classmethod
+ def add_or_update_agent_execution_config(cls, session, execution, agent_execution_configs):
+ agent_execution_configurations = [
+ AgentExecutionConfiguration(agent_execution_id=execution.id, key=key, value=str(value))
+ for key, value in agent_execution_configs.items()
+ ]
+ for agent_execution in agent_execution_configurations:
+ agent_execution_config = (
+ session.query(AgentExecutionConfiguration)
+ .filter(
+ AgentExecutionConfiguration.agent_execution_id == execution.id,
+ AgentExecutionConfiguration.key == agent_execution.key
+ )
+ .first()
+ )
+
+ if agent_execution_config:
+ agent_execution_config.value = str(agent_execution.value)
+ else:
+ agent_execution_config = AgentExecutionConfiguration(
+ agent_execution_id=execution.id,
+ key=agent_execution.key,
+ value=str(agent_execution.value)
+ )
+ session.add(agent_execution_config)
+ session.commit()
+
+ @classmethod
+ def fetch_configuration(cls, session, execution):
+ """
+ Fetches the execution configuration of an agent.
+
+ Args:
+ session: The database session object.
+ execution (AgentExecution): The AgentExecution of the agent.
+
+ Returns:
+ dict: Parsed agent configuration.
+
+ """
+ agent_configurations = session.query(AgentExecutionConfiguration).filter_by(
+ agent_execution_id=execution.id).all()
+ parsed_config = {
+ "goal": [],
+ "instruction": [],
+ }
+ if not agent_configurations:
+ return parsed_config
+ for item in agent_configurations:
+ parsed_config[item.key] = cls.eval_agent_config(item.key, item.value)
+ return parsed_config
+
+ @classmethod
+ def eval_agent_config(cls, key, value):
+ """
+ Evaluates the value of an agent execution configuration setting based on its key.
+
+ Args:
+ key (str): The key of the execution configuration setting.
+ value (str): The value of the execution configuration setting.
+
+ Returns:
+ object: The evaluated value of the execution configuration setting.
+
+ """
+
+ if key == "goal" or key == "instruction":
+ return eval(value)
diff --git a/superagi/models/base_model.py b/superagi/models/base_model.py
index 8b872b11b..180110237 100644
--- a/superagi/models/base_model.py
+++ b/superagi/models/base_model.py
@@ -1,7 +1,7 @@
import json
from sqlalchemy import Column, DateTime, INTEGER
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base
from datetime import datetime
Base = declarative_base()
diff --git a/superagi/models/configuration.py b/superagi/models/configuration.py
index 2193e941d..b2497c644 100644
--- a/superagi/models/configuration.py
+++ b/superagi/models/configuration.py
@@ -1,4 +1,6 @@
from sqlalchemy import Column, Integer, String,Text
+
+from superagi.helper.encyption_helper import decrypt_data
from superagi.models.base_model import DBBaseModel
@@ -29,3 +31,23 @@ def __repr__(self):
"""
return f"Config(id={self.id}, organisation_id={self.organisation_id}, key={self.key}, value={self.value})"
+
+
+ @classmethod
+ def fetch_configuration(cls, session, organisation_id: int, key: str, default_value=None) -> str:
+ """
+ Fetches the configuration of an agent.
+
+ Args:
+ session: The database session object.
+ organisation_id (int): The ID of the organisation.
+ key (str): The key of the configuration.
+ default_value (str): The default value of the configuration.
+
+ Returns:
+ str: The decrypted configuration value, or default_value if no matching row exists.
+
+ """
+
+ configuration = session.query(Configuration).filter_by(organisation_id=organisation_id, key=key).first()
+ return decrypt_data(configuration.value) if configuration else default_value
\ No newline at end of file
diff --git a/superagi/models/oauth_tokens.py b/superagi/models/oauth_tokens.py
new file mode 100644
index 000000000..996c0dae8
--- /dev/null
+++ b/superagi/models/oauth_tokens.py
@@ -0,0 +1,53 @@
+from sqlalchemy import Column, Integer, String, Text
+from sqlalchemy.orm import Session
+
+from superagi.models.base_model import DBBaseModel
+import json
+import yaml
+
+
+
+class OauthTokens(DBBaseModel):
+ """
+ Model representing an OAuth token record.
+
+ Attributes:
+ id (Integer): The primary key of the oauth token.
+ user_id (Integer): The ID of the user associated with the Tokens.
+ toolkit_id (Integer): The ID of the toolkit associated with the Tokens.
+ key (String): The Token Key.
+ value (Text): The Token value.
+ """
+
+ __tablename__ = 'oauth_tokens'
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ user_id = Column(Integer)
+ organisation_id = Column(Integer)
+ toolkit_id = Column(Integer)
+ key = Column(String)
+ value = Column(Text)
+
+ def __repr__(self):
+ """
+ Returns a string representation of the OauthTokens object.
+
+ Returns:
+ str: String representation of the OauthTokens object.
+ """
+
+ return f"Tokens(id={self.id}, user_id={self.user_id}, organisation_id={self.organisation_id} toolkit_id={self.toolkit_id}, key={self.key}, value={self.value})"
+
+ @classmethod
+ def add_or_update(self, session: Session, toolkit_id: int, user_id: int, organisation_id: int, key: str, value: Text = None):
+ oauth_tokens = session.query(OauthTokens).filter_by(toolkit_id=toolkit_id, user_id=user_id).first()
+ if oauth_tokens:
+ # Update existing oauth tokens
+ if value is not None:
+ oauth_tokens.value = value
+ else:
+ # Create new oauth tokens
+ oauth_tokens = OauthTokens(toolkit_id=toolkit_id, user_id=user_id, organisation_id=organisation_id, key=key, value=value)
+ session.add(oauth_tokens)
+
+ session.commit()
\ No newline at end of file
diff --git a/superagi/models/resource.py b/superagi/models/resource.py
index 53091fc1a..15edb3e59 100644
--- a/superagi/models/resource.py
+++ b/superagi/models/resource.py
@@ -1,4 +1,4 @@
-from sqlalchemy import Column, Integer, String, Float
+from sqlalchemy import Column, Integer, String, Float, Text
from superagi.models.base_model import DBBaseModel
from sqlalchemy.orm import sessionmaker
@@ -28,6 +28,7 @@ class Resource(DBBaseModel):
type = Column(String) # application/pdf etc
channel = Column(String) # INPUT,OUTPUT
agent_id = Column(Integer)
+ summary = Column(Text)
def __repr__(self):
"""
diff --git a/superagi/resource_manager/manager.py b/superagi/resource_manager/file_manager.py
similarity index 89%
rename from superagi/resource_manager/manager.py
rename to superagi/resource_manager/file_manager.py
index 0b376a53e..882987c83 100644
--- a/superagi/resource_manager/manager.py
+++ b/superagi/resource_manager/file_manager.py
@@ -1,12 +1,20 @@
+import csv
+import os
+
+from llama_index import SimpleDirectoryReader
+from llama_index.indices.response import ResponseMode
+from llama_index.schema import Document
from sqlalchemy.orm import Session
+from superagi.config.config import get_config
from superagi.helper.resource_helper import ResourceHelper
from superagi.helper.s3_helper import S3Helper
from superagi.lib.logger import logger
-import os
-import csv
+from superagi.types.storage_types import StorageType
+from superagi.types.vector_store_types import VectorStoreType
+
-class ResourceManager:
+class FileManager:
def __init__(self, session: Session, agent_id: int = None):
self.session = session
self.agent_id = agent_id
@@ -38,7 +46,7 @@ def write_to_s3(self, file_name, final_path):
self.session.add(resource)
self.session.commit()
self.session.flush()
- if resource.storage_type == "S3":
+ if resource.storage_type == StorageType.S3.value:
s3_helper = S3Helper()
s3_helper.upload_file(img, path=resource.path)
@@ -77,3 +85,4 @@ def write_csv_file(self, file_name: str, csv_data):
def get_agent_resource_path(self, file_name: str):
return ResourceHelper.get_agent_resource_path(file_name, self.agent_id)
+
diff --git a/superagi/resource_manager/llama_document_summary.py b/superagi/resource_manager/llama_document_summary.py
new file mode 100644
index 000000000..5d7c8e246
--- /dev/null
+++ b/superagi/resource_manager/llama_document_summary.py
@@ -0,0 +1,62 @@
+import os
+
+from llama_index.indices.response import ResponseMode
+from llama_index.schema import Document
+
+from superagi.config.config import get_config
+
+
+class LlamaDocumentSummary:
+ def __init__(self, model_name=get_config("RESOURCES_SUMMARY_MODEL_NAME", "gpt-3.5-turbo"), model_api_key: str = None):
+ self.model_name = model_name
+ self.model_api_key = model_api_key
+
+ def generate_summary_of_document(self, documents: list[Document]):
+ """
+ Generates summary of the documents
+
+ :param documents: list of Document objects
+ :return: summary of the documents
+ """
+ from llama_index import LLMPredictor, ServiceContext, ResponseSynthesizer, DocumentSummaryIndex
+
+ os.environ["OPENAI_API_KEY"] = get_config("OPENAI_API_KEY", "") or self.model_api_key
+ llm_predictor_chatgpt = LLMPredictor(llm=self._build_llm())
+ service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor_chatgpt, chunk_size=1024)
+ response_synthesizer = ResponseSynthesizer.from_args(response_mode=ResponseMode.TREE_SUMMARIZE, use_async=True)
+ doc_summary_index = DocumentSummaryIndex.from_documents(
+ documents=documents,
+ service_context=service_context,
+ response_synthesizer=response_synthesizer
+ )
+
+ return doc_summary_index.get_document_summary(documents[0].doc_id)
+
+ def generate_summary_of_texts(self, texts: list[str]):
+ """
+ Generates summary of the texts
+
+ :param texts: list of texts
+ :return: summary of the texts
+ """
+ from llama_index import Document
+ if texts is not None and len(texts) > 0:
+ documents = [Document(doc_id=f"doc_id_{i}", text=text) for i, text in enumerate(texts)]
+ return self.generate_summary_of_document(documents)
+ raise ValueError("texts must be provided")
+
+ def _build_llm(self):
+ """
+ Builds the LLM model
+
+ :return: LLM model object
+ """
+ open_ai_models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']
+ if self.model_name in open_ai_models:
+ from langchain.chat_models import ChatOpenAI
+
+ openai_api_key = get_config("OPENAI_API_KEY") or self.model_api_key
+ return ChatOpenAI(temperature=0, model_name=self.model_name,
+ openai_api_key=openai_api_key)
+
+ raise Exception(f"Model name {self.model_name} not supported for document summary")
diff --git a/superagi/resource_manager/llama_vector_store_factory.py b/superagi/resource_manager/llama_vector_store_factory.py
new file mode 100644
index 000000000..7b96525fc
--- /dev/null
+++ b/superagi/resource_manager/llama_vector_store_factory.py
@@ -0,0 +1,59 @@
+from llama_index.vector_stores.types import VectorStore
+
+from superagi.config.config import get_config
+from superagi.types.vector_store_types import VectorStoreType
+
+
+class LlamaVectorStoreFactory:
+ """
+ Factory class to create vector stores based on the vector_store_name
+
+ :param vector_store_name: VectorStoreType
+ :param index_name: str
+
+ :return: VectorStore object
+ """
+ def __init__(self, vector_store_name: VectorStoreType, index_name: str):
+ self.vector_store_name = vector_store_name
+ self.index_name = index_name
+
+ def get_vector_store(self) -> VectorStore:
+ """
+ Returns the vector store based on the vector_store_name
+
+ :return: VectorStore object
+ """
+ if self.vector_store_name == VectorStoreType.PINECONE:
+ from llama_index.vector_stores import PineconeVectorStore
+ return PineconeVectorStore(self.index_name)
+
+ if self.vector_store_name == VectorStoreType.REDIS:
+ redis_url = get_config("REDIS_VECTOR_STORE_URL") or "redis://super__redis:6379"
+ from llama_index.vector_stores import RedisVectorStore
+ return RedisVectorStore(
+ index_name=self.index_name,
+ redis_url=redis_url,
+ metadata_fields=["agent_id", "resource_id"]
+ )
+
+ if self.vector_store_name == VectorStoreType.CHROMA:
+ from llama_index.vector_stores import ChromaVectorStore
+ import chromadb
+ from chromadb.config import Settings
+ chroma_host_name = get_config("CHROMA_HOST_NAME") or "localhost"
+ chroma_port = get_config("CHROMA_PORT") or 8000
+ chroma_client = chromadb.Client(
+ Settings(chroma_api_impl="rest", chroma_server_host=chroma_host_name,
+ chroma_server_http_port=chroma_port))
+ chroma_collection = chroma_client.get_or_create_collection(self.index_name)
+ return ChromaVectorStore(chroma_collection)
+
+ if self.vector_store_name == VectorStoreType.QDRANT:
+ from llama_index.vector_stores import QdrantVectorStore
+ qdrant_host_name = get_config("QDRANT_HOST_NAME") or "localhost"
+ qdrant_port = get_config("QDRANT_PORT") or 6333
+ from qdrant_client import QdrantClient
+ qdrant_client = QdrantClient(host=qdrant_host_name, port=qdrant_port)
+ return QdrantVectorStore(client=qdrant_client, collection_name=self.index_name)
+
+ raise ValueError(str(self.vector_store_name) + " vector store is not supported yet.")
diff --git a/superagi/resource_manager/resource_manager.py b/superagi/resource_manager/resource_manager.py
new file mode 100644
index 000000000..9be212afe
--- /dev/null
+++ b/superagi/resource_manager/resource_manager.py
@@ -0,0 +1,97 @@
+import os
+
+from llama_index import SimpleDirectoryReader
+
+from superagi.config.config import get_config
+from superagi.helper.resource_helper import ResourceHelper
+from superagi.lib.logger import logger
+from superagi.resource_manager.llama_vector_store_factory import LlamaVectorStoreFactory
+from superagi.types.vector_store_types import VectorStoreType
+
+
+class ResourceManager:
+ """
+ Resource Manager handles creation of resources and saving them to the vector store.
+
+ :param agent_id: The agent id to use when saving resources to the vector store.
+ """
+ def __init__(self, agent_id: str = None):
+ self.agent_id = agent_id
+
+ def create_llama_document(self, file_path: str):
+ """
+ Creates a document index from a given file path.
+
+ :param file_path: The file path to create the document index from.
+ :return: A list of documents.
+ """
+ if file_path is None:
+ raise Exception("file_path must be provided")
+ documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
+
+ return documents
+
+ def create_llama_document_s3(self, file_path: str):
+ """
+ Creates a document index from a given file path.
+
+ :param file_path: The file path to create the document index from.
+ :return: A list of documents.
+ """
+
+ if file_path is None:
+ raise Exception("file_path must be provided")
+
+ import boto3
+ s3 = boto3.client(
+ 's3',
+ aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"),
+ aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"),
+ )
+ bucket_name = get_config("BUCKET_NAME")
+ file = s3.get_object(Bucket=bucket_name, Key=file_path)
+ file_name = file_path.split("/")[-1]
+ save_directory = ResourceHelper.get_root_input_dir() + "/"
+ file_path = save_directory + file_name
+ with open(file_path, "wb") as f:
+ contents = file['Body'].read()
+ f.write(contents)
+
+ documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
+ os.remove(file_path)
+ return documents
+
+ def save_document_to_vector_store(self, documents: list, resource_id: str, mode_api_key: str = None):
+ """
+ Saves a document to the vector store.
+
+ :param documents: The documents to save to the vector store.
+ :param resource_id: The resource id to use when saving the documents to the vector store.
+ :param mode_api_key: The model api key to use when creating embeddings in the vector store.
+ """
+ from llama_index import VectorStoreIndex, StorageContext
+ import openai
+ openai.api_key = get_config("OPENAI_API_KEY") or mode_api_key
+ os.environ["OPENAI_API_KEY"] = get_config("OPENAI_API_KEY", "") or mode_api_key
+ for docs in documents:
+ if docs.metadata is None:
+ docs.metadata = {}
+ docs.metadata["agent_id"] = str(self.agent_id)
+ docs.metadata["resource_id"] = resource_id
+ vector_store = None
+ storage_context = None
+ vector_store_name = VectorStoreType.get_vector_store_type(get_config("RESOURCE_VECTOR_STORE") or "Redis")
+ vector_store_index_name = get_config("RESOURCE_VECTOR_STORE_INDEX_NAME") or "super-agent-index"
+ try:
+ vector_store = LlamaVectorStoreFactory(vector_store_name, vector_store_index_name).get_vector_store()
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
+ except ValueError as e:
+ logger.error(f"Vector store not found{e}")
+ try:
+ index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
+ index.set_index_id(f'Agent {self.agent_id}')
+ except Exception as e:
+ logger.error(e)
+ # persisting the data in case of redis
+ if vector_store_name == VectorStoreType.REDIS:
+ vector_store.persist(persist_path="")
diff --git a/superagi/resource_manager/resource_summary.py b/superagi/resource_manager/resource_summary.py
new file mode 100644
index 000000000..9115fbb1d
--- /dev/null
+++ b/superagi/resource_manager/resource_summary.py
@@ -0,0 +1,94 @@
+from datetime import datetime
+
+from superagi.lib.logger import logger
+from superagi.models.agent import Agent
+from superagi.models.agent_config import AgentConfiguration
+from superagi.models.configuration import Configuration
+from superagi.models.resource import Resource
+from superagi.resource_manager.llama_document_summary import LlamaDocumentSummary
+from superagi.resource_manager.resource_manager import ResourceManager
+
+
+class ResourceSummarizer:
+ """Class to summarize a resource."""
+
+ def __init__(self, session):
+ self.session = session
+
+ def add_to_vector_store_and_create_summary(self, agent_id: int, resource_id: int, documents: list):
+ """
+ Add a file to the vector store and generate a summary for it.
+
+ Args:
+ agent_id (int): ID of the agent.
+ resource_id (int): ID of the resource.
+ documents (list): List of documents.
+ """
+ agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
+ organization = agent.get_agent_organisation(self.session)
+ model_api_key = Configuration.fetch_configuration(self.session, organization.id, "model_api_key")
+ try:
+ ResourceManager(str(agent_id)).save_document_to_vector_store(documents, str(resource_id), model_api_key)
+ except Exception as e:
+ logger.error(e)
+ summary = None
+ try:
+ summary = LlamaDocumentSummary(model_api_key=model_api_key).generate_summary_of_document(documents)
+ except Exception as e:
+ logger.error(e)
+ resource = self.session.query(Resource).filter(Resource.id == resource_id).first()
+ resource.summary = summary
+ self.session.commit()
+
+ def generate_agent_summary(self, agent_id: int, generate_all: bool = False) -> str:
+ """Generate a summary of all resources for an agent."""
+ agent_config_resource_summary = self.session.query(AgentConfiguration). \
+ filter(AgentConfiguration.agent_id == agent_id,
+ AgentConfiguration.key == "resource_summary").first()
+ resources = self.session.query(Resource).filter(Resource.agent_id == agent_id,Resource.channel == 'INPUT').all()
+ if not resources:
+ return
+
+ agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
+ organization = agent.get_agent_organisation(self.session)
+ model_api_key = Configuration.fetch_configuration(self.session, organization.id, "model_api_key")
+
+ summary_texts = [resource.summary for resource in resources if resource.summary is not None]
+
+ # generate_all is added because we want to generate summary for all resources when agent is created
+ # this is set to false when adding individual resources
+ if len(summary_texts) < len(resources) and generate_all:
+ file_paths = [resource.path for resource in resources if resource.summary is None]
+ for file_path in file_paths:
+ if resources[0].storage_type == 'S3':
+ documents = ResourceManager(str(agent_id)).create_llama_document_s3(file_path)
+ else:
+ documents = ResourceManager(str(agent_id)).create_llama_document(file_path)
+ summary_texts.append(LlamaDocumentSummary(model_api_key=model_api_key).generate_summary_of_document(documents))
+
+ agent_last_resource = self.session.query(AgentConfiguration). \
+ filter(AgentConfiguration.agent_id == agent_id,
+ AgentConfiguration.key == "last_resource_time").first()
+ if agent_last_resource is not None and \
+ datetime.strptime(agent_last_resource.value, '%Y-%m-%d %H:%M:%S.%f') == resources[-1].updated_at \
+ and not generate_all:
+ return
+
+ resource_summary = summary_texts[0] if summary_texts else None
+ if len(summary_texts) > 1:
+ resource_summary = LlamaDocumentSummary(model_api_key=model_api_key).generate_summary_of_texts(summary_texts)
+
+ if agent_config_resource_summary is not None:
+ agent_config_resource_summary.value = resource_summary
+ else:
+ agent_config_resource_summary = AgentConfiguration(agent_id=agent_id, key="resource_summary",
+ value=resource_summary)
+ self.session.add(agent_config_resource_summary)
+ if agent_last_resource is not None:
+ agent_last_resource.value = str(resources[-1].updated_at)
+ else:
+ agent_last_resource = AgentConfiguration(agent_id=agent_id, key="last_resource_time",
+ value=str(resources[-1].updated_at))
+ self.session.add(agent_last_resource)
+ self.session.commit()
diff --git a/superagi/tool_manager.py b/superagi/tool_manager.py
index 54b983dd9..674dd0e03 100644
--- a/superagi/tool_manager.py
+++ b/superagi/tool_manager.py
@@ -64,7 +64,7 @@ def download_and_extract_tools():
tools_config = load_tools_config()
for tool_name, tool_url in tools_config.items():
- tool_folder = os.path.join("", "tools", tool_name)
+ tool_folder = os.path.join("superagi", "tools", tool_name)
if not os.path.exists(tool_folder):
os.makedirs(tool_folder)
download_tool(tool_url, tool_folder)
diff --git a/superagi/tools/code/write_code.py b/superagi/tools/code/write_code.py
index 2b29d6f0d..107589d56 100644
--- a/superagi/tools/code/write_code.py
+++ b/superagi/tools/code/write_code.py
@@ -8,7 +8,7 @@
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
@@ -44,7 +44,7 @@ class CodingTool(BaseTool):
)
args_schema: Type[CodingSchema] = CodingSchema
goals: List[str] = []
- resource_manager: Optional[ResourceManager] = None
+ resource_manager: Optional[FileManager] = None
tool_response_manager: Optional[ToolResponseQueryManager] = None
class Config:
diff --git a/superagi/tools/code/write_spec.py b/superagi/tools/code/write_spec.py
index 45773f024..00e626426 100644
--- a/superagi/tools/code/write_spec.py
+++ b/superagi/tools/code/write_spec.py
@@ -7,7 +7,7 @@
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
@@ -43,7 +43,7 @@ class WriteSpecTool(BaseTool):
)
args_schema: Type[WriteSpecSchema] = WriteSpecSchema
goals: List[str] = []
- resource_manager: Optional[ResourceManager] = None
+ resource_manager: Optional[FileManager] = None
class Config:
arbitrary_types_allowed = True
diff --git a/superagi/tools/code/write_test.py b/superagi/tools/code/write_test.py
index 4d889575a..9fb105873 100644
--- a/superagi/tools/code/write_test.py
+++ b/superagi/tools/code/write_test.py
@@ -8,7 +8,7 @@
from superagi.helper.token_counter import TokenCounter
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
@@ -47,7 +47,7 @@ class WriteTestTool(BaseTool):
)
args_schema: Type[WriteTestSchema] = WriteTestSchema
goals: List[str] = []
- resource_manager: Optional[ResourceManager] = None
+ resource_manager: Optional[FileManager] = None
tool_response_manager: Optional[ToolResponseQueryManager] = None
class Config:
diff --git a/superagi/tools/email/send_email.py b/superagi/tools/email/send_email.py
index 2aa55cde7..09708ab94 100644
--- a/superagi/tools/email/send_email.py
+++ b/superagi/tools/email/send_email.py
@@ -27,7 +27,7 @@ class SendEmailTool(BaseTool):
name: str = "Send Email"
args_schema: Type[BaseModel] = SendEmailInput
description: str = "Send an Email"
-
+
def _execute(self, to: str, subject: str, body: str) -> str:
"""
Execute the send email tool.
diff --git a/superagi/tools/email/send_email_attachment.py b/superagi/tools/email/send_email_attachment.py
index 2a659d1b4..2f48908cb 100644
--- a/superagi/tools/email/send_email_attachment.py
+++ b/superagi/tools/email/send_email_attachment.py
@@ -9,6 +9,7 @@
from pydantic import BaseModel, Field
from superagi.helper.imap_email import ImapEmail
from superagi.tools.base_tool import BaseTool
+from superagi.helper.resource_helper import ResourceHelper
class SendEmailAttachmentInput(BaseModel):
@@ -30,7 +31,8 @@ class SendEmailAttachmentTool(BaseTool):
name: str = "Send Email with Attachment"
args_schema: Type[BaseModel] = SendEmailAttachmentInput
description: str = "Send an Email with a file attached to it"
-
+ agent_id: int = None
+
def _execute(self, to: str, subject: str, body: str, filename: str) -> str:
"""
Execute the send email tool with attachment.
@@ -44,21 +46,13 @@ def _execute(self, to: str, subject: str, body: str, filename: str) -> str:
Returns:
success or failure message
"""
- input_root_dir = self.get_tool_config('RESOURCES_INPUT_ROOT_DIR')
- output_root_dir = self.get_tool_config('RESOURCES_OUTPUT_ROOT_DIR')
- final_path = None
+ final_path = ResourceHelper.get_agent_resource_path(filename, self.agent_id)
- if input_root_dir is not None:
- input_root_dir = input_root_dir if input_root_dir.startswith("/") else os.getcwd() + "/" + input_root_dir
- input_root_dir = input_root_dir if input_root_dir.endswith("/") else input_root_dir + "/"
- final_path = input_root_dir + filename
+ if final_path is None or not os.path.exists(final_path):
+ final_path = ResourceHelper.get_root_input_dir() + filename
if final_path is None or not os.path.exists(final_path):
- if output_root_dir is not None:
- output_root_dir = output_root_dir if output_root_dir.startswith(
- "/") else os.getcwd() + "/" + output_root_dir
- output_root_dir = output_root_dir if output_root_dir.endswith("/") else output_root_dir + "/"
- final_path = output_root_dir + filename
+ raise FileNotFoundError(f"File '{filename}' not found.")
attachment = os.path.basename(final_path)
return self.send_email_with_attachment(to, subject, body, final_path, attachment)
diff --git a/superagi/tools/file/read_file.py b/superagi/tools/file/read_file.py
index 47cb23096..099567bdd 100644
--- a/superagi/tools/file/read_file.py
+++ b/superagi/tools/file/read_file.py
@@ -4,7 +4,7 @@
from pydantic import BaseModel, Field
from superagi.helper.resource_helper import ResourceHelper
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
@@ -26,7 +26,7 @@ class ReadFileTool(BaseTool):
agent_id: int = None
args_schema: Type[BaseModel] = ReadFileSchema
description: str = "Reads the file content in a specified location"
- resource_manager: Optional[ResourceManager] = None
+ resource_manager: Optional[FileManager] = None
def _execute(self, file_name: str):
"""
diff --git a/superagi/tools/file/write_file.py b/superagi/tools/file/write_file.py
index 93e979c00..f8b61aeaf 100644
--- a/superagi/tools/file/write_file.py
+++ b/superagi/tools/file/write_file.py
@@ -3,7 +3,7 @@
from pydantic import BaseModel, Field
# from superagi.helper.s3_helper import upload_to_s3
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
@@ -31,7 +31,7 @@ class WriteFileTool(BaseTool):
args_schema: Type[BaseModel] = WriteFileInput
description: str = "Writes text to a file"
agent_id: int = None
- resource_manager: Optional[ResourceManager] = None
+ resource_manager: Optional[FileManager] = None
class Config:
arbitrary_types_allowed = True
diff --git a/superagi/tools/google_calendar/list_calendar_events.py b/superagi/tools/google_calendar/list_calendar_events.py
index 9d7a411be..c96a30d24 100644
--- a/superagi/tools/google_calendar/list_calendar_events.py
+++ b/superagi/tools/google_calendar/list_calendar_events.py
@@ -7,7 +7,7 @@
from superagi.tools.base_tool import BaseTool
from superagi.helper.google_calendar_creds import GoogleCalendarCreds
from superagi.helper.calendar_date import CalendarDate
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
from superagi.helper.s3_helper import S3Helper
from urllib.parse import urlparse, parse_qs
from sqlalchemy.orm import sessionmaker
@@ -27,7 +27,7 @@ class ListCalendarEventsTool(BaseTool):
args_schema: Type[BaseModel] = ListCalendarEventsInput
description: str = "Get the list of all the events from Google Calendar"
agent_id: int = None
- resource_manager: ResourceManager = None
+ resource_manager: FileManager = None
def _execute(self, start_time: str = 'None', start_date: str = 'None', end_date: str = 'None', end_time: str = 'None'):
service = self.get_google_calendar_service()
diff --git a/superagi/tools/image_generation/dalle_image_gen.py b/superagi/tools/image_generation/dalle_image_gen.py
index c1ed8c578..18d218cc2 100644
--- a/superagi/tools/image_generation/dalle_image_gen.py
+++ b/superagi/tools/image_generation/dalle_image_gen.py
@@ -4,7 +4,7 @@
from pydantic import BaseModel, Field
from superagi.llms.base_llm import BaseLlm
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
class DalleImageGenInput(BaseModel):
@@ -31,7 +31,7 @@ class DalleImageGenTool(BaseTool):
description: str = "Generate Images using Dalle"
llm: Optional[BaseLlm] = None
agent_id: int = None
- resource_manager: Optional[ResourceManager] = None
+ resource_manager: Optional[FileManager] = None
class Config:
arbitrary_types_allowed = True
diff --git a/superagi/tools/image_generation/stable_diffusion_image_gen.py b/superagi/tools/image_generation/stable_diffusion_image_gen.py
index ffc27f49a..983d7ec5e 100644
--- a/superagi/tools/image_generation/stable_diffusion_image_gen.py
+++ b/superagi/tools/image_generation/stable_diffusion_image_gen.py
@@ -6,7 +6,7 @@
from PIL import Image
from pydantic import BaseModel, Field
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
@@ -35,7 +35,7 @@ class StableDiffusionImageGenTool(BaseTool):
args_schema: Type[BaseModel] = StableDiffusionImageGenInput
description: str = "Generate Images using Stable Diffusion"
agent_id: int = None
- resource_manager: Optional[ResourceManager] = None
+ resource_manager: Optional[FileManager] = None
class Config:
arbitrary_types_allowed = True
diff --git a/superagi/tools/resource/__init__.py b/superagi/tools/resource/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/superagi/tools/resource/query_resource.py b/superagi/tools/resource/query_resource.py
new file mode 100644
index 000000000..d853c4811
--- /dev/null
+++ b/superagi/tools/resource/query_resource.py
@@ -0,0 +1,79 @@
+import logging
+import os
+from typing import Type
+
+import openai
+from langchain.chat_models import ChatOpenAI
+from llama_index import VectorStoreIndex, LLMPredictor, ServiceContext
+from llama_index.vector_stores.types import ExactMatchFilter, MetadataFilters
+from pydantic import BaseModel, Field
+
+from superagi.config.config import get_config
+from superagi.resource_manager.llama_vector_store_factory import LlamaVectorStoreFactory
+from superagi.resource_manager.resource_manager import ResourceManager
+from superagi.tools.base_tool import BaseTool
+from superagi.types.vector_store_types import VectorStoreType
+from superagi.vector_store.chromadb import ChromaDB
+from superagi.vector_store.embedding.openai import OpenAiEmbedding
+from typing import Optional
+from superagi.llms.base_llm import BaseLlm
+
+
+class QueryResource(BaseModel):
+ """Input for QueryResource tool."""
+ query: str = Field(..., description="the search query to search resources")
+
+
+class QueryResourceTool(BaseTool):
+ """
+ Read File tool
+
+ Attributes:
+ name : The name.
+ description : The description.
+ args_schema : The args schema.
+ """
+ name: str = "QueryResource"
+ args_schema: Type[BaseModel] = QueryResource
+ description: str = "Tool searches resources content and extracts relevant information to perform the given task. " \
+ "Tool is given preference over other search/read file tools for relevant data. " \
+ "Resources content includes: {summary}"
+ agent_id: int = None
+ llm: Optional[BaseLlm] = None
+
+ def _execute(self, query: str):
+ openai.api_key = getattr(self.llm, 'api_key')
+ os.environ["OPENAI_API_KEY"] = getattr(self.llm, 'api_key')
+ llm_predictor_chatgpt = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name=self.llm.get_model(),
+ openai_api_key=get_config("OPENAI_API_KEY")))
+ service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor_chatgpt)
+ vector_store_name = VectorStoreType.get_vector_store_type(
+ self.get_tool_config(key="RESOURCE_VECTOR_STORE") or "Redis")
+ vector_store_index_name = self.get_tool_config(key="RESOURCE_VECTOR_STORE_INDEX_NAME") or "super-agent-index"
+ logging.info(f"vector_store_name {vector_store_name}")
+ logging.info(f"vector_store_index_name {vector_store_index_name}")
+ vector_store = LlamaVectorStoreFactory(vector_store_name, vector_store_index_name).get_vector_store()
+ logging.info(f"vector_store {vector_store}")
+ as_query_engine_args = dict(
+ filters=MetadataFilters(
+ filters=[
+ ExactMatchFilter(
+ key="agent_id",
+ value=str(self.agent_id)
+ )
+ ]
+ )
+ )
+ if vector_store_name == VectorStoreType.CHROMA:
+ as_query_engine_args["chroma_collection"] = ChromaDB.create_collection(
+ collection_name=vector_store_index_name)
+ index = VectorStoreIndex.from_vector_store(vector_store=vector_store, service_context=service_context)
+ query_engine = index.as_query_engine(
+ **as_query_engine_args
+ )
+ try:
+ response = query_engine.query(query)
+ except ValueError as e:
+ logging.error(f"ValueError {e}")
+ response = "Document not found"
+ return response
diff --git a/superagi/tools/resource/resource_toolkit.py b/superagi/tools/resource/resource_toolkit.py
new file mode 100644
index 000000000..069a5138e
--- /dev/null
+++ b/superagi/tools/resource/resource_toolkit.py
@@ -0,0 +1,20 @@
+from abc import ABC
+from typing import List
+from superagi.tools.base_tool import BaseTool, BaseToolkit
+from superagi.tools.resource.query_resource import QueryResourceTool
+
+
+class ResourceToolkit(BaseToolkit, ABC):
+ name: str = "Resource Toolkit"
+ description: str = "Toolkit containing tools for Resource integration"
+
+ def get_tools(self) -> List[BaseTool]:
+ return [
+ QueryResourceTool(),
+ ]
+
+ def get_env_keys(self) -> List[str]:
+ return [
+ "RESOURCE_VECTOR_STORE",
+ "RESOURCE_VECTOR_STORE_INDEX_NAME",
+ ]
diff --git a/superagi/tools/twitter/send_tweets.py b/superagi/tools/twitter/send_tweets.py
new file mode 100644
index 000000000..79d19db2b
--- /dev/null
+++ b/superagi/tools/twitter/send_tweets.py
@@ -0,0 +1,38 @@
+import os
+import json
+import base64
+import requests
+from typing import Any, Type
+from pydantic import BaseModel, Field
+from superagi.tools.base_tool import BaseTool
+from superagi.helper.twitter_tokens import TwitterTokens
+from superagi.helper.twitter_helper import TwitterHelper
+
+class SendTweetsInput(BaseModel):
+ tweet_text: str = Field(..., description="Tweet text to be posted from twitter handle, if no value is given keep the default value as 'None'")
+ is_media: bool = Field(..., description="'True' if there is any media to be posted with Tweet else 'False'.")
+ media_files: list = Field(..., description="Name of the media files to be uploaded.")
+
+class SendTweetsTool(BaseTool):
+ name: str = "Send Tweets Tool"
+ args_schema: Type[BaseModel] = SendTweetsInput
+ description: str = "Send and Schedule Tweets for your Twitter Handle"
+ agent_id: int = None
+
+ def _execute(self, is_media: bool, tweet_text: str = 'None', media_files: list = []):
+ toolkit_id = self.toolkit_config.toolkit_id
+ session = self.toolkit_config.session
+ creds = TwitterTokens(session).get_twitter_creds(toolkit_id)
+ params = {}
+ if is_media:
+ media_ids = TwitterHelper().get_media_ids(media_files, creds, self.agent_id)
+ params["media"] = {"media_ids": media_ids}
+ if tweet_text is not None:
+ params["text"] = tweet_text
+ tweet_response = TwitterHelper().send_tweets(params, creds)
+ if tweet_response.status_code == 201:
+ return "Tweet posted successfully!!"
+ else:
+ return "Error posting tweet. (Status code: {})".format(tweet_response.status_code)
+
+
\ No newline at end of file
diff --git a/superagi/tools/twitter/twitter_toolkit.py b/superagi/tools/twitter/twitter_toolkit.py
new file mode 100644
index 000000000..75c7eea95
--- /dev/null
+++ b/superagi/tools/twitter/twitter_toolkit.py
@@ -0,0 +1,15 @@
+from abc import ABC
+from superagi.tools.base_tool import BaseToolkit, BaseTool
+from typing import Type, List
+from superagi.tools.twitter.send_tweets import SendTweetsTool
+
+
+class TwitterToolkit(BaseToolkit, ABC):
+ name: str = "Twitter Toolkit"
+ description: str = "Twitter Tool kit contains all tools related to Twitter"
+
+ def get_tools(self) -> List[BaseTool]:
+ return [SendTweetsTool()]
+
+ def get_env_keys(self) -> List[str]:
+ return ["TWITTER_API_KEY", "TWITTER_API_SECRET"]
diff --git a/superagi/types/storage_types.py b/superagi/types/storage_types.py
new file mode 100644
index 000000000..9d41c7fef
--- /dev/null
+++ b/superagi/types/storage_types.py
@@ -0,0 +1,13 @@
+from enum import Enum
+
+
+class StorageType(Enum):
+ FILE = 'FILE'
+ S3 = 'S3'
+
+ @classmethod
+ def get_storage_type(cls, store):
+ store = store.upper()
+ if store in cls.__members__:
+ return cls[store]
+ raise ValueError(f"{store} is not a valid storage name.")
diff --git a/superagi/types/vector_store_types.py b/superagi/types/vector_store_types.py
new file mode 100644
index 000000000..d10a0f5a3
--- /dev/null
+++ b/superagi/types/vector_store_types.py
@@ -0,0 +1,20 @@
+from enum import Enum
+
+
+class VectorStoreType(Enum):
+ REDIS = 'redis'
+ PINECONE = 'pinecone'
+ CHROMA = 'chroma'
+ WEAVIATE = 'weaviate'
+ QDRANT = 'qdrant'
+ LANCEDB = 'LanceDB'
+
+ @classmethod
+ def get_vector_store_type(cls, store):
+ store = store.upper()
+ if store in cls.__members__:
+ return cls[store]
+ raise ValueError(f"{store} is not a valid vector store name.")
+
+ def __str__(self):
+ return self.value
diff --git a/superagi/vector_store/chromadb.py b/superagi/vector_store/chromadb.py
new file mode 100644
index 000000000..ef2f9d8f5
--- /dev/null
+++ b/superagi/vector_store/chromadb.py
@@ -0,0 +1,102 @@
+import uuid
+from typing import Any, Optional, Iterable, List
+
+import chromadb
+from chromadb import Settings
+
+from superagi.config.config import get_config
+from superagi.vector_store.base import VectorStore
+from superagi.vector_store.document import Document
+from superagi.vector_store.embedding.openai import BaseEmbedding
+
+
+def _build_chroma_client():
+ chroma_host_name = get_config("CHROMA_HOST_NAME") or "localhost"
+ chroma_port = get_config("CHROMA_PORT") or 8000
+ return chromadb.Client(Settings(chroma_api_impl="rest", chroma_server_host=chroma_host_name,
+ chroma_server_http_port=chroma_port))
+
+
+class ChromaDB(VectorStore):
+ def __init__(
+ self,
+ collection_name: str,
+ embedding_model: BaseEmbedding,
+ text_field: str,
+ namespace: Optional[str] = "",
+ ):
+ self.client = _build_chroma_client()
+ self.collection_name = collection_name
+ self.embedding_model = embedding_model
+ self.text_field = text_field
+ self.namespace = namespace
+
+ @classmethod
+ def create_collection(cls, collection_name):
+ """Create a Chroma Collection.
+ Args:
+ collection_name: The name of the collection to create.
+ """
+ chroma_client = _build_chroma_client()
+ return chroma_client.get_or_create_collection(name=collection_name)
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ namespace: Optional[str] = None,
+ batch_size: int = 32,
+ **kwargs: Any,
+ ) -> List[str]:
+ """Add texts to the vector store."""
+ if namespace is None:
+ namespace = self.namespace
+
+ metadatas = list(metadatas) if metadatas else []
+ ids = ids or [str(uuid.uuid4()) for _ in texts]
+ if len(ids) < len(texts):
+ raise ValueError("Number of ids must match number of texts.")
+
+ for text, id in zip(texts, ids):
+ metadata = metadatas.pop(0) if metadatas else {}
+ metadata[self.text_field] = text
+ metadatas.append(metadata)
+ collection = self.client.get_collection(name=self.collection_name)
+ collection.add(
+ documents=texts,
+ metadatas=metadatas,
+ ids=ids
+ )
+
+ return ids
+
+ def get_matching_text(self, query: str, top_k: int = 5, metadata: Optional[dict] = {}, **kwargs: Any) -> List[
+ Document]:
+ """Return docs most similar to query using specified search type."""
+ embedding_vector = self.embedding_model.get_embedding(query)
+ collection = self.client.get_collection(name=self.collection_name)
+ filters = {}
+ for key in metadata.keys():
+ filters[key] = metadata[key]
+ results = collection.query(
+ query_embeddings=embedding_vector,
+ include=["documents"],
+ n_results=top_k,
+ where=filters
+ )
+
+ documents = []
+
+ for node_id, text, metadata in zip(
+ results["ids"][0],
+ results["documents"][0],
+ results["metadatas"][0]):
+ documents.append(
+ Document(
+ text_content=text,
+ metadata=metadata
+ )
+ )
+
+ return documents
diff --git a/superagi/vector_store/embedding/openai.py b/superagi/vector_store/embedding/openai.py
index d6318a1de..7cef2a382 100644
--- a/superagi/vector_store/embedding/openai.py
+++ b/superagi/vector_store/embedding/openai.py
@@ -21,8 +21,9 @@ def __init__(self, api_key, model="text-embedding-ada-002"):
async def get_embedding_async(self, text):
try:
# openai.api_key = get_config("OPENAI_API_KEY")
- openai.api_key = self.api_key
+ # openai.api_key = self.api_key
response = await openai.Embedding.create(
+ api_key=self.api_key,
input=[text],
engine=self.model
)
@@ -34,6 +35,7 @@ def get_embedding(self, text):
try:
# openai.api_key = get_config("OPENAI_API_KEY")
response = openai.Embedding.create(
+ api_key=self.api_key,
input=[text],
engine=self.model
)
diff --git a/superagi/vector_store/vector_factory.py b/superagi/vector_store/vector_factory.py
index 528885694..cf6de5459 100644
--- a/superagi/vector_store/vector_factory.py
+++ b/superagi/vector_store/vector_factory.py
@@ -1,17 +1,17 @@
-import os
-
import pinecone
from pinecone import UnauthorizedException
from superagi.vector_store.pinecone import Pinecone
from superagi.vector_store import weaviate
from superagi.config.config import get_config
+from superagi.lib.logger import logger
+from superagi.types.vector_store_types import VectorStoreType
class VectorFactory:
@classmethod
- def get_vector_storage(cls, vector_store, index_name, embedding_model):
+ def get_vector_storage(cls, vector_store: VectorStoreType, index_name, embedding_model):
"""
Get the vector storage.
@@ -23,7 +23,8 @@ def get_vector_storage(cls, vector_store, index_name, embedding_model):
Returns:
The vector storage object.
"""
- if vector_store == "PineCone":
+ vector_store = VectorStoreType.get_vector_store_type(vector_store)
+ if vector_store == VectorStoreType.PINECONE:
try:
api_key = get_config("PINECONE_API_KEY")
env = get_config("PINECONE_ENVIRONMENT")
@@ -33,6 +34,8 @@ def get_vector_storage(cls, vector_store, index_name, embedding_model):
if index_name not in pinecone.list_indexes():
sample_embedding = embedding_model.get_embedding("sample")
+ if "error" in sample_embedding:
+ logger.error(f"Error in embedding model {sample_embedding}")
# if does not exist, create index
pinecone.create_index(
@@ -44,9 +47,9 @@ def get_vector_storage(cls, vector_store, index_name, embedding_model):
return Pinecone(index, embedding_model, 'text')
except UnauthorizedException:
raise ValueError("PineCone API key not found")
+
if vector_store == "Weaviate":
-
use_embedded = get_config("WEAVIATE_USE_EMBEDDED")
url = get_config("WEAVIATE_URL")
api_key = get_config("WEAVIATE_API_KEY")
@@ -58,5 +61,4 @@ def get_vector_storage(cls, vector_store, index_name, embedding_model):
)
return weaviate.Weaviate(client, embedding_model, index_name, 'text')
- else:
- raise Exception("Vector store not supported")
+ raise ValueError(f"Vector store {vector_store} not supported")
diff --git a/superagi/worker.py b/superagi/worker.py
index 243eda64d..103e8c154 100644
--- a/superagi/worker.py
+++ b/superagi/worker.py
@@ -1,15 +1,22 @@
from __future__ import absolute_import
+
+from sqlalchemy.orm import sessionmaker
+
from superagi.lib.logger import logger
from celery import Celery
from superagi.config.config import get_config
+from superagi.models.db import connect_db
+
redis_url = get_config('REDIS_URL') or 'localhost:6379'
app = Celery("superagi", include=["superagi.worker"], imports=["superagi.worker"])
app.conf.broker_url = "redis://" + redis_url + "/0"
app.conf.result_backend = "redis://" + redis_url + "/0"
app.conf.worker_concurrency = 10
+app.conf.accept_content = ['application/x-python-serialize', 'application/json']
+
@app.task(name="execute_agent", autoretry_for=(Exception,), retry_backoff=2, max_retries=5)
def execute_agent(agent_execution_id: int, time):
@@ -17,3 +24,32 @@ def execute_agent(agent_execution_id: int, time):
from superagi.jobs.agent_executor import AgentExecutor
logger.info("Execute agent:" + str(time) + "," + str(agent_execution_id))
AgentExecutor().execute_next_action(agent_execution_id=agent_execution_id)
+
+
+@app.task(name="summarize_resource", autoretry_for=(Exception,), retry_backoff=2, max_retries=5, serializer='pickle')
+def summarize_resource(agent_id: int, resource_id: int):
+ """Summarize a resource in background."""
+ from superagi.resource_manager.resource_summary import ResourceSummarizer
+ from superagi.types.storage_types import StorageType
+ from superagi.models.resource import Resource
+ from superagi.resource_manager.resource_manager import ResourceManager
+
+ engine = connect_db()
+ Session = sessionmaker(bind=engine)
+ session = Session()
+
+ resource = session.query(Resource).filter(Resource.id == resource_id).first()
+ file_path = resource.path
+
+ if resource.storage_type == StorageType.S3.value:
+ documents = ResourceManager(str(agent_id)).create_llama_document_s3(file_path)
+ else:
+ documents = ResourceManager(str(agent_id)).create_llama_document(file_path)
+
+ logger.info("Summarize resource:" + str(agent_id) + "," + str(resource_id))
+ resource_summarizer = ResourceSummarizer(session=session)
+ resource_summarizer.add_to_vector_store_and_create_summary(agent_id=agent_id,
+ resource_id=resource_id,
+ documents=documents)
+ resource_summarizer.generate_agent_summary(agent_id=agent_id)
+ session.close()
diff --git a/tests/unit_tests/controllers/test_agent_execution_config.py b/tests/unit_tests/controllers/test_agent_execution_config.py
new file mode 100644
index 000000000..0e875c9e6
--- /dev/null
+++ b/tests/unit_tests/controllers/test_agent_execution_config.py
@@ -0,0 +1,36 @@
+from unittest.mock import patch
+
+import pytest
+from fastapi.testclient import TestClient
+
+from main import app
+from superagi.models.agent_execution_config import AgentExecutionConfiguration
+
+client = TestClient(app)
+
+
+@pytest.fixture
+def mocks():
+ # Mock tool kit data for testing
+ mock_execution_config = AgentExecutionConfiguration(id=1, key="test_key", value="['test']")
+ return mock_execution_config
+
+
+def test_get_agent_execution_configuration_success(mocks):
+ with patch('superagi.controllers.agent_execution_config.db') as mock_db:
+ mock_execution_config = mocks
+ mock_db.session.query.return_value.filter.return_value.all.return_value = [mock_execution_config]
+
+ response = client.get("/agent_executions_configs/details/1")
+
+ assert response.status_code == 200
+ assert response.json() == {"test_key": ['test']}
+
+
+def test_get_agent_execution_configuration_not_found():
+ with patch('superagi.controllers.agent_execution_config.db') as mock_db:
+ mock_db.session.query.return_value.filter.return_value.all.return_value = []
+ response = client.get("/agent_executions_configs/details/1")
+
+ assert response.status_code == 404
+ assert response.json() == {"detail": "Agent Execution Configuration not found"}
diff --git a/tests/unit_tests/helper/test_resource_helper.py b/tests/unit_tests/helper/test_resource_helper.py
index 15458e1e1..6826a12ae 100644
--- a/tests/unit_tests/helper/test_resource_helper.py
+++ b/tests/unit_tests/helper/test_resource_helper.py
@@ -8,14 +8,14 @@ def test_make_written_file_resource(mocker):
mocker.patch('os.makedirs', return_value=None)
mocker.patch('os.path.getsize', return_value=1000)
mocker.patch('os.path.splitext', return_value=("", ".txt"))
- mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['/', 'local', None])
+ mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['/', 'FILE', None])
with patch('superagi.helper.resource_helper.logger') as logger_mock:
result = ResourceHelper.make_written_file_resource('test.txt', 1, 'INPUT')
assert result.name == 'test.txt'
assert result.path == '/test.txt'
- assert result.storage_type == 'local'
+ assert result.storage_type == 'FILE'
assert result.size == 1000
assert result.type == 'application/txt'
assert result.channel == 'OUTPUT'
diff --git a/tests/unit_tests/helper/test_twitter_helper.py b/tests/unit_tests/helper/test_twitter_helper.py
new file mode 100644
index 000000000..6bf4e0015
--- /dev/null
+++ b/tests/unit_tests/helper/test_twitter_helper.py
@@ -0,0 +1,56 @@
+import unittest
+from unittest.mock import Mock, patch
+from requests.models import Response
+from requests_oauthlib import OAuth1Session
+from superagi.helper.twitter_helper import TwitterHelper
+
+class TestSendTweets(unittest.TestCase):
+
+ @patch.object(OAuth1Session, 'post')
+ def test_send_tweets_success(self, mock_post):
+ # Prepare test data and mocks
+ test_params = {"status": "Hello, Twitter!"}
+ test_creds = Mock()
+ test_oauth = OAuth1Session(test_creds.api_key)
+
+ # Mock successful posting
+ resp = Response()
+ resp.status_code = 200
+ mock_post.return_value = resp
+
+ # Call the method under test
+ response = TwitterHelper().send_tweets(test_params, test_creds)
+
+ # Assert the post request was called correctly
+ test_oauth.post.assert_called_once_with(
+ "https://api.twitter.com/2/tweets",
+ json=test_params)
+
+ # Assert the response is correct
+ self.assertEqual(response.status_code, 200)
+
+ @patch.object(OAuth1Session, 'post')
+ def test_send_tweets_failure(self, mock_post):
+ # Prepare test data and mocks
+ test_params = {"status": "Hello, Twitter!"}
+ test_creds = Mock()
+ test_oauth = OAuth1Session(test_creds.api_key)
+
+ # Mock unsuccessful posting
+ resp = Response()
+ resp.status_code = 400
+ mock_post.return_value = resp
+
+ # Call the method under test
+ response = TwitterHelper().send_tweets(test_params, test_creds)
+
+ # Assert the post request was called correctly
+ test_oauth.post.assert_called_once_with(
+ "https://api.twitter.com/2/tweets",
+ json=test_params)
+
+ # Assert the response is correct
+ self.assertEqual(response.status_code, 400)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/helper/test_twitter_tokens.py b/tests/unit_tests/helper/test_twitter_tokens.py
new file mode 100644
index 000000000..1fc9f50bd
--- /dev/null
+++ b/tests/unit_tests/helper/test_twitter_tokens.py
@@ -0,0 +1,51 @@
+import unittest
+from unittest.mock import patch, Mock, MagicMock
+from typing import NamedTuple
+import ast
+from sqlalchemy.orm import Session
+from superagi.helper.twitter_tokens import Creds, TwitterTokens
+from superagi.models.toolkit import Toolkit
+from superagi.models.oauth_tokens import OauthTokens
+import time
+import http.client
+
+
+class TestCreds(unittest.TestCase):
+ def test_init(self):
+ creds = Creds('api_key', 'api_key_secret', 'oauth_token', 'oauth_token_secret')
+ self.assertEqual(creds.api_key, 'api_key')
+ self.assertEqual(creds.api_key_secret, 'api_key_secret')
+ self.assertEqual(creds.oauth_token, 'oauth_token')
+ self.assertEqual(creds.oauth_token_secret, 'oauth_token_secret')
+
+
+class TestTwitterTokens(unittest.TestCase):
+ twitter_tokens = TwitterTokens(Session)
+ def setUp(self):
+ self.mock_session = Mock(spec=Session)
+ self.twitter_tokens = TwitterTokens(session=self.mock_session)
+
+ def test_init(self):
+ self.assertEqual(self.twitter_tokens.session, self.mock_session)
+
+ def test_percent_encode(self):
+ self.assertEqual(self.twitter_tokens.percent_encode("#"), "%23")
+
+ def test_gen_nonce(self):
+ self.assertEqual(len(self.twitter_tokens.gen_nonce()), 32)
+
+ @patch.object(time, 'time', return_value=1234567890)
+ @patch.object(http.client, 'HTTPSConnection')
+ @patch('superagi.helper.twitter_tokens.TwitterTokens.gen_nonce', return_value=123456) # Replace '__main__' with actual module name
+ @patch('superagi.helper.twitter_tokens.TwitterTokens.percent_encode', return_value="encoded") # Replace '__main__' with actual module name
+ def test_get_request_token(self, mock_percent_encode, mock_gen_nonce, mock_https_connection, mock_time):
+ response_mock = Mock()
+ response_mock.read.return_value = b'oauth_token=test_token&oauth_token_secret=test_secret'
+ mock_https_connection.return_value.getresponse.return_value = response_mock
+
+ api_data = {"api_key": "test_key", "api_secret": "test_secret"}
+ expected_result = {'oauth_token': 'test_token', 'oauth_token_secret': 'test_secret'}
+ self.assertEqual(self.twitter_tokens.get_request_token(api_data), expected_result)
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/jobs/__init__.py b/tests/unit_tests/jobs/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit_tests/jobs/test_resource_summary.py b/tests/unit_tests/jobs/test_resource_summary.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit_tests/models/test_agent_execution_config.py b/tests/unit_tests/models/test_agent_execution_config.py
new file mode 100644
index 000000000..4b95e9253
--- /dev/null
+++ b/tests/unit_tests/models/test_agent_execution_config.py
@@ -0,0 +1,33 @@
+import unittest
+from unittest.mock import MagicMock, patch, call
+
+from distlib.util import AND
+
+from superagi.models.agent_execution_config import AgentExecutionConfiguration
+
+
+class TestAgentExecutionConfiguration(unittest.TestCase):
+
+ def setUp(self):
+ self.session = MagicMock()
+ self.execution = MagicMock()
+ self.execution.id = 1
+
+ def test_fetch_configuration(self):
+ test_db_response = [MagicMock(key="goal", value="['test_goal']"),
+ MagicMock(key="instruction", value="['test_instruction']")]
+
+ self.session.query.return_value.filter_by.return_value.all.return_value = test_db_response
+
+ result = AgentExecutionConfiguration.fetch_configuration(self.session, self.execution)
+
+ expected_result = {"goal": ["test_goal"], "instruction": ["test_instruction"]}
+ self.assertDictEqual(result, expected_result)
+
+ def test_eval_agent_config(self):
+ key = "goal"
+ value = "['test_goal']"
+
+ result = AgentExecutionConfiguration.eval_agent_config(key, value)
+
+ self.assertEqual(result, ["test_goal"])
\ No newline at end of file
diff --git a/tests/unit_tests/resource_manager/test_llama_document_creation.py b/tests/unit_tests/resource_manager/test_llama_document_creation.py
new file mode 100644
index 000000000..84dbe7ced
--- /dev/null
+++ b/tests/unit_tests/resource_manager/test_llama_document_creation.py
@@ -0,0 +1,52 @@
+import pytest
+from unittest.mock import patch, MagicMock
+from superagi.resource_manager.resource_manager import ResourceManager
+
+
+def test_create_llama_document_s3(mocker):
+ agent_id = 'test_agent'
+ resource_manager = ResourceManager(agent_id)
+
+ mock_boto_client = MagicMock()
+ mock_s3_obj = {
+ 'Body': MagicMock(read=MagicMock(return_value='mock_file_content'))
+ }
+ mock_boto_client.get_object.return_value = mock_s3_obj
+ mocker.patch('boto3.client', return_value=mock_boto_client)
+
+ mocker.patch('superagi.resource_manager.resource_manager.get_config',
+ side_effect=['mock_access_key', 'mock_secret_key', 'mock_bucket'])
+
+ mocker.patch('builtins.open', mocker.mock_open())
+ mocker.patch('os.remove')
+
+ MockSimpleDirectoryReader = MagicMock()
+ mocker.patch('superagi.resource_manager.resource_manager.SimpleDirectoryReader',
+ return_value=MockSimpleDirectoryReader)
+
+ resource_manager.create_llama_document_s3('mock_file_path')
+
+ mock_boto_client.get_object.assert_called_once_with(
+ Bucket='mock_bucket',
+ Key='mock_file_path')
+ MockSimpleDirectoryReader.load_data.assert_called_once()
+
+
+def test_create_llama_document_s3_file_path_provided(mocker):
+ resource_manager = ResourceManager('test_agent')
+
+ mock_boto_client = MagicMock()
+ mocker.patch('boto3.client', return_value=mock_boto_client)
+
+ mocker.patch('superagi.resource_manager.resource_manager.get_config',
+ side_effect=['mock_access_key', 'mock_secret_key', 'mock_bucket'])
+
+ mocker.patch('builtins.open', mocker.mock_open())
+ mocker.patch('os.remove')
+
+ MockSimpleDirectoryReader = MagicMock()
+ mocker.patch('superagi.resource_manager.resource_manager.SimpleDirectoryReader',
+ return_value=MockSimpleDirectoryReader)
+
+ with pytest.raises(Exception, match="file_path must be provided"):
+ resource_manager.create_llama_document_s3(None)
\ No newline at end of file
diff --git a/tests/unit_tests/resource_manager/test_llama_vector_store_factory.py b/tests/unit_tests/resource_manager/test_llama_vector_store_factory.py
new file mode 100644
index 000000000..7949c2da8
--- /dev/null
+++ b/tests/unit_tests/resource_manager/test_llama_vector_store_factory.py
@@ -0,0 +1,32 @@
+import pytest
+from unittest.mock import patch
+
+from llama_index.vector_stores import PineconeVectorStore, RedisVectorStore
+
+from superagi.resource_manager.llama_vector_store_factory import LlamaVectorStoreFactory
+from superagi.types.vector_store_types import VectorStoreType
+
+
+def test_llama_vector_store_factory():
+ # Mocking method arguments
+ vector_store_name = VectorStoreType.PINECONE
+ index_name = "test_index_name"
+ factory = LlamaVectorStoreFactory(vector_store_name, index_name)
+
+ # Test case for VectorStoreType.PINECONE
+ with patch.object(PineconeVectorStore, "__init__", return_value=None):
+ vector_store = factory.get_vector_store()
+ assert isinstance(vector_store, PineconeVectorStore)
+
+ # Test case for VectorStoreType.REDIS
+ factory.vector_store_name = VectorStoreType.REDIS
+ with patch.object(RedisVectorStore, "__init__", return_value=None), \
+ patch('superagi.config.config.get_config', return_value=None):
+ vector_store = factory.get_vector_store()
+ assert isinstance(vector_store, RedisVectorStore)
+
+ # Test case for unknown VectorStoreType
+ factory.vector_store_name = "unknown"
+ with pytest.raises(ValueError) as exc_info:
+ factory.get_vector_store()
+ assert str(exc_info.value) == "unknown vector store is not supported yet."
\ No newline at end of file
diff --git a/tests/unit_tests/resource_manager/test_resource_manager.py b/tests/unit_tests/resource_manager/test_resource_manager.py
index 2c0fc624b..b382520d9 100644
--- a/tests/unit_tests/resource_manager/test_resource_manager.py
+++ b/tests/unit_tests/resource_manager/test_resource_manager.py
@@ -5,12 +5,12 @@
from superagi.helper.s3_helper import S3Helper
from superagi.lib.logger import logger
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
@pytest.fixture
def resource_manager():
session_mock = Mock()
- resource_manager = ResourceManager(session_mock)
+ resource_manager = FileManager(session_mock)
#resource_manager.agent_id = 1 # replace with actual value
return resource_manager
diff --git a/tests/unit_tests/resource_manager/test_save_document_to_vector_store.py b/tests/unit_tests/resource_manager/test_save_document_to_vector_store.py
new file mode 100644
index 000000000..e0edfa01c
--- /dev/null
+++ b/tests/unit_tests/resource_manager/test_save_document_to_vector_store.py
@@ -0,0 +1,30 @@
+from unittest.mock import patch, Mock
+from llama_index import VectorStoreIndex, StorageContext, Document
+from superagi.resource_manager.resource_manager import ResourceManager
+from superagi.resource_manager.llama_vector_store_factory import LlamaVectorStoreFactory
+
+
+@patch.object(LlamaVectorStoreFactory, 'get_vector_store')
+@patch.object(StorageContext, 'from_defaults')
+@patch.object(VectorStoreIndex, 'from_documents')
+def test_save_document_to_vector_store(mock_vc_from_docs, mock_sc_from_defaults, mock_get_vector_store):
+ # Prepare test resources
+ mock_vector_store = Mock()
+ mock_get_vector_store.return_value = mock_vector_store
+ mock_sc_from_defaults.return_value = "mock_storage_context"
+ mock_vc_from_docs.return_value = "mock_index"
+
+ resource_manager = ResourceManager("test_agent_id")
+ documents = [Document(text="doc1"), Document(text="doc2")]
+ resource_id = "test_resource_id"
+
+ # Run test method
+ resource_manager.save_document_to_vector_store(documents, resource_id, "test_model_api_key")
+
+ # Validate calls
+ mock_get_vector_store.assert_called_once()
+ mock_sc_from_defaults.assert_called_once_with(vector_store=mock_vector_store)
+ mock_vc_from_docs.assert_called_once_with(documents, storage_context="mock_storage_context")
+
+ # Add more assertions here if needed, e.g., to check side effects
+ mock_vector_store.persist.assert_called_once()
\ No newline at end of file
diff --git a/tests/unit_tests/test_tool_manager.py b/tests/unit_tests/test_tool_manager.py
index eb4a7bcf6..e59a874ac 100644
--- a/tests/unit_tests/test_tool_manager.py
+++ b/tests/unit_tests/test_tool_manager.py
@@ -47,5 +47,5 @@ def test_download_and_extract_tools(mock_load_tools_config, mock_download_tool):
download_and_extract_tools()
mock_load_tools_config.assert_called_once()
- mock_download_tool.assert_any_call('url1', 'tools/tool1')
- mock_download_tool.assert_any_call('url2', 'tools/tool2')
+ mock_download_tool.assert_any_call('url1', os.path.join('superagi', 'tools', 'tool1'))
+ mock_download_tool.assert_any_call('url2', os.path.join('superagi', 'tools', 'tool2'))
diff --git a/tests/unit_tests/tools/code/test_write_code.py b/tests/unit_tests/tools/code/test_write_code.py
index ca1ab6573..aefdf35a0 100644
--- a/tests/unit_tests/tools/code/test_write_code.py
+++ b/tests/unit_tests/tools/code/test_write_code.py
@@ -2,7 +2,7 @@
import pytest
-from superagi.resource_manager.manager import ResourceManager
+from superagi.resource_manager.file_manager import FileManager
from superagi.tools.code.write_code import CodingTool
from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
@@ -20,7 +20,7 @@ class TestCodingTool:
def tool(self):
tool = CodingTool()
tool.llm = MockBaseLlm()
- tool.resource_manager = Mock(spec=ResourceManager)
+ tool.resource_manager = Mock(spec=FileManager)
tool.tool_response_manager = Mock(spec=ToolResponseQueryManager)
return tool
diff --git a/tests/unit_tests/tools/email/test_send_email_attachment.py b/tests/unit_tests/tools/email/test_send_email_attachment.py
index 4f9dfb8b8..d2d69e3b0 100644
--- a/tests/unit_tests/tools/email/test_send_email_attachment.py
+++ b/tests/unit_tests/tools/email/test_send_email_attachment.py
@@ -1,29 +1,30 @@
+import unittest
from unittest.mock import patch, Mock
+import os
+from superagi.tools.email.send_email_attachment import SendEmailAttachmentTool, SendEmailAttachmentInput
-from superagi.tools.email.send_email_attachment import SendEmailAttachmentTool
+class TestSendEmailAttachmentTool(unittest.TestCase):
+ @patch("superagi.tools.email.send_email_attachment.SendEmailAttachmentTool.send_email_with_attachment")
+ @patch("superagi.helper.resource_helper.ResourceHelper.get_agent_resource_path")
+ @patch("superagi.helper.resource_helper.ResourceHelper.get_root_input_dir")
+ @patch("os.path.exists")
+ def test__execute(self, mock_exists, mock_get_root_input_dir, mock_get_agent_resource_path, mock_send_email_with_attachment):
+ # Arrange
+ tool = SendEmailAttachmentTool()
+ tool.agent_id = 1
+ mock_exists.return_value = True
+ mock_get_agent_resource_path.return_value = "/test/path/test.txt"
+ mock_get_root_input_dir.return_value = "/root_dir/"
+ mock_send_email_with_attachment.return_value = "Email sent"
+ expected_result = "Email sent"
+ # Act
+ result = tool._execute("test@example.com", "test subject", "test body", "test.txt")
-def test_send_email_attachment():
- # Arrange
- with patch('superagi.tools.email.send_email_attachment.smtplib.SMTP') as mock_smtp:
- with patch('superagi.tools.email.send_email_attachment.ImapEmail') as mock_imap_email:
- with patch('superagi.tools.email.send_email_attachment.open', mock=Mock(), create=True) as mock_open:
- mock_open.return_value.__enter__.return_value.read.return_value = b"some file content"
- mock_imap_email_instance = mock_imap_email.return_value
- mock_smtp_instance = mock_smtp.return_value
- tool = SendEmailAttachmentTool()
- tool.toolkit_config.get_tool_config = Mock()
- tool.toolkit_config.get_tool_config.return_value = 'dummy_value'
- mock_smtp_instance.send_message = Mock()
- to = 'test@example.com'
- subject = 'test_subject'
- body = 'test_body'
- filename = 'test_file.txt'
+ # Assert
+ self.assertEqual(result, expected_result)
+ mock_get_agent_resource_path.assert_called_once_with("test.txt", tool.agent_id)
+ mock_send_email_with_attachment.assert_called_once_with("test@example.com", "test subject", "test body", "/test/path/test.txt", "test.txt")
- # Act
- result = tool._execute(to, subject, body, filename)
-
- # mock_smtp_instance.send_message.assert_called_once()
- assert result == f"Email was sent to {to}"
- assert 'rb' in mock_open.call_args[0]
- assert filename in mock_open.call_args[0][0]
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py b/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py
index be12f7969..aa48b70b2 100644
--- a/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py
+++ b/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py
@@ -27,7 +27,7 @@ def create_sample_image_base64():
def stable_diffusion_tool():
with patch('superagi.tools.image_generation.stable_diffusion_image_gen.requests.post') as post_mock, \
patch(
- 'superagi.tools.image_generation.stable_diffusion_image_gen.ResourceManager') as resource_manager_mock:
+ 'superagi.tools.image_generation.stable_diffusion_image_gen.FileManager') as resource_manager_mock:
# Create a mock response object
response_mock = Mock()
diff --git a/tests/unit_tests/tools/twitter/test_send_tweets.py b/tests/unit_tests/tools/twitter/test_send_tweets.py
new file mode 100644
index 000000000..a50fa1045
--- /dev/null
+++ b/tests/unit_tests/tools/twitter/test_send_tweets.py
@@ -0,0 +1,51 @@
+import unittest
+from unittest.mock import MagicMock, patch
+from superagi.tools.twitter.send_tweets import SendTweetsInput, SendTweetsTool
+
+
+class TestSendTweetsInput(unittest.TestCase):
+ def test_fields(self):
+ # Creating object
+ data = SendTweetsInput(tweet_text='Hello world', is_media=True, media_files=['image1.png', 'image2.png'])
+ # Testing object
+ self.assertEqual(data.tweet_text, 'Hello world')
+ self.assertEqual(data.is_media, True)
+ self.assertEqual(data.media_files, ['image1.png', 'image2.png'])
+
+
+class TestSendTweetsTool(unittest.TestCase):
+ @patch('superagi.helper.twitter_tokens.TwitterTokens.get_twitter_creds', return_value={'token': '123', 'token_secret': '456'})
+ @patch('superagi.helper.twitter_helper.TwitterHelper.get_media_ids', return_value=[789])
+ @patch('superagi.helper.twitter_helper.TwitterHelper.send_tweets')
+ def test_execute(self, mock_send_tweets, mock_get_media_ids, mock_get_twitter_creds):
+ # Mock the response from 'send_tweets'
+ responseMock = MagicMock()
+ responseMock.status_code = 201
+ mock_send_tweets.return_value = responseMock
+
+ # Creating SendTweetsTool object
+ obj = SendTweetsTool()
+ obj.toolkit_config = MagicMock()
+ obj.toolkit_config.toolkit_id = 1
+ obj.toolkit_config.session = MagicMock()
+ obj.agent_id = 99
+
+ # Testing when 'is_media' is True, 'tweet_text' is 'None' and 'media_files' is an empty list
+ self.assertEqual(obj._execute(True), "Tweet posted successfully!!")
+ mock_get_twitter_creds.assert_called_once_with(1)
+ mock_get_media_ids.assert_called_once_with([], {'token': '123', 'token_secret': '456'}, 99)
+ mock_send_tweets.assert_called_once_with({'media': {'media_ids': [789]}, 'text': 'None'}, {'token': '123', 'token_secret': '456'})
+
+ # Testing when 'is_media' is False, 'tweet_text' is 'Hello world' and 'media_files' is a list with elements
+ mock_get_twitter_creds.reset_mock()
+ mock_get_media_ids.reset_mock()
+ mock_send_tweets.reset_mock()
+ responseMock.status_code = 400
+ self.assertEqual(obj._execute(False, 'Hello world', ['image1.png']), "Error posting tweet. (Status code: 400)")
+ mock_get_twitter_creds.assert_called_once_with(1)
+ mock_get_media_ids.assert_not_called()
+ mock_send_tweets.assert_called_once_with({'text': 'Hello world'}, {'token': '123', 'token_secret': '456'})
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/vector_store/test_chromadb.py b/tests/unit_tests/vector_store/test_chromadb.py
new file mode 100644
index 000000000..a3e41916b
--- /dev/null
+++ b/tests/unit_tests/vector_store/test_chromadb.py
@@ -0,0 +1,46 @@
+import pytest
+from unittest.mock import MagicMock, patch
+from superagi.vector_store.chromadb import ChromaDB
+from superagi.vector_store.document import Document
+from superagi.vector_store.embedding.openai import BaseEmbedding, OpenAiEmbedding
+
+
+@pytest.fixture
+def mock_embedding_model():
+ mock_model = MagicMock(spec=BaseEmbedding)
+ mock_model.get_embedding.return_value = [0.1, 0.2, 0.3] # dummy embedding vector
+ return mock_model
+
+@patch('chromadb.Client')
+def test_create_collection(mock_chromadb_client):
+ ChromaDB.create_collection('test_collection')
+ mock_chromadb_client().get_or_create_collection.assert_called_once_with(name='test_collection')
+
+@patch('chromadb.Client')
+def test_add_texts(mock_chromadb_client, mock_embedding_model):
+ chroma_db = ChromaDB('test_collection', mock_embedding_model, 'text')
+ chroma_db.add_texts(['hello world'], [{'key': 'value'}])
+ mock_chromadb_client().get_collection().add.assert_called_once()
+
+@patch('chromadb.Client')
+@patch.object(BaseEmbedding, 'get_embedding')
+def test_get_matching_text(mock_get_embedding, mock_chromadb_client):
+ # Setup
+ mock_get_embedding.return_value = [0.1, 0.2, 0.3, 0.4, 0.5] # dummy vector
+
+ mock_chromadb_client().get_collection().query.return_value = {
+ 'ids': [['id1', 'id2', 'id3']],
+ 'documents': [['doc1', 'doc2', 'doc3']],
+ 'metadatas': [[{'meta1': 'value1'}, {'meta2': 'value2'}, {'meta3': 'value3'}]]
+ }
+ chroma_db = ChromaDB('test_collection', OpenAiEmbedding(api_key="asas"), 'text')
+
+ # Execute
+ documents = chroma_db.get_matching_text('hello world')
+
+ # Validate
+ assert isinstance(documents[0], Document)
+ assert len(documents) == 3
+ for doc in documents:
+ assert 'text_content' in doc.dict().keys()
+ assert 'metadata' in doc.dict().keys()
\ No newline at end of file
diff --git a/workspace/input/testing.txt b/workspace/input/testing.txt
new file mode 100644
index 000000000..fe857725e
--- /dev/null
+++ b/workspace/input/testing.txt
@@ -0,0 +1 @@
+"Hello world"
diff --git a/workspace/output/testing.txt b/workspace/output/testing.txt
new file mode 100644
index 000000000..06ae699f2
--- /dev/null
+++ b/workspace/output/testing.txt
@@ -0,0 +1 @@
+"Hello World"