Skip to content

Commit

Permalink
BC-6711 - Disable asset sync (#4803)
Browse files Browse the repository at this point in the history
* add feature flag to disable asset sync, refactor some loops

* refactor to improve readability and maintainability

---------

Co-authored-by: Cedric Evers <12080057+CeEv@users.noreply.github.com>
  • Loading branch information
2 people authored and virgilchiriac committed Mar 4, 2024
1 parent 75f29bb commit 94dc691
Show file tree
Hide file tree
Showing 26 changed files with 482 additions and 490 deletions.
2 changes: 2 additions & 0 deletions apps/server/src/modules/tldraw/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export interface TldrawConfig {
TLDRAW_GC_ENABLED: number;
REDIS_URI: string;
TLDRAW_ASSETS_ENABLED: boolean;
TLDRAW_ASSETS_SYNC_ENABLED: boolean;
TLDRAW_ASSETS_MAX_SIZE: number;
ASSETS_ALLOWED_MIME_TYPES_LIST: string;
API_HOST: number;
Expand All @@ -31,6 +32,7 @@ const tldrawConfig = {
TLDRAW_GC_ENABLED: Configuration.get('TLDRAW__GC_ENABLED') as boolean,
REDIS_URI: Configuration.has('REDIS_URI') ? (Configuration.get('REDIS_URI') as string) : null,
TLDRAW_ASSETS_ENABLED: Configuration.get('TLDRAW__ASSETS_ENABLED') as boolean,
TLDRAW_ASSETS_SYNC_ENABLED: Configuration.get('TLDRAW__ASSETS_SYNC_ENABLED') as boolean,
TLDRAW_ASSETS_MAX_SIZE: Configuration.get('TLDRAW__ASSETS_MAX_SIZE') as number,
ASSETS_ALLOWED_MIME_TYPES_LIST: Configuration.get('TLDRAW__ASSETS_ALLOWED_MIME_TYPES_LIST') as string,
API_HOST: Configuration.get('API_HOST') as string,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { createConfigModuleOptions } from '@src/config';
import { Logger } from '@src/core/logger';
import { of, throwError } from 'rxjs';
import { createMock, DeepMocked } from '@golevelup/ts-jest';
import { ConfigModule } from '@nestjs/config';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { HttpService } from '@nestjs/axios';
import { AxiosError, AxiosHeaders, AxiosResponse } from 'axios';
import { axiosResponseFactory } from '@shared/testing';
Expand All @@ -19,14 +19,16 @@ import { TldrawBoardRepo, TldrawRepo, YMongodb } from '../../repo';
import { TestConnection, tldrawTestConfig } from '../../testing';
import { MetricsService } from '../../metrics';
import { TldrawWs } from '..';
import { WsCloseCodeEnum, WsCloseMessageEnum } from '../../types';
import { WsCloseCode, WsCloseMessage } from '../../types';
import { TldrawConfig } from '../../config';

describe('WebSocketController (WsAdapter)', () => {
let app: INestApplication;
let gateway: TldrawWs;
let ws: WebSocket;
let wsService: TldrawWsService;
let httpService: DeepMocked<HttpService>;
let configService: ConfigService<TldrawConfig, true>;

const gatewayPort = 3346;
const wsUrl = TestConnection.getWsUrl(gatewayPort);
Expand Down Expand Up @@ -69,6 +71,7 @@ describe('WebSocketController (WsAdapter)', () => {
gateway = testingModule.get(TldrawWs);
wsService = testingModule.get(TldrawWsService);
httpService = testingModule.get(HttpService);
configService = testingModule.get(ConfigService);
app = testingModule.createNestApplication();
app.useWebSocketAdapter(new WsAdapter(app));
await app.init();
Expand Down Expand Up @@ -163,10 +166,7 @@ describe('WebSocketController (WsAdapter)', () => {

ws = await TestConnection.setupWs(wsUrl, 'TEST', {});

expect(wsCloseSpy).toHaveBeenCalledWith(
WsCloseCodeEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_CODE,
Buffer.from(WsCloseMessageEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_MESSAGE)
);
expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.UNAUTHORIZED, Buffer.from(WsCloseMessage.UNAUTHORIZED));

httpGetCallSpy.mockRestore();
wsCloseSpy.mockRestore();
Expand All @@ -180,22 +180,42 @@ describe('WebSocketController (WsAdapter)', () => {
httpGetCallSpy.mockReturnValueOnce(throwError(() => error));
ws = await TestConnection.setupWs(wsUrl, 'TEST', { cookie: 'jwt=jwt-mocked' });

expect(wsCloseSpy).toHaveBeenCalledWith(
WsCloseCodeEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_CODE,
Buffer.from(WsCloseMessageEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_MESSAGE)
);
expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.UNAUTHORIZED, Buffer.from(WsCloseMessage.UNAUTHORIZED));

httpGetCallSpy.mockRestore();
wsCloseSpy.mockRestore();
ws.close();
});
});

describe('when tldraw feature is disabled', () => {
const setup = () => {
const wsCloseSpy = jest.spyOn(WebSocket.prototype, 'close');
const configSpy = jest.spyOn(configService, 'get').mockReturnValueOnce(false);

return {
wsCloseSpy,
configSpy,
};
};

it('should close', async () => {
const { wsCloseSpy } = setup();

ws = await TestConnection.setupWs(wsUrl, 'test-doc');

expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.BAD_REQUEST, Buffer.from(WsCloseMessage.FEATURE_DISABLED));

wsCloseSpy.mockRestore();
ws.close();
});
});

describe('when checking docName and cookie', () => {
const setup = () => {
const setupConnectionSpy = jest.spyOn(wsService, 'setupWSConnection');
const setupConnectionSpy = jest.spyOn(wsService, 'setupWsConnection');
const wsCloseSpy = jest.spyOn(WebSocket.prototype, 'close');
const closeConnSpy = jest.spyOn(wsService, 'closeConn').mockRejectedValue(new Error('error'));
const closeConnSpy = jest.spyOn(wsService, 'closeConnection').mockRejectedValue(new Error('error'));

return {
setupConnectionSpy,
Expand All @@ -211,10 +231,7 @@ describe('WebSocketController (WsAdapter)', () => {
ws = await TestConnection.setupWs(wsUrl, '', { cookie: 'jwt=jwt-mocked' });
ws.send(buffer);

expect(wsCloseSpy).toHaveBeenCalledWith(
WsCloseCodeEnum.WS_CLIENT_BAD_REQUEST_CODE,
Buffer.from(WsCloseMessageEnum.WS_CLIENT_BAD_REQUEST_MESSAGE)
);
expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.BAD_REQUEST, Buffer.from(WsCloseMessage.BAD_REQUEST));

wsCloseSpy.mockRestore();
setupConnectionSpy.mockRestore();
Expand All @@ -236,30 +253,58 @@ describe('WebSocketController (WsAdapter)', () => {

ws = await TestConnection.setupWs(wsUrl, 'GLOBAL', { cookie: 'jwt=jwt-mocked' });

expect(wsCloseSpy).toHaveBeenCalledWith(
WsCloseCodeEnum.WS_CLIENT_NOT_FOUND_CODE,
Buffer.from(WsCloseMessageEnum.WS_CLIENT_NOT_FOUND_MESSAGE)
);
expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.NOT_FOUND, Buffer.from(WsCloseMessage.NOT_FOUND));

wsCloseSpy.mockRestore();
setupConnectionSpy.mockRestore();
ws.close();
});

it(`should close for not authorizing connection`, async () => {
it(`should close for not authorized connection`, async () => {
const { setupConnectionSpy, wsCloseSpy } = setup();
const { buffer } = getMessage();

const httpGetCallSpy = jest.spyOn(httpService, 'get');
const error = new Error('unknown error');
const error = new AxiosError('unknown error', '401', undefined, undefined, {
config: { headers: new AxiosHeaders() },
data: undefined,
headers: {},
statusText: '401',
status: 401,
});
httpGetCallSpy.mockReturnValueOnce(throwError(() => error));

ws = await TestConnection.setupWs(wsUrl, 'TEST', { cookie: 'jwt=jwt-mocked' });
ws.send(buffer);

expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.UNAUTHORIZED, Buffer.from(WsCloseMessage.UNAUTHORIZED));

wsCloseSpy.mockRestore();
setupConnectionSpy.mockRestore();
httpGetCallSpy.mockRestore();
ws.close();
});

it(`should close on unexpected error code`, async () => {
const { setupConnectionSpy, wsCloseSpy } = setup();
const { buffer } = getMessage();

const httpGetCallSpy = jest.spyOn(httpService, 'get');
const error = new AxiosError('unknown error', '418', undefined, undefined, {
config: { headers: new AxiosHeaders() },
data: undefined,
headers: {},
statusText: '418',
status: 418,
});
httpGetCallSpy.mockReturnValueOnce(throwError(() => error));

ws = await TestConnection.setupWs(wsUrl, 'TEST', { cookie: 'jwt=jwt-mocked' });
ws.send(buffer);

expect(wsCloseSpy).toHaveBeenCalledWith(
WsCloseCodeEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_CODE,
Buffer.from(WsCloseMessageEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_MESSAGE)
WsCloseCode.INTERNAL_SERVER_ERROR,
Buffer.from(WsCloseMessage.INTERNAL_SERVER_ERROR)
);

wsCloseSpy.mockRestore();
Expand Down Expand Up @@ -307,8 +352,8 @@ describe('WebSocketController (WsAdapter)', () => {

expect(setupConnectionSpy).toHaveBeenCalledWith(expect.anything(), 'TEST');
expect(wsCloseSpy).toHaveBeenCalledWith(
WsCloseCodeEnum.WS_CLIENT_FAILED_CONNECTION_CODE,
Buffer.from(WsCloseMessageEnum.WS_CLIENT_FAILED_CONNECTION_MESSAGE)
WsCloseCode.INTERNAL_SERVER_ERROR,
Buffer.from(WsCloseMessage.INTERNAL_SERVER_ERROR)
);

wsCloseSpy.mockRestore();
Expand Down
131 changes: 71 additions & 60 deletions apps/server/src/modules/tldraw/controller/tldraw.ws.ts
Original file line number Diff line number Diff line change
@@ -1,81 +1,48 @@
import { WebSocketGateway, WebSocketServer, OnGatewayInit, OnGatewayConnection } from '@nestjs/websockets';
import { Server, WebSocket } from 'ws';
import WebSocket, { Server } from 'ws';
import { Request } from 'express';
import { ConfigService } from '@nestjs/config';
import cookie from 'cookie';
import { BadRequestException, UnauthorizedException } from '@nestjs/common';
import { InternalServerErrorException, UnauthorizedException, NotFoundException } from '@nestjs/common';
import { Logger } from '@src/core/logger';
import { AxiosError } from 'axios';
import { isAxiosError } from 'axios';
import { firstValueFrom } from 'rxjs';
import { HttpService } from '@nestjs/axios';
import { WebsocketCloseErrorLoggable } from '../loggable';
import { WebsocketInitErrorLoggable } from '../loggable';
import { TldrawConfig, TLDRAW_SOCKET_PORT } from '../config';
import { WsCloseCodeEnum, WsCloseMessageEnum } from '../types';
import { WsCloseCode, WsCloseMessage } from '../types';
import { TldrawWsService } from '../service';

@WebSocketGateway(TLDRAW_SOCKET_PORT)
export class TldrawWs implements OnGatewayInit, OnGatewayConnection {
@WebSocketServer()
server!: Server;

private readonly apiHostUrl: string;

private readonly isTldrawEnabled: boolean;

constructor(
private readonly configService: ConfigService<TldrawConfig, true>,
private readonly tldrawWsService: TldrawWsService,
private readonly httpService: HttpService,
private readonly logger: Logger
) {
this.isTldrawEnabled = this.configService.get<boolean>('FEATURE_TLDRAW_ENABLED');
this.apiHostUrl = this.configService.get<string>('API_HOST');
}
) {}

public async handleConnection(client: WebSocket, request: Request): Promise<void> {
const docName = this.getDocNameFromRequest(request);

if (!this.isTldrawEnabled || !docName) {
this.closeClientAndLogError(
client,
WsCloseCodeEnum.WS_CLIENT_BAD_REQUEST_CODE,
WsCloseMessageEnum.WS_CLIENT_BAD_REQUEST_MESSAGE,
new BadRequestException()
);
if (!this.configService.get<boolean>('FEATURE_TLDRAW_ENABLED')) {
client.close(WsCloseCode.BAD_REQUEST, WsCloseMessage.FEATURE_DISABLED);
return;
}

try {
const cookies = this.parseCookiesFromHeader(request);
await this.authorizeConnection(docName, cookies?.jwt);
} catch (err) {
if (err instanceof AxiosError && (err.response?.status === 400 || err.response?.status === 404)) {
this.closeClientAndLogError(
client,
WsCloseCodeEnum.WS_CLIENT_NOT_FOUND_CODE,
WsCloseMessageEnum.WS_CLIENT_NOT_FOUND_MESSAGE,
err
);
} else {
this.closeClientAndLogError(
client,
WsCloseCodeEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_CODE,
WsCloseMessageEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_MESSAGE,
err
);
}
const docName = this.getDocNameFromRequest(request);
if (!docName) {
client.close(WsCloseCode.BAD_REQUEST, WsCloseMessage.BAD_REQUEST);
return;
}

try {
await this.tldrawWsService.setupWSConnection(client, docName);
const cookies = this.parseCookiesFromHeader(request);
await this.authorizeConnection(docName, cookies?.jwt);
await this.tldrawWsService.setupWsConnection(client, docName);
} catch (err) {
this.closeClientAndLogError(
client,
WsCloseCodeEnum.WS_CLIENT_FAILED_CONNECTION_CODE,
WsCloseMessageEnum.WS_CLIENT_FAILED_CONNECTION_MESSAGE,
err
);
this.handleError(err, client, docName);
}
}

Expand All @@ -93,23 +60,67 @@ export class TldrawWs implements OnGatewayInit, OnGatewayConnection {
return parsedCookies;
}

private closeClientAndLogError(client: WebSocket, code: WsCloseCodeEnum, data: string, err: unknown): void {
client.close(code, data);
this.logger.warning(new WebsocketCloseErrorLoggable(err, `(${code}) ${data}`));
}

private async authorizeConnection(drawingName: string, token: string): Promise<void> {
if (!token) {
throw new UnauthorizedException('Token was not given');
}

await firstValueFrom(
this.httpService.get(`${this.apiHostUrl}/v3/elements/${drawingName}/permission`, {
headers: {
Accept: 'Application/json',
Authorization: `Bearer ${token}`,
},
})
try {
const apiHostUrl = this.configService.get<string>('API_HOST');
await firstValueFrom(
this.httpService.get(`${apiHostUrl}/v3/elements/${drawingName}/permission`, {
headers: {
Accept: 'Application/json',
Authorization: `Bearer ${token}`,
},
})
);
} catch (err) {
if (isAxiosError(err)) {
switch (err.response?.status) {
case 400:
case 404:
throw new NotFoundException();
case 401:
case 403:
throw new UnauthorizedException();
default:
throw new InternalServerErrorException();
}
}

throw new InternalServerErrorException();
}
}

private closeClientAndLog(
client: WebSocket,
code: WsCloseCode,
message: WsCloseMessage,
docName: string,
err?: unknown
): void {
client.close(code, message);
this.logger.warning(new WebsocketInitErrorLoggable(code, message, docName, err));
}

private handleError(err: unknown, client: WebSocket, docName: string): void {
if (err instanceof NotFoundException) {
this.closeClientAndLog(client, WsCloseCode.NOT_FOUND, WsCloseMessage.NOT_FOUND, docName);
return;
}

if (err instanceof UnauthorizedException) {
this.closeClientAndLog(client, WsCloseCode.UNAUTHORIZED, WsCloseMessage.UNAUTHORIZED, docName);
return;
}

this.closeClientAndLog(
client,
WsCloseCode.INTERNAL_SERVER_ERROR,
WsCloseMessage.INTERNAL_SERVER_ERROR,
docName,
err
);
}
}
2 changes: 1 addition & 1 deletion apps/server/src/modules/tldraw/loggable/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ export * from './mongo-transaction-error.loggable';
export * from './redis-error.loggable';
export * from './redis-publish-error.loggable';
export * from './websocket-error.loggable';
export * from './websocket-close-error.loggable';
export * from './websocket-init-error.loggable';
export * from './websocket-message-error.loggable';
export * from './ws-shared-doc-error.loggable';
export * from './close-connection.loggable';
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { RedisPublishErrorLoggable } from './redis-publish-error.loggable';
import { UpdateType } from '../types';

describe('RedisPublishErrorLoggable', () => {
describe('getLogMessage', () => {
const setup = () => {
const type = 'document';
const type = UpdateType.DOCUMENT;
const error = new Error('test');
const loggable = new RedisPublishErrorLoggable(type, error);

Expand Down
Loading

0 comments on commit 94dc691

Please sign in to comment.