-
Notifications
You must be signed in to change notification settings - Fork 17
/
generator.go
125 lines (104 loc) · 3.01 KB
/
generator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package main
import (
"fmt"
"github.com/martinxsliu/protoc-gen-graphql/parameters"
"strings"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/pluginpb"
"github.com/martinxsliu/protoc-gen-graphql/graphql"
"github.com/martinxsliu/protoc-gen-graphql/mapper"
)
// header is the first line written to every generated GraphQL schema file,
// marking the file as machine-generated output that must not be hand-edited.
var header = []byte("# DO NOT EDIT! Generated by protoc-gen-graphql.")
// Generator converts a protoc CodeGeneratorRequest into GraphQL schema files.
// Create one with New; the mapper field stays nil until Generate runs.
type Generator struct {
	// req is the code generation request taken from the protogen plugin.
	req *pluginpb.CodeGeneratorRequest
	// gen is the protogen plugin used to create the output files.
	gen *protogen.Plugin
	// mapper holds the proto-to-GraphQL mappings built by Generate.
	mapper *mapper.Mapper
}
// New returns a Generator bound to the given protogen plugin. The plugin's
// request is cached on the Generator; the mapper is built later by Generate.
func New(gen *protogen.Plugin) *Generator {
	g := new(Generator)
	g.req = gen.Request
	g.gen = gen
	return g
}
// Generate parses the plugin parameters, builds the proto-to-GraphQL mapping
// from every proto file in the request, and emits one GraphQL schema file per
// file listed in FileToGenerate.
//
// Any panic raised while mapping or generating is recovered and returned as
// an error instead of crashing the plugin process. When the panic value is
// itself an error (e.g. a runtime error), it is wrapped with %w so callers
// can still inspect it with errors.Is / errors.As.
func (g *Generator) Generate() (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if e, ok := r.(error); ok {
			err = fmt.Errorf("generating GraphQL schema: %w", e)
			return
		}
		err = fmt.Errorf("generating GraphQL schema: %v", r)
	}()

	params, err := parameters.NewParameters(g.req.GetParameter())
	if err != nil {
		return fmt.Errorf("parsing plugin parameters: %w", err)
	}

	// The mapper must exist before generateFiles runs; it reads g.mapper.
	g.mapper = mapper.New(g.req.GetProtoFile(), params)
	g.generateFiles(params)
	return nil
}
// generateFiles emits one GraphQL schema file for every proto file listed in
// the request's FileToGenerate, using the mappings stored in g.mapper.
//
// A requested file with no mapping is an invariant violation and is reported
// by panicking with a descriptive message; Generate recovers such panics and
// returns them as errors.
func (g *Generator) generateFiles(params *parameters.Parameters) {
	for _, fileName := range g.req.GetFileToGenerate() {
		g.writeSchemaFile(fileName, g.collectTypes(fileName), params)
	}
}

// collectTypes gathers the GraphQL types derived from one proto file, in the
// order they are written out:
//   - service root objects,
//   - per-message types (object, oneof unions with their objects, input,
//     oneof inputs),
//   - enums.
func (g *Generator) collectTypes(fileName string) []graphql.Type {
	file, found := g.mapper.Files[fileName]
	if !found {
		panic(fmt.Sprintf("no GraphQL mapping for file %q", fileName))
	}

	var gqlTypes []graphql.Type

	for _, service := range file.Services {
		m, ok := g.mapper.ServiceMappers[service.FullName]
		if !ok {
			continue // Service was skipped
		}
		if m.Queries != nil {
			if m.Queries.ExtendRootObject != nil {
				gqlTypes = append(gqlTypes, m.Queries.ExtendRootObject)
			}
			gqlTypes = append(gqlTypes, m.Queries.Object)
		}
		if m.Mutations != nil {
			if m.Mutations.ExtendRootObject != nil {
				gqlTypes = append(gqlTypes, m.Mutations.ExtendRootObject)
			}
			gqlTypes = append(gqlTypes, m.Mutations.Object)
		}
		if m.Subscriptions != nil {
			if m.Subscriptions.ExtendRootObject != nil {
				gqlTypes = append(gqlTypes, m.Subscriptions.ExtendRootObject)
			}
			gqlTypes = append(gqlTypes, m.Subscriptions.Object)
		}
	}

	for _, message := range file.Messages {
		m, ok := g.mapper.MessageMappers[message.FullName]
		if !ok {
			// Unmapped messages are skipped, mirroring the service case
			// above, instead of dereferencing a missing mapper entry.
			continue
		}
		if m.Object != nil {
			gqlTypes = append(gqlTypes, m.Object)
		}
		// Each oneof contributes its union type followed by its objects.
		for _, oneof := range m.Oneofs {
			gqlTypes = append(gqlTypes, oneof.Union)
			for _, object := range oneof.Objects {
				gqlTypes = append(gqlTypes, object)
			}
		}
		if m.Input != nil {
			gqlTypes = append(gqlTypes, m.Input)
		}
		// Oneof inputs come after the message's own input; the second pass
		// over Oneofs preserves that output order.
		for _, oneof := range m.Oneofs {
			if oneof.Input != nil {
				gqlTypes = append(gqlTypes, oneof.Input)
			}
		}
	}

	for _, enum := range file.Enums {
		m, ok := g.mapper.EnumMappers[enum.FullName]
		if !ok {
			continue // Unmapped enum: nothing to emit.
		}
		gqlTypes = append(gqlTypes, m.Enum)
	}

	return gqlTypes
}

// writeSchemaFile creates the output file for fileName and writes the
// generated-code header, then each type definition preceded by a blank line,
// then a trailing newline.
func (g *Generator) writeSchemaFile(fileName string, gqlTypes []graphql.Type, params *parameters.Parameters) {
	// The output is GraphQL SDL, not Go, so the import path is presumably just
	// a placeholder required by the NewGeneratedFile signature.
	genFile := g.gen.NewGeneratedFile(graphqlFileName(fileName), "github.com/not-a-real-import")

	// Write errors are deliberately ignored: protogen's GeneratedFile writes
	// into an in-memory buffer, so Write is not expected to fail.
	separator := []byte("\n\n")
	_, _ = genFile.Write(header)
	for _, gqlType := range gqlTypes {
		_, _ = genFile.Write(separator)
		_, _ = genFile.Write([]byte(graphql.TypeDef(gqlType, params)))
	}
	_, _ = genFile.Write([]byte("\n"))
}
// graphqlFileName derives the output path for a proto source file. A trailing
// ".proto" extension (if present) is dropped and "_pb.graphql" is appended,
// e.g. "foo/bar.proto" becomes "foo/bar_pb.graphql".
func graphqlFileName(name string) string {
	const protoExt, schemaSuffix = ".proto", "_pb.graphql"
	base := strings.TrimSuffix(name, protoExt)
	return base + schemaSuffix
}
// stringPtr returns a pointer to a fresh copy of v. It is handy for optional
// protobuf string fields, which generated Go code represents as *string.
func stringPtr(v string) *string {
	p := new(string)
	*p = v
	return p
}