Skip to content

Commit

Permalink
Refactor completions class
Browse files · Browse the repository at this point in the history
  • Loading branch information
kladaFOX committed Sep 25, 2024
1 parent 9ad65aa commit 18428bc
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 52 deletions.
2 changes: 1 addition & 1 deletion Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
spectre_ai (1.0.1)
spectre_ai (1.1.0)

GEM
remote: https://rubygems.org/
Expand Down
115 changes: 65 additions & 50 deletions lib/spectre/openai/completions.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,23 @@ class Completions
API_URL = 'https://api.openai.com/v1/chat/completions'
DEFAULT_MODEL = 'gpt-4o-mini'

# Class method to generate a completion based on a user prompt
# Class method to generate a completion based on user messages and optional tools
#
# @param user_prompt [String] the user's input to generate a completion for
# @param system_prompt [String] an optional system prompt to guide the AI's behavior
# @param assistant_prompt [String] an optional assistant prompt to provide context for the assistant's behavior
# @param model [String] the model to be used for generating completions, defaults to DEFAULT_MODEL
# @param json_schema [Hash, nil] an optional JSON schema to enforce structured output
# @param max_tokens [Integer] the maximum number of tokens for the completion (default: 50)
# @return [String] the generated completion text
# @raise [APIKeyNotConfiguredError] if the API key is not set
# @raise [RuntimeError] for general API errors or unexpected issues
def self.create(user_prompt:, system_prompt: "You are a helpful assistant.", assistant_prompt: nil, model: DEFAULT_MODEL, json_schema: nil, max_tokens: nil)
# @param messages [Array<Hash>] The conversation messages, each with a role and content
# @param model [String] The model to be used for generating completions, defaults to DEFAULT_MODEL
# @param json_schema [Hash, nil] An optional JSON schema to enforce structured output
# @param max_tokens [Integer] The maximum number of tokens for the completion (default: 50)
# @param tools [Array<Hash>, nil] An optional array of tool definitions for function calling
# @return [Hash] The parsed response including any function calls or content
# @raise [APIKeyNotConfiguredError] If the API key is not set
# @raise [RuntimeError] For general API errors or unexpected issues
def self.create(messages:, model: DEFAULT_MODEL, json_schema: nil, max_tokens: nil, tools: nil)
api_key = Spectre.api_key
raise APIKeyNotConfiguredError, "API key is not configured" unless api_key

# Check if messages are empty or not an array
raise ArgumentError, "Messages cannot be blank or empty." if messages.blank? || !messages.is_a?(Array)

uri = URI(API_URL)
http = Net::HTTP.new(uri.host, uri.port)
http.use_ssl = true
Expand All @@ -36,7 +38,7 @@ def self.create(user_prompt:, system_prompt: "You are a helpful assistant.", ass
'Authorization' => "Bearer #{api_key}"
})

request.body = generate_body(user_prompt, system_prompt, assistant_prompt, model, json_schema, max_tokens).to_json
request.body = generate_body(messages, model, json_schema, max_tokens, tools).to_json
response = http.request(request)

unless response.is_a?(Net::HTTPSuccess)
Expand All @@ -45,18 +47,7 @@ def self.create(user_prompt:, system_prompt: "You are a helpful assistant.", ass

parsed_response = JSON.parse(response.body)

# Check if the response contains a refusal
if parsed_response.dig('choices', 0, 'message', 'refusal')
raise "Refusal: #{parsed_response.dig('choices', 0, 'message', 'refusal')}"
end

# Check if the finish reason is "length", indicating incomplete response
if parsed_response.dig('choices', 0, 'finish_reason') == "length"
raise "Incomplete response: The completion was cut off due to token limit."
end

# Return the structured output if it's included
parsed_response.dig('choices', 0, 'message', 'content')
handle_response(parsed_response)
rescue JSON::ParserError => e
raise "JSON Parse Error: #{e.message}"
rescue Net::OpenTimeout, Net::ReadTimeout => e
Expand All @@ -67,38 +58,62 @@ def self.create(user_prompt:, system_prompt: "You are a helpful assistant.", ass

# Helper method to generate the request body for the chat completions API call.
#
# Builds the minimal required payload (model + messages) and appends the
# optional fields only when they are provided, so the request stays small.
#
# @param messages [Array<Hash>] The conversation messages, each with a role and content
# @param model [String] The model to be used for generating completions
# @param json_schema [Hash, nil] An optional JSON schema to enforce structured output
# @param max_tokens [Integer, nil] The maximum number of tokens for the completion
# @param tools [Array<Hash>, nil] An optional array of tool definitions for function calling
# @return [Hash] The body for the API request
def self.generate_body(messages, model, json_schema, max_tokens, tools)
  body = {
    model: model,
    messages: messages
  }

  # Optional fields: use symbol keys consistently (the whole hash is
  # serialized with #to_json, so symbol vs. string keys must not be mixed).
  body[:max_tokens] = max_tokens if max_tokens
  body[:response_format] = { type: 'json_schema', json_schema: json_schema } if json_schema
  body[:tools] = tools if tools

  body
end

# Handles a parsed chat-completions API response.
#
# Raises for refusal / truncation / content-filter outcomes, and returns a
# small structured hash for the successful finish reasons.
#
# @param response [Hash] The parsed API response
# @return [Hash] `{ content: ... }` for a normal stop, or
#   `{ tool_calls: ..., content: ... }` when the model invoked tools
# @raise [RuntimeError] on refusal, truncation, content filtering, or an
#   unrecognized finish reason
def self.handle_response(response)
  first_choice = response.dig('choices', 0) || {}
  message = first_choice['message']
  finish_reason = first_choice['finish_reason']

  # A refusal takes precedence over whatever finish reason was reported.
  raise "Refusal: #{message['refusal']}" if message['refusal']

  case finish_reason
  when "length"
    raise "Incomplete response: The completion was cut off due to token limit."
  when "content_filter"
    raise "Content filtered: The model's output was blocked due to policy violations."
  when "function_call", "tool_calls"
    # The model invoked one or more tools; surface the calls plus any text.
    { tool_calls: message['tool_calls'], content: message['content'] }
  when "stop"
    { content: message['content'] }
  else
    raise "Unexpected finish_reason: #{finish_reason}"
  end
end

end
end
end
2 changes: 1 addition & 1 deletion lib/spectre/version.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# frozen_string_literal: true

module Spectre # :nodoc:all
  # Current gem version (the pre-bump "1.0.1" line was leftover diff residue;
  # keeping both caused a constant-reassignment warning).
  VERSION = "1.1.0"
end

0 comments on commit 18428bc

Please sign in to comment.