Llm timeout and retry; verbose logging; max_concurrent config option
petrgazarov committed Oct 27, 2023
1 parent cdd2f5d commit baf8bed
Showing 12 changed files with 107 additions and 28 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -98,7 +98,7 @@ compiler:
   target_dir: terraform
 ```
-Set `source_dir` to the directory where your Salami files are, and `target_dir` to the directory where you want the Terraform files to be written. The config file supports environment variables, which is useful to avoid storing secrets in version control. To inject an env variable at runtime, use the `${ENV_VAR}` delimiter.
+Set `compiler.source_dir` to the directory where your Salami files are, and `compiler.target_dir` to the directory where the Terraform files should be written. The config file supports environment variables, which is useful to avoid storing secrets in version control. To inject an env variable at runtime, use the `${ENV_VAR}` delimiter. Use the `compiler.llm.max_concurrent` config option to control how many concurrent API calls are made to the OpenAI API. The default is 3.
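For reference, a minimal `salami.yaml` using the new option would look roughly like the sketch below (the nesting of `llm` under `compiler` mirrors the example config updated later in this commit; omit `max_concurrent` to get the default):

```yaml
compiler:
  llm:
    provider: openai
    model: gpt4
    api_key: ${OPENAI_API_KEY}
    max_concurrent: 3
  source_dir: salami
  target_dir: terraform
```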

## ✅ VS Code Extension

50 changes: 39 additions & 11 deletions backend/llm/openai/gpt4/client.go
@@ -5,10 +5,12 @@ import (
     "encoding/json"
     "math"
     backendTypes "salami/backend/types"
+    "salami/common/constants"
     "salami/common/errors"
     "salami/common/logger"
     commonTypes "salami/common/types"
     "strings"
+    "time"
 
     "github.com/sashabaranov/go-openai"
 )
@@ -21,9 +23,10 @@ const LlmMessageAssistantRole = "assistant"
 const functionCallName = "save_code"
 
 type OpenaiGpt4 struct {
-    slug   string
-    model  string
-    client *openai.Client
+    slug                    string
+    model                   string
+    client                  *openai.Client
+    maxConcurrentExecutions int
 }
 
 type LlmMessage struct {
@@ -35,9 +38,10 @@
 
 func NewLlm(llmConfig commonTypes.LlmConfig) backendTypes.Llm {
     return &OpenaiGpt4{
-        client: getClient(llmConfig),
-        model:  openai.GPT4,
-        slug:   commonTypes.LlmOpenaiGpt4,
+        client:                  getClient(llmConfig),
+        model:                   openai.GPT4,
+        slug:                    commonTypes.LlmOpenaiGpt4,
+        maxConcurrentExecutions: llmConfig.MaxConcurrentExecutions,
     }
 }

@@ -66,7 +70,7 @@ func (o *OpenaiGpt4) GetSlug() string {
 }
 
 func (o *OpenaiGpt4) GetMaxConcurrentExecutions() int {
-    return 5
+    return o.maxConcurrentExecutions
 }
 
 func (o *OpenaiGpt4) CreateCompletion(messages []interface{}) (string, error) {
@@ -78,29 +82,53 @@ func (o *OpenaiGpt4) CreateCompletion(messages []interface{}) (string, error) {
 
     logMessages(llmMessages)
 
-    response, err := o.client.CreateChatCompletion(
-        context.Background(),
-        o.getChatCompletionRequest(llmMessages),
-    )
+    response, err := o.callApi(llmMessages)
     if err != nil {
         return "", err
     }

     functionCall := response.Choices[0].Message.FunctionCall
     if functionCall == nil {
         return "", &errors.LlmError{Message: "Function call is nil"}
     }
 
     var parsedArguments map[string]interface{}
     err = json.Unmarshal([]byte(functionCall.Arguments), &parsedArguments)
     if err != nil {
         return "", err
     }
 
     code, ok := parsedArguments["code"].(string)
     if !ok {
         return "", &errors.LlmError{Message: "Code is not a string"}
     }
 
     return strings.TrimSpace(code), nil
 }
+
+func (o *OpenaiGpt4) callApi(llmMessages []*LlmMessage) (openai.ChatCompletionResponse, error) {
+    var response openai.ChatCompletionResponse
+    var err error
+
+    for i := 0; i < 2; i++ {
+        ctx, cancel := context.WithTimeout(
+            context.Background(),
+            time.Duration(constants.LlmTimeoutDurationSeconds)*time.Second,
+        )
+        defer cancel()
+
+        response, err = o.client.CreateChatCompletion(
+            ctx,
+            o.getChatCompletionRequest(llmMessages),
+        )
+        if err == nil {
+            return response, err
+        }
+    }
+
+    return response, err
+}
+
+func getClient(llmConfig commonTypes.LlmConfig) *openai.Client {
+    return openai.NewClient(llmConfig.ApiKey)
+}
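To illustrate what `CreateCompletion` parses above: the model is expected to respond with a `save_code` function call whose `Arguments` field is a JSON object carrying a `code` string. A hypothetical payload (the Terraform snippet is invented for this example) might be:

```json
{
  "code": "resource \"aws_s3_bucket\" \"main\" {\n  bucket = \"my-bucket\"\n}"
}
```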
20 changes: 20 additions & 0 deletions backend/target/terraform/generate_code.go
@@ -1,9 +1,11 @@
 package terraform
 
 import (
+    "fmt"
     "salami/backend/prompts/terraform/openai_gpt4"
     backendTypes "salami/backend/types"
     "salami/common/change_set"
+    "salami/common/logger"
     "salami/common/symbol_table"
     commonTypes "salami/common/types"

@@ -33,6 +35,8 @@ func (t *Terraform) GenerateCode(
         semaphoreChannel <- struct{}{}
         defer func() { <-semaphoreChannel }()
 
+        logDiffProgress(diff)
+
         messages, err := getGenerateCodeLlmMessages(symbolTable, diff, llm)
         if err != nil {
             return err
@@ -73,3 +77,19 @@ func getGenerateCodeLlmMessages(
     }
     return messages, nil
 }
+
+func logDiffProgress(diff *commonTypes.ChangeSetDiff) {
+    var objectType string
+    var objectId string
+
+    if diff.NewObject.IsResource() {
+        objectType = "resource"
+        objectId = string(diff.NewObject.ParsedResource.LogicalName)
+    } else if diff.NewObject.IsVariable() {
+        objectType = "variable"
+        objectId = diff.NewObject.ParsedVariable.Name
+    }
+
+    message := fmt.Sprintf("🖋 Generating code for %s '%s' (diff type: %s)...", objectType, objectId, diff.DiffType)
+    logger.Verbose(message)
+}
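The configured `max_concurrent` value flows from `GetMaxConcurrentExecutions` into the buffered-channel semaphore visible at the top of the `GenerateCode` hunk (`semaphoreChannel <- struct{}{}` to acquire a slot, `<-semaphoreChannel` to release it). For readers unfamiliar with the pattern, here is a standalone sketch with hypothetical names, not code from this repository:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	maxConcurrent := 3 // e.g. the compiler.llm.max_concurrent setting
	semaphoreChannel := make(chan struct{}, maxConcurrent)

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			semaphoreChannel <- struct{}{}        // blocks once maxConcurrent slots are taken
			defer func() { <-semaphoreChannel }() // frees a slot when this task finishes
			fmt.Printf("generating code for object %d\n", id)
		}(i)
	}
	wg.Wait()
}
```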
12 changes: 11 additions & 1 deletion backend/target/terraform/validate_code.go
@@ -50,13 +50,23 @@ func (t *Terraform) ValidateCode(
     if len(changeSetRepository.Diffs) == 0 {
         return nil
     }
+    if retryCount == 0 {
+        logger.Verbose("🔍 Validating code...")
+    }
     validationResults, err := generateValidationResults(newObjects)
     if err != nil {
         return err
     }
-    if len(validationResults) == 0 {
+    errorCount := len(validationResults)
+    if errorCount == 0 {
+        logger.Verbose("🙌 All code is valid")
         return nil
     }
+    errorWord := "errors"
+    if errorCount == 1 {
+        errorWord = "error"
+    }
+    logger.Verbose(fmt.Sprintf("🔧 Found %d validation %s, fixing...", errorCount, errorWord))
 
     if err := processValidationResults(validationResults, symbolTable, changeSetRepository, llm); err != nil {
         return err
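Taken together, the verbose messages above trace the validate-and-fix loop. A run that finds and repairs two errors would log roughly the following (the count is illustrative):

```
🔍 Validating code...
🔧 Found 2 validation errors, fixing...
🙌 All code is valid
```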
14 changes: 11 additions & 3 deletions common/config/config.go
@@ -4,6 +4,7 @@ import (
     "log"
     "os"
     "regexp"
+    "salami/common/constants"
     "salami/common/types"
 
     "gopkg.in/yaml.v3"
@@ -46,10 +47,17 @@ func GetTargetConfig() types.TargetConfig {
 
 func GetLlmConfig() types.LlmConfig {
     compilerLlmConfig := getConfig().Compiler.Llm
+
+    maxConcurrentExecutions := compilerLlmConfig.MaxConcurrentExecutions
+    if maxConcurrentExecutions == 0 {
+        maxConcurrentExecutions = constants.DefaultMaxConcurrentLlmExecutions
+    }
+
     return types.LlmConfig{
-        Provider: compilerLlmConfig.Provider,
-        Model:    compilerLlmConfig.Model,
-        ApiKey:   compilerLlmConfig.ApiKey,
+        Provider:                compilerLlmConfig.Provider,
+        Model:                   compilerLlmConfig.Model,
+        ApiKey:                  compilerLlmConfig.ApiKey,
+        MaxConcurrentExecutions: maxConcurrentExecutions,
     }
 }

9 changes: 6 additions & 3 deletions common/config/validate.go
@@ -62,9 +62,10 @@ type CompilerTargetConfig struct {
 }
 
 type CompilerLlmConfig struct {
-    Provider string `yaml:"provider"`
-    Model    string `yaml:"model"`
-    ApiKey   string `yaml:"api_key"`
+    Provider                string `yaml:"provider"`
+    Model                   string `yaml:"model"`
+    ApiKey                  string `yaml:"api_key"`
+    MaxConcurrentExecutions int    `yaml:"max_concurrent"`
 }
 
 func validateTarget(fl validator.FieldLevel) bool {
@@ -80,9 +81,11 @@ func validateLlm(fl validator.FieldLevel) bool {
     if !ok {
         return false
     }
+
     validLlmProvider := llmConfig.Provider == types.LlmOpenaiProvider
     validLlmModel := llmConfig.Model == types.LlmGpt4Model
     apiKeyExists := llmConfig.ApiKey != ""
+
     return validLlmProvider && validLlmModel && apiKeyExists
 }

6 changes: 5 additions & 1 deletion common/constants/constants.go
@@ -1,9 +1,13 @@
 package constants
 
-const SalamiVersion = "0.0.2"
+const SalamiVersion = "0.0.3"
 
 const SalamiFileExtension = ".sami"
 
 const TerraformFileExtension = ".tf"
 
 const MaxFixValidationErrorRetries = 2
+
+const DefaultMaxConcurrentLlmExecutions = 3
+
+const LlmTimeoutDurationSeconds = 90
10 changes: 7 additions & 3 deletions common/logger/logger.go
@@ -16,6 +16,7 @@ var logger *salamiLogger
 func InitializeLogger(verbose bool) {
     zapConfig := zap.NewDevelopmentConfig()
     zapConfig.EncoderConfig.EncodeCaller = nil
+    zapConfig.EncoderConfig.LevelKey = ""
 
     zapLogger, err := zapConfig.Build()
     if err != nil {
@@ -28,7 +29,7 @@ func InitializeLogger(verbose bool) {
     }
 }
 
-// Log logs a message always
+// Log logs the message always
 func Log(message string) {
     if logger == nil {
         return
@@ -38,17 +39,20 @@
     logger.instance.Info(message)
 }
 
-// Verbose logs a message if the verbose flag is set to true
+// Verbose logs the message if the verbose flag is set to true
 func Verbose(message string) {
     if logger == nil {
         return
     }
     if !logger.verbose {
         return
     }
+
+    defer logger.instance.Sync()
+    logger.instance.Info(message)
 }
 
-// Debug logs a message if the DEBUG environment variable is set to true
+// Debug logs the message if the DEBUG environment variable is set to true
 func Debug(message string) {
     if logger == nil {
         return
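A minimal sketch of how callers consume this logger (it assumes only the exported functions shown above; the flag wiring is hypothetical):

```go
package main

import "salami/common/logger"

func main() {
	// true would typically come from a --verbose CLI flag.
	logger.InitializeLogger(true)

	logger.Log("always printed")
	logger.Verbose("printed only when verbose logging is enabled")
}
```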
7 changes: 4 additions & 3 deletions common/types/types.go
@@ -135,9 +135,10 @@ type TargetConfig struct {
 }
 
 type LlmConfig struct {
-    Provider string
-    Model    string
-    ApiKey   string
+    Provider                string
+    Model                   string
+    ApiKey                  string
+    MaxConcurrentExecutions int
 }
 
 const TerraformPlatform = "terraform"
2 changes: 1 addition & 1 deletion examples/public_and_private_ecs_services/README.md
@@ -23,4 +23,4 @@ To run this example, you need:
 
 1. Note that `salami compile` will examine the `salami-lock.toml` file and the source `.sami` files, and determine which Salami objects have changed since the last compilation. To force a complete recompilation, delete the `salami-lock.toml` file. Or, you can change the source `.sami` files and `salami compile` will recompile only the changed objects.
 
-2. Occasionally, the OpenAI API delays responses significantly. If `salami compile` is stuck for a long time, try again later. Total compilation time varies significantly with the number of objects and their complexity; this project takes about 3 minutes to compile from scratch, and much less for partial changes.
+2. Occasionally, the OpenAI API delays responses significantly. If `salami compile` is stuck, try setting the `compiler.llm.max_concurrent` config option to a lower value.
1 change: 1 addition & 0 deletions examples/public_and_private_ecs_services/salami.yaml
@@ -5,5 +5,6 @@ compiler:
     provider: openai
     model: gpt4
     api_key: ${OPENAI_API_KEY}
+    max_concurrent: 3
   source_dir: salami
   target_dir: terraform
2 changes: 1 addition & 1 deletion examples/simple_s3_bucket/README.md
@@ -23,4 +23,4 @@ To run this example, you need:
 
 1. Note that `salami compile` will examine the `salami-lock.toml` file and the source `.sami` files, and determine which Salami objects have changed since the last compilation. To force a complete recompilation, delete the `salami-lock.toml` file. Or, you can change the source `.sami` files and `salami compile` will recompile only the changed objects.
 
-2. Occasionally, the OpenAI API delays responses significantly. If `salami compile` is stuck for a long time, try again later. Total compilation time varies with the number of objects and their complexity; this project takes about 10 seconds to compile from scratch.
+2. Occasionally, the OpenAI API delays responses significantly. If `salami compile` is stuck, try setting the `compiler.llm.max_concurrent` config option to a lower value.
