diff --git a/internal/server/util_test.go b/internal/server/util_test.go index 32aa6886..d8e306e0 100644 --- a/internal/server/util_test.go +++ b/internal/server/util_test.go @@ -5,7 +5,9 @@ package server import ( "net" + "os" "testing" + "time" "github.com/pion/stun/v3" "github.com/pion/turn/v4/internal/proto" @@ -41,7 +43,7 @@ func TestAuthenticateRequest(t *testing.T) { r = &Request{ Conn: conn, SrcAddr: srcAddr, - AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { + AuthHandler: func(username, realm string, _ net.Addr) (key []byte, ok bool) { return testMsgIntegrity, username == testUsername && realm == testRealm }, NonceHash: nonceHash, @@ -58,6 +60,37 @@ func TestAuthenticateRequest(t *testing.T) { } } + checkSTUNAllocateErrorResponse := func(t *testing.T) stun.ErrorCode { + // Set read deadline to avoid blocking for a long time + err := conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + require.NoError(t, err) + + // Check the error response + buf := make([]byte, 1024) + n, _, err := conn.ReadFrom(buf) + require.NoError(t, err) + + resp := &stun.Message{} + err = resp.UnmarshalBinary(buf[:n]) + require.NoError(t, err) + require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) + var attrErrorCode stun.ErrorCodeAttribute + err = attrErrorCode.GetFrom(resp) + require.NoError(t, err) + return attrErrorCode.Code + } + + checkNoSTUNResponse := func(t *testing.T) { + // Set read deadline to avoid blocking for a long time + err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + require.NoError(t, err) + + // Check the error response + buf := make([]byte, 1024) + _, _, err = conn.ReadFrom(buf) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + } + t.Run("auth success", func(t *testing.T) { tearDown := setUp(t, options{}) defer tearDown() @@ -78,6 +111,8 @@ func TestAuthenticateRequest(t *testing.T) { require.True(t, authResult.hasAuth) require.Equal(t, testRealm, authResult.realm, "Realm value should be present in the result") require.Equal(t, testUsername, authResult.username, "Username value should be present in the result") + + checkNoSTUNResponse(t) }) t.Run("no message integrity", func(t *testing.T) { @@ -100,18 +135,8 @@ func TestAuthenticateRequest(t *testing.T) { require.False(t, authResult.hasAuth) // Check the error response - buf := make([]byte, 1024) - n, _, err := conn.ReadFrom(buf) - require.NoError(t, err) - - var resp stun.Message - err = resp.UnmarshalBinary(buf[:n]) - require.NoError(t, err) - require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) - var attrErrorCode stun.ErrorCodeAttribute - err = attrErrorCode.GetFrom(&resp) - require.NoError(t, err) - require.Equal(t, stun.CodeUnauthorised, attrErrorCode.Code) + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeUnauthorized, errorCode) }) t.Run("no auth handler", func(t *testing.T) { @@ -134,18 +159,8 @@ func TestAuthenticateRequest(t *testing.T) { require.False(t, authResult.hasAuth) // Check the error response - buf := make([]byte, 1024) - n, _, err := conn.ReadFrom(buf) - require.NoError(t, err) - - var resp stun.Message - err = resp.UnmarshalBinary(buf[:n]) - require.NoError(t, err) - require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) - var attrErrorCode stun.ErrorCodeAttribute - err = attrErrorCode.GetFrom(&resp) - require.NoError(t, err) - require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) }) t.Run("no nonce", func(t *testing.T) { @@ -168,18 +183,8 @@ func TestAuthenticateRequest(t *testing.T) { require.False(t, authResult.hasAuth) // Check the error response - buf := make([]byte, 1024) - n, _, err := conn.ReadFrom(buf) - require.NoError(t, err) - - var resp stun.Message - err = resp.UnmarshalBinary(buf[:n]) - require.NoError(t, err) - require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) - var attrErrorCode stun.ErrorCodeAttribute - err = attrErrorCode.GetFrom(&resp) - require.NoError(t, err) - require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) }) t.Run("invalid nonce", func(t *testing.T) { @@ -202,18 +207,8 @@ func TestAuthenticateRequest(t *testing.T) { require.False(t, authResult.hasAuth) // Check the error response - buf := make([]byte, 1024) - n, _, err := conn.ReadFrom(buf) - require.NoError(t, err) - - var resp stun.Message - err = resp.UnmarshalBinary(buf[:n]) - require.NoError(t, err) - require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) - var attrErrorCode stun.ErrorCodeAttribute - err = attrErrorCode.GetFrom(&resp) - require.NoError(t, err) - require.Equal(t, stun.CodeStaleNonce, attrErrorCode.Code) + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeStaleNonce, errorCode) }) t.Run("no realm", func(t *testing.T) { @@ -236,18 +231,8 @@ func TestAuthenticateRequest(t *testing.T) { require.False(t, authResult.hasAuth) // Check the error response - buf := make([]byte, 1024) - n, _, err := conn.ReadFrom(buf) - require.NoError(t, err) - - var resp stun.Message - err = resp.UnmarshalBinary(buf[:n]) - require.NoError(t, err) - require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) - var attrErrorCode stun.ErrorCodeAttribute - err = attrErrorCode.GetFrom(&resp) - require.NoError(t, err) - require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) }) t.Run("no username", func(t *testing.T) { @@ -270,18 +255,8 @@ func TestAuthenticateRequest(t *testing.T) { require.False(t, authResult.hasAuth) // Check the error response - buf := make([]byte, 1024) - n, _, err := conn.ReadFrom(buf) - require.NoError(t, err) - - var resp stun.Message - err = resp.UnmarshalBinary(buf[:n]) - require.NoError(t, err) - require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) - var attrErrorCode stun.ErrorCodeAttribute - err = attrErrorCode.GetFrom(&resp) - require.NoError(t, err) - require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) }) t.Run("unknown username", func(t *testing.T) { @@ -304,18 +279,8 @@ func TestAuthenticateRequest(t *testing.T) { require.False(t, authResult.hasAuth) // Check the error response - buf := make([]byte, 1024) - n, _, err := conn.ReadFrom(buf) - require.NoError(t, err) - - var resp stun.Message - err = resp.UnmarshalBinary(buf[:n]) - require.NoError(t, err) - require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) - var attrErrorCode stun.ErrorCodeAttribute - err = attrErrorCode.GetFrom(&resp) - require.NoError(t, err) - require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) }) t.Run("invalid message integrity", func(t *testing.T) { @@ -338,17 +303,7 @@ func TestAuthenticateRequest(t *testing.T) { require.False(t, authResult.hasAuth) // Check the error response - buf := make([]byte, 1024) - n, _, err := conn.ReadFrom(buf) - require.NoError(t, err) - - var resp stun.Message - err = resp.UnmarshalBinary(buf[:n]) - require.NoError(t, err) - require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) - var attrErrorCode stun.ErrorCodeAttribute - err = attrErrorCode.GetFrom(&resp) - require.NoError(t, err) - require.Equal(t, stun.CodeBadRequest, attrErrorCode.Code) + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) }) }