diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index a3233d97..d791cc32 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -7,7 +7,7 @@ env: jobs: lint: if: github.event.pull_request.draft == false - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 permissions: contents: read actions: read @@ -41,7 +41,7 @@ jobs: - run: pnpm nx affected -t typecheck unit-tests: if: github.event.pull_request.draft == false - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 permissions: contents: read actions: write @@ -87,7 +87,7 @@ jobs: - run: pnpm server-output-test integration-tests-ts-server: if: github.event.pull_request.draft == false - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 permissions: contents: read actions: write @@ -99,7 +99,7 @@ jobs: - name: Install Dart uses: dart-lang/setup-dart@v1 with: - sdk: 3.5.4 + sdk: 3.6.0 - name: Install Rust uses: dtolnay/rust-toolchain@stable - name: Setup Gradle @@ -131,7 +131,7 @@ jobs: - run: pnpm run integration-tests --server ts --affected integration-tests-go-server: if: github.event.pull_request.draft == false - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 permissions: contents: read actions: write @@ -143,7 +143,7 @@ jobs: - name: Install Dart uses: dart-lang/setup-dart@v1 with: - sdk: 3.5.4 + sdk: 3.6.0 - name: Install Rust uses: dtolnay/rust-toolchain@stable - name: Setup Gradle diff --git a/languages/dart/dart-client/pubspec.yaml b/languages/dart/dart-client/pubspec.yaml index 3ff16268..ead8d073 100644 --- a/languages/dart/dart-client/pubspec.yaml +++ b/languages/dart/dart-client/pubspec.yaml @@ -4,7 +4,7 @@ version: "0.69.2" repository: https://github.com/modiimedia/arri issue_tracker: https://github.com/modiimedia/arri/issues environment: - sdk: ">=3.0.0 <4.0.0" + sdk: ">=3.6.0 <4.0.0" dependencies: http: ">=1.0.0 <2.0.0" web_socket_channel: ">=3.0.0 <4.0.0" diff --git a/languages/dart/dart-codegen-reference/lib/reference_client.dart b/languages/dart/dart-codegen-reference/lib/reference_client.dart index 57003e31..415f436b 100644 --- a/languages/dart/dart-codegen-reference/lib/reference_client.dart +++ b/languages/dart/dart-codegen-reference/lib/reference_client.dart @@ -1,5 +1,5 @@ // this file was autogenerated by arri -// ignore_for_file: type=lint, unused_field, unnecessary_cast +// ignore_for_file: type=lint import 'dart:async'; import 'dart:convert'; import 'package:arri_client/arri_client.dart'; @@ -115,13 +115,13 @@ class ExampleClientBooksService { onClose: onClose, onError: onError != null && _onError != null ? (err, es) { - _onError?.call(onError); + _onError.call(onError); return onError(err, es); } : onError != null ? onError : _onError != null - ? (err, _) => _onError?.call(err) + ? (err, _) => _onError.call(err) : null, ); } diff --git a/languages/dart/dart-codegen-reference/pubspec.lock b/languages/dart/dart-codegen-reference/pubspec.lock index e874d3da..ba5d9f47 100644 --- a/languages/dart/dart-codegen-reference/pubspec.lock +++ b/languages/dart/dart-codegen-reference/pubspec.lock @@ -414,4 +414,4 @@ packages: source: hosted version: "3.1.2" sdks: - dart: ">=3.6.0-0 <4.0.0" + dart: ">=3.6.0 <4.0.0" diff --git a/languages/dart/dart-codegen-reference/pubspec.yaml b/languages/dart/dart-codegen-reference/pubspec.yaml index e5b746bc..9100cf36 100644 --- a/languages/dart/dart-codegen-reference/pubspec.yaml +++ b/languages/dart/dart-codegen-reference/pubspec.yaml @@ -4,7 +4,7 @@ publish_to: none repository: https://github.com/modiimedia/arri/tree/master/packages/arri-client-dart-reference issue_tracker: https://github.com/modiimedia/arri/issues environment: - sdk: ">=3.0.0 <4.0.0" + sdk: ">=3.6.0 <4.0.0" dependencies: arri_client: path: "../dart-client" diff --git a/languages/dart/dart-codegen/src/_index.ts b/languages/dart/dart-codegen/src/_index.ts index 6e7ff903..e32ecc78 100644 --- a/languages/dart/dart-codegen/src/_index.ts +++ b/languages/dart/dart-codegen/src/_index.ts @@ -147,7 +147,7 @@ export function createDartClient( } if (rpcParts.length === 0 && subServiceParts.length === 0) { const heading = `// this file was autogenerated by arri -// ignore_for_file: type=lint, unused_field, unnecessary_cast +// ignore_for_file: type=lint import 'dart:convert'; import 'package:arri_client/arri_client.dart';`; @@ -157,7 +157,7 @@ ${typeParts.join("\n\n")}`; } const clientName = validDartClassName(context.clientName, ""); return `// this file was autogenerated by arri -// ignore_for_file: type=lint, unused_field, unnecessary_cast +// ignore_for_file: type=lint import 'dart:async'; import 'dart:convert'; import 'package:arri_client/arri_client.dart'; diff --git a/languages/dart/dart-codegen/src/procedures.ts b/languages/dart/dart-codegen/src/procedures.ts index 27fa8e16..ce770c85 100644 --- a/languages/dart/dart-codegen/src/procedures.ts +++ b/languages/dart/dart-codegen/src/procedures.ts @@ -76,13 +76,13 @@ export function dartHttpRpcFromSchema( onClose: onClose, onError: onError != null && _onError != null ? (err, es) { - _onError?.call(onError); + _onError.call(onError); return onError(err, es); } : onError != null ? onError : _onError != null - ? (err, _) => _onError?.call(err) + ? (err, _) => _onError.call(err) : null, ); }`; diff --git a/languages/go/go-server/app.go b/languages/go/go-server/app.go index 1ef5c9a2..6cb8384c 100644 --- a/languages/go/go-server/app.go +++ b/languages/go/go-server/app.go @@ -7,29 +7,30 @@ import ( "os" ) -type App[TContext Context] struct { +type App[TEvent Event] struct { Mux *http.ServeMux - CreateContext func(w http.ResponseWriter, r *http.Request) (*TContext, RpcError) - InitializationErrors []error - Options AppOptions[TContext] - Procedures *OrderedMap[RpcDef] - Definitions *OrderedMap[TypeDef] + createEvent func(w http.ResponseWriter, r *http.Request) (*TEvent, RpcError) + initializationErrors []error + options AppOptions[TEvent] + middleware []Middleware[TEvent] + procedures *OrderedMap[RpcDef] + definitions *OrderedMap[TypeDef] } -func (app *App[TContext]) GetAppDefinition() AppDef { +func (app *App[TEvent]) GetAppDefinition() AppDef { info := None[AppDefInfo]() name := None[string]() description := None[string]() version := None[string]() - if len(app.Options.AppName) > 0 { - name = Some(app.Options.AppName) + if len(app.options.AppName) > 0 { + name = Some(app.options.AppName) } - if len(app.Options.AppDescription) > 0 { - description = Some(app.Options.AppDescription) + if len(app.options.AppDescription) > 0 { + description = Some(app.options.AppDescription) } - if len(app.Options.AppVersion) > 0 { - version = Some(app.Options.AppVersion) + if len(app.options.AppVersion) > 0 { + version = Some(app.options.AppVersion) } if name.IsSome() || description.IsSome() || version.IsSome() { @@ -43,12 +44,12 @@ func (app *App[TContext]) GetAppDefinition() AppDef { return AppDef{ SchemaVersion: "0.0.7", Info: info, - Procedures: *app.Procedures, - Definitions: *app.Definitions, + Procedures: *app.procedures, + Definitions: *app.definitions, } } -func (app *App[TContext]) Run(options RunOptions) error { +func (app *App[TEvent]) Run(options RunOptions) error { defOutput := flag.String("def-out", "", "definition-out") appDefCmd := flag.NewFlagSet("def", flag.ExitOnError) appDefOutput := appDefCmd.String("output", "__definition.json", "output") @@ -56,14 +57,14 @@ func (app *App[TContext]) Run(options RunOptions) error { switch os.Args[1] { case "def", "definition": appDefCmd.Parse(os.Args[2:]) - return appDefToFile(app.GetAppDefinition(), *appDefOutput, app.Options.KeyCasing) + return appDefToFile(app.GetAppDefinition(), *appDefOutput, app.options.KeyCasing) } } if len(os.Args) > 1 { flag.Parse() } if len(*defOutput) > 0 { - err := appDefToFile(app.GetAppDefinition(), *defOutput, app.Options.KeyCasing) + err := appDefToFile(app.GetAppDefinition(), *defOutput, app.options.KeyCasing) if err != nil { return err } @@ -83,24 +84,24 @@ func appDefToFile(appDef AppDef, output string, keyCasing KeyCasing) error { return nil } -func printServerStartMessages[TContext Context](app *App[TContext], port uint32, isHttps bool) { +func printServerStartMessages[TEvent Event](app *App[TEvent], port uint32, isHttps bool) { protocol := "http" if isHttps { protocol = "https" } baseUrl := fmt.Sprintf("%v://localhost:%v", protocol, port) fmt.Printf("Starting server at %v\n", baseUrl) - if len(app.Options.RpcRoutePrefix) > 0 { - fmt.Printf("Procedures path: %v%v\n", baseUrl, app.Options.RpcRoutePrefix) + if len(app.options.RpcRoutePrefix) > 0 { + fmt.Printf("Procedures path: %v%v\n", baseUrl, app.options.RpcRoutePrefix) } - defPath := app.Options.RpcDefinitionPath + defPath := app.options.RpcDefinitionPath if len(defPath) == 0 { defPath = "/__definition" } - fmt.Printf("App Definition Path: %v%v\n\n", baseUrl, app.Options.RpcRoutePrefix+defPath) + fmt.Printf("App Definition Path: %v%v\n\n", baseUrl, app.options.RpcRoutePrefix+defPath) } -func startServer[TContext Context](app *App[TContext], options RunOptions) error { +func startServer[TEvent Event](app *App[TEvent], options RunOptions) error { port := options.Port if port == 0 { port = 3000 @@ -121,7 +122,7 @@ type RunOptions struct { KeyFile string } -type AppOptions[TContext Context] struct { +type AppOptions[TEvent Event] struct { AppName string // The current app version. Generated clients will send this in the "client-version" header AppVersion string @@ -133,61 +134,62 @@ type AppOptions[TContext Context] struct { RpcRoutePrefix string // if not set it will default to "/{RpcRoutePrefix}/__definition" RpcDefinitionPath string - OnRequest func(*http.Request, *TContext) RpcError - OnBeforeResponse func(*http.Request, *TContext, any) RpcError - OnAfterResponse func(*http.Request, *TContext, any) RpcError - OnError func(*http.Request, *TContext, error) + OnRequest func(*http.Request, *TEvent) RpcError + OnBeforeResponse func(*http.Request, *TEvent, any) RpcError + OnAfterResponse func(*http.Request, *TEvent, any) RpcError + OnError func(*http.Request, *TEvent, error) } -func NewApp[TContext Context](mux *http.ServeMux, options AppOptions[TContext], createContext func(w http.ResponseWriter, r *http.Request) (*TContext, RpcError)) App[TContext] { - app := App[TContext]{ +func NewApp[TEvent Event](mux *http.ServeMux, options AppOptions[TEvent], createEvent func(w http.ResponseWriter, r *http.Request) (*TEvent, RpcError)) App[TEvent] { + app := App[TEvent]{ Mux: mux, - CreateContext: createContext, - Options: options, - InitializationErrors: []error{}, - Procedures: &OrderedMap[RpcDef]{}, - Definitions: &OrderedMap[TypeDef]{}, + createEvent: createEvent, + options: options, + initializationErrors: []error{}, + middleware: []Middleware[TEvent]{}, + procedures: &OrderedMap[RpcDef]{}, + definitions: &OrderedMap[TypeDef]{}, } - defPath := app.Options.RpcRoutePrefix + "/__definition" - if len(app.Options.RpcDefinitionPath) > 0 { - defPath = app.Options.RpcDefinitionPath + defPath := app.options.RpcRoutePrefix + "/__definition" + if len(app.options.RpcDefinitionPath) > 0 { + defPath = app.options.RpcDefinitionPath } - onRequest := app.Options.OnRequest + onRequest := app.options.OnRequest if onRequest == nil { - onRequest = func(r *http.Request, t *TContext) RpcError { + onRequest = func(r *http.Request, t *TEvent) RpcError { return nil } } - onBeforeResponse := app.Options.OnBeforeResponse + onBeforeResponse := app.options.OnBeforeResponse if onBeforeResponse == nil { - onBeforeResponse = func(r *http.Request, t *TContext, a any) RpcError { + onBeforeResponse = func(r *http.Request, t *TEvent, a any) RpcError { return nil } } - onAfterResponse := app.Options.OnAfterResponse + onAfterResponse := app.options.OnAfterResponse if onAfterResponse == nil { - onAfterResponse = func(r *http.Request, t *TContext, a any) RpcError { + onAfterResponse = func(r *http.Request, t *TEvent, a any) RpcError { return nil } } - onError := app.Options.OnError + onError := app.options.OnError if onError == nil { - onError = func(r *http.Request, t *TContext, err error) {} + onError = func(r *http.Request, t *TEvent, err error) {} } mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/json") - ctx, ctxErr := app.CreateContext(w, r) - if ctxErr != nil { - handleError(false, w, r, ctx, ctxErr, onError) + event, err := app.createEvent(w, r) + if err != nil { + handleError(false, w, r, event, err, onError) return } - onRequestErr := onRequest(r, ctx) - if onRequestErr != nil { - handleError(false, w, r, ctx, onRequestErr, onError) + err = onRequest(r, event) + if err != nil { + handleError(false, w, r, event, err, onError) return } if r.URL.Path != "/" { - handleError(false, w, r, ctx, Error(404, ""), onError) + handleError(false, w, r, event, Error(404, ""), onError) return } w.WriteHeader(200) @@ -197,8 +199,8 @@ func NewApp[TContext Context](mux *http.ServeMux, options AppOptions[TContext], Version Option[string] SchemaPath string }{ - Title: app.Options.AppName, - Description: app.Options.AppDescription, + Title: app.options.AppName, + Description: app.options.AppDescription, Version: None[string](), SchemaPath: defPath, } @@ -211,57 +213,57 @@ func NewApp[TContext Context](mux *http.ServeMux, options AppOptions[TContext], if len(options.AppVersion) > 0 { response.Version = Some(options.AppVersion) } - onBeforeResponseErr := onBeforeResponse(r, ctx, response) - if onBeforeResponseErr != nil { - handleError(false, w, r, ctx, onBeforeResponseErr, onError) + err = onBeforeResponse(r, event, response) + if err != nil { + handleError(false, w, r, event, err, onError) return } - jsonResult, _ := EncodeJSON(response, app.Options.KeyCasing) + jsonResult, _ := EncodeJSON(response, app.options.KeyCasing) w.Write(jsonResult) - onAfterResponseErr := onAfterResponse(r, ctx, response) + onAfterResponseErr := onAfterResponse(r, event, response) if onAfterResponseErr != nil { - handleError(true, w, r, ctx, onAfterResponseErr, onError) + handleError(true, w, r, event, onAfterResponseErr, onError) return } }) mux.HandleFunc(defPath, func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/json") - ctx, ctxErr := app.CreateContext(w, r) - if ctxErr != nil { - handleError(false, w, r, ctx, ctxErr, onError) + event, err := app.createEvent(w, r) + if err != nil { + handleError(false, w, r, event, err, onError) return } - onRequestError := onRequest(r, ctx) - if onRequestError != nil { - handleError(false, w, r, ctx, onRequestError, onError) + err = onRequest(r, event) + if err != nil { + handleError(false, w, r, event, err, onError) } - jsonResult, _ := EncodeJSON(app.GetAppDefinition(), app.Options.KeyCasing) - beforeResponseErr := onBeforeResponse(r, ctx, jsonResult) + jsonResult, _ := EncodeJSON(app.GetAppDefinition(), app.options.KeyCasing) + beforeResponseErr := onBeforeResponse(r, event, jsonResult) if beforeResponseErr != nil { - handleError(false, w, r, ctx, beforeResponseErr, onError) + handleError(false, w, r, event, beforeResponseErr, onError) return } w.WriteHeader(200) w.Write(jsonResult) - afterResponseErr := onAfterResponse(r, ctx, jsonResult) - if afterResponseErr != nil { - handleError(true, w, r, ctx, afterResponseErr, onError) + err = onAfterResponse(r, event, jsonResult) + if err != nil { + handleError(true, w, r, event, err, onError) return } }) return app } -func handleError[TContext Context]( +func handleError[TEvent Event]( responseSent bool, w http.ResponseWriter, r *http.Request, - context *TContext, + event *TEvent, err RpcError, - onError func(*http.Request, *TContext, error), + onError func(*http.Request, *TEvent, error), ) { - onError(r, context, err) + onError(r, event, err) if responseSent { return } @@ -275,8 +277,8 @@ type DefOptions struct { Description string } -func RegisterDef[TContext Context](app *App[TContext], input any, options DefOptions) { - def, err := ToTypeDef(input, app.Options.KeyCasing) +func RegisterDef[TEvent Event](app *App[TEvent], input any, options DefOptions) { + def, err := ToTypeDef(input, app.options.KeyCasing) if err != nil { panic(err) } @@ -289,5 +291,5 @@ func RegisterDef[TContext Context](app *App[TContext], input any, options DefOpt if len(options.Description) > 0 { def.Metadata.Value.Description = Some(options.Description) } - app.Definitions.Set(def.Metadata.Unwrap().Id.Unwrap(), *def) + app.definitions.Set(def.Metadata.Unwrap().Id.Unwrap(), *def) } diff --git a/languages/go/go-server/app_def.go b/languages/go/go-server/app_def.go index 3a5a71d3..d65a8587 100644 --- a/languages/go/go-server/app_def.go +++ b/languages/go/go-server/app_def.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/iancoleman/strcase" + "github.com/modiimedia/arri/languages/go/go-server/utils" ) type AppDef struct { @@ -64,7 +65,7 @@ func ToRpcDef(value interface{}, options ArriHttpRpcOptions) (*RpcDef, error) { } rawParams := valueType.In(0) params := Some(rawParams.Name()) - hasParam := !isEmptyMessage(rawParams) + hasParam := !utils.IsEmptyMessage(rawParams) if !hasParam { params = None[string]() } @@ -73,7 +74,7 @@ func ToRpcDef(value interface{}, options ArriHttpRpcOptions) (*RpcDef, error) { rawResponse = rawResponse.Elem() } response := Some(rawResponse.Name()) - hasResponse := !isEmptyMessage(rawResponse) + hasResponse := !utils.IsEmptyMessage(rawResponse) if !hasResponse { response = None[string]() } diff --git a/languages/go/go-server/context.go b/languages/go/go-server/context.go deleted file mode 100644 index ccf96b5c..00000000 --- a/languages/go/go-server/context.go +++ /dev/null @@ -1,31 +0,0 @@ -package arri - -import ( - "net/http" -) - -type Context interface { - Request() *http.Request - Writer() http.ResponseWriter -} - -type DefaultContext struct { - request *http.Request - writer http.ResponseWriter -} - -func (c DefaultContext) Request() *http.Request { - return c.request -} - -func (c DefaultContext) Writer() http.ResponseWriter { - return c.writer -} - -func CreateDefaultContext(w http.ResponseWriter, r *http.Request) (*DefaultContext, RpcError) { - ctx := DefaultContext{ - request: r, - writer: w, - } - return &ctx, nil -} diff --git a/languages/go/go-server/decode_json.go b/languages/go/go-server/decode_json.go index ae3209cf..4eeb85eb 100644 --- a/languages/go/go-server/decode_json.go +++ b/languages/go/go-server/decode_json.go @@ -9,6 +9,7 @@ import ( "time" "github.com/iancoleman/strcase" + "github.com/modiimedia/arri/languages/go/go-server/utils" "github.com/tidwall/gjson" ) @@ -97,7 +98,7 @@ func DecodeJSON[T any](data []byte, v *T, keyCasing KeyCasing) *DecoderError { value := reflect.ValueOf(&v) if !parsedResult.Exists() { t := value.Type() - if isNullableTypeOrPointer(t) || isOptionalType(t) { + if utils.IsNullableTypeOrPointer(t) || utils.IsOptionalType(t) { return nil } err := NewDecoderError([]ValidationError{NewValidationError("expected JSON input but received nothing", "", "")}) @@ -168,10 +169,10 @@ func typeFromJSON(data *gjson.Result, target reflect.Value, context *DecoderCont if t.Name() == "Time" { return timestampFromJSON(data, target, context) } - if isOptionalType(t) { + if utils.IsOptionalType(t) { return optionFromJson(data, target, context) } - if isNullableTypeOrPointer(t) { + if utils.IsNullableTypeOrPointer(t) { return nullableFromJson(data, target, context) } if t.Implements(reflect.TypeFor[ArriModel]()) { @@ -628,7 +629,7 @@ func structFromJSON(data *gjson.Result, target reflect.Value, c *DecoderContext) } enumValues = Some(vals) } - isOptional := isOptionalType(fieldType) + isOptional := utils.IsOptionalType(fieldType) if isOptional { ctx := c.copyWith( Some(c.CurrentDepth+1), @@ -648,7 +649,7 @@ func structFromJSON(data *gjson.Result, target reflect.Value, c *DecoderContext) Some(c.InstancePath+"/"+fieldName), Some(c.SchemaPath+"/properties/"+fieldName), ) - isNullable := isNullableTypeOrPointer(fieldType) + isNullable := utils.IsNullableTypeOrPointer(fieldType) if isNullable { success := nullableFromJson(&jsonResult, field, &ctx) if !success { diff --git a/languages/go/go-server/decode_json_test.go b/languages/go/go-server/decode_json_test.go index 3be94a18..8143170c 100644 --- a/languages/go/go-server/decode_json_test.go +++ b/languages/go/go-server/decode_json_test.go @@ -232,16 +232,10 @@ func TestDecodeStdUser(t *testing.T) { } } -func BenchmarkStdDecodeObjectWithEveryType(b *testing.B) { - for i := 0; i < b.N; i++ { - target := objectWithEveryType{} - json.Unmarshal(_objectWithEveryTypeInput, &target) - } -} -func BenchmarkArriDecodeObjectWithEveryType(b *testing.B) { +func BenchmarkArriDecodeUser(b *testing.B) { for i := 0; i < b.N; i++ { - target := objectWithEveryType{} - arri.DecodeJSON(_objectWithEveryTypeInput, &target, arri.KeyCasingCamelCase) + user := benchUser{} + arri.DecodeJSON(benchUserInput, &user, arri.KeyCasingCamelCase) } } @@ -252,10 +246,17 @@ func BenchmarkStdDecodeUser(b *testing.B) { } } -func BenchmarkArriDecodeUser(b *testing.B) { +func BenchmarkArriDecodeObjectWithEveryType(b *testing.B) { for i := 0; i < b.N; i++ { - user := benchUser{} - arri.DecodeJSON(benchUserInput, &user, arri.KeyCasingCamelCase) + target := objectWithEveryType{} + arri.DecodeJSON(_objectWithEveryTypeInput, &target, arri.KeyCasingCamelCase) + } +} + +func BenchmarkStdDecodeObjectWithEveryType(b *testing.B) { + for i := 0; i < b.N; i++ { + target := objectWithEveryType{} + json.Unmarshal(_objectWithEveryTypeInput, &target) } } diff --git a/languages/go/go-server/decode_url_query.go b/languages/go/go-server/decode_url_query.go index 8b86ab33..7bacc57c 100644 --- a/languages/go/go-server/decode_url_query.go +++ b/languages/go/go-server/decode_url_query.go @@ -9,6 +9,7 @@ import ( "time" "github.com/iancoleman/strcase" + "github.com/modiimedia/arri/languages/go/go-server/utils" ) func FromUrlQuery[T any](values url.Values, target *T, keyCasing KeyCasing) *DecoderError { @@ -49,7 +50,7 @@ func FromUrlQuery[T any](values url.Values, target *T, keyCasing KeyCasing) *Dec } urlValue := values.Get(key) - isOptional := isOptionalType(fieldType) + isOptional := utils.IsOptionalType(fieldType) if isOptional { ctx := ctx.copyWith( None[uint32](), @@ -73,7 +74,7 @@ func FromUrlQuery[T any](values url.Values, target *T, keyCasing KeyCasing) *Dec Some(ctx.InstancePath+"/"+key), Some(ctx.SchemaPath+"/optionalProperties"), ) - isNullable := isNullableTypeOrPointer(fieldType) + isNullable := utils.IsNullableTypeOrPointer(fieldType) if isNullable { nullableTypeFromUrlQuery(urlValue, &field, &ctx) continue diff --git a/languages/go/go-server/encode_json.go b/languages/go/go-server/encode_json.go index 111ecfe3..377f0668 100644 --- a/languages/go/go-server/encode_json.go +++ b/languages/go/go-server/encode_json.go @@ -6,6 +6,8 @@ import ( "sort" "strconv" "time" + + "github.com/modiimedia/arri/languages/go/go-server/utils" ) type jsonEncodingCtx struct { @@ -63,7 +65,7 @@ func encodeValueToJSON(v reflect.Value, c *jsonEncodingCtx) error { return encodeUint64ToJSON(v, c) case reflect.Struct: t := v.Type() - if isNullableType(t) { + if utils.IsNullableType(t) { return encodeNullableToJSON(v, c) } if t.Implements(reflect.TypeFor[ArriModel]()) { @@ -197,12 +199,12 @@ func encodeStructToJSON(v reflect.Value, c *jsonEncodingCtx) error { if !field.IsExported() { continue } - fieldName := getSerialKey(&field, c.keyCasing) + fieldName := utils.GetSerialKey(&field, c.keyCasing) fieldValue := v.Field(i) c.instancePath = c.instancePath + "/" + fieldName - if isOptionalType(field.Type) { + if utils.IsOptionalType(field.Type) { c.schemaPath = "/optionalProperties/" + fieldName - if !optionalHasValue(&fieldValue) { + if !utils.OptionalHasValue(&fieldValue) { c.instancePath = oldInstancePath c.schemaPath = oldSchemaPath continue diff --git a/languages/go/go-server/error.go b/languages/go/go-server/error.go index b4b4e2e1..beb53cd2 100644 --- a/languages/go/go-server/error.go +++ b/languages/go/go-server/error.go @@ -85,6 +85,7 @@ var statusMessages = map[uint32]string{ 511: "Network authentication required", } +// An Arri RPC error. If message is empty arri will replace it with a default message based on the status code. func Error(statusCode uint32, message string) errorResponse { msg := message if len(message) == 0 { @@ -101,6 +102,7 @@ func Error(statusCode uint32, message string) errorResponse { } } +// An Arri RPC error with arbitrary data. If message is empty arri will replace it with a default message based on the status code. func ErrorWithData(statusCode uint32, message string, data Option[any]) errorResponse { msg := message if len(message) == 0 { diff --git a/languages/go/go-server/event.go b/languages/go/go-server/event.go new file mode 100644 index 00000000..d560d5d7 --- /dev/null +++ b/languages/go/go-server/event.go @@ -0,0 +1,37 @@ +package arri + +import ( + "net/http" +) + +// Event interface that is available in every RPC call +type Event interface { + Request() *http.Request + Writer() http.ResponseWriter +} + +// type helper function to help you know if you've fulfilled the IsEvent interface +func IsEvent(input Event) bool { + return true +} + +type DefaultEvent struct { + request *http.Request + writer http.ResponseWriter +} + +func (c DefaultEvent) Request() *http.Request { + return c.request +} + +func (c DefaultEvent) Writer() http.ResponseWriter { + return c.writer +} + +func CreateDefaultEvent(w http.ResponseWriter, r *http.Request) (*DefaultEvent, RpcError) { + event := DefaultEvent{ + request: r, + writer: w, + } + return &event, nil +} diff --git a/languages/go/go-server/event_test.go b/languages/go/go-server/event_test.go new file mode 100644 index 00000000..3a894ee0 --- /dev/null +++ b/languages/go/go-server/event_test.go @@ -0,0 +1,30 @@ +package arri_test + +import ( + "net/http" + "testing" + + arri "github.com/modiimedia/arri/languages/go/go-server" +) + +type CustomEvent struct { + req *http.Request + res http.ResponseWriter +} + +func (e CustomEvent) Request() *http.Request { + return e.req +} + +func (e CustomEvent) Writer() http.ResponseWriter { + return e.res +} + +func TestIsEvent(t *testing.T) { + if !arri.IsEvent(arri.DefaultEvent{}) { + t.Fatal() + } + if !arri.IsEvent(CustomEvent{}) { + t.Fatal() + } +} diff --git a/languages/go/go-server/middleware.go b/languages/go/go-server/middleware.go new file mode 100644 index 00000000..21781e24 --- /dev/null +++ b/languages/go/go-server/middleware.go @@ -0,0 +1,9 @@ +package arri + +import "net/http" + +type Middleware[TEvent Event] func(r *http.Request, event TEvent, rpcName string) RpcError + +func Use[TEvent Event](app *App[TEvent], middleware Middleware[TEvent]) { + app.middleware = append(app.middleware, middleware) +} diff --git a/languages/go/go-server/procedures.go b/languages/go/go-server/procedures.go index 13533664..c7107d13 100644 --- a/languages/go/go-server/procedures.go +++ b/languages/go/go-server/procedures.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/iancoleman/strcase" + "github.com/modiimedia/arri/languages/go/go-server/utils" ) type RpcOptions struct { @@ -17,7 +18,7 @@ type RpcOptions struct { IsDeprecated bool } -func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceName string, options RpcOptions, handler func(TParams, TContext) (TResponse, RpcError)) { +func rpc[TParams, TResponse any, TEvent Event](app *App[TEvent], serviceName string, options RpcOptions, handler func(TParams, TEvent) (TResponse, RpcError)) { handlerType := reflect.TypeOf(handler) rpcSchema, rpcError := ToRpcDef(handler, ArriHttpRpcOptions{}) rpcName := rpcNameFromFunctionName(GetFunctionName(handler)) @@ -25,9 +26,9 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa rpcName = serviceName + "." + rpcName } if len(serviceName) > 0 { - rpcSchema.Http.Path = app.Options.RpcRoutePrefix + "/" + strcase.ToKebab(serviceName) + rpcSchema.Http.Path + rpcSchema.Http.Path = app.options.RpcRoutePrefix + "/" + strcase.ToKebab(serviceName) + rpcSchema.Http.Path } else { - rpcSchema.Http.Path = app.Options.RpcRoutePrefix + rpcSchema.Http.Path + rpcSchema.Http.Path = app.options.RpcRoutePrefix + rpcSchema.Http.Path } if rpcError != nil { panic(rpcError) @@ -36,7 +37,7 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa rpcSchema.Http.Method = strings.ToLower(options.Method) } if len(options.Path) > 0 { - rpcSchema.Http.Path = app.Options.RpcRoutePrefix + options.Path + rpcSchema.Http.Path = app.options.RpcRoutePrefix + options.Path } if len(options.Description) > 0 { rpcSchema.Http.Description.Set(options.Description) @@ -49,9 +50,9 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa panic("rpc params must be a struct. pointers and other types are not allowed.") } paramsName := getModelName(rpcName, params.Name(), "Params") - hasParams := !isEmptyMessage(params) + hasParams := !utils.IsEmptyMessage(params) if hasParams { - paramsDefContext := _NewTypeDefContext(app.Options.KeyCasing) + paramsDefContext := _NewTypeDefContext(app.options.KeyCasing) paramsSchema, paramsSchemaErr := typeToTypeDef(params, paramsDefContext) if paramsSchemaErr != nil { panic(paramsSchemaErr) @@ -60,7 +61,7 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa panic("Procedures cannot accept anonymous structs") } rpcSchema.Http.Params.Set(paramsName) - app.Definitions.Set(paramsName, *paramsSchema) + app.definitions.Set(paramsName, *paramsSchema) } else { rpcSchema.Http.Params.Unset() } @@ -69,9 +70,9 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa response = response.Elem() } responseName := getModelName(rpcName, response.Name(), "Response") - hasResponse := !isEmptyMessage(response) + hasResponse := !utils.IsEmptyMessage(response) if hasResponse { - responseDefContext := _NewTypeDefContext(app.Options.KeyCasing) + responseDefContext := _NewTypeDefContext(app.options.KeyCasing) responseSchema, responseSchemaErr := typeToTypeDef(response, responseDefContext) if responseSchemaErr != nil { panic(responseSchemaErr) @@ -80,94 +81,110 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa panic("Procedures cannot return anonymous structs") } rpcSchema.Http.Response.Set(responseName) - app.Definitions.Set(responseName, *responseSchema) + app.definitions.Set(responseName, *responseSchema) } else { rpcSchema.Http.Response.Unset() } - app.Procedures.Set(rpcName, *rpcSchema) - onRequest := app.Options.OnRequest + app.procedures.Set(rpcName, *rpcSchema) + onRequest := app.options.OnRequest if onRequest == nil { - onRequest = func(r *http.Request, t *TContext) RpcError { + onRequest = func(r *http.Request, t *TEvent) RpcError { return nil } } - onBeforeResponse := app.Options.OnBeforeResponse + onBeforeResponse := app.options.OnBeforeResponse if onBeforeResponse == nil { - onBeforeResponse = func(r *http.Request, t *TContext, a any) RpcError { + onBeforeResponse = func(r *http.Request, t *TEvent, a any) RpcError { return nil } } - onAfterResponse := app.Options.OnAfterResponse + onAfterResponse := app.options.OnAfterResponse if onAfterResponse == nil { - onAfterResponse = func(r *http.Request, t *TContext, a any) RpcError { + onAfterResponse = func(r *http.Request, t *TEvent, a any) RpcError { return nil } } - onError := app.Options.OnError + onError := app.options.OnError if onError == nil { - onError = func(r *http.Request, t *TContext, err error) {} + onError = func(r *http.Request, t *TEvent, err error) {} } paramsZero := reflect.Zero(reflect.TypeFor[TParams]()) app.Mux.HandleFunc(rpcSchema.Http.Path, func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/json") - ctx, ctxErr := app.CreateContext(w, r) - if ctxErr != nil { - handleError(false, w, r, nil, ctxErr, onError) + event, err := app.createEvent(w, r) + if err != nil { + handleError(false, w, r, nil, err, onError) return } if strings.ToLower(r.Method) != rpcSchema.Http.Method { - handleError(false, w, r, ctx, Error(404, "Not found"), onError) + handleError(false, w, r, event, Error(404, "Not found"), onError) return } - onRequestErr := onRequest(r, ctx) - if onRequestErr != nil { - handleError(false, w, r, ctx, onRequestErr, onError) + + err = onRequest(r, event) + if err != nil { + handleError(false, w, r, event, err, onError) return } + + if len(app.middleware) > 0 { + for i := 0; i < len(app.middleware); i++ { + fn := app.middleware[i] + err := fn(r, *event, rpcName) + if err != nil { + handleError(false, w, r, event, err, onError) + return + } + } + } + params, paramsOk := paramsZero.Interface().(TParams) if !paramsOk { - handleError(false, w, r, ctx, Error(500, "Error initializing empty params"), onError) + handleError(false, w, r, event, Error(500, "Error initializing empty params"), onError) return } if hasParams { switch rpcSchema.Http.Method { case HttpMethodGet: urlValues := r.URL.Query() - fromUrlQueryErr := FromUrlQuery(urlValues, ¶ms, app.Options.KeyCasing) + fromUrlQueryErr := FromUrlQuery(urlValues, ¶ms, app.options.KeyCasing) if fromUrlQueryErr != nil { - handleError(false, w, r, ctx, fromUrlQueryErr, onError) + handleError(false, w, r, event, fromUrlQueryErr, onError) return } default: - b, bErr := io.ReadAll(r.Body) - if bErr != nil { - handleError(false, w, r, ctx, Error(400, bErr.Error()), onError) + b, err := io.ReadAll(r.Body) + if err != nil { + handleError(false, w, r, event, Error(400, err.Error()), onError) return } - fromJsonErr := DecodeJSON(b, ¶ms, app.Options.KeyCasing) + fromJsonErr := DecodeJSON(b, ¶ms, app.options.KeyCasing) if fromJsonErr != nil { - handleError(false, w, r, ctx, fromJsonErr, onError) + handleError(false, w, r, event, fromJsonErr, onError) return } } } - response, responseErr := handler(params, *ctx) - if responseErr != nil { - payload := responseErr - handleError(false, w, r, ctx, payload, onError) + + response, err := handler(params, *event) + if err != nil { + payload := err + handleError(false, w, r, event, payload, onError) return } - onBeforeResponseErr := onBeforeResponse(r, ctx, "") - if onBeforeResponseErr != nil { - handleError(false, w, r, ctx, onBeforeResponseErr, onError) + + err = onBeforeResponse(r, event, "") + if err != nil { + handleError(false, w, r, event, err, onError) return } + w.WriteHeader(200) var body []byte if hasResponse { - json, jsonErr := EncodeJSON(response, app.Options.KeyCasing) - if jsonErr != nil { - handleError(false, w, r, ctx, ErrorWithData(500, jsonErr.Error(), Some[any](jsonErr)), onError) + json, err := EncodeJSON(response, app.options.KeyCasing) + if err != nil { + handleError(false, w, r, event, ErrorWithData(500, err.Error(), Some[any](err)), onError) return } body = json @@ -175,9 +192,10 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa body = []byte{} } w.Write([]byte(body)) - onAfterResponseErr := onAfterResponse(r, ctx, "") - if onAfterResponseErr != nil { - handleError(false, w, r, ctx, onAfterResponseErr, onError) + + err = onAfterResponse(r, event, "") + if err != nil { + handleError(true, w, r, event, err, onError) } }) } @@ -189,10 +207,10 @@ func getModelName(rpcName string, modelName string, fallbackSuffix string) strin return modelName } -func Rpc[TParams, TResponse any, TContext Context](app *App[TContext], handler func(TParams, TContext) (TResponse, RpcError), options RpcOptions) { +func Rpc[TParams, TResponse any, TEvent Event](app *App[TEvent], handler func(TParams, TEvent) (TResponse, RpcError), options RpcOptions) { rpc(app, "", options, handler) } -func ScopedRpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceName string, handler func(TParams, TContext) (TResponse, RpcError), options RpcOptions) { +func ScopedRpc[TParams, TResponse any, TEvent Event](app *App[TEvent], serviceName string, handler func(TParams, TEvent) (TResponse, RpcError), options RpcOptions) { rpc(app, serviceName, options, handler) } diff --git a/languages/go/go-server/procedures_sse.go b/languages/go/go-server/procedures_sse.go index d5ee2abf..3fb86507 100644 --- a/languages/go/go-server/procedures_sse.go +++ b/languages/go/go-server/procedures_sse.go @@ -10,6 +10,7 @@ import ( "time" "github.com/iancoleman/strcase" + "github.com/modiimedia/arri/languages/go/go-server/utils" ) type SseController[T any] interface { @@ -113,7 +114,7 @@ func (controller *defaultSseController[T]) SetPingInterval(duration time.Duratio controller.pingDuration = duration } -func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceName string, options RpcOptions, handler func(TParams, SseController[TResponse], TContext) RpcError) { +func eventStreamRpc[TParams, TResponse any, TEvent Event](app *App[TEvent], serviceName string, options RpcOptions, handler func(TParams, SseController[TResponse], TEvent) RpcError) { handlerType := reflect.TypeOf(handler) rpcSchema, rpcError := ToRpcDef( handler, @@ -132,21 +133,21 @@ func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext] panic(rpcError) } if len(serviceName) > 0 { - rpcSchema.Http.Path = app.Options.RpcRoutePrefix + "/" + strcase.ToKebab(serviceName) + rpcSchema.Http.Path + rpcSchema.Http.Path = app.options.RpcRoutePrefix + "/" + strcase.ToKebab(serviceName) + rpcSchema.Http.Path } else { - rpcSchema.Http.Path = app.Options.RpcRoutePrefix + rpcSchema.Http.Path + rpcSchema.Http.Path = app.options.RpcRoutePrefix + rpcSchema.Http.Path } if len(options.Path) > 0 { - rpcSchema.Http.Path = app.Options.RpcRoutePrefix + options.Path + rpcSchema.Http.Path = app.options.RpcRoutePrefix + options.Path } params := handlerType.In(0) if params.Kind() != reflect.Struct { panic("rpc params must be a struct. pointers and other types are not allowed.") } paramName := getModelName(rpcName, params.Name(), "Params") - hasParams := !isEmptyMessage(params) + hasParams := !utils.IsEmptyMessage(params) if hasParams { - paramsDefContext := _NewTypeDefContext(app.Options.KeyCasing) + paramsDefContext := _NewTypeDefContext(app.options.KeyCasing) paramsSchema, paramsSchemaErr := typeToTypeDef(params, paramsDefContext) if paramsSchemaErr != nil { panic(paramsSchemaErr) @@ -155,7 +156,7 @@ func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext] panic("Procedures cannot accept anonymous structs") } rpcSchema.Http.Params.Set(paramName) - app.Definitions.Set(paramName, *paramsSchema) + app.definitions.Set(paramName, *paramsSchema) } else { rpcSchema.Http.Params.Unset() } @@ -164,9 +165,9 @@ func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext] response = response.Elem() } responseName := getModelName(rpcName, response.Name(), "Response") - hasResponse := !isEmptyMessage(response) + hasResponse := !utils.IsEmptyMessage(response) if hasResponse { - responseDefContext := _NewTypeDefContext(app.Options.KeyCasing) + responseDefContext := _NewTypeDefContext(app.options.KeyCasing) responseSchema, responseSchemaErr := typeToTypeDef(response, responseDefContext) if responseSchemaErr != nil { panic(responseSchemaErr) @@ -175,98 +176,111 @@ func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext] panic("Procedures cannot return anonymous structs") } rpcSchema.Http.Response.Set(responseName) - app.Definitions.Set(responseName, *responseSchema) + app.definitions.Set(responseName, *responseSchema) } else { rpcSchema.Http.Response.Unset() } - app.Procedures.Set(rpcName, *rpcSchema) + app.procedures.Set(rpcName, *rpcSchema) onRequest, _, onAfterResponse, onError := getHooks(app) paramsZero := reflect.Zero(reflect.TypeFor[TParams]()) app.Mux.HandleFunc(rpcSchema.Http.Path, func(w http.ResponseWriter, r *http.Request) { - ctx, ctxErr := app.CreateContext(w, r) - if ctxErr != nil { - handleError(false, w, r, nil, ctxErr, onError) + event, err := app.createEvent(w, r) + if err != nil { + handleError(false, w, r, nil, err, onError) return } if strings.ToLower(r.Method) != rpcSchema.Http.Method { - handleError(false, w, r, ctx, Error(404, "Not found"), onError) + handleError(false, w, r, event, Error(404, "Not found"), onError) return } - onRequestErr := onRequest(r, ctx) - if onRequestErr != nil { - handleError(false, w, r, ctx, onRequestErr, onError) + err = onRequest(r, event) + if err != nil { + handleError(false, w, r, event, err, onError) return } + + if len(app.middleware) > 0 { + for i := 0; i < len(app.middleware); i++ { + fn := app.middleware[i] + err := fn(r, *event, rpcName) + if err != nil { + handleError(false, w, r, event, err, onError) + return + } + } + } + params, paramsOk := paramsZero.Interface().(TParams) if !paramsOk { - handleError(false, w, r, ctx, Error(500, "Error initializing empty params"), onError) + handleError(false, w, r, event, Error(500, "Error initializing empty params"), onError) return } if hasParams { switch rpcSchema.Http.Method { case HttpMethodGet: urlValues := r.URL.Query() - fromUrlQueryErr := FromUrlQuery(urlValues, ¶ms, app.Options.KeyCasing) + fromUrlQueryErr := FromUrlQuery(urlValues, ¶ms, app.options.KeyCasing) if fromUrlQueryErr != nil { - handleError(false, w, r, ctx, fromUrlQueryErr, onError) + handleError(false, w, r, event, fromUrlQueryErr, onError) return } default: b, bErr := io.ReadAll(r.Body) if bErr != nil { - handleError(false, w, r, ctx, Error(400, bErr.Error()), onError) + handleError(false, w, r, event, Error(400, bErr.Error()), onError) return } - fromJsonErr := DecodeJSON(b, ¶ms, app.Options.KeyCasing) + fromJsonErr := DecodeJSON(b, ¶ms, app.options.KeyCasing) if fromJsonErr != nil { - handleError(false, w, r, ctx, fromJsonErr, onError) + handleError(false, w, r, event, fromJsonErr, onError) return } } } - sseController := newDefaultSseController[TResponse](w, r, app.Options.KeyCasing) - responseErr := handler(params, sseController, *ctx) - if responseErr != nil { - handleError(false, w, r, ctx, responseErr, onError) + + sseController := newDefaultSseController[TResponse](w, r, app.options.KeyCasing) + err = handler(params, sseController, *event) + if err != nil { + handleError(false, w, r, event, err, onError) return } - onAfterResponseErr := onAfterResponse(r, ctx, "") - if onAfterResponseErr != nil { - handleError(false, w, r, ctx, onAfterResponseErr, onError) + err = onAfterResponse(r, event, "") + if err != nil { + handleError(false, w, r, event, err, onError) } }) } -func EventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext], handler func(TParams, SseController[TResponse], TContext) RpcError, options RpcOptions) { +func EventStreamRpc[TParams, TResponse any, TEvent Event](app *App[TEvent], handler func(TParams, SseController[TResponse], TEvent) RpcError, options RpcOptions) { eventStreamRpc(app, "", options, handler) } -func ScopedEventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext], scope string, handler func(TParams, SseController[TResponse], TContext) RpcError, options RpcOptions) { +func ScopedEventStreamRpc[TParams, TResponse any, TEvent Event](app *App[TEvent], scope string, handler func(TParams, SseController[TResponse], TEvent) RpcError, options RpcOptions) { eventStreamRpc(app, scope, options, handler) } -func getHooks[TContext Context](app *App[TContext]) (func(*http.Request, *TContext) RpcError, func(*http.Request, *TContext, any) RpcError, func(*http.Request, *TContext, any) RpcError, func(*http.Request, *TContext, error)) { - onRequest := app.Options.OnRequest +func getHooks[TEvent Event](app *App[TEvent]) (func(*http.Request, *TEvent) RpcError, func(*http.Request, *TEvent, any) RpcError, func(*http.Request, *TEvent, any) RpcError, func(*http.Request, *TEvent, error)) { + onRequest := app.options.OnRequest if onRequest == nil { - onRequest = func(r *http.Request, t *TContext) RpcError { + onRequest = func(r *http.Request, e *TEvent) RpcError { return nil } } - onBeforeResponse := app.Options.OnBeforeResponse + onBeforeResponse := app.options.OnBeforeResponse if onBeforeResponse == nil { - onBeforeResponse = func(r *http.Request, t *TContext, a any) RpcError { + onBeforeResponse = func(r *http.Request, e *TEvent, a any) RpcError { return nil } } - onAfterResponse := app.Options.OnAfterResponse + onAfterResponse := app.options.OnAfterResponse if onAfterResponse == nil { - onAfterResponse = func(r *http.Request, t *TContext, a any) RpcError { + onAfterResponse = func(r *http.Request, e *TEvent, a any) RpcError { return nil } } - onError := app.Options.OnError + onError := app.options.OnError if onError == nil { - onError = func(r *http.Request, t *TContext, err error) {} + onError = func(r *http.Request, e *TEvent, err error) {} } return onRequest, onBeforeResponse, onAfterResponse, onError } diff --git a/languages/go/go-server/type_def.go b/languages/go/go-server/type_def.go index 2a1a0839..0bf88109 100644 --- a/languages/go/go-server/type_def.go +++ b/languages/go/go-server/type_def.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/iancoleman/strcase" + "github.com/modiimedia/arri/languages/go/go-server/utils" ) const ( @@ -153,7 +154,7 @@ func typeToTypeDef(input reflect.Type, context TypeDefContext) (*TypeDef, error) if input.Implements(reflect.TypeFor[ArriModel]()) { return reflect.New(input).Interface().(ArriModel).TypeDef(context) } - if isNullableTypeOrPointer(input) { + if utils.IsNullableTypeOrPointer(input) { subType := extractNullableType(input) return typeToTypeDef( subType, @@ -332,7 +333,7 @@ func structToTypeDef(input reflect.Type, context TypeDefContext) (*TypeDef, erro s := strings.TrimSpace(field.Tag.Get("enumName")) enumName = Some(Some(s)) } - isOptional := isOptionalType(fieldType) + isOptional := utils.IsOptionalType(fieldType) if isOptional { fieldType = extractOptionalType(fieldType) } diff --git a/languages/go/go-server/utils/reflect_helper_test.go b/languages/go/go-server/utils/reflect_helper_test.go new file mode 100644 index 00000000..b2f870c6 --- /dev/null +++ b/languages/go/go-server/utils/reflect_helper_test.go @@ -0,0 +1,65 @@ +package utils_test + +import ( + "reflect" + "testing" + + arri "github.com/modiimedia/arri/languages/go/go-server" + "github.com/modiimedia/arri/languages/go/go-server/utils" +) + +type Foo struct { + Foo string + Bar int + Baz bool +} + +func TestIsOptionType(t *testing.T) { + stringInput := arri.None[string]() + structInput := arri.Some(Foo{}) + + if !utils.IsOptionalType(reflect.TypeOf(stringInput)) { + t.Fatal() + } + + if !utils.IsOptionalType(reflect.TypeOf(&stringInput)) { + t.Fatal() + } + + if !utils.IsOptionalType(reflect.TypeOf(structInput)) { + t.Fatal() + } + + if !utils.IsOptionalType(reflect.TypeOf(&structInput)) { + t.Fatal() + } + + if utils.IsOptionalType(reflect.TypeOf(Foo{})) { + t.Fatalf("foo is not a valid optional") + } + + if utils.IsOptionalType(reflect.TypeOf("Hello world")) { + t.Fatalf("string is not a valid optional") + } +} + +func TestIsNullableType(t *testing.T) { + stringInput := arri.Null[string]() + structInput := arri.NotNull(Foo{}) + + if !utils.IsNullableType(reflect.TypeOf(stringInput)) { + t.Fatal() + } + + if !utils.IsNullableType(reflect.TypeOf(structInput)) { + t.Fatal() + } + + if utils.IsNullableType(reflect.TypeOf(Foo{})) { + t.Fatal() + } + + if utils.IsNullableType(reflect.TypeOf("Hello world")) { + t.Fatal() + } +} diff --git a/languages/go/go-server/reflect_helpers.go b/languages/go/go-server/utils/reflect_helpers.go similarity index 66% rename from languages/go/go-server/reflect_helpers.go rename to languages/go/go-server/utils/reflect_helpers.go index 7b74b413..fa9a7616 100644 --- a/languages/go/go-server/reflect_helpers.go +++ b/languages/go/go-server/utils/reflect_helpers.go @@ -1,4 +1,4 @@ -package arri +package utils import ( "reflect" @@ -7,14 +7,20 @@ import ( "github.com/iancoleman/strcase" ) -func isOptionalType(t reflect.Type) bool { +const ( + KeyCasingPascalCase = "PASCAL_CASE" + KeyCasingCamelCase = "CAMEL_CASE" + KeyCasingSnakeCase = "SNAKE_CASE" +) + +func IsOptionalType(t reflect.Type) bool { if t.Kind() == reflect.Ptr { - return isOptionalType(t.Elem()) + return IsOptionalType(t.Elem()) } return t.Kind() == reflect.Struct && strings.HasPrefix(t.Name(), "Option[") } -func optionalHasValue(value *reflect.Value) bool { +func OptionalHasValue(value *reflect.Value) bool { target := value if target.Kind() == reflect.Ptr { if target.IsNil() { @@ -27,18 +33,18 @@ func optionalHasValue(value *reflect.Value) bool { return isSome.Bool() } -func isNullableType(t reflect.Type) bool { +func IsNullableType(t reflect.Type) bool { return t.Kind() == reflect.Struct && strings.HasPrefix(t.Name(), "Nullable[") } -func isNullableTypeOrPointer(t reflect.Type) bool { +func IsNullableTypeOrPointer(t reflect.Type) bool { if t.Kind() == reflect.Ptr { - return isNullableTypeOrPointer(t.Elem()) + return IsNullableTypeOrPointer(t.Elem()) } return t.Kind() == reflect.Struct && strings.HasPrefix(t.Name(), "Nullable[") } -func nullableHasValue(val *reflect.Value) bool { +func NullableHasValue(val *reflect.Value) bool { target := val if target.Kind() == reflect.Ptr { if target.IsNil() { @@ -51,7 +57,7 @@ func nullableHasValue(val *reflect.Value) bool { return isSet.Bool() } -func getSerialKey(field *reflect.StructField, keyCasing KeyCasing) string { +func GetSerialKey(field *reflect.StructField, keyCasing string) string { keyTag := field.Tag.Get("key") if len(keyTag) > 0 { return keyTag @@ -67,6 +73,6 @@ func getSerialKey(field *reflect.StructField, keyCasing KeyCasing) string { return strcase.ToLowerCamel(field.Name) } -func isEmptyMessage(t reflect.Type) bool { +func IsEmptyMessage(t reflect.Type) bool { return t.Name() == "EmptyMessage" && strings.Contains(t.PkgPath(), "arri") } diff --git a/playground/go/main.go b/playground/go/main.go index c0632d76..57a63b94 100644 --- a/playground/go/main.go +++ b/playground/go/main.go @@ -8,7 +8,7 @@ import ( ) func main() { - app := arri.NewApp(http.DefaultServeMux, arri.AppOptions[arri.DefaultContext]{}, arri.CreateDefaultContext) + app := arri.NewApp(http.DefaultServeMux, arri.AppOptions[arri.DefaultEvent]{}, arri.CreateDefaultEvent) arri.Rpc(&app, SayHello, arri.RpcOptions{}) app.Run(arri.RunOptions{}) } @@ -21,6 +21,6 @@ type SayHelloResponse struct { Message string `enum:"HELLO,WORLD" enumName:"MESSAGE"` } -func SayHello(params SayHelloParams, ctx arri.DefaultContext) (SayHelloResponse, arri.RpcError) { +func SayHello(params SayHelloParams, event arri.DefaultEvent) (SayHelloResponse, arri.RpcError) { return SayHelloResponse{Message: fmt.Sprintf("Hello %s", params.Name)}, nil } diff --git a/tests/clients/dart/lib/test_client.rpc.dart b/tests/clients/dart/lib/test_client.rpc.dart index c0d1c8d3..a84fdb0c 100644 --- a/tests/clients/dart/lib/test_client.rpc.dart +++ b/tests/clients/dart/lib/test_client.rpc.dart @@ -1,5 +1,5 @@ // this file was autogenerated by arri -// ignore_for_file: type=lint, unused_field, unnecessary_cast +// ignore_for_file: type=lint import 'dart:async'; import 'dart:convert'; import 'package:arri_client/arri_client.dart'; @@ -257,13 +257,13 @@ class TestClientTestsService { onClose: onClose, onError: onError != null && _onError != null ? (err, es) { - _onError?.call(onError); + _onError.call(onError); return onError(err, es); } : onError != null ? onError : _onError != null - ? (err, _) => _onError?.call(err) + ? (err, _) => _onError.call(err) : null, ); } @@ -302,13 +302,13 @@ class TestClientTestsService { onClose: onClose, onError: onError != null && _onError != null ? (err, es) { - _onError?.call(onError); + _onError.call(onError); return onError(err, es); } : onError != null ? onError : _onError != null - ? (err, _) => _onError?.call(err) + ? (err, _) => _onError.call(err) : null, ); } @@ -344,13 +344,13 @@ class TestClientTestsService { onClose: onClose, onError: onError != null && _onError != null ? (err, es) { - _onError?.call(onError); + _onError.call(onError); return onError(err, es); } : onError != null ? onError : _onError != null - ? (err, _) => _onError?.call(err) + ? (err, _) => _onError.call(err) : null, ); } @@ -385,13 +385,13 @@ class TestClientTestsService { onClose: onClose, onError: onError != null && _onError != null ? (err, es) { - _onError?.call(onError); + _onError.call(onError); return onError(err, es); } : onError != null ? onError : _onError != null - ? (err, _) => _onError?.call(err) + ? (err, _) => _onError.call(err) : null, ); } @@ -430,13 +430,13 @@ class TestClientTestsService { onClose: onClose, onError: onError != null && _onError != null ? (err, es) { - _onError?.call(onError); + _onError.call(onError); return onError(err, es); } : onError != null ? onError : _onError != null - ? (err, _) => _onError?.call(err) + ? (err, _) => _onError.call(err) : null, ); } @@ -470,13 +470,13 @@ class TestClientTestsService { onClose: onClose, onError: onError != null && _onError != null ? (err, es) { - _onError?.call(onError); + _onError.call(onError); return onError(err, es); } : onError != null ? onError : _onError != null - ? (err, _) => _onError?.call(err) + ? (err, _) => _onError.call(err) : null, ); } @@ -530,13 +530,13 @@ class TestClientUsersService { onClose: onClose, onError: onError != null && _onError != null ? (err, es) { - _onError?.call(onError); + _onError.call(onError); return onError(err, es); } : onError != null ? onError : _onError != null - ? (err, _) => _onError?.call(err) + ? (err, _) => _onError.call(err) : null, ); } diff --git a/tests/clients/dart/pubspec.lock b/tests/clients/dart/pubspec.lock index 26cbdf14..c6df82f2 100644 --- a/tests/clients/dart/pubspec.lock +++ b/tests/clients/dart/pubspec.lock @@ -414,4 +414,4 @@ packages: source: hosted version: "3.1.2" sdks: - dart: ">=3.6.0-0 <4.0.0" + dart: ">=3.6.0 <4.0.0" diff --git a/tests/clients/dart/pubspec.yaml b/tests/clients/dart/pubspec.yaml index 5137fbea..f4629443 100644 --- a/tests/clients/dart/pubspec.yaml +++ b/tests/clients/dart/pubspec.yaml @@ -3,7 +3,7 @@ description: repository: https://github.com/modiimedia/arri issue_tracker: https://github.com/modiimedia/arri/issues environment: - sdk: ">=3.0.0 <4.0.0" + sdk: ">=3.6.0 <4.0.0" dependencies: arri_client: path: ../../../languages/dart/dart-client diff --git a/tests/servers/go/main.go b/tests/servers/go/main.go index b819248b..e10e17c7 100644 --- a/tests/servers/go/main.go +++ b/tests/servers/go/main.go @@ -12,17 +12,17 @@ import ( "gopkg.in/loremipsum.v1" ) -type AppContext struct { +type RpcEvent struct { XTestHeader string request *http.Request writer http.ResponseWriter } -func (c AppContext) Request() *http.Request { +func (c RpcEvent) Request() *http.Request { return c.request } -func (c AppContext) Writer() http.ResponseWriter { +func (c RpcEvent) Writer() http.ResponseWriter { return c.writer } @@ -38,30 +38,33 @@ func main() { }) app := arri.NewApp( mux, - arri.AppOptions[AppContext]{ + arri.AppOptions[RpcEvent]{ AppVersion: "10", RpcRoutePrefix: "/rpcs", - OnRequest: func(r *http.Request, ac *AppContext) arri.RpcError { - ac.Writer().Header().Set("Access-Control-Allow-Origin", "*") - if len(ac.XTestHeader) == 0 && - r.URL.Path != "/" && - r.URL.Path != "/status" && - r.URL.Path != "/favicon.ico" && - !strings.HasSuffix(r.URL.Path, "__definition") { - return arri.Error(401, "Missing test auth header 'x-test-header'") - } + OnRequest: func(r *http.Request, event *RpcEvent) arri.RpcError { + event.Writer().Header().Set("Access-Control-Allow-Origin", "*") return nil }, - OnError: func(r *http.Request, ac *AppContext, err error) {}, + OnError: func(r *http.Request, ac *RpcEvent, err error) {}, }, - func(w http.ResponseWriter, r *http.Request) (*AppContext, arri.RpcError) { - return &AppContext{ + func(w http.ResponseWriter, r *http.Request) (*RpcEvent, arri.RpcError) { + return &RpcEvent{ request: r, writer: w, XTestHeader: r.Header.Get("x-test-header"), }, nil }, ) + arri.Use(&app, func(r *http.Request, event RpcEvent, rpcName string) arri.RpcError { + if len(event.XTestHeader) == 0 && + r.URL.Path != "/" && + r.URL.Path != "/status" && + r.URL.Path != "/favicon.ico" && + !strings.HasSuffix(r.URL.Path, "__definition") { + return arri.Error(401, "Missing test auth header 'x-test-header'") + } + return nil + }) arri.RegisterDef(&app, ManuallyAddedModel{}, arri.DefOptions{}) arri.ScopedRpc(&app, "tests", EmptyParamsGetRequest, arri.RpcOptions{Method: arri.HttpMethodGet}) arri.ScopedRpc(&app, "tests", EmptyParamsPostRequest, arri.RpcOptions{}) @@ -98,7 +101,7 @@ type DeprecatedRpcParams struct { DeprecatedField string `arri:"deprecated"` } -func DeprecatedRpc(_ DeprecatedRpcParams, _ AppContext) (arri.EmptyMessage, arri.RpcError) { +func DeprecatedRpc(_ DeprecatedRpcParams, _ RpcEvent) (arri.EmptyMessage, arri.RpcError) { return arri.EmptyMessage{}, nil } @@ -106,19 +109,19 @@ type DefaultPayload struct { Message string } -func EmptyParamsGetRequest(_ arri.EmptyMessage, _ AppContext) (DefaultPayload, arri.RpcError) { +func EmptyParamsGetRequest(_ arri.EmptyMessage, _ RpcEvent) (DefaultPayload, arri.RpcError) { return DefaultPayload{Message: "ok"}, nil } -func EmptyParamsPostRequest(_ arri.EmptyMessage, _ AppContext) (DefaultPayload, arri.RpcError) { +func EmptyParamsPostRequest(_ arri.EmptyMessage, _ RpcEvent) (DefaultPayload, arri.RpcError) { return DefaultPayload{Message: "ok"}, nil } -func EmptyResponseGetRequest(_ DefaultPayload, _ AppContext) (arri.EmptyMessage, arri.RpcError) { +func EmptyResponseGetRequest(_ DefaultPayload, _ RpcEvent) (arri.EmptyMessage, arri.RpcError) { return arri.EmptyMessage{}, nil } -func EmptyResponsePostRequest(_ DefaultPayload, _ AppContext) (arri.EmptyMessage, arri.RpcError) { +func EmptyResponsePostRequest(_ DefaultPayload, _ RpcEvent) (arri.EmptyMessage, arri.RpcError) { return arri.EmptyMessage{}, nil } @@ -127,7 +130,7 @@ type SendErrorParams struct { Message string } -func SendError(params SendErrorParams, _ AppContext) (arri.EmptyMessage, arri.RpcError) { +func SendError(params SendErrorParams, _ RpcEvent) (arri.EmptyMessage, arri.RpcError) { return arri.EmptyMessage{}, arri.Error(uint32(params.Code), params.Message) } @@ -181,7 +184,7 @@ type ObjectWithEveryType struct { } } -func SendObject(params ObjectWithEveryType, _ AppContext) (ObjectWithEveryType, arri.RpcError) { +func SendObject(params ObjectWithEveryType, _ RpcEvent) (ObjectWithEveryType, arri.RpcError) { return params, nil } @@ -235,7 +238,7 @@ type ObjectWithEveryNullableType struct { }]]] } -func SendObjectWithNullableFields(params ObjectWithEveryNullableType, _ AppContext) (ObjectWithEveryNullableType, arri.RpcError) { +func SendObjectWithNullableFields(params ObjectWithEveryNullableType, _ RpcEvent) (ObjectWithEveryNullableType, arri.RpcError) { return params, nil } @@ -247,7 +250,7 @@ type ObjectWithPascalCaseKeys struct { IsAdmin arri.Option[bool] `key:"IsAdmin"` } -func SendObjectWithPascalCaseKeys(params ObjectWithPascalCaseKeys, _ AppContext) (ObjectWithPascalCaseKeys, arri.RpcError) { +func SendObjectWithPascalCaseKeys(params ObjectWithPascalCaseKeys, _ RpcEvent) (ObjectWithPascalCaseKeys, arri.RpcError) { return params, nil } @@ -259,7 +262,7 @@ type ObjectWithSnakeCaseKeys struct { IsAdmin arri.Option[bool] `key:"is_admin"` } -func SendObjectWithSnakeCaseKeys(params ObjectWithSnakeCaseKeys, _ AppContext) (ObjectWithSnakeCaseKeys, arri.RpcError) { +func SendObjectWithSnakeCaseKeys(params ObjectWithSnakeCaseKeys, _ RpcEvent) (ObjectWithSnakeCaseKeys, arri.RpcError) { return params, nil } @@ -313,7 +316,7 @@ type ObjectWithEveryOptionalType struct { }] } -func SendPartialObject(params ObjectWithEveryOptionalType, _ AppContext) (ObjectWithEveryOptionalType, arri.RpcError) { +func SendPartialObject(params ObjectWithEveryOptionalType, _ RpcEvent) (ObjectWithEveryOptionalType, arri.RpcError) { return params, nil } @@ -323,7 +326,7 @@ type RecursiveObject struct { Value string } -func SendRecursiveObject(params RecursiveObject, _ AppContext) (RecursiveObject, arri.RpcError) { +func SendRecursiveObject(params RecursiveObject, _ RpcEvent) (RecursiveObject, arri.RpcError) { return params, nil } @@ -346,7 +349,7 @@ type RecursiveUnion struct { } `discriminator:"SHAPE" description:"Shape node"` } -func SendRecursiveUnion(params RecursiveUnion, _ AppContext) (RecursiveUnion, arri.RpcError) { +func SendRecursiveUnion(params RecursiveUnion, _ RpcEvent) (RecursiveUnion, arri.RpcError) { return params, nil } @@ -359,7 +362,7 @@ type AutoReconnectResponse struct { Message string } -func StreamAutoReconnect(params AutoReconnectParams, controller arri.SseController[AutoReconnectResponse], ctx AppContext) arri.RpcError { +func StreamAutoReconnect(params AutoReconnectParams, controller arri.SseController[AutoReconnectResponse], event RpcEvent) arri.RpcError { t := time.NewTicker(time.Millisecond) defer t.Stop() var msgCount uint8 = 0 @@ -393,7 +396,7 @@ type StreamConnectionErrorTestResponse struct { func StreamConnectionErrorTest( params StreamConnectionErrorTestParams, controller arri.SseController[StreamConnectionErrorTestResponse], - _ AppContext, + _ RpcEvent, ) arri.RpcError { return arri.Error(uint32(params.StatusCode), params.StatusMessage) } @@ -407,7 +410,7 @@ type StreamLargeObjectsResponse struct { } } -func StreamLargeObjects(params arri.EmptyMessage, controller arri.SseController[StreamLargeObjectsResponse], _ AppContext) arri.RpcError { +func StreamLargeObjects(params arri.EmptyMessage, controller arri.SseController[StreamLargeObjectsResponse], _ RpcEvent) arri.RpcError { t := time.NewTicker(time.Millisecond) defer t.Stop() for { @@ -480,7 +483,7 @@ type ChatMessageUrl struct { Url string } -func StreamMessages(params ChatMessageParams, controller arri.SseController[ChatMessage], context AppContext) arri.RpcError { +func StreamMessages(params ChatMessageParams, controller arri.SseController[ChatMessage], event RpcEvent) arri.RpcError { t := time.NewTicker(time.Millisecond) for { select { @@ -500,8 +503,8 @@ type TestsStreamRetryWithNewCredentialsResponse struct { var usedTokens map[string]bool = map[string]bool{} -func StreamRetryWithNewCredentials(_ arri.EmptyMessage, controller arri.SseController[TestsStreamRetryWithNewCredentialsResponse], ctx AppContext) arri.RpcError { - authToken := ctx.XTestHeader +func StreamRetryWithNewCredentials(_ arri.EmptyMessage, controller arri.SseController[TestsStreamRetryWithNewCredentialsResponse], event RpcEvent) arri.RpcError { + authToken := event.XTestHeader if len(authToken) == 0 { return arri.Error(400, "") } @@ -528,7 +531,7 @@ func StreamRetryWithNewCredentials(_ arri.EmptyMessage, controller arri.SseContr } } -func StreamTenEventsThenEnd(_ arri.EmptyMessage, controller arri.SseController[ChatMessage], ctx AppContext) arri.RpcError { +func StreamTenEventsThenEnd(_ arri.EmptyMessage, controller arri.SseController[ChatMessage], event RpcEvent) arri.RpcError { t := time.NewTicker(time.Millisecond) defer t.Stop() msgCount := 0 @@ -599,7 +602,7 @@ type UserSettings struct { PreferredTheme string `enum:"dark-mode,light-mode,system"` } -func WatchUser(params UsersWatchUserParams, stream arri.SseController[UsersWatchUserResponse], ctx AppContext) arri.RpcError { +func WatchUser(params UsersWatchUserParams, stream arri.SseController[UsersWatchUserResponse], event RpcEvent) arri.RpcError { t := time.NewTicker(time.Millisecond) defer t.Stop() msgCount := 0