diff --git a/src/core/packet_builder.c b/src/core/packet_builder.c index 84d1a27052..52704b8369 100644 --- a/src/core/packet_builder.c +++ b/src/core/packet_builder.c @@ -479,8 +479,22 @@ QuicPacketBuilderGetPacketTypeAndKeyForControlFrames( CXPLAT_DBG_ASSERT(SendFlags != 0); QuicSendValidate(&Builder->Connection->Send); + QUIC_PACKET_KEY_TYPE MaxKeyType = Connection->Crypto.TlsState.WriteKey; + + if (QuicConnIsClient(Connection) && + !Connection->State.HandshakeConfirmed && + MaxKeyType == QUIC_PACKET_KEY_1_RTT && + (SendFlags & QUIC_CONN_SEND_FLAG_CONNECTION_CLOSE)) { + // + // Server is not allowed to process 1-RTT packets until the handshake is confirmed and since we are + // closing the connection, the handshake is unlikely to complete. Ensure the CONNECTION_CLOSE is sent + // in a packet which server can process. + // + MaxKeyType = QUIC_PACKET_KEY_HANDSHAKE; + } + for (QUIC_PACKET_KEY_TYPE KeyType = 0; - KeyType <= Connection->Crypto.TlsState.WriteKey; + KeyType <= MaxKeyType; ++KeyType) { if (KeyType == QUIC_PACKET_KEY_0_RTT) { @@ -538,7 +552,7 @@ QuicPacketBuilderGetPacketTypeAndKeyForControlFrames( if (Connection->Crypto.TlsState.WriteKey == QUIC_PACKET_KEY_0_RTT) { *PacketKeyType = QUIC_PACKET_KEY_INITIAL; } else { - *PacketKeyType = Connection->Crypto.TlsState.WriteKey; + *PacketKeyType = MaxKeyType; } return TRUE; } diff --git a/src/test/lib/HandshakeTest.cpp b/src/test/lib/HandshakeTest.cpp index a1bbdaca0f..974cdc4e14 100644 --- a/src/test/lib/HandshakeTest.cpp +++ b/src/test/lib/HandshakeTest.cpp @@ -885,13 +885,11 @@ QuicTestCustomServerCertificateValidation( } TEST_EQUAL(AcceptCert, Client.GetIsConnected()); - if (AcceptCert) { // Server will be deleted on reject case, so can't validate. - TEST_NOT_EQUAL(nullptr, Server); - if (!Server->WaitForConnectionComplete()) { - return; - } - TEST_TRUE(Server->GetIsConnected()); + TEST_NOT_EQUAL(nullptr, Server); + if (!Server->WaitForConnectionComplete()) { + return; } + TEST_EQUAL(AcceptCert, Server->GetIsConnected()); } } } @@ -998,16 +996,15 @@ QuicTestCustomClientCertificateValidation( if (!Client.WaitForConnectionComplete()) { return; } - - if (AcceptCert) { // Server will be deleted on reject case, so can't validate. - TEST_NOT_EQUAL(nullptr, Server); - if (!Server->WaitForConnectionComplete()) { - return; - } - TEST_TRUE(Server->GetIsConnected()); - } // In all cases, the client "connects", but in the rejection case, it gets disconnected. TEST_TRUE(Client.GetIsConnected()); + + TEST_NOT_EQUAL(nullptr, Server); + if (!Server->WaitForConnectionComplete()) { + return; + } + + TEST_EQUAL(AcceptCert, Server->GetIsConnected()); } } }