Skip to content

Commit

Permalink
add connector pooling for #51
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieuancelin committed Jan 15, 2025
1 parent ded7a83 commit 64b11f5
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object LlmFunctions {
} yield wasmFunctionsR ++ mcpConnectorsR
}

def tools(wasmFunctions: Seq[String], mcpConnectors: Seq[String])(implicit env: Env): JsObject = {
def tools(wasmFunctions: Seq[String], mcpConnectors: Seq[String])(implicit ec: ExecutionContext, env: Env): JsObject = {
val tools: Seq[JsObject] = LlmToolFunction._tools(wasmFunctions) ++ McpSupport.tools(mcpConnectors)
Json.obj(
"tools" -> tools
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ import otoroshi_plugins.com.cloud.apim.extensions.aigateway.{AiExtension, AiGate
import play.api.libs.json._

import java.util.UUID
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap
import scala.concurrent.duration.{Duration, DurationLong, FiniteDuration}
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

Expand Down Expand Up @@ -76,13 +78,18 @@ object McpConnectorTransport {
}
}

case class McpConnectorPoolSettings(size: Int = 1) {
def json: JsValue = Json.obj("size" -> size)
}

case class McpConnector(
location: EntityLocation,
id: String,
name: String,
description: String,
tags: Seq[String],
metadata: Map[String, String],
pool: McpConnectorPoolSettings,
transport: McpConnectorTransport,
) extends EntityLocationSupport {
override def internalId: String = id
Expand All @@ -98,21 +105,74 @@ case class McpConnector(

def hasClientChanged(): Boolean = {
McpConnector.connectorsCache.get(id) match {
case Some((client, hash, time)) if hash != json.stringify.sha256 => true
case Some((_, _, hash, _)) if hash != json.stringify.sha256 => true
case _ => false
}
}

def client(): DefaultMcpClient = {
def restartIfNeeded(): Unit = {
clientPool()
}

private def clientPool(): ConcurrentLinkedQueue[DefaultMcpClient] = synchronized {
McpConnector.connectorsCache.get(id) match {
case Some((cli, hash, _)) if hash == json.stringify.sha256 => cli
case Some((cli, _, hash, _)) if hash == json.stringify.sha256 => cli
case e => {
val cli = buildClient()
McpConnector.connectorsCache.put(id, (cli, json.stringify.sha256, System.currentTimeMillis()))
e.foreach(_._1.close())
cli
val pool = new ConcurrentLinkedQueue[DefaultMcpClient]()
pool.add(cli)
McpConnector.connectorsCache.put(id, (pool, new AtomicInteger(1), json.stringify.sha256, System.currentTimeMillis()))
e.foreach(_._1.asScala.foreach(_.close()))
pool
}
}
}

private def withClient[T](f: DefaultMcpClient => T)(implicit ec: ExecutionContext, env: Env): Future[T] = {
val promise = Promise.apply[T]()
McpConnector.connectorsCache.get(id) match {
case None => {
clientPool()
withClient(f).andThen {
case Failure(e) => promise.tryFailure(e)
case Success(e) => promise.trySuccess(e)
}
}
case Some((queue, counter, _, _)) => {
val item = queue.poll()
if (item == null) {
if (counter.get() < pool.size) {
counter.incrementAndGet()
val cli = buildClient()
try {
val r = f(cli)
promise.trySuccess(r)
} catch {
case e: Throwable => promise.tryFailure(e)
} finally {
queue.add(cli)
}
} else {
env.otoroshiScheduler.scheduleOnce(100.millis) {
withClient(f).andThen {
case Failure(e) => promise.tryFailure(e)
case Success(e) => promise.trySuccess(e)
}
}
}
} else {
try {
val r = f(item)
promise.trySuccess(r)
} catch {
case e: Throwable => promise.tryFailure(e)
} finally {
queue.add(item)
}
}
}
}
promise.future
}

private def buildClient(): DefaultMcpClient = {
Expand Down Expand Up @@ -143,47 +203,26 @@ case class McpConnector(
.build()
}

// def withClient[T](f: DefaultMcpClient => T): T = {
// val stdioTransport = new StdioMcpTransport.Builder()
// .command(java.util.List.of("/Users/mathieuancelin/.nvm/versions/node/v18.19.0/bin/node", "/Users/mathieuancelin/projects/clever-ai/mpc-test/mcp-otoroshi-proxy/bin/proxy.js"))
// .logEvents(true) // only if you want to see the traffic in the log
// .build()
// val sseTransport = new HttpMcpTransport.Builder()
// .sseUrl("http://localhost:3001/sse")
// //.postUrl("http://localhost:3001/message")
// .logRequests(true) // if you want to see the traffic in the log
// .logResponses(true)
// .build()
// val mcpClient = new DefaultMcpClient.Builder()
// .transport(stdioTransport)
// .build()
// // mcpClient.listTools()
// // val res = mcpClient.executeTool()
// val c = client()
// try {
// f(c)
// } finally {
// c.close()
// }
// }
def listTools()(implicit ec: ExecutionContext, env: Env): Future[Seq[ToolSpecification]] = withClient(_.listTools().asScala)

def listTools(): Seq[ToolSpecification] = client().listTools().asScala
def listToolsBlocking()(implicit ec: ExecutionContext, env: Env): Seq[ToolSpecification] = Await.result(listTools(), 10.seconds)

def call(name: String = "get-rooms", args: String = "{}"): String = {
def call(name: String, args: String)(implicit ec: ExecutionContext, env: Env): Future[String] = {
val request = ToolExecutionRequest.builder().id(UUID.randomUUID().toString()).name(name).arguments(args).build()
client().executeTool(request)
withClient(_.executeTool(request))
}
}

object McpConnector {
val connectorsCache = new TrieMap[String, (DefaultMcpClient, String, Long)]()
val connectorsCache = new TrieMap[String, (ConcurrentLinkedQueue[DefaultMcpClient], AtomicInteger, String, Long)]()
val format = new Format[McpConnector] {
override def writes(o: McpConnector): JsValue = o.location.jsonWithKey ++ Json.obj(
"id" -> o.id,
"name" -> o.name,
"description" -> o.description,
"metadata" -> o.metadata,
"tags" -> JsArray(o.tags.map(JsString.apply)),
"id" -> o.id,
"name" -> o.name,
"description" -> o.description,
"metadata" -> o.metadata,
"tags" -> JsArray(o.tags.map(JsString.apply)),
"pool" -> o.pool.json,
"transport" -> o.transport.json,
)
override def reads(json: JsValue): JsResult[McpConnector] = Try {
Expand All @@ -194,6 +233,7 @@ object McpConnector {
description = (json \ "description").as[String],
metadata = (json \ "metadata").asOpt[Map[String, String]].getOrElse(Map.empty),
tags = (json \ "tags").asOpt[Seq[String]].getOrElse(Seq.empty[String]),
pool = McpConnectorPoolSettings((json \ "pool" \ "size").asOpt[Int].filter(_ > 0).getOrElse(1)),
transport = (json \ "transport").asOpt(McpConnectorTransport.format).getOrElse(McpConnectorTransport()),
)
} match {
Expand Down Expand Up @@ -223,6 +263,7 @@ object McpConnector {
metadata = Map.empty,
tags = Seq.empty,
location = EntityLocation.default,
pool = McpConnectorPoolSettings(),
transport = McpConnectorTransport(
kind = Stdio,
options = Json.obj(
Expand Down Expand Up @@ -266,15 +307,15 @@ object McpSupport {
def restartConnectorsIfNeeded()(implicit env: Env): Unit = {
val ext = env.adminExtensions.extension[AiExtension].get
ext.states.allMcpConnectors().foreach { connector =>
connector.client()
connector.restartIfNeeded()
}
}

def stopConnectorsIfNeeded()(implicit env: Env): Unit = {
val ext = env.adminExtensions.extension[AiExtension].get
McpConnector.connectorsCache.keySet.foreach { key =>
ext.states.mcpConnector(key) match {
case None => McpConnector.connectorsCache.remove(key).foreach(_._1.close())
case None => McpConnector.connectorsCache.remove(key).foreach(_._1.asScala.foreach(_.close()))
case Some(_) => ()
}
}
Expand Down Expand Up @@ -312,11 +353,11 @@ object McpSupport {
}
}

def tools(connectors: Seq[String])(implicit env: Env): Seq[JsObject] = {
def tools(connectors: Seq[String])(implicit env: Env, ec: ExecutionContext): Seq[JsObject] = {
val ext = env.adminExtensions.extension[AiExtension].get
connectors.zipWithIndex.flatMap(tuple => ext.states.mcpConnector(tuple._1).map(v => (v, tuple._2))).flatMap {
case (connector, idx) =>
connector.listTools().map { function =>
connector.listToolsBlocking().map { function =>
val additionalProperties: scala.Boolean = Option(function.parameters().additionalProperties()).map(_.booleanValue()).getOrElse(false)
val required: Seq[String] = Option(function.parameters().required()).map(_.asScala.toSeq).getOrElse(Seq.empty)
val properties: JsObject = JsObject(Option(function.parameters().properties()).map(_.asScala).getOrElse(Map.empty[String, JsonSchemaElement]).mapValues { el =>
Expand Down Expand Up @@ -356,7 +397,7 @@ object McpSupport {
case None => (s"undefined mcp connector ${connectorId}", toolCall).some.vfuture
case Some(function) => {
println(s"calling mcp function '${functionName}' with args: '${toolCall.function.arguments}'")
function.call(functionName, toolCall.function.arguments).vfuture.map { r =>
function.call(functionName, toolCall.function.arguments).map { r =>
(r, toolCall).some
}
}
Expand Down
Loading

0 comments on commit 64b11f5

Please sign in to comment.