Skip to content

Commit

Permalink
improve go imports
Browse files Browse the repository at this point in the history
  • Loading branch information
sxwebdev authored Feb 14, 2023
1 parent 806524e commit d997784
Show file tree
Hide file tree
Showing 14 changed files with 117 additions and 41 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.19

require (
github.com/antlr/antlr4/runtime/Go/antlr v1.4.10
github.com/bytecodealliance/wasmtime-go/v3 v3.0.2
github.com/cubicdaiya/gonp v1.0.4
github.com/davecgh/go-spew v1.1.1
github.com/go-sql-driver/mysql v1.7.0
Expand All @@ -28,6 +27,7 @@ require (

require (
github.com/benbjohnson/clock v1.3.0 // indirect
github.com/bytecodealliance/wasmtime-go/v5 v5.0.0
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 // indirect
github.com/go-ozzo/ozzo-validation/v4 v4.3.0 // indirect
github.com/goccy/go-json v0.10.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 h1:3uZCA/BLTIu+DqCfguByNMJa2HVHpXvjfy0Dy7g6fuA=
github.com/bytecodealliance/wasmtime-go/v3 v3.0.2/go.mod h1:RnUjnIXxEJcL6BgCvNyzCCRzZcxCgsZCi+RNlvYor5Q=
github.com/bytecodealliance/wasmtime-go/v5 v5.0.0 h1:Ue3eBDElMrdzWoUtr7uPr7NeDZriuR5oIivp5EHknQU=
github.com/bytecodealliance/wasmtime-go/v5 v5.0.0/go.mod h1:KcecyOqumZrvLnlaEIMFRbBaQeUYNvsbPjAEVho1Fcs=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
Expand Down
2 changes: 1 addition & 1 deletion internal/pgxgen/pgxgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"gopkg.in/yaml.v3"
)

var version = "v0.0.21"
var version = "v0.0.22"

func Start(args []string) error {
if len(args) == 0 {
Expand Down
23 changes: 22 additions & 1 deletion internal/sqlc/movemodels.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"

"github.com/tkcrm/pgxgen/internal/config"
"github.com/tkcrm/pgxgen/internal/structs"
)

// replacePackageName - replace package name for golang file
Expand Down Expand Up @@ -48,7 +49,7 @@ func replacePackageName(c config.Config, str string) (res string) {
return res
}

func replaceImports(c config.Config, str string) (res string) {
func replaceImports(c config.Config, str string, modelFileStructs structs.Structs) (res string) {
if c.Pgxgen.SqlcModels.OutputDir == "" {
return str
}
Expand All @@ -62,6 +63,26 @@ func replaceImports(c config.Config, str string) (res string) {
log.Fatal("empty package path")
}

var existsSomeModelStruct bool
for _, item := range modelFileStructs {
re := regexp.MustCompile(fmt.Sprintf(`(?sm)\([\[\]\*]+%s[\,\){]+`, item.Name))
if re.MatchString(str) {
existsSomeModelStruct = true
break
}

for _, field := range item.Fields {
re := regexp.MustCompile(fmt.Sprintf(`(?sm)\s+\w+\s+%s\s+`, field.Name))
if re.MatchString(str) {
existsSomeModelStruct = true
break
}
}
}
if !existsSomeModelStruct {
return str
}

r := regexp.MustCompile(`import (\"\w+\")`)
r2 := regexp.MustCompile(`(?sm)^import \(\s(([^\)]+)\s)+\)`)

Expand Down
5 changes: 3 additions & 2 deletions internal/sqlc/movemodels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/tkcrm/pgxgen/internal/config"
"github.com/tkcrm/pgxgen/internal/structs"
)

func TestReplaceImports(t *testing.T) {
Expand Down Expand Up @@ -48,6 +49,6 @@ const a = "b"`
},
}

replaceImports(cfg, str1)
replaceImports(cfg, str2)
replaceImports(cfg, str1, structs.Structs{})
replaceImports(cfg, str2, structs.Structs{})
}
67 changes: 50 additions & 17 deletions internal/sqlc/sqlc.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
"github.com/tkcrm/pgxgen/internal/config"
"github.com/tkcrm/pgxgen/internal/crud"
"github.com/tkcrm/pgxgen/internal/generator"
"github.com/tkcrm/pgxgen/internal/structs"
sqlcpkg "github.com/tkcrm/pgxgen/pkg/sqlc"
"github.com/tkcrm/pgxgen/utils"
"golang.org/x/tools/imports"
)

Expand Down Expand Up @@ -52,11 +54,15 @@ func (s *sqlc) process(args []string) error {
return fmt.Errorf("unsupported sqlc version: %d", s.config.Sqlc.Version)
}

var modelFileStructs structs.Structs
modelPaths := s.config.Sqlc.GetPaths().ModelsPaths

for _, path := range modelPaths {
if err := s.processFile(path); err != nil {
return errors.Wrapf(err, "failed to process file \"%s\"", path)
if err := s.processModelFilePaths(path, &modelFileStructs); err != nil {
return errors.Wrapf(err, "processModelFilePaths error: failed to process file \"%s\"", path)
}

if err := s.processGoFilePaths(path, modelFileStructs); err != nil {
return errors.Wrapf(err, "processGoFilePaths error: failed to process file \"%s\"", path)
}
}

Expand All @@ -67,7 +73,7 @@ func (s *sqlc) process(args []string) error {
return nil
}

func (s *sqlc) processFile(modelFilePath string) error {
func (s *sqlc) processModelFilePaths(modelFilePath string, modelFileStructs *structs.Structs) error {
modelFileDir := filepath.Dir(modelFilePath)
modelFileName := filepath.Base(modelFilePath)

Expand All @@ -91,17 +97,17 @@ func (s *sqlc) processFile(modelFilePath string) error {
}
}

// replace imports
if strings.HasSuffix(file.Name(), ".sql.go") || file.Name() == "querier.go" {
if err := s.replace(filepath.Join(modelFileDir, file.Name()), replaceImports); err != nil {
return err
}
}

// process models file
if file.Name() == modelFileName {
modelFilePath := filepath.Join(modelFileDir, file.Name())

modelFile, err := utils.ReadFile(modelFilePath)
if err != nil {
return err
}

*modelFileStructs = structs.GetStructs(string(modelFile))

// replace json tags
if err := s.replace(modelFilePath, replaceJsonTags); err != nil {
return err
Expand All @@ -122,10 +128,8 @@ func (s *sqlc) processFile(modelFilePath string) error {
oldPathDir := filepath.Join(currentDir, modelFilePath)

// create dir if new path not exists
if _, err := os.Stat(newPathDir); errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(newPathDir, os.ModePerm); err != nil {
return err
}
if err := utils.CreatePath(newPathDir); err != nil {
return err
}

// move file to new directory
Expand All @@ -149,6 +153,35 @@ func (s *sqlc) processFile(modelFilePath string) error {
return nil
}

func (s *sqlc) processGoFilePaths(path string, modelFileStructs structs.Structs) error {
modelFileDir := filepath.Dir(path)

files, err := os.ReadDir(modelFileDir)
if err != nil {
return errors.Wrapf(err, "failed to read path \"%s\"", modelFileDir)
}

for _, file := range files {
goFileRegexp := regexp.MustCompile(`(\.go)`)

// skip not golang files
if !goFileRegexp.MatchString(file.Name()) {
continue
}

// replace imports
if strings.HasSuffix(file.Name(), ".sql.go") || file.Name() == "querier.go" {
if err := s.replace(filepath.Join(modelFileDir, file.Name()), func(c config.Config, str string) string {
return replaceImports(c, str, modelFileStructs)
}); err != nil {
return err
}
}
}

return nil
}

func replaceStructTypes(c config.Config, str string) string {
res := str
for old, new := range types {
Expand All @@ -172,9 +205,9 @@ func replaceJsonTags(c config.Config, str string) string {
}

func (s *sqlc) replace(path string, fn func(c config.Config, str string) string) error {
file, err := os.ReadFile(path)
file, err := utils.ReadFile(path)
if err != nil {
return errors.Wrapf(err, "failed to read path \"%s\"", path)
return err
}

result := fn(s.config, string(file))
Expand Down
16 changes: 11 additions & 5 deletions pkg/sqlc/compiler/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
_, ok := node.(*ast.ColumnRef)
return ok
})
if len(list.Items) == 0 {
list = astutils.Search(n.Rexpr, func(node ast.Node) bool {
_, ok := node.(*ast.ColumnRef)
return ok
})
}

if len(list.Items) == 0 {
// TODO: Move this to database-specific engine package
Expand All @@ -135,9 +141,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
continue
}

switch left := list.Items[0].(type) {
switch node := list.Items[0].(type) {
case *ast.ColumnRef:
items := stringSlice(left.Fields)
items := stringSlice(node.Fields)
var key, alias string
switch len(items) {
case 1:
Expand Down Expand Up @@ -165,7 +171,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
return nil, &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("table alias \"%s\" does not exist", alias),
Location: left.Location,
Location: node.Location,
}
}
}
Expand Down Expand Up @@ -204,14 +210,14 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
return nil, &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("column \"%s\" does not exist", key),
Location: left.Location,
Location: node.Location,
}
}
if found > 1 {
return nil, &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("column reference \"%s\" is ambiguous", key),
Location: left.Location,
Location: node.Location,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sqlc/engine/postgresql/convert.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//go:build !windows
// +build !windows
//go:build !windows && cgo
// +build !windows,cgo

package postgresql

Expand Down
4 changes: 2 additions & 2 deletions pkg/sqlc/engine/postgresql/parse.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//go:build !windows
// +build !windows
//go:build !windows && cgo
// +build !windows,cgo

package postgresql

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//go:build windows
// +build windows
//go:build windows || !cgo
// +build windows !cgo

package postgresql

Expand Down
4 changes: 2 additions & 2 deletions pkg/sqlc/engine/postgresql/utils.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//go:build !windows
// +build !windows
//go:build !windows && cgo
// +build !windows,cgo

package postgresql

Expand Down
4 changes: 2 additions & 2 deletions pkg/sqlc/ext/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ import (
"runtime/trace"
"strings"

wasmtime "github.com/bytecodealliance/wasmtime-go/v3"
wasmtime "github.com/bytecodealliance/wasmtime-go/v5"
"golang.org/x/sync/singleflight"

"github.com/tkcrm/pgxgen/pkg/sqlc/info"
"github.com/tkcrm/pgxgen/pkg/sqlc/plugin"
)

// This version must be updated whenever the wasmtime-go dependency is updated
const wasmtimeVersion = `v3.0.2`
const wasmtimeVersion = `v5.0.0`

func cacheDir() (string, error) {
cache := os.Getenv("SQLCCACHE")
Expand Down
2 changes: 1 addition & 1 deletion pkg/sqlc/info/facts.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ package info

// When no version is set, return the next bug fix version
// after the most recent tag
const Version = "v1.16.0"
const Version = "v1.17.0"
17 changes: 16 additions & 1 deletion utils/paths.go → utils/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,22 @@ import (
)

func CreatePath(path string) error {
return os.MkdirAll(path, 0755)
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(path, 0755); err != nil {
return err
}
}

return nil
}

func ReadFile(filePath string) ([]byte, error) {
file, err := os.ReadFile(filePath)
if err != nil {
return nil, errors.Wrapf(err, "failed to read file by path \"%s\"", filePath)
}

return file, nil
}

func SaveFile(path, fileName string, data []byte) error {
Expand Down

0 comments on commit d997784

Please sign in to comment.