Skip to content

Commit

Permalink
refactor: get rid of olekukonko/tablewriter package (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
i-sevostyanov authored May 27, 2024
1 parent e59d45d commit e2baa05
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 96 deletions.
23 changes: 8 additions & 15 deletions cmd/repl/main.go → cmd/shell/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package main

import (
"context"
"errors"
"log"
"os"
"os/signal"
"syscall"

"github.com/i-sevostyanov/NanoDB/internal/repl"
"github.com/i-sevostyanov/NanoDB/internal/shell"
"github.com/i-sevostyanov/NanoDB/internal/sql/engine"
"github.com/i-sevostyanov/NanoDB/internal/sql/parsing/ast"
"github.com/i-sevostyanov/NanoDB/internal/sql/parsing/lexer"
Expand All @@ -21,22 +19,17 @@ func main() {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()

parseFn := engine.ParseFn(func(sql string) (ast.Node, error) {
sqlParser := engine.ParseFn(func(sql string) (ast.Node, error) {
lx := lexer.New(sql)
pr := parser.New(lx)
return pr.Parse()
})

catalog := memory.NewCatalog()
aPlanner := planner.New(catalog)
anEngine := engine.New(parseFn, aPlanner)
aRepl := repl.New(os.Stdin, os.Stdout, catalog, anEngine)
sqlCatalog := memory.NewCatalog()
sqlPlanner := planner.New(sqlCatalog)
sqlEngine := engine.New(sqlParser, sqlPlanner)
tableWriter := shell.NewTableWriterFactory()

if err := aRepl.Run(ctx); err != nil {
switch {
case errors.Is(err, repl.ErrQuit):
default:
log.Printf("repl: %v\n", err)
}
}
sh := shell.New(os.Stdin, os.Stdout, sqlCatalog, sqlEngine, tableWriter)
sh.Run(ctx)
}
3 changes: 0 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@ module github.com/i-sevostyanov/NanoDB
go 1.22

require (
github.com/olekukonko/tablewriter v0.0.5
github.com/stretchr/testify v1.9.0
go.uber.org/mock v0.4.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/tools v0.21.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
8 changes: 0 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
Expand Down
128 changes: 58 additions & 70 deletions internal/repl/repl.go → internal/shell/shell.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package repl
package shell

import (
"bufio"
Expand All @@ -8,129 +8,130 @@ import (
"fmt"
"io"
"os"
"strconv"
"strings"

"github.com/i-sevostyanov/NanoDB/internal/sql"

"github.com/olekukonko/tablewriter"
)

const prompt = "#> "

var ErrQuit = errors.New("quit")
type TableWriter interface {
WriteTable(w io.Writer, headers []string, data [][]string, showRowsCount bool)
}

type Engine interface {
Exec(database, sql string) (columns []string, iter sql.RowIter, err error)
}

// Repl is terminal-based front-end to NanoDB.
type Repl struct {
// Shell is terminal-based front-end to NanoDB.
type Shell struct {
input io.Reader
output io.Writer
catalog sql.Catalog
database sql.Database
engine Engine
tw TableWriter
prompt string
closeCh chan struct{}
}

func New(in io.Reader, out io.Writer, catalog sql.Catalog, engine Engine) *Repl {
return &Repl{
func New(in io.Reader, out io.Writer, catalog sql.Catalog, engine Engine, tw TableWriter) *Shell {
return &Shell{
input: in,
output: out,
catalog: catalog,
engine: engine,
tw: tw,
prompt: prompt,
closeCh: make(chan struct{}),
}
}

func (r *Repl) Run(ctx context.Context) error {
r.write("repl is the NanoDB interactive terminal.\n")
func (s *Shell) Run(ctx context.Context) {
s.write("shell is the NanoDB interactive terminal.\n")

go func() {
for {
r.write(r.prompt)
s.write(s.prompt)

scanner := bufio.NewScanner(r.input)
scanner := bufio.NewScanner(s.input)
scanner.Scan()
input := scanner.Text()

if input == "" {
continue
}

if repl, err := r.exec(input); err != nil {
r.write(err.Error() + "\n")
if repl, err := s.exec(input); err != nil {
s.write(err.Error() + "\n")
} else if repl != "" {
r.write(repl)
s.write(repl)
}
}
}()

select {
case <-ctx.Done():
return nil
case <-r.closeCh:
return ErrQuit
case <-s.closeCh:
}
}

func (r *Repl) write(s string) {
_, _ = r.output.Write([]byte(s))
func (s *Shell) write(line string) {
_, _ = s.output.Write([]byte(line))
}

func (r *Repl) exec(input string) (string, error) {
func (s *Shell) exec(input string) (string, error) {
switch input[0] {
case '\\':
return r.execCommand(input)
return s.execCommand(input)
default:
return r.execQuery(input)
return s.execQuery(input)
}
}

func (r *Repl) execCommand(input string) (string, error) {
func (s *Shell) execCommand(input string) (string, error) {
cmd := strings.TrimSpace(input)
params := strings.Fields(cmd)

switch params[0] {
case `\use`:
return r.useDatabase(params)
return s.useDatabase(params)
case `\databases`:
return r.listDatabases()
return s.listDatabases()
case `\tables`:
return r.listTables()
return s.listTables()
case `\describe`:
return r.describeTable(params)
return s.describeTable(params)
case `\import`:
return r.importFile(params)
return s.importFile(params)
case `\help`:
return r.showHelp(), nil
return s.showHelp(), nil
case `\quit`:
return r.quit(), nil
return s.quit(), nil
default:
return "", fmt.Errorf("unknown command: %v", params[0])
}
}

func (r *Repl) useDatabase(params []string) (string, error) {
func (s *Shell) useDatabase(params []string) (string, error) {
if len(params) < 2 {
return "", fmt.Errorf("database name not specified")
}

db, err := r.catalog.GetDatabase(params[1])
db, err := s.catalog.GetDatabase(params[1])
if err != nil {
return "", err
}

r.database = db
r.prompt = fmt.Sprintf("%s %s", db.Name(), prompt)
s.database = db
s.prompt = fmt.Sprintf("%s %s", db.Name(), prompt)

return "database changed\n", nil
}

func (r *Repl) listDatabases() (string, error) {
databases, err := r.catalog.ListDatabases()
func (s *Shell) listDatabases() (string, error) {
databases, err := s.catalog.ListDatabases()
if err != nil {
return "", err
}
Expand All @@ -142,41 +143,39 @@ func (r *Repl) listDatabases() (string, error) {
data = append(data, []string{databases[i].Name()})
}

r.drawTable(buf, []string{"Database"}, data)
buf.WriteString(fmt.Sprintf("(%d rows)\n\n", len(data)))
s.tw.WriteTable(buf, []string{"Database"}, data, true)

return buf.String(), nil
}

func (r *Repl) listTables() (string, error) {
if r.database == nil {
func (s *Shell) listTables() (string, error) {
if s.database == nil {
return "", fmt.Errorf("connect to database first")
}

buf := bytes.NewBuffer(nil)
tables := r.database.ListTables()
tables := s.database.ListTables()
data := make([][]string, 0, len(tables))

for i := range tables {
data = append(data, []string{tables[i].Name()})
}

r.drawTable(buf, []string{"Table"}, data)
buf.WriteString(fmt.Sprintf("(%d rows)\n\n", len(data)))
s.tw.WriteTable(buf, []string{"Table"}, data, true)

return buf.String(), nil
}

func (r *Repl) describeTable(params []string) (string, error) {
if r.database == nil {
func (s *Shell) describeTable(params []string) (string, error) {
if s.database == nil {
return "", fmt.Errorf("connect to database first")
}

if len(params) < 2 {
return "", fmt.Errorf("table name not specified")
}

table, err := r.database.GetTable(params[1])
table, err := s.database.GetTable(params[1])
if err != nil {
return "", err
}
Expand Down Expand Up @@ -209,21 +208,21 @@ func (r *Repl) describeTable(params []string) (string, error) {
row := []string{
columns[i].Name,
columns[i].DataType.String(),
fmt.Sprintf("%t", columns[i].Nullable),
strconv.FormatBool(columns[i].Nullable),
defaultValue,
}

data = append(data, row)
}

r.drawTable(buf, []string{"Column", "Type", "Nullable", "Default"}, data)
s.tw.WriteTable(buf, []string{"Column", "Type", "Nullable", "Default"}, data, false)
buf.WriteString("Indexes:\n")
buf.WriteString(fmt.Sprintf(" PRIMARY KEY (%s) autoincrement\n\n", primaryKey.Name))

return buf.String(), nil
}

func (r *Repl) importFile(params []string) (string, error) {
func (s *Shell) importFile(params []string) (string, error) {
if len(params) < 2 {
return "", fmt.Errorf("filename not specified")
}
Expand All @@ -241,15 +240,15 @@ func (r *Repl) importFile(params []string) (string, error) {
continue
}

if _, err = r.exec(stmt); err != nil {
if _, err = s.exec(stmt); err != nil {
return "", err
}
}

return "OK\n", nil
}

func (r *Repl) showHelp() string {
func (s *Shell) showHelp() string {
help := `repl is the NanoDB interactive terminal.
Commands:
Expand All @@ -264,19 +263,19 @@ Commands:
return help
}

func (r *Repl) quit() string {
close(r.closeCh)
func (s *Shell) quit() string {
close(s.closeCh)
return "Bye!\n"
}

func (r *Repl) execQuery(input string) (string, error) {
func (s *Shell) execQuery(input string) (string, error) {
var database string

if r.database != nil {
database = r.database.Name()
if s.database != nil {
database = s.database.Name()
}

columns, rowIter, err := r.engine.Exec(database, input)
columns, rowIter, err := s.engine.Exec(database, input)
if err != nil {
return "", fmt.Errorf("failed to execute query: %w", err)
}
Expand Down Expand Up @@ -311,19 +310,8 @@ loop:
buf := bytes.NewBuffer(nil)

if len(data) > 0 {
r.drawTable(buf, columns, data)
buf.WriteString(fmt.Sprintf("(%d rows)\n\n", len(data)))
s.tw.WriteTable(buf, columns, data, true)
}

return buf.String(), nil
}

func (r *Repl) drawTable(w io.Writer, headers []string, data [][]string) {
tw := tablewriter.NewWriter(w)
tw.SetColWidth(75)
tw.AppendBulk(data)
tw.SetAutoFormatHeaders(false)
tw.SetAlignment(tablewriter.ALIGN_LEFT)
tw.SetHeader(headers)
tw.Render()
}
Loading

0 comments on commit e2baa05

Please sign in to comment.