diff --git a/go.mod b/go.mod index 2e89ee8..0e50c7e 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/stretchr/testify v1.8.4 go.uber.org/mock v0.3.0 golang.org/x/exp v0.0.0-20231219160207-73b9e39aefca + golang.org/x/net v0.20.0 ) require ( @@ -29,8 +30,8 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.26.0 // indirect golang.org/x/arch v0.6.0 // indirect - golang.org/x/net v0.19.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/sys v0.16.0 // indirect + golang.org/x/text v0.14.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4d9a7ba..f8c6417 100644 --- a/go.sum +++ b/go.sum @@ -79,8 +79,8 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -93,8 +93,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -106,6 +106,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/http.go b/http.go index c581529..7aaaa17 100644 --- a/http.go +++ b/http.go @@ -3,11 +3,14 @@ package kutils import ( "bytes" "context" + "crypto/tls" + "net" "net/http" "time" "github.com/go-resty/resty/v2" "github.com/hashicorp/go-retryablehttp" + "golang.org/x/net/http2" "github.com/KyberNetwork/kutils/internal/json" ) @@ -21,6 +24,7 @@ type HttpCfg struct { MaxIdleConns int // max idle connections for all hosts, default 100 MaxIdleConnsPerHost int // max idle connections per host, default GOMAXPROCS+1 MaxConnsPerHost int // max total connections per host, default 0 (unlimited) + UseH2c bool // whether to use http2 h2c, default false RetryCount int // retry count (exponential backoff), default 0 RetryWaitTime time.Duration // first exponential backoff, default 100ms RetryMaxWaitTime time.Duration // max exponential backoff, default 2s @@ -52,6 +56,18 @@ func (h *HttpCfg) NewRestyClient() (client *resty.Client) { transport.MaxConnsPerHost = h.MaxConnsPerHost } client.SetTransport(transport) + if h.UseH2c { + if h2cTransport, err := http2.ConfigureTransports(transport); err == nil { + h2cTransport.AllowHTTP = true + } + client.SetTransport(&http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, addr) + }, + }) + } } client.SetBaseURL(h.BaseUrl). diff --git a/http_test.go b/http_test.go index a895575..d700343 100644 --- a/http_test.go +++ b/http_test.go @@ -1,9 +1,11 @@ package kutils import ( + "context" "encoding/json" "errors" "fmt" + "net" "net/http" "net/url" "testing" @@ -11,6 +13,8 @@ import ( "github.com/go-resty/resty/v2" "github.com/stretchr/testify/assert" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) func TestHttpCfg_NewRestyClient(t *testing.T) { @@ -160,3 +164,87 @@ func Test_retryableHttpError(t *testing.T) { }) } } + +func Test_h2c(t *testing.T) { + body := "hello" + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(body)) + assert.Nil(t, r.TLS, nil) + }) + h2s := &http2.Server{} + h2cHandler := h2c.NewHandler(handler, h2s) + tests := []struct { + name string + serverHandler http.Handler + clientCfg *HttpCfg + wantErr assert.ErrorAssertionFunc + expects func(*testing.T, *resty.Response) + }{ + { + "h1 server, h1 client", + handler, + &HttpCfg{}, + assert.NoError, + func(t *testing.T, resp *resty.Response) { + assert.Equal(t, "HTTP/1.1", resp.Proto()) + assert.Equal(t, http.StatusOK, resp.StatusCode()) + assert.Equal(t, body, resp.String()) + }, + }, + { + "h1 server, h2c client", + handler, + &HttpCfg{UseH2c: true}, + assert.Error, + nil, + }, + { + "h2 server, h1 client", + h2cHandler, + &HttpCfg{}, + assert.NoError, + func(t *testing.T, resp *resty.Response) { + assert.Equal(t, "HTTP/1.1", resp.Proto()) + assert.Equal(t, http.StatusOK, resp.StatusCode()) + assert.Equal(t, body, resp.String()) + }, + }, + { + "h2 server, h2c client", + h2cHandler, + &HttpCfg{UseH2c: true}, + assert.NoError, + func(t *testing.T, resp *resty.Response) { + assert.Equal(t, "HTTP/2.0", resp.Proto()) + assert.Equal(t, http.StatusOK, resp.StatusCode()) + assert.Equal(t, body, resp.String()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:") + assert.NoError(t, err) + srv := &http.Server{ + Addr: ln.Addr().String(), + Handler: tt.serverHandler, + } + + go func() { + _ = srv.Serve(ln) + }() + defer func(srv *http.Server, ctx context.Context) { + _ = srv.Shutdown(ctx) + }(srv, context.Background()) + + client := tt.clientCfg.NewRestyClient() + resp, err := client.R().EnableTrace().Get("http://" + srv.Addr) + + tt.wantErr(t, err) + if tt.expects != nil { + tt.expects(t, resp) + } + }) + } +}