Skip to content

Commit

Permalink
add ws and rest plugins for #51
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieuancelin committed Jan 14, 2025
1 parent 69845c0 commit ecdff17
Showing 1 changed file with 329 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package otoroshi_plugins.com.cloud.apim.otoroshi.extensions.aigateway.plugins

import akka.stream.scaladsl.{Sink, Source, SourceQueueWithComplete}
import akka.actor.{Actor, ActorRef, PoisonPill, Props}
import akka.stream.scaladsl.{Flow, Sink, Source, SourceQueueWithComplete}
import akka.stream.{Materializer, OverflowStrategy}
import akka.util.ByteString
import otoroshi.env.Env
Expand All @@ -9,7 +10,9 @@ import otoroshi.next.proxy.NgProxyEngineError
import otoroshi.security.IdGenerator
import otoroshi.utils.syntax.implicits._
import otoroshi_plugins.com.cloud.apim.extensions.aigateway.AiExtension
import play.api.http.websocket.Message
import play.api.libs.json._
import play.api.libs.streams.ActorFlow
import play.api.mvc.Results

import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
Expand Down Expand Up @@ -383,3 +386,328 @@ class McpSseEndpoint extends NgBackendCall {
}
}

class McpWebsocketEndpoint extends NgWebsocketBackendPlugin {

override def name: String = "Cloud APIM - MCP Websocket Endpoint"
override def description: Option[String] = "Exposes tool functions as an MCP server using the (non-official) Websocket Transport".some

override def core: Boolean = false
override def visibility: NgPluginVisibility = NgPluginVisibility.NgUserLand
override def categories: Seq[NgPluginCategory] = Seq(NgPluginCategory.Custom("Cloud APIM"), NgPluginCategory.Custom("AI - LLM"))
override def steps: Seq[NgStep] = Seq(NgStep.CallBackend)
override def defaultConfigObject: Option[NgPluginConfig] = Some(McpProxyEndpointConfig.default)

override def noJsForm: Boolean = true
override def configFlow: Seq[String] = McpProxyEndpointConfig.configFlow
override def configSchema: Option[JsObject] = McpProxyEndpointConfig.configSchema

override def start(env: Env): Future[Unit] = {
env.adminExtensions.extension[AiExtension].foreach { ext =>
ext.logger.info("the 'MCP SSE Endpoint' plugin is available !")
}
().vfuture
}

override def callBackend(ctx: NgWebsocketPluginContext)(implicit env: Env, ec: ExecutionContext): Flow[Message, Message, _] = {
val config = ctx.cachedConfig(internalName)(McpProxyEndpointConfig.format).getOrElse(McpProxyEndpointConfig.default)
ActorFlow
.actorRef(out => McpActor.props(out, config, env))(env.otoroshiActorSystem, env.otoroshiMaterializer)
}
}

object McpActor {
def props(out: ActorRef, config: McpProxyEndpointConfig, env: Env) = Props(new McpActor(out, config, env))
}

