Skip to content

Commit

Permalink
♻️ refactored schemas to module
Browse files Browse the repository at this point in the history
  • Loading branch information
chriamue committed Dec 20, 2023
1 parent c7d6f8d commit fdd9e8e
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 207 deletions.
19 changes: 3 additions & 16 deletions src/api/generate.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,13 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

use crate::config::Config;

use super::{
generate_parameters::GenerateParameters, generate_stream::generate_stream_handler,
generate_text::GenerateRequest, ErrorResponse,
generate_stream::generate_stream_handler,
model::ErrorResponse,
model::{CompatGenerateRequest, GenerateRequest},
};

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct CompatGenerateRequest {
#[schema(example = "My name is Olivier and I")]
pub inputs: String,

#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<GenerateParameters>,

#[serde(default)]
pub stream: bool,
}

/// Generate tokens
#[utoipa::path(
post,
Expand Down
68 changes: 0 additions & 68 deletions src/api/generate_parameters.rs

This file was deleted.

9 changes: 3 additions & 6 deletions src/api/generate_stream.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
use std::vec;

use super::model::GenerateRequest;
use crate::{config::Config, llm::create_text_generation};
use axum::{
extract::State,
response::{sse::Event, IntoResponse, Sse},
Json,
};
use futures::stream::StreamExt;
use log::debug;

use crate::{config::Config, llm::create_text_generation};

use super::generate_text::GenerateRequest;
use std::vec;

/// Generate tokens
#[utoipa::path(
Expand Down
23 changes: 2 additions & 21 deletions src/api/generate_text.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,6 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

use super::model::{ErrorResponse, GenerateRequest, GenerateResponse};
use crate::{config::Config, llm::create_text_generation};

use super::{generate_parameters::GenerateParameters, ErrorResponse};

#[derive(Deserialize, ToSchema, Debug)]
pub struct GenerateRequest {
#[schema(example = "My name is John")]
pub inputs: String,

#[serde(default, skip_serializing_if = "Option::is_none")]
pub parameters: Option<GenerateParameters>,
}

#[derive(Serialize, ToSchema)]
pub struct GenerateResponse {
pub generated_text: String,
// Add other fields as necessary
}
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};

/// Generate tokens
#[utoipa::path(
Expand Down
2 changes: 1 addition & 1 deletion src/api/health.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use axum::{
Json,
};

use super::ErrorResponse;
use super::model::ErrorResponse;

/// Health check endpoint
#[utoipa::path(
Expand Down
12 changes: 1 addition & 11 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
pub mod generate;
pub mod generate_parameters;
pub mod generate_stream;
pub mod generate_text;
pub mod health;
pub mod model;
pub mod openapi;
pub mod stream_response;

use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

#[derive(Serialize, Deserialize, ToSchema)]
pub struct ErrorResponse {
error: String,
error_type: Option<String>,
}
168 changes: 168 additions & 0 deletions src/api/model/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
Length,
EosToken,
StopSequence,
}

#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)]
pub struct Token {
pub id: i32,
pub text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprob: Option<f64>,
pub special: bool,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct PrefillToken {
pub id: i32,
pub text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprob: Option<f64>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct BestOfSequence {
pub finish_reason: FinishReason,
pub generated_text: String,
pub generated_tokens: i32,
pub prefill: Vec<PrefillToken>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
pub tokens: Vec<Token>,
pub top_tokens: Vec<Vec<Token>>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct Details {
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>,
pub finish_reason: FinishReason,
pub generated_tokens: i32,
pub prefill: Vec<PrefillToken>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
pub tokens: Vec<Token>,
pub top_tokens: Vec<Vec<Token>>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)]
pub struct StreamDetails {
pub finish_reason: FinishReason,
pub generated_tokens: i32,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)]
pub struct StreamResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<StreamDetails>,
#[serde(skip_serializing_if = "Option::is_none")]
pub generated_text: Option<String>,
pub token: Token,
pub top_tokens: Vec<Token>,
}

#[derive(Serialize, Deserialize, ToSchema)]
pub struct ErrorResponse {
pub error: String,
pub error_type: Option<String>,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct GenerateParameters {
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = json!(1))]
pub best_of: Option<i32>,

#[serde(default = "default_true")]
pub decoder_input_details: bool,

#[serde(default = "default_true")]
pub details: bool,

#[serde(default)]
pub do_sample: bool,

#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = json!(20))]
pub max_new_tokens: Option<i32>,

#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = json!(1.03))]
pub repetition_penalty: Option<f32>,

#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = json!(false))]
pub return_full_text: Option<bool>,

#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = json!(299792458))]
pub seed: Option<i64>,

#[serde(default, skip_serializing_if = "Vec::is_empty")]
#[schema(example = json!(vec!["photographer"]))]
pub stop: Vec<String>,

#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = json!(0.5))]
pub temperature: Option<f64>,

#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = json!(10))]
pub top_k: Option<i32>,

#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = json!(5))]
pub top_n_tokens: Option<i32>,

#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = json!(0.95))]
pub top_p: Option<f64>,

#[serde(default, skip_serializing_if = "Option::is_none")]
pub truncate: Option<i32>,

#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = json!(0.95))]
pub typical_p: Option<f32>,

#[serde(default)]
pub watermark: bool,
}

#[derive(Deserialize, ToSchema, Debug)]
pub struct GenerateRequest {
#[schema(example = "My name is John")]
pub inputs: String,

#[serde(default, skip_serializing_if = "Option::is_none")]
pub parameters: Option<GenerateParameters>,
}

#[derive(Serialize, ToSchema)]
pub struct GenerateResponse {
pub generated_text: String,
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct CompatGenerateRequest {
#[schema(example = "My name is Olivier and I")]
pub inputs: String,

#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<GenerateParameters>,

#[serde(default)]
pub stream: bool,
}

fn default_true() -> bool {
true
}
12 changes: 4 additions & 8 deletions src/api/openapi.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use utoipa::OpenApi;

use crate::api::{
generate_text::{GenerateRequest, GenerateResponse},
ErrorResponse,
};
use crate::api::model::ErrorResponse;

use super::{
generate::CompatGenerateRequest,
generate_parameters::GenerateParameters,
stream_response::{FinishReason, StreamDetails, StreamResponse, Token},
use super::model::{
CompatGenerateRequest, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
StreamDetails, StreamResponse, Token,
};

#[derive(OpenApi)]
Expand Down
Loading

0 comments on commit fdd9e8e

Please sign in to comment.