From 4438847df9f80a6f6a374299dda5a03dbbb4fa9e Mon Sep 17 00:00:00 2001 From: Michael Li Date: Tue, 19 Dec 2023 10:26:40 +0800 Subject: [PATCH] add Context primitive logic for hertz engine --- assert/assert.go | 12 ++- assert/assert_any.go | 13 ++++ core/core.go | 7 ++ engine/hertz/assert.go | 6 +- engine/hertz/go.mod | 4 +- engine/hertz/go.sum | 12 +-- internal/generator/templates/hertz_iface.tmpl | 76 ++++++++++--------- 7 files changed, 83 insertions(+), 47 deletions(-) diff --git a/assert/assert.go b/assert/assert.go index c34b72a..db7a557 100644 --- a/assert/assert.go +++ b/assert/assert.go @@ -20,7 +20,7 @@ type Binding[T any] interface { Bind(T) mir.Error } -// Binding2[R, P] binding interface for custom T context +// Binding2[R, P] binding interface for custom R, P context type Binding2[R, P any] interface { Bind(R, P) mir.Error } @@ -30,6 +30,11 @@ type Render[T any] interface { Render(T) } +// Render[C, T] render interface for custom C, T context +type Render2[C, T any] interface { + Render(C, T) +} + // TypeAssertor type assert for Binding and Render interface type TypeAssertor interface { AssertBinding(any) bool @@ -56,6 +61,11 @@ func RegisterType3[B, P, R any]() { _typeAssertor = anyTypeAssertor3[B, P, R]{} } +// RegisterType4[C, T] register custom TypeAssertor to assert C, T interface +func RegisterType4[C, T any]() { + _typeAssertor = anyTypeAssertor4[C, T]{} +} + // AssertBinding assert Binding interface for obj func AssertBinding(obj any) bool { return _typeAssertor.AssertBinding(obj) diff --git a/assert/assert_any.go b/assert/assert_any.go index c260245..79c7f18 100644 --- a/assert/assert_any.go +++ b/assert/assert_any.go @@ -13,6 +13,9 @@ type anyTypeAssertor2[B, R any] struct{} // anyTypeAssertor3 a common type assert for type B(Binding)/P(Params) and R(Render) type anyTypeAssertor3[B, P, R any] struct{} +// anyTypeAssertor4 a common type assert for type C and T +type anyTypeAssertor4[C, T any] struct{} + func (anyTypeAssertor[T]) AssertBinding(obj any) bool { _, ok := obj.(Binding[T]) return ok @@ -42,3 +45,13 @@ func (anyTypeAssertor3[B, P, R]) AssertRender(obj any) bool { _, ok := obj.(Render[R]) return ok } + +func (anyTypeAssertor4[C, T]) AssertBinding(obj any) bool { + _, ok := obj.(Binding2[C, T]) + return ok +} + +func (anyTypeAssertor4[C, T]) AssertRender(obj any) bool { + _, ok := obj.(Render2[C, T]) + return ok +} diff --git a/core/core.go b/core/core.go index b148194..0bc8885 100644 --- a/core/core.go +++ b/core/core.go @@ -277,6 +277,13 @@ func AssertType3[B, P, R any]() Option { }) } +// AssertType4[C, T] register assert.TypeAssertor for custom C, T type +func AssertType4[C, T any]() Option { + return optFunc(func(_opts *InitOpts) { + assert.RegisterType4[C, T]() + }) +} + // WatchCtxDone set generator whether watch context done when Register Servants in // generated code. default watch context done. func WatchCtxDone(enable bool) Option { diff --git a/engine/hertz/assert.go b/engine/hertz/assert.go index 14e55a0..afb3950 100644 --- a/engine/hertz/assert.go +++ b/engine/hertz/assert.go @@ -5,6 +5,8 @@ package engine_hertz import ( + "context" + "github.com/alimy/mir/v4" "github.com/alimy/mir/v4/assert" "github.com/cloudwego/hertz/pkg/app" @@ -15,11 +17,11 @@ func init() { } type Binding interface { - Bind(*app.RequestContext) mir.Error + Bind(context.Context, *app.RequestContext) mir.Error } type Render interface { - Render(*app.RequestContext) + Render(context.Context, *app.RequestContext) } type typeAssertor struct{} diff --git a/engine/hertz/go.mod b/engine/hertz/go.mod index 04a4cdb..517bd1e 100644 --- a/engine/hertz/go.mod +++ b/engine/hertz/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/alimy/mir/v4 v4.0.0 - github.com/cloudwego/hertz v0.6.5 + github.com/cloudwego/hertz v0.7.3 ) require ( @@ -18,7 +18,7 @@ require ( github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/nyaruka/phonenumbers v1.0.55 // indirect - github.com/tidwall/gjson v1.13.0 // indirect + github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/engine/hertz/go.sum b/engine/hertz/go.sum index 1ece0de..0f38c05 100644 --- a/engine/hertz/go.sum +++ b/engine/hertz/go.sum @@ -11,10 +11,10 @@ github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/cloudwego/hertz v0.6.5 h1:2yZY8Tn4YIX7CF75oaGw3NymXgewPzGcxdwfVKfNBCo= -github.com/cloudwego/hertz v0.6.5/go.mod h1:KhztQcZtMQ46gOjZcmCy557AKD29cbumGEV0BzwevwA= -github.com/cloudwego/netpoll v0.3.2 h1:/998ICrNMVBo4mlul4j7qcIeY7QnEfuCCPPwck9S3X4= -github.com/cloudwego/netpoll v0.3.2/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/hertz v0.7.3 h1:VM1DxditA6vxI97rG5SBu4hHB24xdzDbKBQfUy7sfVE= +github.com/cloudwego/hertz v0.7.3/go.mod h1:WliNtVbwihWHHgAaIQEbVXl0O3aWj0ks1eoPrcEAnjs= +github.com/cloudwego/netpoll v0.5.0 h1:oRrOp58cPCvK2QbMozZNDESvrxQaEHW2dCimmwH1lcU= +github.com/cloudwego/netpoll v0.5.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -51,8 +51,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.13.0 h1:3TFY9yxOQShrvmjdM76K+jc66zJeT6D3/VFFYCGQf7M= -github.com/tidwall/gjson v1.13.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= diff --git a/internal/generator/templates/hertz_iface.tmpl b/internal/generator/templates/hertz_iface.tmpl index 7a0c2c0..7c22b55 100644 --- a/internal/generator/templates/hertz_iface.tmpl +++ b/internal/generator/templates/hertz_iface.tmpl @@ -15,16 +15,16 @@ import ( ) {{- if .DeclareCoreInterface }} type _binding_ interface { - Bind(*app.RequestContext) mir.Error + Bind(context.Context, *app.RequestContext) mir.Error } type _render_ interface { - Render(*app.RequestContext) + Render(context.Context, *app.RequestContext) } type _default_ interface { - Bind(*app.RequestContext, any) mir.Error - Render(*app.RequestContext, any, mir.Error) + Bind(context.Context, *app.RequestContext, any) mir.Error + Render(context.Context, *app.RequestContext, any, mir.Error) } {{- end }} @@ -37,7 +37,7 @@ type {{.TypeName}} interface { {{if notEmptyStr .Chain }}// Chain provide handlers chain for hertz {{.Chain}}() []app.HandlerFunc {{end}} -{{range .Fields}} {{.MethodName}}({{if notEmptyStr .InName }}*{{ .InName }}{{end}}) {{if notEmptyStr .OutName }}(*{{ .OutName}}, mir.Error){{else}}mir.Error{{end}} +{{range .Fields}} {{if .JustUseContext }}{{ .MethodName}}(context.Context, *app.RequestContext){{else}}{{.MethodName}}({{if .IsUseContext }}context.Context, *app.RequestContext{{if notEmptyStr .InName }}, {{end}}{{end}}{{if notEmptyStr .InName }}*{{ .InName }}{{end}}) {{if notEmptyStr .OutName }}(*{{ .OutName}}, mir.Error){{else}}mir.Error{{end}}{{end}} {{end}} mustEmbedUnimplemented{{.TypeName}}Servant() @@ -67,7 +67,7 @@ func Register{{.TypeName}}Servant(e *route.Engine, s {{.TypeName}}{{if .IsUseFie router.Use(middlewares...) {{end}} // register routes info to router -{{range .Fields}}{{if .NotHttpAny }} router.Handle("{{.HttpMethod}}", "{{.Path}}", {{if .IsFieldChain }}append(cc.Chain{{.MethodName}}(), {{end}}func(c context.Context, ctx *app.RequestContext) { +{{range .Fields}}{{if .NotHttpAny }} router.Handle("{{.HttpMethod}}", "{{.Path}}", {{if .IsFieldChain }}append(cc.Chain{{.MethodName}}(), {{end}}{{if .JustUseContext}}s.{{ .MethodName}}{{else}}func(c context.Context, ctx *app.RequestContext) { {{- if $.WatchCtxDone }} select { case <- c.Done(): @@ -79,31 +79,31 @@ func Register{{.TypeName}}Servant(e *route.Engine, s {{.TypeName}}{{if .IsUseFie req := new({{.InName}}) {{if .IsBindIn -}} var bv _binding_ = req - if err := bv.Bind(ctx); err != nil { + if err := bv.Bind(c, ctx); err != nil { {{- else -}} - if err := s.Bind(ctx, req); err != nil { + if err := s.Bind(c, ctx, req); err != nil { {{- end }} - s.Render(ctx, nil, err) + s.Render(c, ctx, nil, err) return } {{- end }} {{if notEmptyStr .OutName -}} - resp, err := s.{{ .MethodName}}({{if notEmptyStr .InName}}req{{end}}) + resp, err := s.{{ .MethodName}}({{if .IsUseContext }}c, ctx{{if notEmptyStr .InName }}, {{end}}{{end}}{{if notEmptyStr .InName}}req{{end}}) {{if .IsRenderOut -}} if err != nil { - s.Render(ctx, nil, err) + s.Render(c, ctx, nil, err) return } var rv _render_ = resp - rv.Render(ctx) + rv.Render(c, ctx) {{- else -}} - s.Render(ctx, resp, err) + s.Render(c, ctx, resp, err) {{- end }} {{- else -}} - s.Render(ctx, nil, s.{{.MethodName}}({{if notEmptyStr .InName}}req{{end}})) + s.Render(c, ctx, nil, s.{{.MethodName}}({{if .IsUseContext }}c, ctx{{if notEmptyStr .InName }}, {{end}}{{end}}{{if notEmptyStr .InName}}req{{end}})) {{- end }} - }{{if .IsFieldChain }})...{{end}}) - {{else if .JustHttpAny}} router.Any("{{.Path}}", {{if .IsFieldChain }}append(cc.Chain{{.MethodName}}(), {{end}}func(c context.Context, ctx *app.RequestContext) { + }{{end}}{{if .IsFieldChain }})...{{end}}) + {{else if .JustHttpAny}} router.Any("{{.Path}}", {{if .IsFieldChain }}append(cc.Chain{{.MethodName}}(), {{end}}{{if .JustUseContext}}s.{{ .MethodName}}{{else}}func(c context.Context, ctx *app.RequestContext) { {{- if $.WatchCtxDone }} select { case <- c.Done(): @@ -115,32 +115,32 @@ func Register{{.TypeName}}Servant(e *route.Engine, s {{.TypeName}}{{if .IsUseFie req := new({{.InName}}) {{if .IsBindIn -}} var bv _binding_ = req - if err := bv.Bind(ctx); err != nil { + if err := bv.Bind(c, ctx); err != nil { {{- else -}} - if err := s.Bind(ctx, req); err != nil { + if err := s.Bind(c, ctx, req); err != nil { {{- end }} - s.Render(ctx, nil, err) + s.Render(c, ctx, nil, err) return } {{- end }} {{if notEmptyStr .OutName -}} - resp, err := s.{{ .MethodName}}({{if notEmptyStr .InName}}req{{end}}) + resp, err := s.{{ .MethodName}}({{if .IsUseContext }}c, ctx{{if notEmptyStr .InName }}, {{end}}{{end}}{{if notEmptyStr .InName}}req{{end}}) {{if .IsRenderOut -}} if err != nil { - s.Render(ctx, nil, err) + s.Render(c, ctx, nil, err) return } var rv _render_ = resp - rv.Render(ctx) + rv.Render(c, ctx) {{- else -}} - s.Render(ctx, resp, err) + s.Render(c, ctx, resp, err) {{- end }} {{- else -}} - s.Render(ctx, nil, s.{{.MethodName}}({{if notEmptyStr .InName}}req{{end}})) + s.Render(c, ctx, nil, s.{{.MethodName}}({{if .IsUseContext }}c, ctx{{if notEmptyStr .InName }}, {{end}}{{end}}{{if notEmptyStr .InName}}req{{end}})) {{- end }} - }{{if .IsFieldChain }})...{{end}}) + }{{end}}{{if .IsFieldChain }})...{{end}}) {{else}}{{$field := .}} { - h := {{if .IsFieldChain }}append(cc.Chain{{.MethodName}}(), {{end}}func(c context.Context, ctx *app.RequestContext) { + h := {{if .IsFieldChain }}append(cc.Chain{{.MethodName}}(), {{end}}{{if .JustUseContext}}s.{{ .MethodName}}{{else}}func(c context.Context, ctx *app.RequestContext) { {{- if $.WatchCtxDone }} select { case <- c.Done(): @@ -152,30 +152,30 @@ func Register{{.TypeName}}Servant(e *route.Engine, s {{.TypeName}}{{if .IsUseFie req := new({{.InName}}) {{if .IsBindIn -}} var bv _binding_ = req - if err := bv.Bind(ctx); err != nil { + if err := bv.Bind(c, ctx); err != nil { {{- else -}} - if err := s.Bind(ctx, req); err != nil { + if err := s.Bind(c, ctx, req); err != nil { {{- end }} - s.Render(ctx, nil, err) + s.Render(c, ctx, nil, err) return } {{- end }} {{if notEmptyStr .OutName -}} - resp, err := s.{{ .MethodName}}({{if notEmptyStr .InName}}req{{end}}) + resp, err := s.{{ .MethodName}}({{if .IsUseContext }}c, ctx{{if notEmptyStr .InName }}, {{end}}{{end}}{{if notEmptyStr .InName}}req{{end}}) {{if .IsRenderOut -}} if err != nil { - s.Render(ctx, nil, err) + s.Render(c, ctx, nil, err) return } var rv _render_ = resp - rv.Render(ctx) + rv.Render(c, ctx) {{- else -}} - s.Render(ctx, resp, err) + s.Render(c, ctx, resp, err) {{- end }} {{- else -}} - s.Render(ctx, nil, s.{{.MethodName}}({{if notEmptyStr .InName}}req{{end}})) + s.Render(c, ctx, nil, s.{{.MethodName}}({{if .IsUseContext }}c, ctx{{if notEmptyStr .InName }}, {{end}}{{end}}{{if notEmptyStr .InName}}req{{end}})) {{- end }} - }{{if .IsFieldChain }}){{end}} + }{{end}}{{if .IsFieldChain }}){{end}} {{range .AnyHttpMethods}} router.Handle("{{.}}", "{{$field.Path}}", h{{if $field.IsFieldChain }}...{{end}}) {{end}} } {{end}} @@ -193,8 +193,12 @@ func ({{$unimplementedServant}}){{.Chain}}() []app.HandlerFunc { {{end}} {{range .Fields}} -func ({{$unimplementedServant}}){{.MethodName}}({{if notEmptyStr .InName }}req *{{ .InName }}{{end}}) {{if notEmptyStr .OutName }}(*{{ .OutName}}, mir.Error){{else}}mir.Error{{end}} { +func ({{$unimplementedServant}}){{if .JustUseContext}}{{ .MethodName}}(c context.Context, ctx *app.RequestContext){{else}}{{.MethodName}}({{if .IsUseContext }}c context.Context, ctx *app.RequestContext{{if notEmptyStr .InName }}, {{end}}{{end}}{{if notEmptyStr .InName }}req *{{ .InName }}{{end}}) {{if notEmptyStr .OutName }}(*{{ .OutName}}, mir.Error){{else}}mir.Error{{end}}{{end}} { + {{if .JustUseContext -}} + ctx.String(http.StatusNotImplemented, http.StatusText(http.StatusNotImplemented)) + {{else -}} return {{if notEmptyStr .OutName }}nil, {{end}}mir.Errorln(http.StatusNotImplemented, http.StatusText(http.StatusNotImplemented)) + {{end -}} } {{end}}