class McpActor(out: ActorRef, config: McpProxyEndpointConfig, env: Env) extends Actor {

val ready = new AtomicBoolean(false)
val canceledRequests = new TrieMap[Long, Unit]()

def send(msg: JsValue): Unit = {
val id = msg.select("id").asLong
if (!canceledRequests.contains(id)) {
out ! play.api.http.websocket.TextMessage(msg.stringify)
}
}

def jsonRpcResponse(id: Long, payload: JsValue): JsValue = {
Json.obj(
"jsonrpc" -> "2.0",
"id" -> id,
"result" -> payload
)
}

def jsonRpcError(id: Long, code: Int, message: String, data: JsValue): JsValue = {
Json.obj(
"jsonrpc" -> "2.0",
"id" -> id,
"result" -> Json.obj(
"code" -> code,
"message" -> message,
"data" -> data
)
)
}

def emptyResp(id: Long): JsValue = {
jsonRpcResponse(id, Json.obj())
}

def initialize(id: Long): JsValue = {
val response = Json.obj(
"protocolVersion" -> "2024-11-05",
"capabilities" -> Json.obj("tools" -> Json.obj(), "logging" -> Json.obj()),
"serverInfo" -> Json.obj("name" -> "otoroshi-ws-endpoint", "version" -> "1.0.0"), // TODO: custom valies
)
jsonRpcResponse(id, response)
}

def getToolList(id: Long, config: McpProxyEndpointConfig): JsValue = {
val ext = env.adminExtensions.extension[AiExtension].get
val functions = config.refs.flatMap(r => ext.states.toolFunction(r))
val response = Json.obj("tools" -> JsArray(functions.map { wf =>
val required: JsArray = wf.required.map(v => JsArray(v.map(_.json))).getOrElse(JsArray(wf.parameters.value.keySet.toSeq.map(_.json)))
Json.obj(
"name" -> wf.name,
"description" -> wf.description,
"inputSchema" -> Json.obj(
"type" -> "object",
"properties" -> wf.parameters,
"required" -> required,
),
)
}))
jsonRpcResponse(id, response)
}

def toolsCall(id: Long, request: JsValue, config: McpProxyEndpointConfig): Future[JsValue] = {
implicit val ec = env.otoroshiExecutionContext
implicit val ev = env
val params = request.select("params").asOpt[JsObject].getOrElse(Json.obj())
val ext = env.adminExtensions.extension[AiExtension].get
val functions = config.refs.flatMap(r => ext.states.toolFunction(r))
val functionsMap = functions.map(f => (f.name, f)).toMap
val name = params.select("name").asString
val arguments = params.select("arguments").asOpt[JsObject].getOrElse(Json.obj())
functionsMap.get(name) match {
case None => {
jsonRpcError(id, 400, s"unknown function ${name}", Json.obj()).vfuture
}
case Some(function) => {
function.call(arguments.stringify).flatMap { res =>
val payload = Json.obj("content" -> Json.arr(Json.obj(
"type" -> "text",
"text" -> res
)))
jsonRpcResponse(id, payload).vfuture
}
}
}
}

def getResourcesList(id: Long): JsValue = {
jsonRpcResponse(id, Json.obj("resources" -> Json.arr()))
}

def getPromptsList(id: Long): JsValue = {
jsonRpcResponse(id, Json.obj("prompts" -> Json.arr()))
}

def getTemplatesList(id: Long): JsValue = {
jsonRpcResponse(id, Json.obj("templates" -> Json.arr()))
}

def handle(data: String): Unit = {
Try(data.parseJson) match {
case Failure(e) => send(jsonRpcError(0, 400, "error while parsing json-rpc payload", Json.obj()))
case Success(json) => {
val id = json.select("id").asOpt[Long].getOrElse(0L)
val resp: Future[JsValue] = json.select("method").asOpt[String] match {
case Some("initialize") => initialize(id).vfuture
case Some("shutdown") => {
self ! PoisonPill
emptyResp(id).vfuture
}
case Some("exit") => {
self ! PoisonPill
emptyResp(id).vfuture
}
case Some("ping") => jsonRpcResponse(id, Json.obj()).vfuture
case Some("cancelled") => {
canceledRequests.put(id, ())
emptyResp(id).vfuture
}
case Some("notifications/cancelled") => {
canceledRequests.put(id, ())
emptyResp(id).vfuture
}
case Some("notifications/initialized") => {
ready.set(true)
emptyResp(id).vfuture
}
case Some("tools/list") if ready.get() => getToolList(id, config).vfuture
case Some("resources/list") if ready.get() => getResourcesList(id).vfuture
case Some("resources/read") if ready.get() => emptyResp(id).vfuture // TODO: support ?
case Some("resources/templates/list") if ready.get() => getTemplatesList(id).vfuture
case Some("prompts/list") if ready.get() => getPromptsList(id).vfuture
case Some("prompts/get") if ready.get() => emptyResp(id).vfuture // TODO: support ?
case Some("tools/call") if ready.get() => toolsCall(id, json, config)
case _ => {
val method = json.select("method").asOpt[String].getOrElse("--")
jsonRpcResponse(id, Json.obj("error" -> "method unsupported", "error_details" -> Json.obj("method" -> method, "ready" -> ready.get()))).vfuture
}
}
resp.map(r => send(r))(env.otoroshiExecutionContext)
}
}
}

override def receive: Receive = {
case play.api.http.websocket.TextMessage(data) => handle(data)
case play.api.http.websocket.BinaryMessage(data) => handle(data.utf8String)
case play.api.http.websocket.PingMessage(_) => out ! play.api.http.websocket.PongMessage(ByteString.empty)
case play.api.http.websocket.CloseMessage(statusCode, reason) => self ! PoisonPill
case play.api.http.websocket.PongMessage(_) => ()
case _ => ()
}
}

