diff --git a/.gitignore b/.gitignore index 66fd13c..dd552ec 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,8 @@ # Output of the go coverage tool, specifically when used with LiteIDE *.out +bin/ +.vscode # Dependency directories (remove the comment below to include it) # vendor/ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..195790c --- /dev/null +++ b/Makefile @@ -0,0 +1,9 @@ + +build: + go build -o bin/ai-shell-go main.go + +run: + go run main.go + + +all: build \ No newline at end of file diff --git a/completion/completion.go b/completion/completion.go index fe3cb16..2fa617e 100644 --- a/completion/completion.go +++ b/completion/completion.go @@ -12,7 +12,7 @@ import ( ) const ( - promptTemplate = `I will give you a prompt to create a single line bash command that one can enter in a terminal and run, based on what is asked in the prompt. + generatePromptTemplate = `I will give you a prompt to create a single line bash command that one can enter in a terminal and run, based on what is asked in the prompt. {{ .details }} @@ -20,9 +20,19 @@ const ( The prompt is: {{ .prompt }}` - details = `Please only reply with the single line bash command surrounded by 3 backticks. It should be able to be directly run in a bash terminal. Do not include any other text.` + regeneratePromptTemplate = `Please update the following bash script based on what is asked in the following prompt. + + The script: {{ .command }} + The prompt: {{ .prompt }} - explain = `Then please describe the bash script in plain english, step by step, what exactly it does. + {{ .details }} + + {{ .explain }} + ` + + promptDeteails = `Please only reply with the single line bash command surrounded by 3 backticks. It should be able to be directly run in a bash terminal. Do not include any other text.` + + promptExplain = `Then please describe the bash script in plain english, step by step, what exactly it does. Please describe succintly, use as few words as possible, do not be verbose. If there are multiple steps, please display them as a list. ` @@ -45,12 +55,17 @@ func New(openAIClient *openai.Client) *Completion { } } -func (c *Completion) Suggest(input string) (*CompletionResponse, error) { +func (c *Completion) Suggest(input, previousStep string) (*CompletionResponse, error) { if input == "" { return nil, fmt.Errorf("input is empty") } - prompt := buildPrompt(details, explain, input) + var prompt string + if previousStep == "" { + prompt = buildGenerationPrompt(input) + } else { + prompt = buildRenerationPrompt(input, previousStep) + } response, err := c.openAIClient.CreateChatCompletion( context.Background(), @@ -94,18 +109,44 @@ func (c *Completion) Suggest(input string) (*CompletionResponse, error) { // support methods // --------------- -func buildPrompt(details, explain, prompt string) string { +func buildGenerationPrompt(prompt string) string { + var output bytes.Buffer + + templ := template.Must(template.New("prompt").Parse(generatePromptTemplate)) + err := templ.Execute(&output, map[string]interface{}{ + "details": promptDeteails, + "explain": promptExplain, + "prompt": prompt, + }) + if err != nil { + panic(err) + } + + return removeInitialSpaces(output.String()) +} + +func buildRenerationPrompt(prompt, command string) string { var output bytes.Buffer - templ := template.Must(template.New("prompt").Parse(promptTemplate)) + templ := template.Must(template.New("prompt").Parse(regeneratePromptTemplate)) err := templ.Execute(&output, map[string]interface{}{ - "details": details, - "explain": explain, + "details": promptDeteails, + "explain": promptExplain, "prompt": prompt, + "command": command, }) if err != nil { panic(err) } - return output.String() + return removeInitialSpaces(output.String()) +} + +func removeInitialSpaces(input string) string { + lines := strings.Split(input, "\n") + for i, line := range lines { + lines[i] = strings.TrimLeft(line, " ") + lines[i] = strings.TrimLeft(lines[i], "\t") + } + return strings.Join(lines, "\n") } diff --git a/main.go b/main.go index 215830a..ca5d1c9 100644 --- a/main.go +++ b/main.go @@ -1,10 +1,9 @@ package main import ( - "bytes" "fmt" "os" - "text/template" + "strings" openai "github.com/sashabaranov/go-openai" @@ -12,48 +11,45 @@ import ( "github.com/henomis/ai-shell-go/shell" ) -const promptTemplate = `I will give you a prompt to create a single line bash command that one can enter in a terminal and run, based on what is asked in the prompt. - - {{ .details }} - - {{ .explain }} - - The prompt is: {{ .prompt }}` - -const details = `Please only reply with the single line bash command surrounded by 3 backticks. It should be able to be directly run in a bash terminal. Do not include any other text.` - -const explain = `Then please describe the bash script in plain english, step by step, what exactly it does. - Please describe succintly, use as few words as possible, do not be verbose. - If there are multiple steps, please display them as a list. -` - -func buildPrompt(details, explain, prompt string) string { - var output bytes.Buffer - - templ := template.Must(template.New("prompt").Parse(promptTemplate)) - err := templ.Execute(&output, map[string]interface{}{ - "details": details, - "explain": explain, - "prompt": prompt, - }) - if err != nil { - panic(err) - } - - return output.String() -} +var ( + ErrorShellAI = fmt.Errorf("something went wrong") +) func main() { openAIKey := os.Getenv("OPENAI_API_KEY") if openAIKey == "" { - fmt.Println("OPEN_AI_KEY is not set") + fmt.Println("OPEN_AI_KEY is not set.") + fmt.Println("Please set the OPENAI_API_KEY environment variable to your OpenAI API key.") return } + userInput := strings.Join(os.Args[1:], " ") + client := openai.NewClient(openAIKey) completion := completion.New(client) s := shell.New(completion) - s.Run() + shellResponse, err := s.Suggest(userInput) + if err != nil { + fmt.Printf("%s: %s\n", ErrorShellAI, err) + return + } + + for shellResponse.CommandAction == shell.CommandActionRevise { + shellResponse, err = s.Retry(shellResponse.Command) + if err != nil { + fmt.Printf("%s: %s\n", ErrorShellAI, err) + return + } + } + + if shellResponse.CommandAction == shell.CommandActionExecute { + err = s.Execute(shellResponse.Command) + if err != nil { + fmt.Printf("%s: %s\n", ErrorShellAI, err) + return + } + } + } diff --git a/shell/shell.go b/shell/shell.go index cc70dc3..c199a51 100644 --- a/shell/shell.go +++ b/shell/shell.go @@ -1,7 +1,9 @@ package shell import ( + "bufio" "fmt" + "os" "strings" "github.com/commander-cli/cmd" @@ -14,12 +16,17 @@ type Shell struct { completion *completion.Completion } -type executeResponse string +type ShellResponse struct { + CommandAction CommandAction + Command string +} + +type CommandAction string const ( - executeResponseExecute executeResponse = "e" - executeResponseRetry executeResponse = "r" - executeResponseExit executeResponse = "q" + CommandActionExecute CommandAction = "execute" + CommandActionRevise CommandAction = "retry" + CommandActionExit CommandAction = "exit" ) func New(completion *completion.Completion) *Shell { @@ -28,50 +35,101 @@ func New(completion *completion.Completion) *Shell { } } -func (s *Shell) Run(input string) error { +func (s *Shell) Suggest(input string) (*ShellResponse, error) { if input == "" { - return fmt.Errorf("input is empty") + return nil, fmt.Errorf("input is empty") } - for { - executeResponse, command := s.suggest(input) - if executeResponse == executeResponseExit { - c := cmd.NewCommand(command) - err := c.Execute() - if err != nil { - return fmt.Errorf("command: %w", err) - } - break - } - + response, err := s.completion.Suggest(input, "") + if err != nil { + return nil, fmt.Errorf("completion: %w", err) } - return nil + color.New(color.FgWhite, color.Bold).Printf("\nHere is your command line:\n\n") + color.New(color.FgCyan, color.Bold).Printf("$ %s\n", response.Command) + color.New(color.FgWhite).Printf("--\n%s\n\n", response.Explain) + + userAction := getUserActionFromStdin() + + return &ShellResponse{ + Command: response.Command, + CommandAction: getCommandActionFromUserAction(userAction), + }, nil + } -func (s *Shell) suggest(input string) (executeResponse, string) { - response, err := s.completion.Suggest(input) +func (s *Shell) Retry(previousCommand string) (*ShellResponse, error) { + + if previousCommand == "" { + return nil, fmt.Errorf("command is empty") + } + var userAction string + + color.New(color.FgWhite, color.Bold).Printf("\nEnter your revision:\n\n") + reader := bufio.NewReader(os.Stdin) + userAction, _ = reader.ReadString('\n') + + response, err := s.completion.Suggest(userAction, previousCommand) if err != nil { - return executeResponseExit, "" + return nil, fmt.Errorf("completion: %w", err) } + color.New(color.FgWhite, color.Bold).Printf("\nHere is your command line:\n\n") color.New(color.FgCyan, color.Bold).Printf("$ %s\n", response.Command) - color.New(color.FgWhite).Printf("--\n%s", response.Explain) + color.New(color.FgWhite).Printf("--\n%s\n\n", response.Explain) + + userAction = getUserActionFromStdin() - var userInput string - fmt.Println("[E]xecute, [R]etry, [Q]uit? > ") - fmt.Scanf("%s", &userInput) + return &ShellResponse{ + Command: response.Command, + CommandAction: getCommandActionFromUserAction(userAction), + }, nil - switch strings.ToLower(userInput) { +} + +func (s *Shell) Execute(command string) error { + + c := cmd.NewCommand(command, cmd.WithStandardStreams, cmd.WithInheritedEnvironment(cmd.EnvVars{})) + + err := c.Execute() + if err != nil { + return fmt.Errorf("command: %w", err) + } + + return nil +} + +// --------------- +// support methods +// --------------- + +func getUserActionFromStdin() string { + + var userAction string + color.New(color.FgWhite).Printf("[") + color.New(color.FgGreen).Printf("E") + color.New(color.FgWhite).Printf("]xecute, [") + color.New(color.FgYellow).Printf("R") + color.New(color.FgWhite).Printf("]evise, [") + color.New(color.FgRed).Printf("Q") + color.New(color.FgWhite).Printf("]uit? > ") + reader := bufio.NewReader(os.Stdin) + userAction, _ = reader.ReadString('\n') + userAction = strings.TrimSpace(userAction) + + return userAction +} + +func getCommandActionFromUserAction(userAction string) CommandAction { + switch strings.ToLower(userAction) { case "e": - return executeResponseExecute, response.Command + return CommandActionExecute case "r": - return executeResponseRetry, response.Command + return CommandActionRevise case "q": - return executeResponseExit, response.Command + return CommandActionExit default: - return executeResponseExit, response.Command + return CommandActionExit } - }