Skip to content

Commit

Permalink
refactor: Add the output parameter, which can be markdown or raw
Browse files Browse the repository at this point in the history
  • Loading branch information
coding-hui committed Jul 14, 2024
1 parent b6e5c4c commit d1786a6
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 40 deletions.
20 changes: 12 additions & 8 deletions internal/cli/ask/ask.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@ var askExample = templates.Examples(`

// Options is a struct to support ask command.
type Options struct {
interactive, printRaw bool
prompts []string
promptFile string
pipe string
genericclioptions.IOStreams
modelOptions *options.ModelOptions

interactive bool
prompts []string
promptFile string
tempPromptFile string
pipe string
genericclioptions.IOStreams
}

// NewOptions returns initialized Options.
Expand Down Expand Up @@ -91,9 +92,9 @@ func NewCmdASK(ioStreams genericclioptions.IOStreams) *cobra.Command {

cmd.Flags().BoolVarP(&o.interactive, "interactive", "i", o.interactive, "Interactive dialogue model.")
cmd.Flags().StringVarP(&o.promptFile, "file", "f", o.promptFile, "File containing prompt.")
cmd.Flags().BoolVar(&o.printRaw, "raw", o.printRaw, "Return model raw return, no Stream UI.")

options.NewLLMFlags(false).AddFlags(cmd.Flags())
o.modelOptions = options.NewLLMFlags(false)
o.modelOptions.AddFlags(cmd.Flags())

cmd.Flags().VisitAll(func(flag *pflag.Flag) {
_ = viper.BindPFlag(flag.Name, flag)
Expand All @@ -104,6 +105,9 @@ func NewCmdASK(ioStreams genericclioptions.IOStreams) *cobra.Command {

// Validate validates the provided options.
func (o *Options) Validate() error {
	// modelOptions is wired up in NewCmdASK; guard against a nil field so
	// Validate stays safe when called on a zero-value Options.
	if o.modelOptions == nil {
		return nil
	}
	return o.modelOptions.Validate()
}

Expand All @@ -121,7 +125,7 @@ func (o *Options) Run(args []string) error {

klog.V(2).InfoS("start ask cli mode.", "args", args, "runMode", runMode, "pipe", input.GetPipe())

if o.printRaw {
if o.modelOptions.OutputFormat != nil && *o.modelOptions.OutputFormat == string(options.RawOutputFormat) {
cfg, err := options.NewConfig()
if err != nil {
display.FatalErr(err, "Failed to load ask cmd config")
Expand Down
3 changes: 3 additions & 0 deletions internal/cli/llm/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,5 +277,8 @@ func (e *Engine) prepareSystemPromptExecPart() string {
}

// prepareSystemPromptChatPart returns the chat portion of the system prompt.
// When the configured output format is raw, the markdown-formatting
// instruction is omitted; otherwise the model is asked to answer in markdown.
func (e *Engine) prepareSystemPromptChatPart() string {
	const base = `You are a powerful terminal assistant. Your primary language is Chinese and you are good at answering users' questions`
	if e.config.Ai.OutputFormat == options.RawOutputFormat {
		return base + `.`
	}
	return base + ` in markdown format.`
}
25 changes: 17 additions & 8 deletions internal/cli/options/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,24 @@ type Config struct {
}

type AiConfig struct {
SystemPrompt string `yaml:"system-prompt,omitempty" mapstructure:"system-prompt,omitempty"`
Token string `yaml:"token,omitempty" mapstructure:"token,omitempty"`
Model string `yaml:"model,omitempty" mapstructure:"model,omitempty"`
ApiBase string `yaml:"api-base,omitempty" mapstructure:"api-base,omitempty"`
Temperature float64 `yaml:"temperature,omitempty" mapstructure:"temperature,omitempty"`
TopP float64 `yaml:"top-p,omitempty" mapstructure:"top-p,omitempty"`
MaxTokens int `yaml:"max-tokens,omitempty" mapstructure:"max-tokens,omitempty"`
Proxy string `yaml:"proxy,omitempty" mapstructure:"proxy,omitempty"`
SystemPrompt string `yaml:"system-prompt,omitempty" mapstructure:"system-prompt,omitempty"`
Token string `yaml:"token,omitempty" mapstructure:"token,omitempty"`
Model string `yaml:"model,omitempty" mapstructure:"model,omitempty"`
ApiBase string `yaml:"api-base,omitempty" mapstructure:"api-base,omitempty"`
Temperature float64 `yaml:"temperature,omitempty" mapstructure:"temperature,omitempty"`
TopP float64 `yaml:"top-p,omitempty" mapstructure:"top-p,omitempty"`
MaxTokens int `yaml:"max-tokens,omitempty" mapstructure:"max-tokens,omitempty"`
Proxy string `yaml:"proxy,omitempty" mapstructure:"proxy,omitempty"`
OutputFormat OutputFormat `yaml:"output-format,omitempty" mapstructure:"output-format,omitempty"`
}

// OutputFormat selects how model responses are rendered to the user.
type OutputFormat string

