Skip to content

Commit

Permalink
input sanitizing (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis authored Nov 3, 2023
1 parent c93bae5 commit 81c73b7
Showing 1 changed file with 53 additions and 30 deletions.
83 changes: 53 additions & 30 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,25 @@
import asyncio
import os
from operator import itemgetter
from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Tuple, Union

import langsmith
from fastapi import FastAPI, Request
from fastapi import FastAPI, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chat_models import ChatAnthropic, ChatOpenAI, ChatVertexAI
from langchain.document_loaders import AsyncHtmlLoader
from langchain.document_transformers import Html2TextTransformer
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import (ChatPromptTemplate, MessagesPlaceholder,
PromptTemplate)
from langchain.retrievers import (ContextualCompressionRetriever,
TavilySearchAPIRetriever)
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain.retrievers import (
ContextualCompressionRetriever,
TavilySearchAPIRetriever,
)
from langchain.retrievers.document_compressors import (
DocumentCompressorPipeline, EmbeddingsFilter)
DocumentCompressorPipeline,
EmbeddingsFilter,
)
from langchain.retrievers.kay import KayAiRetriever
from langchain.retrievers.you import YouRetriever
from langchain.schema import Document
Expand All @@ -26,15 +29,21 @@
from langchain.schema.messages import AIMessage, HumanMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (ConfigurableField, Runnable,
RunnableBranch, RunnableLambda,
RunnableMap)
from langchain.schema.runnable import (
ConfigurableField,
Runnable,
RunnableBranch,
RunnableLambda,
RunnableMap,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Backup
from langchain.utilities import GoogleSearchAPIWrapper
from langserve import add_routes
from langsmith import Client
from pydantic import BaseModel, Field
from uuid import UUID

RESPONSE_TEMPLATE = """\
You are an expert researcher and writer, tasked with answering any question.
Expand Down Expand Up @@ -362,34 +371,45 @@ def create_chain(
)


class SendFeedbackBody(BaseModel):
run_id: UUID
key: str = "user_score"

score: Union[float, int, bool, None] = None
feedback_id: Optional[UUID] = None
comment: Optional[str] = None


@app.post("/feedback")
async def send_feedback(request: Request):
data = await request.json()
run_id = data.get("run_id")
if run_id is None:
return {
"result": "No LangSmith run ID provided",
"code": 400,
}
key = data.get("key", "user_score")
vals = {**data, "key": key}
client.create_feedback(**vals)
async def send_feedback(body: SendFeedbackBody):
client.create_feedback(
body.run_id,
body.key,
score=body.score,
comment=body.comment,
feedback_id=body.feedback_id,
)
return {"result": "posted feedback successfully", "code": 200}


class UpdateFeedbackBody(BaseModel):
feedback_id: UUID
score: Union[float, int, bool, None] = None
comment: Optional[str] = None


@app.patch("/feedback")
async def update_feedback(request: Request):
data = await request.json()
feedback_id = data.get("feedback_id")
async def update_feedback(body: UpdateFeedbackBody):
feedback_id = body.feedback_id
if feedback_id is None:
return {
"result": "No feedback ID provided",
"code": 400,
}
client.update_feedback(
feedback_id,
score=data.get("score"),
comment=data.get("comment"),
score=body.score,
comment=body.comment,
)
return {"result": "patched feedback successfully", "code": 200}

Expand All @@ -412,16 +432,19 @@ async def aget_trace_url(run_id: str) -> str:
return await _arun(client.share_run, run_id)


class GetTraceBody(BaseModel):
run_id: UUID


@app.post("/get_trace")
async def get_trace(request: Request):
data = await request.json()
run_id = data.get("run_id")
async def get_trace(body: GetTraceBody):
run_id = body.run_id
if run_id is None:
return {
"result": "No LangSmith run ID provided",
"code": 400,
}
return await aget_trace_url(run_id)
return await aget_trace_url(str(run_id))


if __name__ == "__main__":
Expand Down

0 comments on commit 81c73b7

Please sign in to comment.