diff --git a/cmd/repl/main.go b/cmd/shell/main.go similarity index 59% rename from cmd/repl/main.go rename to cmd/shell/main.go index 0c5ec81..1e1cfdc 100644 --- a/cmd/repl/main.go +++ b/cmd/shell/main.go @@ -2,13 +2,11 @@ package main import ( "context" - "errors" - "log" "os" "os/signal" "syscall" - "github.com/i-sevostyanov/NanoDB/internal/repl" + "github.com/i-sevostyanov/NanoDB/internal/shell" "github.com/i-sevostyanov/NanoDB/internal/sql/engine" "github.com/i-sevostyanov/NanoDB/internal/sql/parsing/ast" "github.com/i-sevostyanov/NanoDB/internal/sql/parsing/lexer" @@ -21,22 +19,17 @@ func main() { ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() - parseFn := engine.ParseFn(func(sql string) (ast.Node, error) { + sqlParser := engine.ParseFn(func(sql string) (ast.Node, error) { lx := lexer.New(sql) pr := parser.New(lx) return pr.Parse() }) - catalog := memory.NewCatalog() - aPlanner := planner.New(catalog) - anEngine := engine.New(parseFn, aPlanner) - aRepl := repl.New(os.Stdin, os.Stdout, catalog, anEngine) + sqlCatalog := memory.NewCatalog() + sqlPlanner := planner.New(sqlCatalog) + sqlEngine := engine.New(sqlParser, sqlPlanner) + tableWriter := shell.NewTableWriterFactory() - if err := aRepl.Run(ctx); err != nil { - switch { - case errors.Is(err, repl.ErrQuit): - default: - log.Printf("repl: %v\n", err) - } - } + sh := shell.New(os.Stdin, os.Stdout, sqlCatalog, sqlEngine, tableWriter) + sh.Run(ctx) } diff --git a/go.mod b/go.mod index ca3d82c..109ffc4 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,13 @@ module github.com/i-sevostyanov/NanoDB go 1.22 require ( - github.com/olekukonko/tablewriter v0.0.5 github.com/stretchr/testify v1.9.0 go.uber.org/mock v0.4.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rivo/uniseg v0.4.7 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/tools v0.21.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index db0c905..ba778e8 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= -github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= diff --git a/internal/repl/repl.go b/internal/shell/shell.go similarity index 62% rename from internal/repl/repl.go rename to internal/shell/shell.go index f4284f4..935a122 100644 --- a/internal/repl/repl.go +++ b/internal/shell/shell.go @@ -1,4 +1,4 @@ -package repl +package shell import ( "bufio" @@ -8,51 +8,54 @@ import ( "fmt" "io" "os" + "strconv" "strings" "github.com/i-sevostyanov/NanoDB/internal/sql" - - "github.com/olekukonko/tablewriter" ) const prompt = "#> " -var ErrQuit = errors.New("quit") +type TableWriter interface { + WriteTable(w io.Writer, headers []string, data [][]string, showRowsCount bool) +} type Engine interface { Exec(database, sql string) (columns []string, iter sql.RowIter, err error) } -// Repl is terminal-based front-end to NanoDB. -type Repl struct { +// Shell is terminal-based front-end to NanoDB. +type Shell struct { input io.Reader output io.Writer catalog sql.Catalog database sql.Database engine Engine + tw TableWriter prompt string closeCh chan struct{} } -func New(in io.Reader, out io.Writer, catalog sql.Catalog, engine Engine) *Repl { - return &Repl{ +func New(in io.Reader, out io.Writer, catalog sql.Catalog, engine Engine, tw TableWriter) *Shell { + return &Shell{ input: in, output: out, catalog: catalog, engine: engine, + tw: tw, prompt: prompt, closeCh: make(chan struct{}), } } -func (r *Repl) Run(ctx context.Context) error { - r.write("repl is the NanoDB interactive terminal.\n") +func (s *Shell) Run(ctx context.Context) { + s.write("shell is the NanoDB interactive terminal.\n") go func() { for { - r.write(r.prompt) + s.write(s.prompt) - scanner := bufio.NewScanner(r.input) + scanner := bufio.NewScanner(s.input) scanner.Scan() input := scanner.Text() @@ -60,77 +63,75 @@ func (r *Repl) Run(ctx context.Context) error { continue } - if repl, err := r.exec(input); err != nil { - r.write(err.Error() + "\n") + if repl, err := s.exec(input); err != nil { + s.write(err.Error() + "\n") } else if repl != "" { - r.write(repl) + s.write(repl) } } }() select { case <-ctx.Done(): - return nil - case <-r.closeCh: - return ErrQuit + case <-s.closeCh: } } -func (r *Repl) write(s string) { - _, _ = r.output.Write([]byte(s)) +func (s *Shell) write(line string) { + _, _ = s.output.Write([]byte(line)) } -func (r *Repl) exec(input string) (string, error) { +func (s *Shell) exec(input string) (string, error) { switch input[0] { case '\\': - return r.execCommand(input) + return s.execCommand(input) default: - return r.execQuery(input) + return s.execQuery(input) } } -func (r *Repl) execCommand(input string) (string, error) { +func (s *Shell) execCommand(input string) (string, error) { cmd := strings.TrimSpace(input) params := strings.Fields(cmd) switch params[0] { case `\use`: - return r.useDatabase(params) + return s.useDatabase(params) case `\databases`: - return r.listDatabases() + return s.listDatabases() case `\tables`: - return r.listTables() + return s.listTables() case `\describe`: - return r.describeTable(params) + return s.describeTable(params) case `\import`: - return r.importFile(params) + return s.importFile(params) case `\help`: - return r.showHelp(), nil + return s.showHelp(), nil case `\quit`: - return r.quit(), nil + return s.quit(), nil default: return "", fmt.Errorf("unknown command: %v", params[0]) } } -func (r *Repl) useDatabase(params []string) (string, error) { +func (s *Shell) useDatabase(params []string) (string, error) { if len(params) < 2 { return "", fmt.Errorf("database name not specified") } - db, err := r.catalog.GetDatabase(params[1]) + db, err := s.catalog.GetDatabase(params[1]) if err != nil { return "", err } - r.database = db - r.prompt = fmt.Sprintf("%s %s", db.Name(), prompt) + s.database = db + s.prompt = fmt.Sprintf("%s %s", db.Name(), prompt) return "database changed\n", nil } -func (r *Repl) listDatabases() (string, error) { - databases, err := r.catalog.ListDatabases() +func (s *Shell) listDatabases() (string, error) { + databases, err := s.catalog.ListDatabases() if err != nil { return "", err } @@ -142,33 +143,31 @@ func (r *Repl) listDatabases() (string, error) { data = append(data, []string{databases[i].Name()}) } - r.drawTable(buf, []string{"Database"}, data) - buf.WriteString(fmt.Sprintf("(%d rows)\n\n", len(data))) + s.tw.WriteTable(buf, []string{"Database"}, data, true) return buf.String(), nil } -func (r *Repl) listTables() (string, error) { - if r.database == nil { +func (s *Shell) listTables() (string, error) { + if s.database == nil { return "", fmt.Errorf("connect to database first") } buf := bytes.NewBuffer(nil) - tables := r.database.ListTables() + tables := s.database.ListTables() data := make([][]string, 0, len(tables)) for i := range tables { data = append(data, []string{tables[i].Name()}) } - r.drawTable(buf, []string{"Table"}, data) - buf.WriteString(fmt.Sprintf("(%d rows)\n\n", len(data))) + s.tw.WriteTable(buf, []string{"Table"}, data, true) return buf.String(), nil } -func (r *Repl) describeTable(params []string) (string, error) { - if r.database == nil { +func (s *Shell) describeTable(params []string) (string, error) { + if s.database == nil { return "", fmt.Errorf("connect to database first") } @@ -176,7 +175,7 @@ func (r *Repl) describeTable(params []string) (string, error) { return "", fmt.Errorf("table name not specified") } - table, err := r.database.GetTable(params[1]) + table, err := s.database.GetTable(params[1]) if err != nil { return "", err } @@ -209,21 +208,21 @@ func (r *Repl) describeTable(params []string) (string, error) { row := []string{ columns[i].Name, columns[i].DataType.String(), - fmt.Sprintf("%t", columns[i].Nullable), + strconv.FormatBool(columns[i].Nullable), defaultValue, } data = append(data, row) } - r.drawTable(buf, []string{"Column", "Type", "Nullable", "Default"}, data) + s.tw.WriteTable(buf, []string{"Column", "Type", "Nullable", "Default"}, data, false) buf.WriteString("Indexes:\n") buf.WriteString(fmt.Sprintf(" PRIMARY KEY (%s) autoincrement\n\n", primaryKey.Name)) return buf.String(), nil } -func (r *Repl) importFile(params []string) (string, error) { +func (s *Shell) importFile(params []string) (string, error) { if len(params) < 2 { return "", fmt.Errorf("filename not specified") } @@ -241,7 +240,7 @@ func (r *Repl) importFile(params []string) (string, error) { continue } - if _, err = r.exec(stmt); err != nil { + if _, err = s.exec(stmt); err != nil { return "", err } } @@ -249,7 +248,7 @@ func (r *Repl) importFile(params []string) (string, error) { return "OK\n", nil } -func (r *Repl) showHelp() string { +func (s *Shell) showHelp() string { help := `repl is the NanoDB interactive terminal. Commands: @@ -264,19 +263,19 @@ Commands: return help } -func (r *Repl) quit() string { - close(r.closeCh) +func (s *Shell) quit() string { + close(s.closeCh) return "Bye!\n" } -func (r *Repl) execQuery(input string) (string, error) { +func (s *Shell) execQuery(input string) (string, error) { var database string - if r.database != nil { - database = r.database.Name() + if s.database != nil { + database = s.database.Name() } - columns, rowIter, err := r.engine.Exec(database, input) + columns, rowIter, err := s.engine.Exec(database, input) if err != nil { return "", fmt.Errorf("failed to execute query: %w", err) } @@ -311,19 +310,8 @@ loop: buf := bytes.NewBuffer(nil) if len(data) > 0 { - r.drawTable(buf, columns, data) - buf.WriteString(fmt.Sprintf("(%d rows)\n\n", len(data))) + s.tw.WriteTable(buf, columns, data, true) } return buf.String(), nil } - -func (r *Repl) drawTable(w io.Writer, headers []string, data [][]string) { - tw := tablewriter.NewWriter(w) - tw.SetColWidth(75) - tw.AppendBulk(data) - tw.SetAutoFormatHeaders(false) - tw.SetAlignment(tablewriter.ALIGN_LEFT) - tw.SetHeader(headers) - tw.Render() -} diff --git a/internal/shell/table.go b/internal/shell/table.go new file mode 100644 index 0000000..81cc9e7 --- /dev/null +++ b/internal/shell/table.go @@ -0,0 +1,93 @@ +package shell + +import ( + "fmt" + "io" + "strings" + "text/tabwriter" +) + +type Table struct { + w *tabwriter.Writer +} + +func NewTableWriter(w io.Writer) *Table { + return &Table{ + w: tabwriter.NewWriter(w, 0, 0, 0, ' ', tabwriter.TabIndent), + } +} + +func (t *Table) WriteTable(headers []string, data [][]string, showRowsCount bool) { + divLine := t.formatDividerLine(headers, data) + + t.write(divLine) + t.write(t.formatRow(headers)) + t.write(divLine) + + for _, row := range data { + t.write(t.formatRow(row)) + } + + t.write(divLine) + + if showRowsCount { + t.write(t.formatRowsCount(data)) + } + + t.flush() +} + +func (t *Table) formatDividerLine(headers []string, data [][]string) string { + columnsWidth := t.columnsWidth(headers, data) + columns := make([]string, len(headers)) + + for i, size := range columnsWidth { + columns[i] = strings.Repeat("-", size+2) + } + + return fmt.Sprintf("+%s\t+\n", strings.Join(columns, "\t+")) +} + +func (t *Table) formatRow(columns []string) string { + return fmt.Sprintf("| %s\t|\n", strings.Join(columns, "\t| ")) +} + +func (t *Table) formatRowsCount(data [][]string) string { + return fmt.Sprintf("(%d rows)\n\n", len(data)) +} + +func (t *Table) write(line string) { + _, _ = t.w.Write([]byte(line)) +} + +func (t *Table) flush() { + _ = t.w.Flush() +} + +func (t *Table) columnsWidth(headers []string, data [][]string) []int { + columns := make([]int, len(headers)) + + for i := range data { + for j := range data[i] { + if len(headers[j])+2 > columns[j] { + columns[j] = len(headers[j]) + } + + if len(data[i][j]) > columns[j] { + columns[j] = len(data[i][j]) + } + } + } + + return columns +} + +type TableWriterFactory struct{} + +func NewTableWriterFactory() TableWriterFactory { + return TableWriterFactory{} +} + +func (f TableWriterFactory) WriteTable(w io.Writer, headers []string, data [][]string, showRowsCount bool) { + NewTableWriter(w).WriteTable(headers, data, showRowsCount) +}