diff --git a/internal/cli/ask/ask.go b/internal/cli/ask/ask.go index 2a6f3ce..0ce6dda 100644 --- a/internal/cli/ask/ask.go +++ b/internal/cli/ask/ask.go @@ -40,13 +40,14 @@ var askExample = templates.Examples(` // Options is a struct to support ask command. type Options struct { - interactive, printRaw bool - prompts []string - promptFile string - pipe string - genericclioptions.IOStreams + modelOptions *options.ModelOptions + interactive bool + prompts []string + promptFile string tempPromptFile string + pipe string + genericclioptions.IOStreams } // NewOptions returns initialized Options. @@ -91,9 +92,9 @@ func NewCmdASK(ioStreams genericclioptions.IOStreams) *cobra.Command { cmd.Flags().BoolVarP(&o.interactive, "interactive", "i", o.interactive, "Interactive dialogue model.") cmd.Flags().StringVarP(&o.promptFile, "file", "f", o.promptFile, "File containing prompt.") - cmd.Flags().BoolVar(&o.printRaw, "raw", o.printRaw, "Return model raw return, no Stream UI.") - options.NewLLMFlags(false).AddFlags(cmd.Flags()) + o.modelOptions = options.NewLLMFlags(false) + o.modelOptions.AddFlags(cmd.Flags()) cmd.Flags().VisitAll(func(flag *pflag.Flag) { _ = viper.BindPFlag(flag.Name, flag) @@ -104,6 +105,9 @@ func NewCmdASK(ioStreams genericclioptions.IOStreams) *cobra.Command { // Validate validates the provided options. func (o *Options) Validate() error { + if err := o.modelOptions.Validate(); err != nil { + return err + } return nil } @@ -121,7 +125,7 @@ func (o *Options) Run(args []string) error { klog.V(2).InfoS("start ask cli mode.", "args", args, "runMode", runMode, "pipe", input.GetPipe()) - if o.printRaw { + if o.modelOptions.OutputFormat != nil && *o.modelOptions.OutputFormat == string(options.RawOutputFormat) { cfg, err := options.NewConfig() if err != nil { display.FatalErr(err, "Failed to load ask cmd config") diff --git a/internal/cli/llm/engine.go b/internal/cli/llm/engine.go index 56d5f31..36a0bda 100644 --- a/internal/cli/llm/engine.go +++ b/internal/cli/llm/engine.go @@ -277,5 +277,8 @@ func (e *Engine) prepareSystemPromptExecPart() string { } func (e *Engine) prepareSystemPromptChatPart() string { + if e.config.Ai.OutputFormat == options.RawOutputFormat { + return `You are a powerful terminal assistant. Your primary language is Chinese and you are good at answering users' questions.` + } return `You are a powerful terminal assistant. Your primary language is Chinese and you are good at answering users' questions in markdown format.` } diff --git a/internal/cli/options/config.go b/internal/cli/options/config.go index 30aec61..aa37d9c 100644 --- a/internal/cli/options/config.go +++ b/internal/cli/options/config.go @@ -33,16 +33,24 @@ type Config struct { } type AiConfig struct { - SystemPrompt string `yaml:"system-prompt,omitempty" mapstructure:"system-prompt,omitempty"` - Token string `yaml:"token,omitempty" mapstructure:"token,omitempty"` - Model string `yaml:"model,omitempty" mapstructure:"model,omitempty"` - ApiBase string `yaml:"api-base,omitempty" mapstructure:"api-base,omitempty"` - Temperature float64 `yaml:"temperature,omitempty" mapstructure:"temperature,omitempty"` - TopP float64 `yaml:"top-p,omitempty" mapstructure:"top-p,omitempty"` - MaxTokens int `yaml:"max-tokens,omitempty" mapstructure:"max-tokens,omitempty"` - Proxy string `yaml:"proxy,omitempty" mapstructure:"proxy,omitempty"` + SystemPrompt string `yaml:"system-prompt,omitempty" mapstructure:"system-prompt,omitempty"` + Token string `yaml:"token,omitempty" mapstructure:"token,omitempty"` + Model string `yaml:"model,omitempty" mapstructure:"model,omitempty"` + ApiBase string `yaml:"api-base,omitempty" mapstructure:"api-base,omitempty"` + Temperature float64 `yaml:"temperature,omitempty" mapstructure:"temperature,omitempty"` + TopP float64 `yaml:"top-p,omitempty" mapstructure:"top-p,omitempty"` + MaxTokens int `yaml:"max-tokens,omitempty" mapstructure:"max-tokens,omitempty"` + Proxy string `yaml:"proxy,omitempty" mapstructure:"proxy,omitempty"` + OutputFormat OutputFormat `yaml:"output-format,omitempty" mapstructure:"output-format,omitempty"` } +type OutputFormat string + +const ( + RawOutputFormat OutputFormat = "raw" + MarkdownOutputFormat OutputFormat = "markdown" +) + // NewConfig returns a Config struct with the default values. func NewConfig() (*Config, error) { return &Config{ @@ -54,6 +62,7 @@ func NewConfig() (*Config, error) { Temperature: viper.GetFloat64(FlagAiTemperature), TopP: viper.GetFloat64(FlagAiTopP), MaxTokens: viper.GetInt(FlagAiMaxTokens), + OutputFormat: OutputFormat(viper.GetString(FlagOutputFormat)), Proxy: "", }, System: system.Analyse(), diff --git a/internal/cli/options/config_flags.go b/internal/cli/options/config_flags.go index 24b80a1..7a9dd27 100644 --- a/internal/cli/options/config_flags.go +++ b/internal/cli/options/config_flags.go @@ -21,6 +21,7 @@ const ( FlagAiTemperature = "temperature" FlagAiTopP = "top-p" FlagAiMaxTokens = "max-tokens" + FlagOutputFormat = "output-format" FlagLogFlushFrequency = "log-flush-frequency" ) diff --git a/internal/cli/options/llm_flags.go b/internal/cli/options/llm_flags.go index 0b18add..93cc735 100644 --- a/internal/cli/options/llm_flags.go +++ b/internal/cli/options/llm_flags.go @@ -1,18 +1,21 @@ package options import ( + "fmt" + "github.com/AlekSi/pointer" flag "github.com/spf13/pflag" ) -type LLMFlags struct { - Token *string - Model *string - ApiBase *string - Temperature *float64 - TopP *float64 - MaxTokens *int - Proxy *string +type ModelOptions struct { + Token *string + Model *string + ApiBase *string + Temperature *float64 + TopP *float64 + MaxTokens *int + Proxy *string + OutputFormat *string // If set to true, will use persistent client config and // propagate the config to the places that need it, rather than @@ -21,36 +24,50 @@ type LLMFlags struct { } // AddFlags binds client configuration flags to a given flagset. -func (f *LLMFlags) AddFlags(flags *flag.FlagSet) { - if f.Token != nil { - flags.StringVar(f.Token, FlagAiToken, *f.Token, "Api token to use for CLI requests") +func (m *ModelOptions) AddFlags(flags *flag.FlagSet) { + if m.Token != nil { + flags.StringVar(m.Token, FlagAiToken, *m.Token, "Api token to use for CLI requests") + } + if m.Model != nil { + flags.StringVar(m.Model, FlagAiModel, *m.Model, "The encoding of the model to be called.") } - if f.Model != nil { - flags.StringVar(f.Model, FlagAiModel, *f.Model, "The encoding of the model to be called.") + if m.ApiBase != nil { + flags.StringVar(m.ApiBase, FlagAiApiBase, *m.ApiBase, "Interface for the API.") } - if f.ApiBase != nil { - flags.StringVar(f.ApiBase, FlagAiApiBase, *f.ApiBase, "Interface for the API.") + if m.Temperature != nil { + flags.Float64Var(m.Temperature, FlagAiTemperature, *m.Temperature, "Sampling temperature to control the randomness of the output.") } - if f.Temperature != nil { - flags.Float64Var(f.Temperature, FlagAiTemperature, *f.Temperature, "Sampling temperature to control the randomness of the output.") + if m.TopP != nil { + flags.Float64Var(m.TopP, FlagAiTopP, *m.TopP, "Nucleus sampling method to control the probability mass of the output.") } - if f.TopP != nil { - flags.Float64Var(f.TopP, FlagAiTopP, *f.TopP, "Nucleus sampling method to control the probability mass of the output.") + if m.MaxTokens != nil { + flags.IntVar(m.MaxTokens, FlagAiMaxTokens, *m.MaxTokens, "The maximum number of tokens the model can output.") } - if f.MaxTokens != nil { - flags.IntVar(f.MaxTokens, FlagAiMaxTokens, *f.MaxTokens, "The maximum number of tokens the model can output.") + if m.OutputFormat != nil { + flags.StringVarP(m.OutputFormat, FlagOutputFormat, "o", *m.OutputFormat, "Output format. One of: (markdown, raw).") } } -// NewLLMFlags returns LLMFlags with default values set. -func NewLLMFlags(usePersistentConfig bool) *LLMFlags { - return &LLMFlags{ +// NewLLMFlags returns ModelOptions with default values set. +func NewLLMFlags(usePersistentConfig bool) *ModelOptions { + return &ModelOptions{ Token: pointer.ToString(""), Model: pointer.ToString(""), ApiBase: pointer.ToString(""), Temperature: pointer.ToFloat64(0.5), TopP: pointer.ToFloat64(0.5), MaxTokens: pointer.ToInt(1024), + OutputFormat: pointer.ToString(string(MarkdownOutputFormat)), usePersistentConfig: usePersistentConfig, } } + +func (m *ModelOptions) Validate() error { + if m.OutputFormat != nil { + output := *m.OutputFormat + if output != string(MarkdownOutputFormat) && output != string(RawOutputFormat) { + return fmt.Errorf("invalid output format: %s", output) + } + } + return nil +}