From 0f9cfb4801a43c4edd81f359c92eb9d7c7b0191a Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 28 Aug 2024 16:26:28 +0800 Subject: [PATCH] Support HTTP API server proxy. --- proxy/api.go | 90 ++++++++++++++++++++++++ proxy/http.go | 31 +++++++-- proxy/main.go | 13 +++- proxy/rtmp.go | 9 ++- proxy/rtmp/amf0.go | 104 +++++++++++++-------------- proxy/rtmp/rtmp.go | 170 ++++++++++++++++++++++----------------------- 6 files changed, 273 insertions(+), 144 deletions(-) create mode 100644 proxy/api.go diff --git a/proxy/api.go b/proxy/api.go new file mode 100644 index 00000000000..4f5c3ec607f --- /dev/null +++ b/proxy/api.go @@ -0,0 +1,90 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "net/http" + "srs-proxy/logger" + "strings" + "sync" + "time" +) + +type httpAPI struct { + // The underlayer HTTP server. + server *http.Server + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewHttpAPI(opts ...func(*httpAPI)) *httpAPI { + v := &httpAPI{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *httpAPI) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *httpAPI) Run(ctx context.Context) error { + // Parse address to listen. + addr := envHttpAPI() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "HTTP API server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + apiResponse(ctx, w, r, map[string]string{ + "signature": Signature(), + "version": Version(), + }) + }) + + // Run HTTP API server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If HTTP API server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "HTTP API accept err %+v", err) + } else { + logger.Df(ctx, "HTTP API server done") + } + } + }() + + return nil +} diff --git a/proxy/http.go b/proxy/http.go index 7d7881f0427..6c437ad176a 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -10,6 +10,7 @@ import ( "os" "srs-proxy/logger" "strings" + "sync" "time" ) @@ -18,6 +19,8 @@ type httpServer struct { server *http.Server // The gracefully quit timeout, wait server to quit. gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup } func NewHttpServer(opts ...func(*httpServer)) *httpServer { @@ -29,10 +32,15 @@ func NewHttpServer(opts ...func(*httpServer)) *httpServer { } func (v *httpServer) Close() error { - return v.server.Close() + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil } -func (v *httpServer) ListenAndServe(ctx context.Context) error { +func (v *httpServer) Run(ctx context.Context) error { // Parse address to listen. addr := envHttpServer() if !strings.Contains(addr, ":") { @@ -42,7 +50,7 @@ func (v *httpServer) ListenAndServe(ctx context.Context) error { // Create server and handler. mux := http.NewServeMux() v.server = &http.Server{Addr: addr, Handler: mux} - logger.Df(ctx, "HTTP stream server listen at %v", addr) + logger.Df(ctx, "HTTP Stream server listen at %v", addr) // Shutdown the server gracefully when quiting. go func() { @@ -79,5 +87,20 @@ func (v *httpServer) ListenAndServe(ctx context.Context) error { }) // Run HTTP server. - return v.server.ListenAndServe() + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If HTTP Stream server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "HTTP Stream accept err %+v", err) + } else { + logger.Df(ctx, "HTTP Stream server done") + } + } + }() + + return nil } diff --git a/proxy/main.go b/proxy/main.go index 1f57060a606..616c95f631e 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -59,14 +59,25 @@ func doMain(ctx context.Context) error { return errors.Wrapf(err, "rtmp server") } + // Start the HTTP API server. + httpAPI := NewHttpAPI(func(server *httpAPI) { + server.gracefulQuitTimeout = gracefulQuitTimeout + }) + defer httpAPI.Close() + if err := httpAPI.Run(ctx); err != nil { + return errors.Wrapf(err, "http api server") + } + // Start the HTTP web server. httpServer := NewHttpServer(func(server *httpServer) { server.gracefulQuitTimeout = gracefulQuitTimeout }) defer httpServer.Close() - if err := httpServer.ListenAndServe(ctx); err != nil { + if err := httpServer.Run(ctx); err != nil { return errors.Wrapf(err, "http server") } + // Wait for the main loop to quit. + <-ctx.Done() return nil } diff --git a/proxy/rtmp.go b/proxy/rtmp.go index a4d95c9a721..fbfad95ac76 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -5,6 +5,7 @@ package main import ( "context" + "io" "math/rand" "net" "os" @@ -72,7 +73,7 @@ func (v *rtmpServer) Run(ctx context.Context) error { if err != nil { if ctx.Err() != context.Canceled { // TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit. - logger.Wf(ctx, "accept rtmp err %+v", err) + logger.Wf(ctx, "RTMP server accept err %+v", err) } else { logger.Df(ctx, "RTMP server done") } @@ -82,7 +83,11 @@ func (v *rtmpServer) Run(ctx context.Context) error { go func(ctx context.Context, conn *net.TCPConn) { defer conn.Close() if err := v.serve(ctx, conn); err != nil { - logger.Wf(ctx, "serve conn %v err %+v", conn.RemoteAddr(), err) + if errors.Cause(err) == io.EOF { + logger.Df(ctx, "RTMP client peer closed") + } else { + logger.Wf(ctx, "serve conn %v err %+v", conn.RemoteAddr(), err) + } } else { logger.Df(ctx, "RTMP client done") } diff --git a/proxy/rtmp/amf0.go b/proxy/rtmp/amf0.go index 4a94457f023..f61a0b98e3c 100644 --- a/proxy/rtmp/amf0.go +++ b/proxy/rtmp/amf0.go @@ -11,7 +11,7 @@ import ( "math" "sync" - oe "srs-proxy/errors" + "srs-proxy/errors" ) // Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview @@ -109,7 +109,7 @@ type amf0Any interface { // Discovery the amf0 object from the bytes b. func Amf0Discovery(p []byte) (a amf0Any, err error) { if len(p) < 1 { - return nil, oe.Errorf("require 1 bytes only %v", len(p)) + return nil, errors.Errorf("require 1 bytes only %v", len(p)) } m := amf0Marker(p[0]) @@ -136,9 +136,9 @@ func Amf0Discovery(p []byte) (a amf0Any, err error) { case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument, amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip, amf0MarkerRecordSet: - return nil, oe.Errorf("Marker %v is not supported", m) + return nil, errors.Errorf("Marker %v is not supported", m) } - return nil, oe.Errorf("Marker %v is invalid", m) + return nil, errors.Errorf("Marker %v is invalid", m) } // The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8 @@ -151,12 +151,12 @@ func (v *amf0UTF8) Size() int { func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 2 { - return oe.Errorf("require 2 bytes only %v", len(p)) + return errors.Errorf("require 2 bytes only %v", len(p)) } size := uint16(p[0])<<8 | uint16(p[1]) if p = data[2:]; len(p) < int(size) { - return oe.Errorf("require %v bytes only %v", int(size), len(p)) + return errors.Errorf("require %v bytes only %v", int(size), len(p)) } *v = amf0UTF8(string(p[:size])) @@ -196,10 +196,10 @@ func (v *amf0Number) Size() int { func (v *amf0Number) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 9 { - return oe.Errorf("require 9 bytes only %v", len(p)) + return errors.Errorf("require 9 bytes only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerNumber { - return oe.Errorf("Amf0Number amf0Marker %v is illegal", m) + return errors.Errorf("Amf0Number amf0Marker %v is illegal", m) } f := binary.BigEndian.Uint64(p[1:]) @@ -235,15 +235,15 @@ func (v *amf0String) Size() int { func (v *amf0String) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 1 { - return oe.Errorf("require 1 bytes only %v", len(p)) + return errors.Errorf("require 1 bytes only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerString { - return oe.Errorf("Amf0String amf0Marker %v is illegal", m) + return errors.Errorf("Amf0String amf0Marker %v is illegal", m) } var sv amf0UTF8 if err = sv.UnmarshalBinary(p[1:]); err != nil { - return oe.WithMessage(err, "utf8") + return errors.WithMessage(err, "utf8") } *v = amf0String(string(sv)) return @@ -254,7 +254,7 @@ func (v *amf0String) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = u.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "utf8") + return nil, errors.WithMessage(err, "utf8") } data = append([]byte{byte(amf0MarkerString)}, pb...) @@ -277,11 +277,11 @@ func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) { p := data if len(p) < 3 { - return oe.Errorf("require 3 bytes only %v", len(p)) + return errors.Errorf("require 3 bytes only %v", len(p)) } if p[0] != 0 || p[1] != 0 || p[2] != 9 { - return oe.Errorf("EOF amf0Marker %v is illegal", p[0:3]) + return errors.Errorf("EOF amf0Marker %v is illegal", p[0:3]) } return } @@ -353,23 +353,23 @@ func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase { func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) { // if no eof, elems specified by maxElems. if !eof && maxElems < 0 { - return oe.Errorf("maxElems=%v without eof", maxElems) + return errors.Errorf("maxElems=%v without eof", maxElems) } // if eof, maxElems must be -1. if eof && maxElems != -1 { - return oe.Errorf("maxElems=%v with eof", maxElems) + return errors.Errorf("maxElems=%v with eof", maxElems) } readOne := func() (amf0UTF8, amf0Any, error) { var u amf0UTF8 if err = u.UnmarshalBinary(p); err != nil { - return "", nil, oe.WithMessage(err, "prop name") + return "", nil, errors.WithMessage(err, "prop name") } p = p[u.Size():] var a amf0Any if a, err = Amf0Discovery(p); err != nil { - return "", nil, oe.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) + return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) } return u, a, nil } @@ -377,7 +377,7 @@ func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) pushOne := func(u amf0UTF8, a amf0Any) error { // For object property, consume the whole bytes. if err = a.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) + return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) } v.Set(string(u), a) @@ -388,7 +388,7 @@ func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) for eof { u, a, err := readOne() if err != nil { - return oe.WithMessage(err, "read") + return errors.WithMessage(err, "read") } // For object EOF, we should only consume total 3bytes. @@ -399,18 +399,18 @@ func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) } if err := pushOne(u, a); err != nil { - return oe.WithMessage(err, "push") + return errors.WithMessage(err, "push") } } for len(v.properties) < maxElems { u, a, err := readOne() if err != nil { - return oe.WithMessage(err, "read") + return errors.WithMessage(err, "read") } if err := pushOne(u, a); err != nil { - return oe.WithMessage(err, "push") + return errors.WithMessage(err, "push") } } @@ -426,17 +426,17 @@ func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) { key, value := p.key, p.value if pb, err = key.MarshalBinary(); err != nil { - return oe.WithMessage(err, fmt.Sprintf("marshal %v", string(key))) + return errors.WithMessage(err, fmt.Sprintf("marshal %v", string(key))) } if _, err = b.Write(pb); err != nil { - return oe.Wrapf(err, "write %v", string(key)) + return errors.Wrapf(err, "write %v", string(key)) } if pb, err = value.MarshalBinary(); err != nil { - return oe.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key))) + return errors.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key))) } if _, err = b.Write(pb); err != nil { - return oe.Wrapf(err, "marshal value for %v", string(key)) + return errors.Wrapf(err, "marshal value for %v", string(key)) } } @@ -466,15 +466,15 @@ func (v *amf0Object) Size() int { func (v *amf0Object) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 1 { - return oe.Errorf("require 1 byte only %v", len(p)) + return errors.Errorf("require 1 byte only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerObject { - return oe.Errorf("Amf0Object amf0Marker %v is illegal", m) + return errors.Errorf("Amf0Object amf0Marker %v is illegal", m) } p = p[1:] if err = v.unmarshal(p, true, -1); err != nil { - return oe.WithMessage(err, "unmarshal") + return errors.WithMessage(err, "unmarshal") } return @@ -484,19 +484,19 @@ func (v *amf0Object) MarshalBinary() (data []byte, err error) { b := createBuffer() if err = b.WriteByte(byte(amf0MarkerObject)); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } if err = v.marshal(b); err != nil { - return nil, oe.WithMessage(err, "marshal") + return nil, errors.WithMessage(err, "marshal") } var pb []byte if pb, err = v.eof.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal") + return nil, errors.WithMessage(err, "marshal") } if _, err = b.Write(pb); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } return b.Bytes(), nil @@ -526,16 +526,16 @@ func (v *amf0EcmaArray) Size() int { func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 5 { - return oe.Errorf("require 5 bytes only %v", len(p)) + return errors.Errorf("require 5 bytes only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray { - return oe.Errorf("EcmaArray amf0Marker %v is illegal", m) + return errors.Errorf("EcmaArray amf0Marker %v is illegal", m) } v.count = binary.BigEndian.Uint32(p[1:]) p = p[5:] if err = v.unmarshal(p, true, -1); err != nil { - return oe.WithMessage(err, "unmarshal") + return errors.WithMessage(err, "unmarshal") } return } @@ -544,23 +544,23 @@ func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { b := createBuffer() if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } if err = binary.Write(b, binary.BigEndian, v.count); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } if err = v.marshal(b); err != nil { - return nil, oe.WithMessage(err, "marshal") + return nil, errors.WithMessage(err, "marshal") } var pb []byte if pb, err = v.eof.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal") + return nil, errors.WithMessage(err, "marshal") } if _, err = b.Write(pb); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } return b.Bytes(), nil @@ -589,10 +589,10 @@ func (v *amf0StrictArray) Size() int { func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 5 { - return oe.Errorf("require 5 bytes only %v", len(p)) + return errors.Errorf("require 5 bytes only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerStrictArray { - return oe.Errorf("StrictArray amf0Marker %v is illegal", m) + return errors.Errorf("StrictArray amf0Marker %v is illegal", m) } v.count = binary.BigEndian.Uint32(p[1:]) p = p[5:] @@ -602,7 +602,7 @@ func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { } if err = v.unmarshal(p, false, int(v.count)); err != nil { - return oe.WithMessage(err, "unmarshal") + return errors.WithMessage(err, "unmarshal") } return } @@ -611,15 +611,15 @@ func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) { b := createBuffer() if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } if err = binary.Write(b, binary.BigEndian, v.count); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } if err = v.marshal(b); err != nil { - return nil, oe.WithMessage(err, "marshal") + return nil, errors.WithMessage(err, "marshal") } return b.Bytes(), nil @@ -645,10 +645,10 @@ func (v *amf0SingleMarkerObject) Size() int { func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 1 { - return oe.Errorf("require 1 byte only %v", len(p)) + return errors.Errorf("require 1 byte only %v", len(p)) } if m := amf0Marker(p[0]); m != v.target { - return oe.Errorf("%v amf0Marker %v is illegal", v.target, m) + return errors.Errorf("%v amf0Marker %v is illegal", v.target, m) } return } @@ -698,10 +698,10 @@ func (v *amf0Boolean) Size() int { func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 2 { - return oe.Errorf("require 2 bytes only %v", len(p)) + return errors.Errorf("require 2 bytes only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerBoolean { - return oe.Errorf("BOOL amf0Marker %v is illegal", m) + return errors.Errorf("BOOL amf0Marker %v is illegal", m) } if p[1] == 0 { *v = false diff --git a/proxy/rtmp/rtmp.go b/proxy/rtmp/rtmp.go index 10a99ec3cbc..cc1f6611a55 100644 --- a/proxy/rtmp/rtmp.go +++ b/proxy/rtmp/rtmp.go @@ -14,7 +14,7 @@ import ( "math/rand" "sync" - oe "srs-proxy/errors" + "srs-proxy/errors" ) // The handshake implements the RTMP handshake protocol. @@ -36,7 +36,7 @@ func (v *Handshake) C1S1() []byte { func (v *Handshake) WriteC0S0(w io.Writer) (err error) { r := bytes.NewReader([]byte{0x03}) if _, err = io.Copy(w, r); err != nil { - return oe.Wrap(err, "write c0s0") + return errors.Wrap(err, "write c0s0") } return @@ -45,7 +45,7 @@ func (v *Handshake) WriteC0S0(w io.Writer) (err error) { func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1); err != nil { - return nil, oe.Wrap(err, "read c0s0") + return nil, errors.Wrap(err, "read c0s0") } c0 = b.Bytes() @@ -62,7 +62,7 @@ func (v *Handshake) WriteC1S1(w io.Writer) (err error) { r := bytes.NewReader(p) if _, err = io.Copy(w, r); err != nil { - return oe.Wrap(err, "write c0s1") + return errors.Wrap(err, "write c0s1") } return @@ -71,7 +71,7 @@ func (v *Handshake) WriteC1S1(w io.Writer) (err error) { func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1536); err != nil { - return nil, oe.Wrap(err, "read c1s1") + return nil, errors.Wrap(err, "read c1s1") } c1s1 = b.Bytes() @@ -83,7 +83,7 @@ func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { r := bytes.NewReader(s1c1[:]) if _, err = io.Copy(w, r); err != nil { - return oe.Wrap(err, "write c2s2") + return errors.Wrap(err, "write c2s2") } return @@ -92,7 +92,7 @@ func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1536); err != nil { - return nil, oe.Wrap(err, "read c2s2") + return nil, errors.Wrap(err, "read c2s2") } c2 = b.Bytes() @@ -173,12 +173,12 @@ func NewProtocol(rw io.ReadWriter) *Protocol { func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) { for { if m, err = v.ReadMessage(ctx); err != nil { - return nil, oe.WithMessage(err, "read message") + return nil, errors.WithMessage(err, "read message") } var pkt Packet if pkt, err = v.DecodeMessage(m); err != nil { - return nil, oe.WithMessage(err, "decode message") + return nil, errors.WithMessage(err, "decode message") } if p, ok := pkt.(T); ok { @@ -198,7 +198,7 @@ func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) { for { if m, err = v.ReadMessage(ctx); err != nil { - return nil, oe.WithMessage(err, "read message") + return nil, errors.WithMessage(err, "read message") } if len(types) == 0 { @@ -218,14 +218,14 @@ func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m * func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { var commandName amf0String if err = commandName.UnmarshalBinary(p); err != nil { - return nil, oe.WithMessage(err, "unmarshal command name") + return nil, errors.WithMessage(err, "unmarshal command name") } switch commandName { case commandResult, commandError: var transactionID amf0Number if err = transactionID.UnmarshalBinary(p[commandName.Size():]); err != nil { - return nil, oe.WithMessage(err, "unmarshal tid") + return nil, errors.WithMessage(err, "unmarshal tid") } var requestName amf0String @@ -235,13 +235,13 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { var ok bool if requestName, ok = v.input.transactions[transactionID]; !ok { - return oe.Errorf("No matched request for tid=%v", transactionID) + return errors.Errorf("No matched request for tid=%v", transactionID) } delete(v.input.transactions, transactionID) return nil }(); err != nil { - return nil, oe.WithMessage(err, "discovery request name") + return nil, errors.WithMessage(err, "discovery request name") } switch requestName { @@ -250,7 +250,7 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { case commandCreateStream: return NewCreateStreamResPacket(transactionID), nil default: - return nil, oe.Errorf("No request for %v", string(requestName)) + return nil, errors.Errorf("No request for %v", string(requestName)) } case commandConnect: return NewConnectAppPacket(), nil @@ -264,7 +264,7 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { p := m.Payload[:] if len(p) == 0 { - return nil, oe.New("Empty packet") + return nil, errors.New("Empty packet") } switch m.MessageType { @@ -281,16 +281,16 @@ func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { pkt = NewSetPeerBandwidth() case MessageTypeAMF0Command, MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data: if pkt, err = v.parseAMFObject(p); err != nil { - return nil, oe.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) + return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) } case MessageTypeUserControl: pkt = NewUserControl() default: - return nil, oe.Errorf("Unknown message %v", m.MessageType) + return nil, errors.Errorf("Unknown message %v", m.MessageType) } if err = pkt.UnmarshalBinary(p); err != nil { - return nil, oe.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) + return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) } return @@ -307,7 +307,7 @@ func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { var cid chunkID var format formatType if format, cid, err = v.readBasicHeader(ctx); err != nil { - return nil, oe.WithMessage(err, "read basic header") + return nil, errors.WithMessage(err, "read basic header") } var ok bool @@ -319,15 +319,15 @@ func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { } if err = v.readMessageHeader(ctx, chunk, format); err != nil { - return nil, oe.WithMessage(err, "read message header") + return nil, errors.WithMessage(err, "read message header") } if m, err = v.readMessagePayload(ctx, chunk); err != nil { - return nil, oe.WithMessage(err, "read message payload") + return nil, errors.WithMessage(err, "read message payload") } if err = v.onMessageArrivated(m); err != nil { - return nil, oe.WithMessage(err, "on message") + return nil, errors.WithMessage(err, "on message") } } @@ -350,7 +350,7 @@ func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) ( b := make([]byte, chunkedPayloadSize) if _, err = io.ReadFull(v.r, b); err != nil { - return nil, oe.Wrapf(err, "read chunk %vB", chunkedPayloadSize) + return nil, errors.Wrapf(err, "read chunk %vB", chunkedPayloadSize) } chunk.message.Payload = append(chunk.message.Payload, b...) @@ -460,14 +460,14 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo if chunk.cid == chunkIDProtocolControl && format == formatType1 { // We accept cid=2, fmt=1 to make librtmp happy. } else { - return oe.Errorf("For fresh chunk, fmt %v != %v(required), cid is %v", format, formatType0, chunk.cid) + return errors.Errorf("For fresh chunk, fmt %v != %v(required), cid is %v", format, formatType0, chunk.cid) } } // When exists cache msg, means got an partial message, // the fmt must not be type0 which means new message. if chunk.message != nil && format == formatType0 { - return oe.Errorf("For exists chunk, fmt is %v, cid is %v", format, chunk.cid) + return errors.Errorf("For exists chunk, fmt is %v, cid is %v", format, chunk.cid) } // Create msg when new chunk stream start @@ -478,7 +478,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo // Read the message header. p := make([]byte, messageHeaderSizes[format]) if _, err = io.ReadFull(v.r, p); err != nil { - return oe.Wrapf(err, "read %vB message header", len(p)) + return errors.Wrapf(err, "read %vB message header", len(p)) } // Prse the message header. @@ -543,7 +543,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo // for the fmt type1(stream_id not changed), user can change the payload // length(it's not allowed in the continue chunks). if !isFirstChunkOfMsg && chunk.header.payloadLength != payloadLength { - return oe.Errorf("Chunk message size %v != %v(required)", payloadLength, chunk.header.payloadLength) + return errors.Errorf("Chunk message size %v != %v(required)", payloadLength, chunk.header.payloadLength) } chunk.header.payloadLength = payloadLength @@ -566,7 +566,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo if chunk.extendedTimestamp { var timestamp uint32 if err = binary.Read(v.r, binary.BigEndian, ×tamp); err != nil { - return oe.Wrapf(err, "read ext-ts, pkt-ts=%v", chunk.header.Timestamp) + return errors.Wrapf(err, "read ext-ts, pkt-ts=%v", chunk.header.Timestamp) } // We always use 31bits timestamp, for some server may use 32bits extended timestamp. @@ -655,7 +655,7 @@ func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid // 2-63, 1B chunk header var t uint8 if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { - return format, cid, oe.Wrap(err, "read basic header") + return format, cid, errors.Wrap(err, "read basic header") } cid = chunkID(t & 0x3f) format = formatType((t >> 6) & 0x03) @@ -666,14 +666,14 @@ func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid // 64-319, 2B chunk header if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { - return format, cid, oe.Wrapf(err, "read basic header for cid=%v", cid) + return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) } cid = chunkID(64 + uint32(t)) // 64-65599, 3B chunk header if cid == 1 { if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { - return format, cid, oe.Wrapf(err, "read basic header for cid=%v", cid) + return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) } cid += chunkID(uint32(t) * 256) } @@ -685,7 +685,7 @@ func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (e m := NewMessage() if m.Payload, err = pkt.MarshalBinary(); err != nil { - return oe.WithMessage(err, "marshal payload") + return errors.WithMessage(err, "marshal payload") } m.MessageType = pkt.Type() @@ -693,11 +693,11 @@ func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (e m.betterCid = pkt.BetterCid() if err = v.WriteMessage(ctx, m); err != nil { - return oe.WithMessage(err, "write message") + return errors.WithMessage(err, "write message") } if err = v.onPacketWriten(m, pkt); err != nil { - return oe.WithMessage(err, "on write packet") + return errors.WithMessage(err, "on write packet") } return @@ -733,7 +733,7 @@ func (v *Protocol) onMessageArrivated(m *Message) (err error) { switch m.MessageType { case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize: if pkt, err = v.DecodeMessage(m); err != nil { - return oe.Errorf("decode message %v", m.MessageType) + return errors.Errorf("decode message %v", m.MessageType) } } @@ -750,10 +750,10 @@ func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { var c0h, c3h []byte if c0h, err = m.generateC0Header(); err != nil { - return oe.WithMessage(err, "generate c0 header") + return errors.WithMessage(err, "generate c0 header") } if c3h, err = m.generateC3Header(); err != nil { - return oe.WithMessage(err, "generate c3 header") + return errors.WithMessage(err, "generate c3 header") } var h []byte @@ -772,7 +772,7 @@ func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { } if _, err = io.Copy(v.w, bytes.NewReader(h)); err != nil { - return oe.Wrapf(err, "write c0c3 header %x", h) + return errors.Wrapf(err, "write c0c3 header %x", h) } size := len(p) @@ -781,7 +781,7 @@ func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { } if _, err = io.Copy(v.w, bytes.NewReader(p[:size])); err != nil { - return oe.Wrapf(err, "write chunk payload %vB", size) + return errors.Wrapf(err, "write chunk payload %vB", size) } p = p[size:] } @@ -794,7 +794,7 @@ func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { // TODO: FIXME: Use writev to write for high performance. if err = v.w.Flush(); err != nil { - return oe.Wrapf(err, "flush writer") + return errors.Wrapf(err, "flush writer") } return @@ -1053,17 +1053,17 @@ func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.CommandName.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal command name") + return errors.WithMessage(err, "unmarshal command name") } p = p[v.CommandName.Size():] if err = v.TransactionID.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal tid") + return errors.WithMessage(err, "unmarshal tid") } p = p[v.TransactionID.Size():] if err = v.CommandObject.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal command") + return errors.WithMessage(err, "unmarshal command") } p = p[v.CommandObject.Size():] @@ -1073,7 +1073,7 @@ func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { v.Args = NewAmf0Object() if err = v.Args.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal args") + return errors.WithMessage(err, "unmarshal args") } return @@ -1082,23 +1082,23 @@ func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { func (v *objectCallPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.CommandName.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal command name") + return nil, errors.WithMessage(err, "marshal command name") } data = append(data, pb...) if pb, err = v.TransactionID.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal tid") + return nil, errors.WithMessage(err, "marshal tid") } data = append(data, pb...) if pb, err = v.CommandObject.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal command object") + return nil, errors.WithMessage(err, "marshal command object") } data = append(data, pb...) if v.Args != nil { if pb, err = v.Args.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal args") + return nil, errors.WithMessage(err, "marshal args") } data = append(data, pb...) } @@ -1123,15 +1123,15 @@ func NewConnectAppPacket() *ConnectAppPacket { func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } if v.CommandName != commandConnect { - return oe.Errorf("Invalid command name %v", string(v.CommandName)) + return errors.Errorf("Invalid command name %v", string(v.CommandName)) } if v.TransactionID != 1.0 { - return oe.Errorf("Invalid transaction ID %v", float64(v.TransactionID)) + return errors.Errorf("Invalid transaction ID %v", float64(v.TransactionID)) } return @@ -1165,11 +1165,11 @@ func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) { if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } if v.CommandName != commandResult { - return oe.Errorf("Invalid command name %v", string(v.CommandName)) + return errors.Errorf("Invalid command name %v", string(v.CommandName)) } return @@ -1204,21 +1204,21 @@ func (v *variantCallPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.CommandName.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal command name") + return errors.WithMessage(err, "unmarshal command name") } p = p[v.CommandName.Size():] if err = v.TransactionID.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal tid") + return errors.WithMessage(err, "unmarshal tid") } p = p[v.TransactionID.Size():] if len(p) > 0 { if v.CommandObject, err = Amf0Discovery(p); err != nil { - return oe.WithMessage(err, "discovery command object") + return errors.WithMessage(err, "discovery command object") } if err = v.CommandObject.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal command object") + return errors.WithMessage(err, "unmarshal command object") } p = p[v.CommandObject.Size():] } @@ -1229,18 +1229,18 @@ func (v *variantCallPacket) UnmarshalBinary(data []byte) (err error) { func (v *variantCallPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.CommandName.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal command name") + return nil, errors.WithMessage(err, "marshal command name") } data = append(data, pb...) if pb, err = v.TransactionID.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal tid") + return nil, errors.WithMessage(err, "marshal tid") } data = append(data, pb...) if v.CommandObject != nil { if pb, err = v.CommandObject.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal command object") + return nil, errors.WithMessage(err, "marshal command object") } data = append(data, pb...) } @@ -1283,16 +1283,16 @@ func (v *CallPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } p = p[v.variantCallPacket.Size():] if len(p) > 0 { if v.Args, err = Amf0Discovery(p); err != nil { - return oe.WithMessage(err, "discovery args") + return errors.WithMessage(err, "discovery args") } if err = v.Args.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal args") + return errors.WithMessage(err, "unmarshal args") } } @@ -1302,13 +1302,13 @@ func (v *CallPacket) UnmarshalBinary(data []byte) (err error) { func (v *CallPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal call") + return nil, errors.WithMessage(err, "marshal call") } data = append(data, pb...) if v.Args != nil { if pb, err = v.Args.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal args") + return nil, errors.WithMessage(err, "marshal args") } data = append(data, pb...) } @@ -1355,12 +1355,12 @@ func (v *CreateStreamResPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } p = p[v.variantCallPacket.Size():] if err = v.StreamID.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal sid") + return errors.WithMessage(err, "unmarshal sid") } return @@ -1369,12 +1369,12 @@ func (v *CreateStreamResPacket) UnmarshalBinary(data []byte) (err error) { func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal call") + return nil, errors.WithMessage(err, "marshal call") } data = append(data, pb...) if pb, err = v.StreamID.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal sid") + return nil, errors.WithMessage(err, "marshal sid") } data = append(data, pb...) @@ -1404,17 +1404,17 @@ func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } p = p[v.variantCallPacket.Size():] if err = v.StreamName.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal stream name") + return errors.WithMessage(err, "unmarshal stream name") } p = p[v.StreamName.Size():] if err = v.StreamType.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal stream type") + return errors.WithMessage(err, "unmarshal stream type") } return @@ -1423,17 +1423,17 @@ func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { func (v *PublishPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal call") + return nil, errors.WithMessage(err, "marshal call") } data = append(data, pb...) if pb, err = v.StreamName.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal stream name") + return nil, errors.WithMessage(err, "marshal stream name") } data = append(data, pb...) if pb, err = v.StreamType.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal stream type") + return nil, errors.WithMessage(err, "marshal stream type") } data = append(data, pb...) @@ -1461,12 +1461,12 @@ func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } p = p[v.variantCallPacket.Size():] if err = v.StreamName.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal stream name") + return errors.WithMessage(err, "unmarshal stream name") } p = p[v.StreamName.Size():] @@ -1476,12 +1476,12 @@ func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { func (v *PlayPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal call") + return nil, errors.WithMessage(err, "marshal call") } data = append(data, pb...) if pb, err = v.StreamName.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal stream name") + return nil, errors.WithMessage(err, "marshal stream name") } data = append(data, pb...) @@ -1515,7 +1515,7 @@ func (v *SetChunkSize) Size() int { func (v *SetChunkSize) UnmarshalBinary(data []byte) (err error) { if len(data) < 4 { - return oe.Errorf("requires 4 only %v bytes, %x", len(data), data) + return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) } v.ChunkSize = binary.BigEndian.Uint32(data) @@ -1554,7 +1554,7 @@ func (v *WindowAcknowledgementSize) Size() int { func (v *WindowAcknowledgementSize) UnmarshalBinary(data []byte) (err error) { if len(data) < 4 { - return oe.Errorf("requires 4 only %v bytes, %x", len(data), data) + return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) } v.AckSize = binary.BigEndian.Uint32(data) @@ -1605,7 +1605,7 @@ func (v *SetPeerBandwidth) Size() int { func (v *SetPeerBandwidth) UnmarshalBinary(data []byte) (err error) { if len(data) < 5 { - return oe.Errorf("requires 5 only %v bytes, %x", len(data), data) + return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) } v.Bandwidth = binary.BigEndian.Uint32(data) v.LimitType = LimitType(data[4]) @@ -1734,12 +1734,12 @@ func (v *UserControl) Size() int { func (v *UserControl) UnmarshalBinary(data []byte) (err error) { if len(data) < 3 { - return oe.Errorf("requires 5 only %v bytes, %x", len(data), data) + return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) } v.EventType = EventType(binary.BigEndian.Uint16(data)) if len(data) < v.Size() { - return oe.Errorf("requires %v only %v bytes, %x", v.Size(), len(data), data) + return errors.Errorf("requires %v only %v bytes, %x", v.Size(), len(data), data) } if v.EventType == EventTypeFmsEvent0 {