class McpRespEndpoint extends NgBackendCall {

override def name: String = "Cloud APIM - MCP Rest Endpoint"
override def description: Option[String] = "Exposes tool functions as an MCP server using the (non-official) Rest Transport".some

override def core: Boolean = false
override def visibility: NgPluginVisibility = NgPluginVisibility.NgUserLand
override def categories: Seq[NgPluginCategory] = Seq(NgPluginCategory.Custom("Cloud APIM"), NgPluginCategory.Custom("AI - LLM"))
override def steps: Seq[NgStep] = Seq(NgStep.CallBackend)
override def useDelegates: Boolean = false
override def defaultConfigObject: Option[NgPluginConfig] = Some(McpProxyEndpointConfig.default)

override def noJsForm: Boolean = true
override def configFlow: Seq[String] = McpProxyEndpointConfig.configFlow
override def configSchema: Option[JsObject] = McpProxyEndpointConfig.configSchema

override def start(env: Env): Future[Unit] = {
env.adminExtensions.extension[AiExtension].foreach { ext =>
ext.logger.info("the 'MCP Rest Endpoint' plugin is available !")
}
().vfuture
}

def error(status: Int, msg: String): Future[Either[NgProxyEngineError, BackendCallResponse]] = {
// println(s"http error: ${status} - ${msg}")
NgProxyEngineError.NgResultProxyEngineError(Results.Status(status)(Json.obj("error" -> msg))).leftf
}

def jsonRpcResponse(id: Long, payload: JsValue): Future[Either[NgProxyEngineError, BackendCallResponse]] = {
BackendCallResponse(NgPluginHttpResponse.fromResult(Results.Ok(Json.obj(
"jsonrpc" -> "2.0",
"id" -> id,
"result" -> payload
))), None).rightf
}

def emptyResp(id: Long): Future[Either[NgProxyEngineError, BackendCallResponse]] = {
jsonRpcResponse(id, Json.obj())
}

def initialize(id: Long)(implicit env: Env, ec: ExecutionContext): Future[Either[NgProxyEngineError, BackendCallResponse]] = {
val response = Json.obj(
"protocolVersion" -> "2024-11-05",
"capabilities" -> Json.obj("tools" -> Json.obj(), "logging" -> Json.obj()),
"serverInfo" -> Json.obj("name" -> "otoroshi-sse-endpoint", "version" -> "1.0.0"), // TODO: custom valies
)
jsonRpcResponse(id, response)
}

def getToolList(id: Long, config: McpProxyEndpointConfig)(implicit env: Env, ec: ExecutionContext): Future[Either[NgProxyEngineError, BackendCallResponse]] = {
val ext = env.adminExtensions.extension[AiExtension].get
val functions = config.refs.flatMap(r => ext.states.toolFunction(r))
val response = Json.obj("tools" -> JsArray(functions.map { wf =>
val required: JsArray = wf.required.map(v => JsArray(v.map(_.json))).getOrElse(JsArray(wf.parameters.value.keySet.toSeq.map(_.json)))
Json.obj(
"name" -> wf.name,
"description" -> wf.description,
"inputSchema" -> Json.obj(
"type" -> "object",
"properties" -> wf.parameters,
"required" -> required,
),
)
}))
jsonRpcResponse(id, response)
}

def toolsCall(id: Long, request: JsValue, config: McpProxyEndpointConfig)(implicit env: Env, ec: ExecutionContext): Future[Either[NgProxyEngineError, BackendCallResponse]] = {
val params = request.select("params").asOpt[JsObject].getOrElse(Json.obj())
val ext = env.adminExtensions.extension[AiExtension].get
val functions = config.refs.flatMap(r => ext.states.toolFunction(r))
val functionsMap = functions.map(f => (f.name, f)).toMap
val name = params.select("name").asString
val arguments = params.select("arguments").asOpt[JsObject].getOrElse(Json.obj())
functionsMap.get(name) match {
case None => error(400, s"unknown function ${name}")
case Some(function) => {
function.call(arguments.stringify).flatMap { res =>
val payload = Json.obj("content" -> Json.arr(Json.obj(
"type" -> "text",
"text" -> res
)))
jsonRpcResponse(id, payload)
}
}
}
}

def getResourcesList(id: Long)(implicit env: Env, ec: ExecutionContext): Future[Either[NgProxyEngineError, BackendCallResponse]] = {
jsonRpcResponse(id, Json.obj("resources" -> Json.arr()))
}

def getPromptsList(id: Long)(implicit env: Env, ec: ExecutionContext): Future[Either[NgProxyEngineError, BackendCallResponse]] = {
jsonRpcResponse(id, Json.obj("prompts" -> Json.arr()))
}

def getTemplatesList(id: Long)(implicit env: Env, ec: ExecutionContext): Future[Either[NgProxyEngineError, BackendCallResponse]] = {
jsonRpcResponse(id, Json.obj("templates" -> Json.arr()))
}

override def callBackend(ctx: NgbBackendCallContext, delegates: () => Future[Either[NgProxyEngineError, BackendCallResponse]])(implicit env: Env, ec: ExecutionContext, mat: Materializer): Future[Either[NgProxyEngineError, BackendCallResponse]] = {
val config = ctx.cachedConfig(internalName)(McpProxyEndpointConfig.format).getOrElse(McpProxyEndpointConfig.default)
if (ctx.request.hasBody && ctx.request.method.toLowerCase() == "post") {
ctx.request.body.runFold(ByteString.empty)(_ ++ _).flatMap { bodyRaw =>
Try(bodyRaw.utf8String.parseJson) match {
case Failure(e) => error(400,"error while parsing json-rpc payload")
case Success(json) => {
val id = json.select("id").asOpt[Long].getOrElse(0L)
json.select("method").asOpt[String] match {
case Some("initialize") => initialize(id)
case Some("shutdown") => emptyResp(id)
case Some("exit") => emptyResp(id)
case Some("ping") => jsonRpcResponse(id, Json.obj())
case Some("cancelled") => emptyResp(id)
case Some("notifications/cancelled") => emptyResp(id)
case Some("notifications/initialized") => emptyResp(id)
case Some("tools/list") => getToolList(id, config)
case Some("resources/list") => getResourcesList(id)
case Some("resources/read") => emptyResp(id)
case Some("resources/templates/list") => getTemplatesList(id)
case Some("prompts/list") => getPromptsList(id)
case Some("prompts/get") => emptyResp(id)
case Some("tools/call") => toolsCall(id, json, config)
case _ => {
val method = json.select("method").asOpt[String].getOrElse("--")
jsonRpcResponse(id, Json.obj("error" -> "method unsupported", "error_details" -> Json.obj("method" -> method)))
}
}
}
}
}
} else {
NgProxyEngineError.NgResultProxyEngineError(Results.BadRequest(Json.obj("error" -> "bad request"))).leftf
}
}
}

0 comments on commit ecdff17

Please sign in to comment.