Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify proxy, move seq check to no session path #233

Merged
merged 5 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions __tests__/fixtures/transports.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import {
ClientTransport,
Connection,
OpaqueTransportMessageSchema,
ServerTransport,
TransportClientId,
} from '../../transport';
Expand Down Expand Up @@ -29,8 +28,6 @@ import {
ServerHandshakeOptions,
} from '../../router/handshake';
import { MessageFramer } from '../../transport/transforms/messageFraming';
import { BinaryCodec } from '../../codec';
import { Value } from '@sinclair/typebox/value';

export type ValidTransports = 'ws' | 'unix sockets' | 'ws + uds proxy';

Expand Down Expand Up @@ -205,8 +202,6 @@ export const transports: Array<TransportMatrixEntry> = [
let port: number;
let wss: NodeWs.Server;

const codec = opts?.client?.codec ?? BinaryCodec;

async function setupProxyServer() {
udsServer = net.createServer();
await onUdsServeReady(udsServer, socketPath);
Expand All @@ -222,7 +217,6 @@ export const transports: Array<TransportMatrixEntry> = [
wss = createWebSocketServer(proxyServer);

// dumb proxy
// assume that we are using the binary msgpack protocol
wss.on('connection', (ws) => {
const framer = MessageFramer.createFramedStream();
const uds = net.createConnection(socketPath);
Expand All @@ -236,23 +230,11 @@ export const transports: Array<TransportMatrixEntry> = [
// ws -> uds
ws.onmessage = (msg) => {
const data = msg.data as Uint8Array;
const res = codec.fromBuffer(data);
if (!res) return;
if (!Value.Check(OpaqueTransportMessageSchema, res)) {
return;
}

uds.write(MessageFramer.write(data));
};

// ws <- uds
uds.pipe(framer).on('data', (data: Uint8Array) => {
const res = codec.fromBuffer(data);
if (!res) return;
if (!Value.Check(OpaqueTransportMessageSchema, res)) {
return;
}

ws.send(data);
});

Expand Down
58 changes: 29 additions & 29 deletions transport/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,14 @@ export abstract class ServerTransport<
| 'unknown session'
| 'transparent reconnection'
| 'hard reconnection' = 'new session';
const clientNextExpectedSeq =
msg.payload.expectedSessionState.nextExpectedSeq;
const clientNextSentSeq = msg.payload.expectedSessionState.nextSentSeq ?? 0;

if (oldSession && oldSession.id === msg.payload.sessionId) {
connectCase = 'transparent reconnection';

// invariant: ordering must be correct
const clientNextExpectedSeq =
msg.payload.expectedSessionState.nextExpectedSeq;
// TODO: remove nullish coalescing when we're sure this is always set
const clientNextSentSeq =
msg.payload.expectedSessionState.nextSentSeq ?? 0;
const ourNextSeq = oldSession.nextSeq();
const ourAck = oldSession.ack;

Expand Down Expand Up @@ -367,33 +366,34 @@ export abstract class ServerTransport<
connectCase = 'hard reconnection';

// just nuke the old session entirely and proceed as if this was new
this.log?.info(
`client is reconnecting to a new session (${msg.payload.sessionId}) with an old session (${oldSession.id}) already existing, closing old session`,
{
...session.loggingMetadata,
connectedTo: msg.from,
sessionId: msg.payload.sessionId,
},
);
this.deleteSession(oldSession);
oldSession = undefined;
} else {
connectCase = 'unknown session';

const clientNextExpectedSeq =
msg.payload.expectedSessionState.nextExpectedSeq;
// TODO: remove nullish coalescing when we're sure this is always set
const clientNextSentSeq =
msg.payload.expectedSessionState.nextSentSeq ?? 0;
}

if (clientNextSentSeq > 0 || clientNextExpectedSeq > 0) {
// we don't have a session, but the client is trying to reconnect
// to an old session. we can't do anything about this, so we reject
this.rejectHandshakeRequest(
session,
msg.from,
`client is trying to reconnect to a session the server don't know about: ${msg.payload.sessionId}`,
'SESSION_STATE_MISMATCH',
{
...session.loggingMetadata,
connectedTo: msg.from,
transportMessage: msg,
},
);
return;
}
if (!oldSession && (clientNextSentSeq > 0 || clientNextExpectedSeq > 0)) {
// we don't have a session, but the client is trying to reconnect
// to an old session. we can't do anything about this, so we reject
connectCase = 'unknown session';
this.rejectHandshakeRequest(
session,
msg.from,
`client is trying to reconnect to a session the server don't know about: ${msg.payload.sessionId}`,
'SESSION_STATE_MISMATCH',
{
...session.loggingMetadata,
connectedTo: msg.from,
transportMessage: msg,
},
);
return;
}

// from this point on, we're committed to connecting
Expand Down
1 change: 1 addition & 0 deletions transport/sessionStateMachine/SessionConnected.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ export class SessionConnected<
}

private sendHeartbeat() {
this.log?.debug('sending heartbeat', this.loggingMetadata);
this.send({
streamId: 'heartbeat',
controlFlags: ControlFlags.AckBit,
Expand Down
9 changes: 3 additions & 6 deletions transport/transport.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -733,10 +733,8 @@ describe.each(testMatrix())(
await cleanupTransports([clientTransport, serverTransport]);
});

await waitFor(() => expect(onConnect).toHaveBeenCalledTimes(1));
closeAllConnections(clientTransport);
await waitFor(() => {
expect(onConnect).toHaveBeenCalledTimes(2);
lhchavez marked this conversation as resolved.
Show resolved Hide resolved
expect(onConnect).toHaveBeenCalledTimes(1);
expect(numberOfConnections(clientTransport)).toEqual(1);
expect(numberOfConnections(serverTransport)).toEqual(1);
});
Expand All @@ -746,10 +744,9 @@ describe.each(testMatrix())(
expect(oldClientSessionId).not.toBeUndefined();
expect(oldServerSessionId).not.toBeUndefined();

// make sure our connection is still intact even after session grace elapses
await advanceFakeTimersBySessionGrace();
lhchavez marked this conversation as resolved.
Show resolved Hide resolved

closeAllConnections(clientTransport);
await waitFor(() => {
expect(onConnect).toHaveBeenCalledTimes(2);
expect(numberOfConnections(clientTransport)).toEqual(1);
expect(numberOfConnections(serverTransport)).toEqual(1);
});
Expand Down
Loading