Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: get rid of olekukonko/tablewriter package #20

Merged
merged 1 commit into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions cmd/repl/main.go → cmd/shell/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package main

import (
"context"
"errors"
"log"
"os"
"os/signal"
"syscall"

"github.com/i-sevostyanov/NanoDB/internal/repl"
"github.com/i-sevostyanov/NanoDB/internal/shell"
"github.com/i-sevostyanov/NanoDB/internal/sql/engine"
"github.com/i-sevostyanov/NanoDB/internal/sql/parsing/ast"
"github.com/i-sevostyanov/NanoDB/internal/sql/parsing/lexer"
Expand All @@ -21,22 +19,17 @@ func main() {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()

parseFn := engine.ParseFn(func(sql string) (ast.Node, error) {
sqlParser := engine.ParseFn(func(sql string) (ast.Node, error) {
lx := lexer.New(sql)
pr := parser.New(lx)
return pr.Parse()
})

catalog := memory.NewCatalog()
aPlanner := planner.New(catalog)
anEngine := engine.New(parseFn, aPlanner)
aRepl := repl.New(os.Stdin, os.Stdout, catalog, anEngine)
sqlCatalog := memory.NewCatalog()
sqlPlanner := planner.New(sqlCatalog)
sqlEngine := engine.New(sqlParser, sqlPlanner)
tableWriter := shell.NewTableWriterFactory()

if err := aRepl.Run(ctx); err != nil {
switch {
case errors.Is(err, repl.ErrQuit):
default:
log.Printf("repl: %v\n", err)
}
}
sh := shell.New(os.Stdin, os.Stdout, sqlCatalog, sqlEngine, tableWriter)
sh.Run(ctx)
}
3 changes: 0 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@ module github.com/i-sevostyanov/NanoDB
go 1.22

require (
github.com/olekukonko/tablewriter v0.0.5
github.com/stretchr/testify v1.9.0
go.uber.org/mock v0.4.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/tools v0.21.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
8 changes: 0 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
Expand Down
128 changes: 58 additions & 70 deletions internal/repl/repl.go → internal/shell/shell.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package repl
package shell

import (
"bufio"
Expand All @@ -8,129 +8,130 @@ import (
"fmt"
"io"
"os"
"strconv"
"strings"

"github.com/i-sevostyanov/NanoDB/internal/sql"

"github.com/olekukonko/tablewriter"
)

const prompt = "#> "

var ErrQuit = errors.New("quit")
type TableWriter interface {
WriteTable(w io.Writer, headers []string, data [][]string, showRowsCount bool)
}

type Engine interface {
Exec(database, sql string) (columns []string, iter sql.RowIter, err error)
}

// Repl is terminal-based front-end to NanoDB.
type Repl struct {
// Shell is terminal-based front-end to NanoDB.
type Shell struct {
input io.Reader
output io.Writer
catalog sql.Catalog
database sql.Database
engine Engine
tw TableWriter
prompt string
closeCh chan struct{}
}

func New(in io.Reader, out io.Writer, catalog sql.Catalog, engine Engine) *Repl {
return &Repl{
func New(in io.Reader, out io.Writer, catalog sql.Catalog, engine Engine, tw TableWriter) *Shell {
return &Shell{
input: in,
output: out,
catalog: catalog,
engine: engine,
tw: tw,
prompt: prompt,
closeCh: make(chan struct{}),
}
}

func (r *Repl) Run(ctx context.Context) error {
r.write("repl is the NanoDB interactive terminal.\n")
func (s *Shell) Run(ctx context.Context) {
s.write("shell is the NanoDB interactive terminal.\n")

go func() {
for {
r.write(r.prompt)
s.write(s.prompt)

scanner := bufio.NewScanner(r.input)
scanner := bufio.NewScanner(s.input)
scanner.Scan()
input := scanner.Text()

if input == "" {
continue
}

if repl, err := r.exec(input); err != nil {
r.write(err.Error() + "\n")
if repl, err := s.exec(input); err != nil {
s.write(err.Error() + "\n")
} else if repl != "" {
r.write(repl)
s.write(repl)
}
}
}()

select {
case <-ctx.Done():
return nil
case <-r.closeCh:
return ErrQuit
case <-s.closeCh:
}
}

func (r *Repl) write(s string) {
_, _ = r.output.Write([]byte(s))
func (s *Shell) write(line string) {
_, _ = s.output.Write([]byte(line))
}

func (r *Repl) exec(input string) (string, error) {
func (s *Shell) exec(input string) (string, error) {
switch input[0] {
case '\\':
return r.execCommand(input)
return s.execCommand(input)
default:
return r.execQuery(input)
return s.execQuery(input)
}
}

func (r *Repl) execCommand(input string) (string, error) {
func (s *Shell) execCommand(input string) (string, error) {
cmd := strings.TrimSpace(input)
params := strings.Fields(cmd)

switch params[0] {
case `\use`:
return r.useDatabase(params)
return s.useDatabase(params)
case `\databases`:
return r.listDatabases()
return s.listDatabases()
case `\tables`:
return r.listTables()
return s.listTables()
case `\describe`:
return r.describeTable(params)
return s.describeTable(params)
case `\import`:
return r.importFile(params)
return s.importFile(params)
case `\help`:
return r.showHelp(), nil
return s.showHelp(), nil
case `\quit`:
return r.quit(), nil
return s.quit(), nil
default:
return "", fmt.Errorf("unknown command: %v", params[0])
}
}

func (r *Repl) useDatabase(params []string) (string, error) {
func (s *Shell) useDatabase(params []string) (string, error) {
if len(params) < 2 {
return "", fmt.Errorf("database name not specified")
}

db, err := r.catalog.GetDatabase(params[1])
db, err := s.catalog.GetDatabase(params[1])
if err != nil {
return "", err
}

r.database = db
r.prompt = fmt.Sprintf("%s %s", db.Name(), prompt)
s.database = db
s.prompt = fmt.Sprintf("%s %s", db.Name(), prompt)

return "database changed\n", nil
}

func (r *Repl) listDatabases() (string, error) {
databases, err := r.catalog.ListDatabases()
func (s *Shell) listDatabases() (string, error) {
databases, err := s.catalog.ListDatabases()
if err != nil {
return "", err
}
Expand All @@ -142,41 +143,39 @@ func (r *Repl) listDatabases() (string, error) {
data = append(data, []string{databases[i].Name()})
}

r.drawTable(buf, []string{"Database"}, data)
buf.WriteString(fmt.Sprintf("(%d rows)\n\n", len(data)))
s.tw.WriteTable(buf, []string{"Database"}, data, true)

return buf.String(), nil
}

func (r *Repl) listTables() (string, error) {
if r.database == nil {
func (s *Shell) listTables() (string, error) {
if s.database == nil {
return "", fmt.Errorf("connect to database first")
}

buf := bytes.NewBuffer(nil)
tables := r.database.ListTables()
tables := s.database.ListTables()
data := make([][]string, 0, len(tables))

for i := range tables {
data = append(data, []string{tables[i].Name()})
}

r.drawTable(buf, []string{"Table"}, data)
buf.WriteString(fmt.Sprintf("(%d rows)\n\n", len(data)))
s.tw.WriteTable(buf, []string{"Table"}, data, true)

return buf.String(), nil
}

func (r *Repl) describeTable(params []string) (string, error) {
if r.database == nil {
func (s *Shell) describeTable(params []string) (string, error) {
if s.database == nil {
return "", fmt.Errorf("connect to database first")
}

if len(params) < 2 {
return "", fmt.Errorf("table name not specified")
}

table, err := r.database.GetTable(params[1])
table, err := s.database.GetTable(params[1])
if err != nil {
return "", err
}
Expand Down Expand Up @@ -209,21 +208,21 @@ func (r *Repl) describeTable(params []string) (string, error) {
row := []string{
columns[i].Name,
columns[i].DataType.String(),
fmt.Sprintf("%t", columns[i].Nullable),
strconv.FormatBool(columns[i].Nullable),
defaultValue,
}

data = append(data, row)
}

r.drawTable(buf, []string{"Column", "Type", "Nullable", "Default"}, data)
s.tw.WriteTable(buf, []string{"Column", "Type", "Nullable", "Default"}, data, false)
buf.WriteString("Indexes:\n")
buf.WriteString(fmt.Sprintf(" PRIMARY KEY (%s) autoincrement\n\n", primaryKey.Name))

return buf.String(), nil
}

func (r *Repl) importFile(params []string) (string, error) {
func (s *Shell) importFile(params []string) (string, error) {
if len(params) < 2 {
return "", fmt.Errorf("filename not specified")
}
Expand All @@ -241,15 +240,15 @@ func (r *Repl) importFile(params []string) (string, error) {
continue
}

if _, err = r.exec(stmt); err != nil {
if _, err = s.exec(stmt); err != nil {
return "", err
}
}

return "OK\n", nil
}

func (r *Repl) showHelp() string {
func (s *Shell) showHelp() string {
help := `repl is the NanoDB interactive terminal.

Commands:
Expand All @@ -264,19 +263,19 @@ Commands:
return help
}

func (r *Repl) quit() string {
close(r.closeCh)
func (s *Shell) quit() string {
close(s.closeCh)
return "Bye!\n"
}

func (r *Repl) execQuery(input string) (string, error) {
func (s *Shell) execQuery(input string) (string, error) {
var database string

if r.database != nil {
database = r.database.Name()
if s.database != nil {
database = s.database.Name()
}

columns, rowIter, err := r.engine.Exec(database, input)
columns, rowIter, err := s.engine.Exec(database, input)
if err != nil {
return "", fmt.Errorf("failed to execute query: %w", err)
}
Expand Down Expand Up @@ -311,19 +310,8 @@ loop:
buf := bytes.NewBuffer(nil)

if len(data) > 0 {
r.drawTable(buf, columns, data)
buf.WriteString(fmt.Sprintf("(%d rows)\n\n", len(data)))
s.tw.WriteTable(buf, columns, data, true)
}

return buf.String(), nil
}

func (r *Repl) drawTable(w io.Writer, headers []string, data [][]string) {
tw := tablewriter.NewWriter(w)
tw.SetColWidth(75)
tw.AppendBulk(data)
tw.SetAutoFormatHeaders(false)
tw.SetAlignment(tablewriter.ALIGN_LEFT)
tw.SetHeader(headers)
tw.Render()
}
Loading