Skip to content

Commit

Permalink
Merge pull request #3 from launchdarkly-labs/optimizations
Browse files Browse the repository at this point in the history
Find types
  • Loading branch information
erinpentecost authored Feb 22, 2022
2 parents ad8f65c + 9ece900 commit 2af551c
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 58 deletions.
23 changes: 17 additions & 6 deletions files_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package main

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestFileList(t *testing.T) {
found := []string{}
searchDir := expandPath(".")
require.NoError(t, runOnFiles([]string{searchDir}, []string{"go.mod", "go.sum"}, func(file string) error {
require.NoError(t, runOnFiles(context.TODO(), []string{searchDir}, []string{"go.mod", "go.sum"}, func(file string) error {
found = append(found, file)
return nil
}))
Expand All @@ -21,21 +23,30 @@ func TestFileList(t *testing.T) {

func TestExports(t *testing.T) {
searchDir := expandPath("./internal/dummy/")
exports, err := findExports([]string{searchDir}, []string{})
exports, err := findExports(context.TODO(), []string{searchDir}, []string{})
require.NoError(t, err)
if _, ok := exports["github.com/launchdarkly-labs/refaudit/internal/dummy.ExportedFunction"]; !ok {
require.FailNow(t, "missing export")
assert.FailNow(t, "missing exported function")
}
if _, ok := exports["github.com/launchdarkly-labs/refaudit/internal/dummy.ExportedVariable"]; !ok {
require.FailNow(t, "missing export")
assert.FailNow(t, "missing exported variable")
}
if _, ok := exports["github.com/launchdarkly-labs/refaudit/internal/dummy.ExportedStruct"]; !ok {
assert.FailNow(t, "missing exported struct")
}
if _, ok := exports["github.com/launchdarkly-labs/refaudit/internal/dummy.ExportedInterface"]; !ok {
assert.FailNow(t, "missing exported interface")
}
}

func TestImports(t *testing.T) {
searchDir := expandPath("./internal/dummy/")
imports, err := findImports([]string{searchDir}, []string{})
imports, err := findImports(context.TODO(), []string{searchDir}, []string{})
require.NoError(t, err)
if _, ok := imports["fmt.Print"]; !ok {
require.FailNow(t, "missing import")
assert.FailNow(t, "missing imported function call")
}
if _, ok := imports["fmt.Stringer"]; !ok {
assert.FailNow(t, "missing imported interface ref")
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.17

require (
github.com/stretchr/testify v1.7.0
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/tools v0.1.9
)

Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
7 changes: 7 additions & 0 deletions internal/dummy/dummy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ import "fmt"

func ExportedFunction() {
fmt.Print("hi")

}

var ExportedVariable = 10

type ExportedStruct struct{}

type ExportedInterface interface {
fmt.Stringer
}
155 changes: 103 additions & 52 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
package main

import (
"context"
"encoding/json"
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"os/signal"
"path"
"path/filepath"
"sort"
"strings"

"golang.org/x/sync/errgroup"
"golang.org/x/tools/go/packages"
)

Expand All @@ -34,6 +37,9 @@ type Report struct {
}

func main() {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()

// parse input
from := []string{}
excludeFrom := []string{}
Expand Down Expand Up @@ -74,13 +80,13 @@ func main() {
fmt.Fprintf(os.Stderr, "%s: %s\n", excludeToArg, strings.Join(excludeTo, ", "))
fmt.Fprintf(os.Stderr, "%s: %s\n", excludeFromArg, strings.Join(excludeFrom, ", "))

globals, err := findExports(from, excludeFrom)
globals, err := findExports(ctx, from, excludeFrom)
if err != nil {
fmt.Fprintf(os.Stderr, "%v", err)
os.Exit(2)
}

refs, err := findImports(to, excludeTo)
refs, err := findImports(ctx, to, excludeTo)
if err != nil {
fmt.Fprintf(os.Stderr, "%v", err)
os.Exit(2)
Expand All @@ -93,17 +99,15 @@ func main() {
UnusedExports: []string{},
}
for k := range globals {
rpt.Exported = append(rpt.Exported, k)
rpt.Exported = sortedInsert(rpt.Exported, k)
//rpt.Exported = append(rpt.Exported, k)
if _, ok := refs[k]; !ok {
rpt.UnusedExports = append(rpt.UnusedExports, k)
rpt.UnusedExports = sortedInsert(rpt.UnusedExports, k)
}
}
for k := range refs {
rpt.Imported = append(rpt.Imported, k)
rpt.Imported = sortedInsert(rpt.Imported, k)
}
sort.Strings(rpt.Exported)
sort.Strings(rpt.Imported)
sort.Strings(rpt.UnusedExports)

outB, err := json.MarshalIndent(rpt, "", " ")
if err != nil {
Expand All @@ -113,6 +117,20 @@ func main() {
fmt.Println(string(outB))
}

// sortedInsert
func sortedInsert(list []string, elem string) []string {
// find spot to insert element
i := sort.Search(len(list), func(i int) bool { return list[i] >= elem })
// handle not found case
if i == len(list) {
return append(list, elem)
}
// shift over and set
list = append(list[:i+1], list[i:]...)
list[i] = elem
return list
}

func expandPath(path string) string {
exp, err := filepath.Abs(os.ExpandEnv(path))
if err != nil {
Expand All @@ -123,51 +141,79 @@ func expandPath(path string) string {
}

// runOnFiles runs fn on every file/dir specified, recursively.
func runOnFiles(files []string, excluding []string, fn func(file string) error) error {
vendor := fmt.Sprintf("%svendor%s", fsep, fsep)
for _, file := range files {
err := filepath.Walk(file,
func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// don't run on vendor sub-directories
if strings.Contains(path, vendor) {
return filepath.SkipDir
}
// exclude any top-level paths as needed
for _, ex := range excluding {
if strings.TrimSuffix(path, fsep) == strings.TrimSuffix(ex, fsep) {
return filepath.SkipDir
func runOnFiles(ctx context.Context, files []string, excluding []string, fn func(file string) error) error {
g, ctx := errgroup.WithContext(ctx)
filesChan := make(chan string, 4) // buffered chan since walking can take a while

// set up file consumer
g.Go(func() error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case f, ok := <-filesChan:
if ok {
if err := fn(f); err != nil {
return err
}
}
// don't run on dirs
if info.IsDir() {
} else {
return nil
}
// don't run on non-go files
if !strings.HasSuffix(path, ".go") {
return nil
}
// run fn on the file
if err := fn(path); err != nil {
return err
}
}
}
})

return nil
})
if err != nil {
return fmt.Errorf("could not walk %s: %w", file, err)
// walk the dir tree, producing files
g.Go(func() error {
defer close(filesChan)
vendor := fmt.Sprintf("%svendor%s", fsep, fsep)
for _, file := range files {
err := filepath.Walk(file,
func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
// don't run on vendor sub-directories
if strings.Contains(path, vendor) {
return filepath.SkipDir
}
// exclude any top-level paths as needed
for _, ex := range excluding {
if strings.TrimSuffix(path, fsep) == strings.TrimSuffix(ex, fsep) {
return filepath.SkipDir
}
}
// don't run on dirs
if info.IsDir() {
return nil
}
// don't run on non-go files
if !strings.HasSuffix(path, ".go") {
return nil
}
// send to consumer
filesChan <- path

return nil
})
if err != nil {
return fmt.Errorf("could not walk %s: %w", file, err)
}
}
}
return nil
return nil
})

return g.Wait()
}

func findExports(from []string, excludeFrom []string) (map[string]interface{}, error) {
func findExports(ctx context.Context, from []string, excludeFrom []string) (map[string]interface{}, error) {
globals := make(map[string]interface{})

fs := token.NewFileSet()
err := runOnFiles(from, excludeFrom, func(file string) error {
err := runOnFiles(ctx, from, excludeFrom, func(file string) error {
f, err := parser.ParseFile(fs, file, nil, parser.AllErrors)
if err != nil {
return fmt.Errorf("could not parse %s: %w", file, err)
Expand Down Expand Up @@ -230,13 +276,18 @@ func (v exportVisitor) Visit(n ast.Node) ast.Visitor {
case *ast.FuncDecl:
v.add(d.Name)
case *ast.GenDecl:
if d.Tok != token.VAR {
return v
}
for _, spec := range d.Specs {
if value, ok := spec.(*ast.ValueSpec); ok {
for _, name := range value.Names {
v.add(name)
if d.Tok == token.VAR {
for _, spec := range d.Specs {
if value, ok := spec.(*ast.ValueSpec); ok {
for _, name := range value.Names {
v.add(name)
}
}
}
} else if d.Tok == token.TYPE {
for _, spec := range d.Specs {
if value, ok := spec.(*ast.TypeSpec); ok {
v.add(value.Name)
}
}
}
Expand All @@ -260,11 +311,11 @@ func (v exportVisitor) add(n ast.Node) {
}
}

func findImports(to []string, excludeTo []string) (map[string]interface{}, error) {
func findImports(ctx context.Context, to []string, excludeTo []string) (map[string]interface{}, error) {
refs := make(map[string]interface{})

fs := token.NewFileSet()
err := runOnFiles(to, excludeTo, func(file string) error {
err := runOnFiles(ctx, to, excludeTo, func(file string) error {
f, err := parser.ParseFile(fs, file, nil, parser.AllErrors)

if err != nil {
Expand Down

0 comments on commit 2af551c

Please sign in to comment.