const (
	// RawOutputFormat returns the model output as raw text, without
	// markdown rendering.
	RawOutputFormat OutputFormat = "raw"
	// MarkdownOutputFormat returns the model output formatted as markdown.
	MarkdownOutputFormat OutputFormat = "markdown"
)

// NewConfig returns a Config struct with the default values.
func NewConfig() (*Config, error) {
return &Config{
Expand All @@ -54,6 +62,7 @@ func NewConfig() (*Config, error) {
Temperature: viper.GetFloat64(FlagAiTemperature),
TopP: viper.GetFloat64(FlagAiTopP),
MaxTokens: viper.GetInt(FlagAiMaxTokens),
OutputFormat: OutputFormat(viper.GetString(FlagOutputFormat)),
Proxy: "",
},
System: system.Analyse(),
Expand Down
1 change: 1 addition & 0 deletions internal/cli/options/config_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
FlagAiTemperature = "temperature"
FlagAiTopP = "top-p"
FlagAiMaxTokens = "max-tokens"
FlagOutputFormat = "output-format"

FlagLogFlushFrequency = "log-flush-frequency"
)
Expand Down
65 changes: 41 additions & 24 deletions internal/cli/options/llm_flags.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
package options

import (
"fmt"

"github.com/AlekSi/pointer"
flag "github.com/spf13/pflag"
)

type LLMFlags struct {
Token *string
Model *string
ApiBase *string
Temperature *float64
TopP *float64
MaxTokens *int
Proxy *string
type ModelOptions struct {
Token *string
Model *string
ApiBase *string
Temperature *float64
TopP *float64
MaxTokens *int
Proxy *string
OutputFormat *string

// If set to true, will use persistent client config and
// propagate the config to the places that need it, rather than
Expand All @@ -21,36 +24,50 @@ type LLMFlags struct {
}

// AddFlags binds client configuration flags to a given flagset.
func (f *LLMFlags) AddFlags(flags *flag.FlagSet) {
if f.Token != nil {
flags.StringVar(f.Token, FlagAiToken, *f.Token, "Api token to use for CLI requests")
func (m *ModelOptions) AddFlags(flags *flag.FlagSet) {
if m.Token != nil {
flags.StringVar(m.Token, FlagAiToken, *m.Token, "Api token to use for CLI requests")
}
if m.Model != nil {
flags.StringVar(m.Model, FlagAiModel, *m.Model, "The encoding of the model to be called.")
}
if f.Model != nil {
flags.StringVar(f.Model, FlagAiModel, *f.Model, "The encoding of the model to be called.")
if m.ApiBase != nil {
flags.StringVar(m.ApiBase, FlagAiApiBase, *m.ApiBase, "Interface for the API.")
}
if f.ApiBase != nil {
flags.StringVar(f.ApiBase, FlagAiApiBase, *f.ApiBase, "Interface for the API.")
if m.Temperature != nil {
flags.Float64Var(m.Temperature, FlagAiTemperature, *m.Temperature, "Sampling temperature to control the randomness of the output.")
}
if f.Temperature != nil {
flags.Float64Var(f.Temperature, FlagAiTemperature, *f.Temperature, "Sampling temperature to control the randomness of the output.")
if m.TopP != nil {
flags.Float64Var(m.TopP, FlagAiTopP, *m.TopP, "Nucleus sampling method to control the probability mass of the output.")
}
if f.TopP != nil {
flags.Float64Var(f.TopP, FlagAiTopP, *f.TopP, "Nucleus sampling method to control the probability mass of the output.")
if m.MaxTokens != nil {
flags.IntVar(m.MaxTokens, FlagAiMaxTokens, *m.MaxTokens, "The maximum number of tokens the model can output.")
}
if f.MaxTokens != nil {
flags.IntVar(f.MaxTokens, FlagAiMaxTokens, *f.MaxTokens, "The maximum number of tokens the model can output.")
if m.OutputFormat != nil {
flags.StringVarP(m.OutputFormat, FlagOutputFormat, "o", *m.OutputFormat, "Output format. One of: (markdown, raw).")
}
}

// NewLLMFlags returns LLMFlags with default values set.
func NewLLMFlags(usePersistentConfig bool) *LLMFlags {
return &LLMFlags{
// NewLLMFlags returns ModelOptions with default values set.
func NewLLMFlags(usePersistentConfig bool) *ModelOptions {
return &ModelOptions{
Token: pointer.ToString(""),
Model: pointer.ToString(""),
ApiBase: pointer.ToString(""),
Temperature: pointer.ToFloat64(0.5),
TopP: pointer.ToFloat64(0.5),
MaxTokens: pointer.ToInt(1024),
OutputFormat: pointer.ToString(string(MarkdownOutputFormat)),
usePersistentConfig: usePersistentConfig,
}
}

// Validate reports an error when the configured output format is neither
// markdown nor raw. A nil OutputFormat is treated as valid (nothing to check).
func (m *ModelOptions) Validate() error {
	if m.OutputFormat == nil {
		return nil
	}
	switch *m.OutputFormat {
	case string(MarkdownOutputFormat), string(RawOutputFormat):
		return nil
	default:
		return fmt.Errorf("invalid output format: %s", *m.OutputFormat)
	}
}

0 comments on commit d1786a6

Please sign in to comment.