diff --git a/apps/server/src/modules/tldraw/config.ts b/apps/server/src/modules/tldraw/config.ts index d1b421463eb..4292185c91d 100644 --- a/apps/server/src/modules/tldraw/config.ts +++ b/apps/server/src/modules/tldraw/config.ts @@ -11,6 +11,7 @@ export interface TldrawConfig { TLDRAW_GC_ENABLED: number; REDIS_URI: string; TLDRAW_ASSETS_ENABLED: boolean; + TLDRAW_ASSETS_SYNC_ENABLED: boolean; TLDRAW_ASSETS_MAX_SIZE: number; ASSETS_ALLOWED_MIME_TYPES_LIST: string; API_HOST: number; @@ -31,6 +32,7 @@ const tldrawConfig = { TLDRAW_GC_ENABLED: Configuration.get('TLDRAW__GC_ENABLED') as boolean, REDIS_URI: Configuration.has('REDIS_URI') ? (Configuration.get('REDIS_URI') as string) : null, TLDRAW_ASSETS_ENABLED: Configuration.get('TLDRAW__ASSETS_ENABLED') as boolean, + TLDRAW_ASSETS_SYNC_ENABLED: Configuration.get('TLDRAW__ASSETS_SYNC_ENABLED') as boolean, TLDRAW_ASSETS_MAX_SIZE: Configuration.get('TLDRAW__ASSETS_MAX_SIZE') as number, ASSETS_ALLOWED_MIME_TYPES_LIST: Configuration.get('TLDRAW__ASSETS_ALLOWED_MIME_TYPES_LIST') as string, API_HOST: Configuration.get('API_HOST') as string, diff --git a/apps/server/src/modules/tldraw/controller/api-test/tldraw.ws.api.spec.ts b/apps/server/src/modules/tldraw/controller/api-test/tldraw.ws.api.spec.ts index 811f6324584..fd02a8097da 100644 --- a/apps/server/src/modules/tldraw/controller/api-test/tldraw.ws.api.spec.ts +++ b/apps/server/src/modules/tldraw/controller/api-test/tldraw.ws.api.spec.ts @@ -8,7 +8,7 @@ import { createConfigModuleOptions } from '@src/config'; import { Logger } from '@src/core/logger'; import { of, throwError } from 'rxjs'; import { createMock, DeepMocked } from '@golevelup/ts-jest'; -import { ConfigModule } from '@nestjs/config'; +import { ConfigModule, ConfigService } from '@nestjs/config'; import { HttpService } from '@nestjs/axios'; import { AxiosError, AxiosHeaders, AxiosResponse } from 'axios'; import { axiosResponseFactory } from '@shared/testing'; @@ -19,7 +19,8 @@ import { TldrawBoardRepo, TldrawRepo, YMongodb } from '../../repo'; import { TestConnection, tldrawTestConfig } from '../../testing'; import { MetricsService } from '../../metrics'; import { TldrawWs } from '..'; -import { WsCloseCodeEnum, WsCloseMessageEnum } from '../../types'; +import { WsCloseCode, WsCloseMessage } from '../../types'; +import { TldrawConfig } from '../../config'; describe('WebSocketController (WsAdapter)', () => { let app: INestApplication; @@ -27,6 +28,7 @@ describe('WebSocketController (WsAdapter)', () => { let ws: WebSocket; let wsService: TldrawWsService; let httpService: DeepMocked; + let configService: ConfigService; const gatewayPort = 3346; const wsUrl = TestConnection.getWsUrl(gatewayPort); @@ -69,6 +71,7 @@ describe('WebSocketController (WsAdapter)', () => { gateway = testingModule.get(TldrawWs); wsService = testingModule.get(TldrawWsService); httpService = testingModule.get(HttpService); + configService = testingModule.get(ConfigService); app = testingModule.createNestApplication(); app.useWebSocketAdapter(new WsAdapter(app)); await app.init(); @@ -163,10 +166,7 @@ describe('WebSocketController (WsAdapter)', () => { ws = await TestConnection.setupWs(wsUrl, 'TEST', {}); - expect(wsCloseSpy).toHaveBeenCalledWith( - WsCloseCodeEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_CODE, - Buffer.from(WsCloseMessageEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_MESSAGE) - ); + expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.UNAUTHORIZED, Buffer.from(WsCloseMessage.UNAUTHORIZED)); httpGetCallSpy.mockRestore(); wsCloseSpy.mockRestore(); @@ -180,10 +180,7 @@ describe('WebSocketController (WsAdapter)', () => { httpGetCallSpy.mockReturnValueOnce(throwError(() => error)); ws = await TestConnection.setupWs(wsUrl, 'TEST', { cookie: 'jwt=jwt-mocked' }); - expect(wsCloseSpy).toHaveBeenCalledWith( - WsCloseCodeEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_CODE, - Buffer.from(WsCloseMessageEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_MESSAGE) - ); + expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.UNAUTHORIZED, Buffer.from(WsCloseMessage.UNAUTHORIZED)); httpGetCallSpy.mockRestore(); wsCloseSpy.mockRestore(); @@ -191,11 +188,34 @@ describe('WebSocketController (WsAdapter)', () => { }); }); + describe('when tldraw feature is disabled', () => { + const setup = () => { + const wsCloseSpy = jest.spyOn(WebSocket.prototype, 'close'); + const configSpy = jest.spyOn(configService, 'get').mockReturnValueOnce(false); + + return { + wsCloseSpy, + configSpy, + }; + }; + + it('should close', async () => { + const { wsCloseSpy } = setup(); + + ws = await TestConnection.setupWs(wsUrl, 'test-doc'); + + expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.BAD_REQUEST, Buffer.from(WsCloseMessage.FEATURE_DISABLED)); + + wsCloseSpy.mockRestore(); + ws.close(); + }); + }); + describe('when checking docName and cookie', () => { const setup = () => { - const setupConnectionSpy = jest.spyOn(wsService, 'setupWSConnection'); + const setupConnectionSpy = jest.spyOn(wsService, 'setupWsConnection'); const wsCloseSpy = jest.spyOn(WebSocket.prototype, 'close'); - const closeConnSpy = jest.spyOn(wsService, 'closeConn').mockRejectedValue(new Error('error')); + const closeConnSpy = jest.spyOn(wsService, 'closeConnection').mockRejectedValue(new Error('error')); return { setupConnectionSpy, @@ -211,10 +231,7 @@ describe('WebSocketController (WsAdapter)', () => { ws = await TestConnection.setupWs(wsUrl, '', { cookie: 'jwt=jwt-mocked' }); ws.send(buffer); - expect(wsCloseSpy).toHaveBeenCalledWith( - WsCloseCodeEnum.WS_CLIENT_BAD_REQUEST_CODE, - Buffer.from(WsCloseMessageEnum.WS_CLIENT_BAD_REQUEST_MESSAGE) - ); + expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.BAD_REQUEST, Buffer.from(WsCloseMessage.BAD_REQUEST)); wsCloseSpy.mockRestore(); setupConnectionSpy.mockRestore(); @@ -236,30 +253,58 @@ describe('WebSocketController (WsAdapter)', () => { ws = await TestConnection.setupWs(wsUrl, 'GLOBAL', { cookie: 'jwt=jwt-mocked' }); - expect(wsCloseSpy).toHaveBeenCalledWith( - WsCloseCodeEnum.WS_CLIENT_NOT_FOUND_CODE, - Buffer.from(WsCloseMessageEnum.WS_CLIENT_NOT_FOUND_MESSAGE) - ); + expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.NOT_FOUND, Buffer.from(WsCloseMessage.NOT_FOUND)); wsCloseSpy.mockRestore(); setupConnectionSpy.mockRestore(); ws.close(); }); - it(`should close for not authorizing connection`, async () => { + it(`should close for not authorized connection`, async () => { const { setupConnectionSpy, wsCloseSpy } = setup(); const { buffer } = getMessage(); const httpGetCallSpy = jest.spyOn(httpService, 'get'); - const error = new Error('unknown error'); + const error = new AxiosError('unknown error', '401', undefined, undefined, { + config: { headers: new AxiosHeaders() }, + data: undefined, + headers: {}, + statusText: '401', + status: 401, + }); + httpGetCallSpy.mockReturnValueOnce(throwError(() => error)); + + ws = await TestConnection.setupWs(wsUrl, 'TEST', { cookie: 'jwt=jwt-mocked' }); + ws.send(buffer); + + expect(wsCloseSpy).toHaveBeenCalledWith(WsCloseCode.UNAUTHORIZED, Buffer.from(WsCloseMessage.UNAUTHORIZED)); + + wsCloseSpy.mockRestore(); + setupConnectionSpy.mockRestore(); + httpGetCallSpy.mockRestore(); + ws.close(); + }); + + it(`should close on unexpected error code`, async () => { + const { setupConnectionSpy, wsCloseSpy } = setup(); + const { buffer } = getMessage(); + + const httpGetCallSpy = jest.spyOn(httpService, 'get'); + const error = new AxiosError('unknown error', '418', undefined, undefined, { + config: { headers: new AxiosHeaders() }, + data: undefined, + headers: {}, + statusText: '418', + status: 418, + }); httpGetCallSpy.mockReturnValueOnce(throwError(() => error)); ws = await TestConnection.setupWs(wsUrl, 'TEST', { cookie: 'jwt=jwt-mocked' }); ws.send(buffer); expect(wsCloseSpy).toHaveBeenCalledWith( - WsCloseCodeEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_CODE, - Buffer.from(WsCloseMessageEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_MESSAGE) + WsCloseCode.INTERNAL_SERVER_ERROR, + Buffer.from(WsCloseMessage.INTERNAL_SERVER_ERROR) ); wsCloseSpy.mockRestore(); @@ -307,8 +352,8 @@ describe('WebSocketController (WsAdapter)', () => { expect(setupConnectionSpy).toHaveBeenCalledWith(expect.anything(), 'TEST'); expect(wsCloseSpy).toHaveBeenCalledWith( - WsCloseCodeEnum.WS_CLIENT_FAILED_CONNECTION_CODE, - Buffer.from(WsCloseMessageEnum.WS_CLIENT_FAILED_CONNECTION_MESSAGE) + WsCloseCode.INTERNAL_SERVER_ERROR, + Buffer.from(WsCloseMessage.INTERNAL_SERVER_ERROR) ); wsCloseSpy.mockRestore(); diff --git a/apps/server/src/modules/tldraw/controller/tldraw.ws.ts b/apps/server/src/modules/tldraw/controller/tldraw.ws.ts index c0d36b7fc69..61161001fe4 100644 --- a/apps/server/src/modules/tldraw/controller/tldraw.ws.ts +++ b/apps/server/src/modules/tldraw/controller/tldraw.ws.ts @@ -1,16 +1,16 @@ import { WebSocketGateway, WebSocketServer, OnGatewayInit, OnGatewayConnection } from '@nestjs/websockets'; -import { Server, WebSocket } from 'ws'; +import WebSocket, { Server } from 'ws'; import { Request } from 'express'; import { ConfigService } from '@nestjs/config'; import cookie from 'cookie'; -import { BadRequestException, UnauthorizedException } from '@nestjs/common'; +import { InternalServerErrorException, UnauthorizedException, NotFoundException } from '@nestjs/common'; import { Logger } from '@src/core/logger'; -import { AxiosError } from 'axios'; +import { isAxiosError } from 'axios'; import { firstValueFrom } from 'rxjs'; import { HttpService } from '@nestjs/axios'; -import { WebsocketCloseErrorLoggable } from '../loggable'; +import { WebsocketInitErrorLoggable } from '../loggable'; import { TldrawConfig, TLDRAW_SOCKET_PORT } from '../config'; -import { WsCloseCodeEnum, WsCloseMessageEnum } from '../types'; +import { WsCloseCode, WsCloseMessage } from '../types'; import { TldrawWsService } from '../service'; @WebSocketGateway(TLDRAW_SOCKET_PORT) @@ -18,64 +18,31 @@ export class TldrawWs implements OnGatewayInit, OnGatewayConnection { @WebSocketServer() server!: Server; - private readonly apiHostUrl: string; - - private readonly isTldrawEnabled: boolean; - constructor( private readonly configService: ConfigService, private readonly tldrawWsService: TldrawWsService, private readonly httpService: HttpService, private readonly logger: Logger - ) { - this.isTldrawEnabled = this.configService.get('FEATURE_TLDRAW_ENABLED'); - this.apiHostUrl = this.configService.get('API_HOST'); - } + ) {} public async handleConnection(client: WebSocket, request: Request): Promise { - const docName = this.getDocNameFromRequest(request); - - if (!this.isTldrawEnabled || !docName) { - this.closeClientAndLogError( - client, - WsCloseCodeEnum.WS_CLIENT_BAD_REQUEST_CODE, - WsCloseMessageEnum.WS_CLIENT_BAD_REQUEST_MESSAGE, - new BadRequestException() - ); + if (!this.configService.get('FEATURE_TLDRAW_ENABLED')) { + client.close(WsCloseCode.BAD_REQUEST, WsCloseMessage.FEATURE_DISABLED); return; } - try { - const cookies = this.parseCookiesFromHeader(request); - await this.authorizeConnection(docName, cookies?.jwt); - } catch (err) { - if (err instanceof AxiosError && (err.response?.status === 400 || err.response?.status === 404)) { - this.closeClientAndLogError( - client, - WsCloseCodeEnum.WS_CLIENT_NOT_FOUND_CODE, - WsCloseMessageEnum.WS_CLIENT_NOT_FOUND_MESSAGE, - err - ); - } else { - this.closeClientAndLogError( - client, - WsCloseCodeEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_CODE, - WsCloseMessageEnum.WS_CLIENT_UNAUTHORISED_CONNECTION_MESSAGE, - err - ); - } + const docName = this.getDocNameFromRequest(request); + if (!docName) { + client.close(WsCloseCode.BAD_REQUEST, WsCloseMessage.BAD_REQUEST); return; } try { - await this.tldrawWsService.setupWSConnection(client, docName); + const cookies = this.parseCookiesFromHeader(request); + await this.authorizeConnection(docName, cookies?.jwt); + await this.tldrawWsService.setupWsConnection(client, docName); } catch (err) { - this.closeClientAndLogError( - client, - WsCloseCodeEnum.WS_CLIENT_FAILED_CONNECTION_CODE, - WsCloseMessageEnum.WS_CLIENT_FAILED_CONNECTION_MESSAGE, - err - ); + this.handleError(err, client, docName); } } @@ -93,23 +60,67 @@ export class TldrawWs implements OnGatewayInit, OnGatewayConnection { return parsedCookies; } - private closeClientAndLogError(client: WebSocket, code: WsCloseCodeEnum, data: string, err: unknown): void { - client.close(code, data); - this.logger.warning(new WebsocketCloseErrorLoggable(err, `(${code}) ${data}`)); - } - private async authorizeConnection(drawingName: string, token: string): Promise { if (!token) { throw new UnauthorizedException('Token was not given'); } - await firstValueFrom( - this.httpService.get(`${this.apiHostUrl}/v3/elements/${drawingName}/permission`, { - headers: { - Accept: 'Application/json', - Authorization: `Bearer ${token}`, - }, - }) + try { + const apiHostUrl = this.configService.get('API_HOST'); + await firstValueFrom( + this.httpService.get(`${apiHostUrl}/v3/elements/${drawingName}/permission`, { + headers: { + Accept: 'Application/json', + Authorization: `Bearer ${token}`, + }, + }) + ); + } catch (err) { + if (isAxiosError(err)) { + switch (err.response?.status) { + case 400: + case 404: + throw new NotFoundException(); + case 401: + case 403: + throw new UnauthorizedException(); + default: + throw new InternalServerErrorException(); + } + } + + throw new InternalServerErrorException(); + } + } + + private closeClientAndLog( + client: WebSocket, + code: WsCloseCode, + message: WsCloseMessage, + docName: string, + err?: unknown + ): void { + client.close(code, message); + this.logger.warning(new WebsocketInitErrorLoggable(code, message, docName, err)); + } + + private handleError(err: unknown, client: WebSocket, docName: string): void { + if (err instanceof NotFoundException) { + this.closeClientAndLog(client, WsCloseCode.NOT_FOUND, WsCloseMessage.NOT_FOUND, docName); + return; + } + + if (err instanceof UnauthorizedException) { + this.closeClientAndLog(client, WsCloseCode.UNAUTHORIZED, WsCloseMessage.UNAUTHORIZED, docName); + return; + } + + this.closeClientAndLog( + client, + WsCloseCode.INTERNAL_SERVER_ERROR, + WsCloseMessage.INTERNAL_SERVER_ERROR, + docName, + err ); } } diff --git a/apps/server/src/modules/tldraw/loggable/index.ts b/apps/server/src/modules/tldraw/loggable/index.ts index 286e877131c..00bfbc2fa7b 100644 --- a/apps/server/src/modules/tldraw/loggable/index.ts +++ b/apps/server/src/modules/tldraw/loggable/index.ts @@ -2,7 +2,7 @@ export * from './mongo-transaction-error.loggable'; export * from './redis-error.loggable'; export * from './redis-publish-error.loggable'; export * from './websocket-error.loggable'; -export * from './websocket-close-error.loggable'; +export * from './websocket-init-error.loggable'; export * from './websocket-message-error.loggable'; export * from './ws-shared-doc-error.loggable'; export * from './close-connection.loggable'; diff --git a/apps/server/src/modules/tldraw/loggable/redis-publish-error.loggable.spec.ts b/apps/server/src/modules/tldraw/loggable/redis-publish-error.loggable.spec.ts index f942d4d4091..915b1596dd5 100644 --- a/apps/server/src/modules/tldraw/loggable/redis-publish-error.loggable.spec.ts +++ b/apps/server/src/modules/tldraw/loggable/redis-publish-error.loggable.spec.ts @@ -1,9 +1,10 @@ import { RedisPublishErrorLoggable } from './redis-publish-error.loggable'; +import { UpdateType } from '../types'; describe('RedisPublishErrorLoggable', () => { describe('getLogMessage', () => { const setup = () => { - const type = 'document'; + const type = UpdateType.DOCUMENT; const error = new Error('test'); const loggable = new RedisPublishErrorLoggable(type, error); diff --git a/apps/server/src/modules/tldraw/loggable/redis-publish-error.loggable.ts b/apps/server/src/modules/tldraw/loggable/redis-publish-error.loggable.ts index f96a21bfba0..2e3d6b1559e 100644 --- a/apps/server/src/modules/tldraw/loggable/redis-publish-error.loggable.ts +++ b/apps/server/src/modules/tldraw/loggable/redis-publish-error.loggable.ts @@ -1,9 +1,10 @@ import { ErrorLogMessage, Loggable, LogMessage, ValidationErrorLogMessage } from '@src/core/logger'; +import { UpdateType } from '../types'; export class RedisPublishErrorLoggable implements Loggable { private error: Error | undefined; - constructor(private readonly type: 'document' | 'awareness', private readonly err: unknown) { + constructor(private readonly type: UpdateType, private readonly err: unknown) { if (err instanceof Error) { this.error = err; } diff --git a/apps/server/src/modules/tldraw/loggable/websocket-close-error.loggable.spec.ts b/apps/server/src/modules/tldraw/loggable/websocket-close-error.loggable.spec.ts deleted file mode 100644 index b14fb64c1e7..00000000000 --- a/apps/server/src/modules/tldraw/loggable/websocket-close-error.loggable.spec.ts +++ /dev/null @@ -1,21 +0,0 @@ -import { WebsocketCloseErrorLoggable } from './websocket-close-error.loggable'; - -describe('WebsocketCloseErrorLoggable', () => { - describe('getLogMessage', () => { - const setup = () => { - const error = new Error('test'); - const errorMessage = 'message'; - - const loggable = new WebsocketCloseErrorLoggable(error, errorMessage); - return { loggable, error, errorMessage }; - }; - - it('should return a loggable message', () => { - const { loggable, error, errorMessage } = setup(); - - const message = loggable.getLogMessage(); - - expect(message).toEqual({ message: errorMessage, error, type: 'WEBSOCKET_CLOSE_ERROR' }); - }); - }); -}); diff --git a/apps/server/src/modules/tldraw/loggable/websocket-close-error.loggable.ts b/apps/server/src/modules/tldraw/loggable/websocket-close-error.loggable.ts deleted file mode 100644 index f4a77a7b4ad..00000000000 --- a/apps/server/src/modules/tldraw/loggable/websocket-close-error.loggable.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { ErrorLogMessage, Loggable, LogMessage, ValidationErrorLogMessage } from '@src/core/logger'; - -export class WebsocketCloseErrorLoggable implements Loggable { - private readonly error: Error | undefined; - - constructor(private readonly err: unknown, private readonly message: string) { - if (err instanceof Error) { - this.error = err; - } - } - - getLogMessage(): LogMessage | ErrorLogMessage | ValidationErrorLogMessage { - return { - message: this.message, - type: 'WEBSOCKET_CLOSE_ERROR', - error: this.error, - }; - } -} diff --git a/apps/server/src/modules/tldraw/loggable/websocket-init-error.loggable.spec.ts b/apps/server/src/modules/tldraw/loggable/websocket-init-error.loggable.spec.ts new file mode 100644 index 00000000000..faada42a29d --- /dev/null +++ b/apps/server/src/modules/tldraw/loggable/websocket-init-error.loggable.spec.ts @@ -0,0 +1,28 @@ +import { WebsocketInitErrorLoggable } from './websocket-init-error.loggable'; +import { WsCloseCode, WsCloseMessage } from '../types'; + +describe('WebsocketInitErrorLoggable', () => { + describe('getLogMessage', () => { + const setup = () => { + const error = new Error('test'); + const errorCode = WsCloseCode.BAD_REQUEST; + const errorMessage = WsCloseMessage.BAD_REQUEST; + const docName = 'test'; + + const loggable = new WebsocketInitErrorLoggable(errorCode, errorMessage, docName, error); + return { loggable, error, errorCode, errorMessage, docName }; + }; + + it('should return a loggable message', () => { + const { loggable, error, errorMessage, errorCode, docName } = setup(); + + const message = loggable.getLogMessage(); + + expect(message).toEqual({ + message: `[${docName}] [${errorCode}] ${errorMessage}`, + type: 'WEBSOCKET_CONNECTION_INIT_ERROR', + error, + }); + }); + }); +}); diff --git a/apps/server/src/modules/tldraw/loggable/websocket-init-error.loggable.ts b/apps/server/src/modules/tldraw/loggable/websocket-init-error.loggable.ts new file mode 100644 index 00000000000..d82760290b8 --- /dev/null +++ b/apps/server/src/modules/tldraw/loggable/websocket-init-error.loggable.ts @@ -0,0 +1,25 @@ +import { ErrorLogMessage, Loggable, LogMessage, ValidationErrorLogMessage } from '@src/core/logger'; +import { WsCloseCode, WsCloseMessage } from '../types'; + +export class WebsocketInitErrorLoggable implements Loggable { + private readonly error: Error | undefined; + + constructor( + private readonly code: WsCloseCode, + private readonly message: WsCloseMessage, + private readonly docName: string, + private readonly err?: unknown + ) { + if (err instanceof Error) { + this.error = err; + } + } + + getLogMessage(): LogMessage | ErrorLogMessage | ValidationErrorLogMessage { + return { + message: `[${this.docName}] [${this.code}] ${this.message}`, + type: 'WEBSOCKET_CONNECTION_INIT_ERROR', + error: this.error, + }; + } +} diff --git a/apps/server/src/modules/tldraw/redis/tldraw-redis.factory.spec.ts b/apps/server/src/modules/tldraw/redis/tldraw-redis.factory.spec.ts index 7517093b787..c24fec60514 100644 --- a/apps/server/src/modules/tldraw/redis/tldraw-redis.factory.spec.ts +++ b/apps/server/src/modules/tldraw/redis/tldraw-redis.factory.spec.ts @@ -3,8 +3,9 @@ import { Test } from '@nestjs/testing'; import { createConfigModuleOptions } from '@src/config'; import { INestApplication } from '@nestjs/common'; import { WsAdapter } from '@nestjs/platform-ws'; -import { createMock, DeepMocked } from '@golevelup/ts-jest'; +import { createMock } from '@golevelup/ts-jest'; import { Logger } from '@src/core/logger'; +import { RedisConnectionTypeEnum } from '../types'; import { TldrawConfig } from '../config'; import { tldrawTestConfig } from '../testing'; import { TldrawRedisFactory } from './tldraw-redis.factory'; @@ -12,8 +13,7 @@ import { TldrawRedisFactory } from './tldraw-redis.factory'; describe('TldrawRedisFactory', () => { let app: INestApplication; let configService: ConfigService; - let logger: DeepMocked; - let redisFactory: DeepMocked; + let redisFactory: TldrawRedisFactory; beforeAll(async () => { const testingModule = await Test.createTestingModule({ @@ -28,7 +28,6 @@ describe('TldrawRedisFactory', () => { }).compile(); configService = testingModule.get(ConfigService); - logger = testingModule.get(Logger); redisFactory = testingModule.get(TldrawRedisFactory); app = testingModule.createNestApplication(); app.useWebSocketAdapter(new WsAdapter(app)); @@ -43,11 +42,19 @@ describe('TldrawRedisFactory', () => { expect(redisFactory).toBeDefined(); }); - describe('constructor', () => { + describe('build', () => { it('should throw if REDIS_URI is not set', () => { - const configSpy = jest.spyOn(configService, 'get').mockReturnValue(null); + const configSpy = jest.spyOn(configService, 'get').mockReturnValueOnce(null); - expect(() => new TldrawRedisFactory(configService, logger)).toThrow('REDIS_URI is not set'); + expect(() => redisFactory.build(RedisConnectionTypeEnum.PUBLISH)).toThrow('REDIS_URI is not set'); + configSpy.mockRestore(); + }); + + it('should return redis connection', () => { + const configSpy = jest.spyOn(configService, 'get').mockReturnValueOnce('redis://localhost:6379'); + const redis = redisFactory.build(RedisConnectionTypeEnum.PUBLISH); + + expect(redis).toBeDefined(); configSpy.mockRestore(); }); }); diff --git a/apps/server/src/modules/tldraw/redis/tldraw-redis.factory.ts b/apps/server/src/modules/tldraw/redis/tldraw-redis.factory.ts index b5e6ad8c65b..b71a6b401f8 100644 --- a/apps/server/src/modules/tldraw/redis/tldraw-redis.factory.ts +++ b/apps/server/src/modules/tldraw/redis/tldraw-redis.factory.ts @@ -8,19 +8,17 @@ import { RedisConnectionTypeEnum } from '../types'; @Injectable() export class TldrawRedisFactory { - private readonly redisUri: string; - constructor(private readonly configService: ConfigService, private readonly logger: Logger) { this.logger.setContext(TldrawRedisFactory.name); - this.redisUri = this.configService.get('REDIS_URI'); + } - if (!this.redisUri) { + public build(connectionType: RedisConnectionTypeEnum) { + const redisUri = this.configService.get('REDIS_URI'); + if (!redisUri) { throw new Error('REDIS_URI is not set'); } - } - public build(connectionType: RedisConnectionTypeEnum) { - const redis = new Redis(this.redisUri, { + const redis = new Redis(redisUri, { maxRetriesPerRequest: null, }); diff --git a/apps/server/src/modules/tldraw/repo/tldraw-board.repo.spec.ts b/apps/server/src/modules/tldraw/repo/tldraw-board.repo.spec.ts index a7c9b115d7e..eed655808ce 100644 --- a/apps/server/src/modules/tldraw/repo/tldraw-board.repo.spec.ts +++ b/apps/server/src/modules/tldraw/repo/tldraw-board.repo.spec.ts @@ -79,7 +79,7 @@ describe('TldrawBoardRepo', () => { describe('getYDocFromMdb', () => { describe('when taking doc data from db', () => { const setup = () => { - const storeGetYDocSpy = jest.spyOn(repo.mdb, 'getYDoc').mockResolvedValueOnce(new WsSharedDocDo('TEST')); + const storeGetYDocSpy = jest.spyOn(repo.mdb, 'getDocument').mockResolvedValueOnce(new WsSharedDocDo('TEST')); return { storeGetYDocSpy, @@ -89,7 +89,7 @@ describe('TldrawBoardRepo', () => { it('should return ydoc', async () => { const { storeGetYDocSpy } = setup(); - const result = await repo.getYDocFromMdb('test'); + const result = await repo.getDocumentFromDb('test'); expect(result).toBeInstanceOf(Doc); storeGetYDocSpy.mockRestore(); diff --git a/apps/server/src/modules/tldraw/repo/tldraw-board.repo.ts b/apps/server/src/modules/tldraw/repo/tldraw-board.repo.ts index 57ea2d408dd..7d3887feb68 100644 --- a/apps/server/src/modules/tldraw/repo/tldraw-board.repo.ts +++ b/apps/server/src/modules/tldraw/repo/tldraw-board.repo.ts @@ -19,8 +19,8 @@ export class TldrawBoardRepo { await this.mdb.createIndex(); } - public async getYDocFromMdb(docName: string): Promise { - const yDoc = await this.mdb.getYDoc(docName); + public async getDocumentFromDb(docName: string): Promise { + const yDoc = await this.mdb.getDocument(docName); return yDoc; } diff --git a/apps/server/src/modules/tldraw/repo/y-mongodb.spec.ts b/apps/server/src/modules/tldraw/repo/y-mongodb.spec.ts index 669dc930096..7fa11273b2b 100644 --- a/apps/server/src/modules/tldraw/repo/y-mongodb.spec.ts +++ b/apps/server/src/modules/tldraw/repo/y-mongodb.spec.ts @@ -81,15 +81,13 @@ describe('YMongoDb', () => { await em.persistAndFlush(drawing); em.clear(); - const update = new Uint8Array([2, 2]); - - return { drawing, update }; + return { drawing }; }; it('should create new document with updates in the database', async () => { - const { drawing, update } = await setup(); + const { drawing } = await setup(); - await mdb.storeUpdateTransactional(drawing.docName, update); + await mdb.storeUpdateTransactional(drawing.docName, new Uint8Array([])); const docs = await em.findAndCount(TldrawDrawing, { docName: drawing.docName }); expect(docs.length).toEqual(2); @@ -196,7 +194,7 @@ describe('YMongoDb', () => { it('should return ydoc', async () => { const { applyUpdateSpy } = await setup(); - const doc = await mdb.getYDoc('test-name'); + const doc = await mdb.getDocument('test-name'); expect(doc).toBeDefined(); applyUpdateSpy.mockRestore(); @@ -222,7 +220,7 @@ describe('YMongoDb', () => { it('should not return ydoc', async () => { const { applyUpdateSpy } = await setup(); - const doc = await mdb.getYDoc('test-name'); + const doc = await mdb.getDocument('test-name'); expect(doc).toBeUndefined(); applyUpdateSpy.mockRestore(); @@ -247,7 +245,7 @@ describe('YMongoDb', () => { it('should return ydoc from the database', async () => { const { applyUpdateSpy } = await setup(); - const doc = await mdb.getYDoc('test-name'); + const doc = await mdb.getDocument('test-name'); expect(doc).toBeDefined(); applyUpdateSpy.mockRestore(); @@ -255,11 +253,9 @@ describe('YMongoDb', () => { describe('when single entity size is greater than MAX_DOCUMENT_SIZE', () => { it('should return ydoc from the database', async () => { - // @ts-expect-error test-case - mdb.maxDocumentSize = 1; const { applyUpdateSpy } = await setup(); - const doc = await mdb.getYDoc('test-name'); + const doc = await mdb.getDocument('test-name'); expect(doc).toBeDefined(); applyUpdateSpy.mockRestore(); diff --git a/apps/server/src/modules/tldraw/repo/y-mongodb.ts b/apps/server/src/modules/tldraw/repo/y-mongodb.ts index faad33396f7..94f24bb0b7c 100644 --- a/apps/server/src/modules/tldraw/repo/y-mongodb.ts +++ b/apps/server/src/modules/tldraw/repo/y-mongodb.ts @@ -17,10 +17,6 @@ import { KeyFactory } from './key.factory'; @Injectable() export class YMongodb { - private readonly maxDocumentSize: number; - - private readonly gcEnabled: boolean; - private readonly _transact: >(docName: string, fn: () => T) => T; // scope the queue of the transaction to each docName @@ -34,9 +30,6 @@ export class YMongodb { ) { this.logger.setContext(YMongodb.name); - this.gcEnabled = this.configService.get('TLDRAW_GC_ENABLED'); - this.maxDocumentSize = this.configService.get('TLDRAW_MAX_DOCUMENT_SIZE'); - // execute a transaction on a database // this will ensure that other processes are currently not writing this._transact = >(docName: string, fn: () => T): T => { @@ -75,12 +68,13 @@ export class YMongodb { await this.repo.ensureIndexes(); } - public getYDoc(docName: string): Promise { + public getDocument(docName: string): Promise { return this._transact(docName, async (): Promise => { const updates = await this.getMongoUpdates(docName); const mergedUpdates = mergeUpdates(updates); - const ydoc = new WsSharedDocDo(docName, this.gcEnabled); + const gcEnabled = this.configService.get('TLDRAW_GC_ENABLED'); + const ydoc = new WsSharedDocDo(docName, gcEnabled); applyUpdate(ydoc, mergedUpdates); return ydoc; @@ -215,21 +209,22 @@ export class YMongodb { await this.writeStateVector(docName, sv, 0); } + const maxDocumentSize = this.configService.get('TLDRAW_MAX_DOCUMENT_SIZE'); const value = Buffer.from(update); // if our buffer exceeds maxDocumentSize, we store the update in multiple documents - if (value.length <= this.maxDocumentSize) { + if (value.length <= maxDocumentSize) { const uniqueKey = KeyFactory.createForUpdate(docName, clock + 1); await this.repo.put(uniqueKey, { value, }); } else { - const totalChunks = Math.ceil(value.length / this.maxDocumentSize); + const totalChunks = Math.ceil(value.length / maxDocumentSize); const putPromises: Promise[] = []; for (let i = 0; i < totalChunks; i += 1) { - const start = i * this.maxDocumentSize; - const end = Math.min(start + this.maxDocumentSize, value.length); + const start = i * maxDocumentSize; + const end = Math.min(start + maxDocumentSize, value.length); const chunk = value.subarray(start, end); putPromises.push( diff --git a/apps/server/src/modules/tldraw/service/tldraw.ws.service.spec.ts b/apps/server/src/modules/tldraw/service/tldraw.ws.service.spec.ts index f83b1f1c9b8..724b0c1022d 100644 --- a/apps/server/src/modules/tldraw/service/tldraw.ws.service.spec.ts +++ b/apps/server/src/modules/tldraw/service/tldraw.ws.service.spec.ts @@ -142,7 +142,7 @@ describe('TldrawWSService', () => { ws = await TestConnection.setupWs(wsUrl, 'TEST'); const clientMessageMock = 'test-message'; - const closeConSpy = jest.spyOn(service, 'closeConn').mockResolvedValueOnce(); + const closeConSpy = jest.spyOn(service, 'closeConnection').mockResolvedValueOnce(); const sendSpy = jest.spyOn(service, 'send'); const doc = TldrawWsFactory.createWsSharedDocDo(); const byteArray = new TextEncoder().encode(clientMessageMock); @@ -173,7 +173,7 @@ describe('TldrawWSService', () => { const socketMock = TldrawWsFactory.createWebsocket(WebSocketReadyStateEnum.OPEN); const clientMessageMock = 'test-message'; - const closeConSpy = jest.spyOn(service, 'closeConn').mockRejectedValue(new Error('error')); + const closeConSpy = jest.spyOn(service, 'closeConnection').mockRejectedValue(new Error('error')); jest.spyOn(socketMock, 'send').mockImplementation((...args: unknown[]) => { args.forEach((arg) => { if (typeof arg === 'function') { @@ -216,7 +216,7 @@ describe('TldrawWSService', () => { const socketMock = TldrawWsFactory.createWebsocket(WebSocketReadyStateEnum.CLOSED); const clientMessageMock = 'test-message'; - const closeConSpy = jest.spyOn(service, 'closeConn').mockRejectedValue(new Error('error')); + const closeConSpy = jest.spyOn(service, 'closeConnection').mockRejectedValue(new Error('error')); const sendSpy = jest.spyOn(service, 'send'); const errorLogSpy = jest.spyOn(logger, 'warning'); const doc = TldrawWsFactory.createWsSharedDocDo(); @@ -250,7 +250,7 @@ describe('TldrawWSService', () => { describe('when websocket has ready state different than Open (1) or Connecting (0)', () => { const setup = () => { const clientMessageMock = 'test-message'; - const closeConSpy = jest.spyOn(service, 'closeConn'); + const closeConSpy = jest.spyOn(service, 'closeConnection'); const sendSpy = jest.spyOn(service, 'send'); const doc = TldrawWsFactory.createWsSharedDocDo(); const socketMock = TldrawWsFactory.createWebsocket(WebSocketReadyStateEnum.CLOSED); @@ -470,8 +470,8 @@ describe('TldrawWSService', () => { const messageHandlerSpy = jest.spyOn(service, 'messageHandler').mockReturnValueOnce(); const sendSpy = jest.spyOn(service, 'send').mockImplementation(() => {}); - const getYDocSpy = jest.spyOn(service, 'getYDoc').mockResolvedValueOnce(doc); - const closeConnSpy = jest.spyOn(service, 'closeConn').mockResolvedValue(); + const getYDocSpy = jest.spyOn(service, 'getDocument').mockResolvedValueOnce(doc); + const closeConnSpy = jest.spyOn(service, 'closeConnection').mockResolvedValue(); const { msg } = createMessage([0]); jest.spyOn(AwarenessProtocol, 'encodeAwarenessUpdate').mockReturnValueOnce(msg); @@ -486,7 +486,7 @@ describe('TldrawWSService', () => { it('should send to every client', async () => { const { messageHandlerSpy, sendSpy, getYDocSpy, closeConnSpy } = await setup(); - await service.setupWSConnection(ws, 'TEST'); + await service.setupWsConnection(ws, 'TEST'); await delay(20); ws.emit('pong'); @@ -504,7 +504,7 @@ describe('TldrawWSService', () => { describe('on websocket error', () => { const setup = async () => { - boardRepo.getYDocFromMdb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); + boardRepo.getDocumentFromDb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); ws = await TestConnection.setupWs(wsUrl, 'TEST'); const errorLogSpy = jest.spyOn(logger, 'warning'); @@ -515,7 +515,7 @@ describe('TldrawWSService', () => { it('should log error', async () => { const { errorLogSpy } = await setup(); - await service.setupWSConnection(ws, 'TEST'); + await service.setupWsConnection(ws, 'TEST'); ws.emit('error', new Error('error')); expect(errorLogSpy).toHaveBeenCalled(); @@ -526,12 +526,12 @@ describe('TldrawWSService', () => { describe('closeConn', () => { describe('when there is no error', () => { const setup = async () => { - boardRepo.getYDocFromMdb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); + boardRepo.getDocumentFromDb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); ws = await TestConnection.setupWs(wsUrl); boardRepo.compressDocument.mockResolvedValueOnce(); const redisUnsubscribeSpy = jest.spyOn(Ioredis.Redis.prototype, 'unsubscribe').mockResolvedValueOnce(1); - const closeConnSpy = jest.spyOn(service, 'closeConn'); + const closeConnSpy = jest.spyOn(service, 'closeConnection'); jest.spyOn(Ioredis.Redis.prototype, 'subscribe').mockResolvedValueOnce({}); return { @@ -543,7 +543,7 @@ describe('TldrawWSService', () => { it('should close connection', async () => { const { redisUnsubscribeSpy, closeConnSpy } = await setup(); - await service.setupWSConnection(ws, 'TEST'); + await service.setupWsConnection(ws, 'TEST'); expect(closeConnSpy).toHaveBeenCalled(); ws.close(); @@ -552,10 +552,34 @@ describe('TldrawWSService', () => { }); }); + describe('when there are active connections', () => { + const setup = async () => { + const doc = new WsSharedDocDo('TEST'); + ws = await TestConnection.setupWs(wsUrl); + const ws2 = await TestConnection.setupWs(wsUrl); + doc.connections.set(ws, new Set()); + doc.connections.set(ws2, new Set()); + + return { + doc, + }; + }; + + it('should not call compressDocument', async () => { + const { doc } = await setup(); + + await service.closeConnection(doc, ws); + + expect(boardRepo.compressDocument).not.toHaveBeenCalled(); + ws.close(); + }); + }); + describe('when deleteUnusedFilesForDocument fails', () => { const setup = async () => { ws = await TestConnection.setupWs(wsUrl); const doc = TldrawWsFactory.createWsSharedDocDo(); + doc.connections.set(ws, new Set()); const errorLogSpy = jest.spyOn(logger, 'warning'); const storageSpy = jest @@ -572,7 +596,7 @@ describe('TldrawWSService', () => { it('should log error', async () => { const { doc, errorLogSpy } = await setup(); - await service.closeConn(doc, ws); + await service.closeConnection(doc, ws); await delay(100); expect(errorLogSpy).toHaveBeenCalled(); @@ -582,12 +606,12 @@ describe('TldrawWSService', () => { describe('when close connection fails', () => { const setup = async () => { - boardRepo.getYDocFromMdb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); + boardRepo.getDocumentFromDb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); ws = await TestConnection.setupWs(wsUrl); boardRepo.compressDocument.mockResolvedValueOnce(); const redisUnsubscribeSpy = jest.spyOn(Ioredis.Redis.prototype, 'unsubscribe').mockResolvedValueOnce(1); - const closeConnSpy = jest.spyOn(service, 'closeConn').mockRejectedValueOnce(new Error('error')); + const closeConnSpy = jest.spyOn(service, 'closeConnection').mockRejectedValueOnce(new Error('error')); const errorLogSpy = jest.spyOn(logger, 'warning'); const sendSpyError = jest.spyOn(service, 'send').mockReturnValue(); jest.spyOn(Ioredis.Redis.prototype, 'subscribe').mockResolvedValueOnce({}); @@ -603,7 +627,7 @@ describe('TldrawWSService', () => { it('should log error', async () => { const { redisUnsubscribeSpy, closeConnSpy, errorLogSpy, sendSpyError } = await setup(); - await service.setupWSConnection(ws, 'TEST'); + await service.setupWsConnection(ws, 'TEST'); await delay(100); @@ -618,45 +642,6 @@ describe('TldrawWSService', () => { }); }); - describe('when unsubscribing from Redis fails', () => { - const setup = async () => { - ws = await TestConnection.setupWs(wsUrl); - const doc = TldrawWsFactory.createWsSharedDocDo(); - doc.connections.set(ws, new Set()); - - boardRepo.compressDocument.mockResolvedValueOnce(); - const redisUnsubscribeSpy = jest - .spyOn(Ioredis.Redis.prototype, 'unsubscribe') - .mockImplementationOnce((...args: unknown[]) => { - args.forEach((arg) => { - if (typeof arg === 'function') { - arg(new Error('error')); - } - }); - return Promise.resolve(0); - }); - const errorLogSpy = jest.spyOn(logger, 'warning'); - - return { - doc, - redisUnsubscribeSpy, - errorLogSpy, - }; - }; - - it('should log error', async () => { - const { doc, errorLogSpy, redisUnsubscribeSpy } = await setup(); - - await service.closeConn(doc, ws); - - await delay(100); - - expect(redisUnsubscribeSpy).toHaveBeenCalled(); - expect(errorLogSpy).toHaveBeenCalled(); - redisUnsubscribeSpy.mockRestore(); - }); - }); - describe('when unsubscribing from Redis throw error', () => { const setup = async () => { ws = await TestConnection.setupWs(wsUrl); @@ -667,7 +652,7 @@ describe('TldrawWSService', () => { const redisUnsubscribeSpy = jest .spyOn(Ioredis.Redis.prototype, 'unsubscribe') .mockRejectedValue(new Error('error')); - const closeConnSpy = jest.spyOn(service, 'closeConn'); + const closeConnSpy = jest.spyOn(service, 'closeConnection'); const errorLogSpy = jest.spyOn(logger, 'warning'); jest.spyOn(Ioredis.Redis.prototype, 'subscribe').mockResolvedValueOnce({}); @@ -682,7 +667,7 @@ describe('TldrawWSService', () => { it('should log error', async () => { const { doc, errorLogSpy, redisUnsubscribeSpy, closeConnSpy } = await setup(); - await service.closeConn(doc, ws); + await service.closeConnection(doc, ws); await delay(200); expect(redisUnsubscribeSpy).toHaveBeenCalled(); @@ -693,43 +678,13 @@ describe('TldrawWSService', () => { }); }); - describe('when updating new document fails', () => { - const setup = async () => { - boardRepo.getYDocFromMdb.mockResolvedValueOnce(new WsSharedDocDo('test-update-fail')); - ws = await TestConnection.setupWs(wsUrl); - - const closeConnSpy = jest.spyOn(service, 'closeConn'); - const errorLogSpy = jest.spyOn(logger, 'warning'); - const sendSpy = jest.spyOn(service, 'send').mockImplementation(() => {}); - jest.spyOn(Ioredis.Redis.prototype, 'subscribe').mockResolvedValueOnce({}); - - return { - closeConnSpy, - errorLogSpy, - sendSpy, - }; - }; - - it('should log error', async () => { - const { sendSpy, errorLogSpy, closeConnSpy } = await setup(); - - await service.setupWSConnection(ws, 'test-update-fail'); - ws.close(); - - expect(errorLogSpy).toHaveBeenCalled(); - closeConnSpy.mockRestore(); - sendSpy.mockRestore(); - ws.close(); - }); - }); - describe('when pong not received', () => { const setup = async () => { - boardRepo.getYDocFromMdb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); + boardRepo.getDocumentFromDb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); ws = await TestConnection.setupWs(wsUrl, 'TEST'); const messageHandlerSpy = jest.spyOn(service, 'messageHandler').mockReturnValueOnce(); - const closeConnSpy = jest.spyOn(service, 'closeConn').mockImplementation(() => Promise.resolve()); + const closeConnSpy = jest.spyOn(service, 'closeConnection').mockImplementation(() => Promise.resolve()); const pingSpy = jest.spyOn(ws, 'ping').mockImplementationOnce(() => {}); const sendSpy = jest.spyOn(service, 'send').mockImplementation(() => {}); const clearIntervalSpy = jest.spyOn(global, 'clearInterval'); @@ -747,7 +702,7 @@ describe('TldrawWSService', () => { it('should close connection', async () => { const { messageHandlerSpy, closeConnSpy, pingSpy, sendSpy, clearIntervalSpy } = await setup(); - await service.setupWSConnection(ws, 'TEST'); + await service.setupWsConnection(ws, 'TEST'); await delay(20); @@ -764,11 +719,11 @@ describe('TldrawWSService', () => { describe('when pong not received and close connection fails', () => { const setup = async () => { - boardRepo.getYDocFromMdb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); + boardRepo.getDocumentFromDb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); ws = await TestConnection.setupWs(wsUrl, 'TEST'); const messageHandlerSpy = jest.spyOn(service, 'messageHandler').mockReturnValueOnce(); - const closeConnSpy = jest.spyOn(service, 'closeConn').mockRejectedValue(new Error('error')); + const closeConnSpy = jest.spyOn(service, 'closeConnection').mockRejectedValue(new Error('error')); const pingSpy = jest.spyOn(ws, 'ping').mockImplementation(() => {}); const sendSpy = jest.spyOn(service, 'send').mockImplementation(() => {}); const clearIntervalSpy = jest.spyOn(global, 'clearInterval'); @@ -788,7 +743,7 @@ describe('TldrawWSService', () => { it('should log error', async () => { const { messageHandlerSpy, closeConnSpy, pingSpy, sendSpy, clearIntervalSpy, errorLogSpy } = await setup(); - await service.setupWSConnection(ws, 'TEST'); + await service.setupWsConnection(ws, 'TEST'); await delay(200); @@ -836,7 +791,7 @@ describe('TldrawWSService', () => { const { doc, assets } = await setup(); const initialSize = assets.size; - await service.closeConn(doc, ws); + await service.closeConnection(doc, ws); const finalSize = assets.size; expect(initialSize).toBe(2); @@ -845,7 +800,7 @@ describe('TldrawWSService', () => { }); }); - describe('when flushDocument failed', () => { + describe('when compressDocument failed', () => { const setup = async () => { ws = await TestConnection.setupWs(wsUrl, 'TEST'); const doc = TldrawWsFactory.createWsSharedDocDo(); @@ -863,7 +818,7 @@ describe('TldrawWSService', () => { it('should log error', async () => { const { doc, errorLogSpy } = await setup(); - await expect(service.closeConn(doc, ws)).rejects.toThrow('error'); + await service.closeConnection(doc, ws); expect(boardRepo.compressDocument).toHaveBeenCalled(); expect(errorLogSpy).toHaveBeenCalled(); @@ -930,47 +885,6 @@ describe('TldrawWSService', () => { }); }); - describe('when publish to Redis has errors', () => { - const setup = async () => { - ws = await TestConnection.setupWs(wsUrl); - - const sendSpy = jest.spyOn(service, 'send').mockReturnValueOnce(); - const errorLogSpy = jest.spyOn(logger, 'warning'); - const publishSpy = jest - .spyOn(Ioredis.Redis.prototype, 'publish') - .mockImplementationOnce((_channel, _message, cb) => { - if (cb) { - cb(new Error('error')); - } - return Promise.resolve(0); - }); - - const doc = TldrawWsFactory.createWsSharedDocDo(); - const socketMock = TldrawWsFactory.createWebsocket(WebSocketReadyStateEnum.OPEN); - doc.connections.set(socketMock, new Set()); - const msg = new Uint8Array([0]); - - return { - doc, - sendSpy, - socketMock, - msg, - errorLogSpy, - publishSpy, - }; - }; - - it('should log error', async () => { - const { doc, socketMock, msg, errorLogSpy, publishSpy } = await setup(); - - service.updateHandler(msg, socketMock, doc); - - expect(errorLogSpy).toHaveBeenCalled(); - ws.close(); - publishSpy.mockRestore(); - }); - }); - describe('when publish to Redis throws errors', () => { const setup = async () => { ws = await TestConnection.setupWs(wsUrl); @@ -980,14 +894,12 @@ describe('TldrawWSService', () => { const publishSpy = jest.spyOn(Ioredis.Redis.prototype, 'publish').mockRejectedValueOnce(new Error('error')); const doc = TldrawWsFactory.createWsSharedDocDo(); - const socketMock = TldrawWsFactory.createWebsocket(WebSocketReadyStateEnum.OPEN); - doc.connections.set(socketMock, new Set()); + doc.connections.set(ws, new Set()); const msg = new Uint8Array([0]); return { doc, sendSpy, - socketMock, msg, errorLogSpy, publishSpy, @@ -995,9 +907,9 @@ describe('TldrawWSService', () => { }; it('should log error', async () => { - const { doc, socketMock, msg, errorLogSpy, publishSpy } = await setup(); + const { doc, msg, errorLogSpy, publishSpy } = await setup(); - service.updateHandler(msg, socketMock, doc); + service.updateHandler(msg, ws, doc); await delay(20); @@ -1010,7 +922,7 @@ describe('TldrawWSService', () => { describe('messageHandler', () => { describe('when message is received', () => { const setup = async (messageValues: number[]) => { - boardRepo.getYDocFromMdb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); + boardRepo.getDocumentFromDb.mockResolvedValueOnce(new WsSharedDocDo('TEST')); ws = await TestConnection.setupWs(wsUrl, 'TEST'); const errorLogSpy = jest.spyOn(logger, 'warning'); @@ -1036,7 +948,7 @@ describe('TldrawWSService', () => { const { messageHandlerSpy, msg, readSyncMessageSpy, publishSpy } = await setup([0, 1]); publishSpy.mockResolvedValueOnce(1); - await service.setupWSConnection(ws, 'TEST'); + await service.setupWsConnection(ws, 'TEST'); ws.emit('message', msg); await delay(20); @@ -1054,7 +966,7 @@ describe('TldrawWSService', () => { throw new Error('error'); }); - await service.setupWSConnection(ws, 'TEST'); + await service.setupWsConnection(ws, 'TEST'); ws.emit('message', msg); await delay(20); @@ -1069,7 +981,7 @@ describe('TldrawWSService', () => { const { errorLogSpy, publishSpy } = await setup([1, 1]); publishSpy.mockRejectedValueOnce(new Error('error')); - await service.setupWSConnection(ws, 'TEST'); + await service.setupWsConnection(ws, 'TEST'); expect(errorLogSpy).toHaveBeenCalled(); ws.close(); @@ -1080,10 +992,10 @@ describe('TldrawWSService', () => { describe('getYDoc', () => { describe('when getting yDoc by name', () => { it('should assign to service docs map and return instance', async () => { - boardRepo.getYDocFromMdb.mockResolvedValueOnce(new WsSharedDocDo('get-test')); + boardRepo.getDocumentFromDb.mockResolvedValueOnce(new WsSharedDocDo('get-test')); jest.spyOn(Ioredis.Redis.prototype, 'subscribe').mockResolvedValueOnce({}); const docName = 'get-test'; - const doc = await service.getYDoc(docName); + const doc = await service.getDocument(docName); expect(doc).toBeInstanceOf(WsSharedDocDo); expect(service.docs.get(docName)).not.toBeUndefined(); @@ -1091,13 +1003,13 @@ describe('TldrawWSService', () => { describe('when subscribing to redis channel', () => { const setup = () => { - boardRepo.getYDocFromMdb.mockResolvedValueOnce(new WsSharedDocDo('test-redis')); + boardRepo.getDocumentFromDb.mockResolvedValueOnce(new WsSharedDocDo('test-redis')); const doc = new WsSharedDocDo('test-redis'); const redisSubscribeSpy = jest.spyOn(Ioredis.Redis.prototype, 'subscribe').mockResolvedValueOnce(1); const redisOnSpy = jest.spyOn(Ioredis.Redis.prototype, 'on'); const errorLogSpy = jest.spyOn(logger, 'warning'); - boardRepo.getYDocFromMdb.mockResolvedValueOnce(doc); + boardRepo.getDocumentFromDb.mockResolvedValueOnce(doc); return { redisOnSpy, @@ -1106,47 +1018,13 @@ describe('TldrawWSService', () => { }; }; - it('should register new listener', async () => { + it('should subscribe', async () => { const { redisOnSpy, redisSubscribeSpy } = setup(); - const doc = await service.getYDoc('test-redis'); + const doc = await service.getDocument('test-redis'); expect(doc).toBeDefined(); - expect(redisOnSpy).toHaveBeenCalled(); - redisSubscribeSpy.mockRestore(); - redisOnSpy.mockRestore(); - }); - }); - - describe('when subscribing to redis channel fails', () => { - const setup = () => { - const redisSubscribeSpy = jest - .spyOn(Ioredis.Redis.prototype, 'subscribe') - .mockImplementationOnce((...args: unknown[]) => { - args.forEach((arg) => { - if (typeof arg === 'function') { - arg(new Error('error')); - } - }); - return Promise.resolve(0); - }); - const redisOnSpy = jest.spyOn(Ioredis.Redis.prototype, 'on'); - const errorLogSpy = jest.spyOn(logger, 'warning'); - - return { - redisOnSpy, - redisSubscribeSpy, - errorLogSpy, - }; - }; - - it('should log error', async () => { - const { errorLogSpy, redisSubscribeSpy, redisOnSpy } = setup(); - - await service.getYDoc('test-redis-fail'); - expect(redisSubscribeSpy).toHaveBeenCalled(); - expect(errorLogSpy).toHaveBeenCalled(); redisSubscribeSpy.mockRestore(); redisOnSpy.mockRestore(); }); @@ -1155,7 +1033,7 @@ describe('TldrawWSService', () => { describe('when subscribing to redis channel throws error', () => { const setup = () => { - boardRepo.getYDocFromMdb.mockResolvedValueOnce(new WsSharedDocDo('test-redis-fail-2')); + boardRepo.getDocumentFromDb.mockResolvedValueOnce(new WsSharedDocDo('test-redis-fail-2')); const redisSubscribeSpy = jest .spyOn(Ioredis.Redis.prototype, 'subscribe') .mockRejectedValue(new Error('error')); @@ -1172,7 +1050,7 @@ describe('TldrawWSService', () => { it('should log error', async () => { const { errorLogSpy, redisSubscribeSpy, redisOnSpy } = setup(); - await service.getYDoc('test-redis-fail-2'); + await service.getDocument('test-redis-fail-2'); await delay(500); @@ -1190,7 +1068,7 @@ describe('TldrawWSService', () => { const applyAwarenessUpdateSpy = jest.spyOn(AwarenessProtocol, 'applyAwarenessUpdate').mockReturnValueOnce(); const doc = new WsSharedDocDo('TEST'); - doc.awarenessChannel = 'TEST-AWARENESS'; + doc.awarenessChannel = 'TEST-awareness'; return { doc, @@ -1202,8 +1080,8 @@ describe('TldrawWSService', () => { describe('when channel name is the same as docName', () => { it('should call applyUpdate', () => { const { doc, applyUpdateSpy } = setup(); - - service.redisMessageHandler(Buffer.from('TEST'), Buffer.from('message'), doc); + service.docs.set('TEST', doc); + service.redisMessageHandler(Buffer.from('TEST'), Buffer.from('message')); expect(applyUpdateSpy).toHaveBeenCalled(); }); @@ -1212,12 +1090,61 @@ describe('TldrawWSService', () => { describe('when channel name is the same as docAwarenessChannel name', () => { it('should call applyAwarenessUpdate', () => { const { doc, applyAwarenessUpdateSpy } = setup(); - - service.redisMessageHandler(Buffer.from('TEST-AWARENESS'), Buffer.from('message'), doc); + service.docs.set('TEST', doc); + service.redisMessageHandler(Buffer.from('TEST-awareness'), Buffer.from('message')); expect(applyAwarenessUpdateSpy).toHaveBeenCalled(); }); }); + + describe('when channel name is not found as document name', () => { + it('should not call applyUpdate or applyAwarenessUpdate', () => { + const { doc, applyUpdateSpy, applyAwarenessUpdateSpy } = setup(); + service.docs.set('TEST', doc); + service.redisMessageHandler(Buffer.from('NOTFOUND'), Buffer.from('message')); + + expect(applyUpdateSpy).not.toHaveBeenCalled(); + expect(applyAwarenessUpdateSpy).not.toHaveBeenCalled(); + }); + }); + }); + + describe('updateHandler', () => { + describe('when update comes from connected websocket', () => { + const setup = async () => { + ws = await TestConnection.setupWs(wsUrl, 'TEST'); + + const doc = new WsSharedDocDo('TEST'); + doc.connections.set(ws, new Set()); + const publishSpy = jest.spyOn(Ioredis.Redis.prototype, 'publish'); + const errorLogSpy = jest.spyOn(logger, 'warning'); + + return { + doc, + publishSpy, + errorLogSpy, + }; + }; + + it('should publish update to redis', async () => { + const { doc, publishSpy } = await setup(); + + service.updateHandler(new Uint8Array([]), ws, doc); + + expect(publishSpy).toHaveBeenCalled(); + ws.close(); + }); + + it('should log error on failed publish', async () => { + const { doc, publishSpy, errorLogSpy } = await setup(); + publishSpy.mockRejectedValueOnce(new Error('error')); + + service.updateHandler(new Uint8Array([]), ws, doc); + + expect(errorLogSpy).toHaveBeenCalled(); + ws.close(); + }); + }); }); describe('awarenessUpdateHandler', () => { diff --git a/apps/server/src/modules/tldraw/service/tldraw.ws.service.ts b/apps/server/src/modules/tldraw/service/tldraw.ws.service.ts index cc11442842a..69cca7c267f 100644 --- a/apps/server/src/modules/tldraw/service/tldraw.ws.service.ts +++ b/apps/server/src/modules/tldraw/service/tldraw.ws.service.ts @@ -2,7 +2,7 @@ import { Injectable } from '@nestjs/common'; import { ConfigService } from '@nestjs/config'; import WebSocket from 'ws'; import { applyAwarenessUpdate, encodeAwarenessUpdate, removeAwarenessStates } from 'y-protocols/awareness'; -import { decoding, encoding, map } from 'lib0'; +import { decoding, encoding } from 'lib0'; import { readSyncMessage, writeSyncStep1, writeUpdate } from 'y-protocols/sync'; import { applyUpdate, encodeStateAsUpdate } from 'yjs'; import { Buffer } from 'node:buffer'; @@ -25,6 +25,7 @@ import { TldrawAsset, TldrawShape, UpdateOrigin, + UpdateType, WSMessageType, } from '../types'; import { WsSharedDocDo } from '../domain'; @@ -34,9 +35,7 @@ import { TldrawFilesStorageAdapterService } from './tldraw-files-storage.service @Injectable() export class TldrawWsService { - public docs = new Map(); - - private readonly pingTimeout: number; + public docs = new Map(); public readonly sub: Redis; @@ -51,35 +50,36 @@ export class TldrawWsService { private readonly filesStorageTldrawAdapterService: TldrawFilesStorageAdapterService ) { this.logger.setContext(TldrawWsService.name); - this.pingTimeout = this.configService.get('TLDRAW_PING_TIMEOUT'); this.sub = this.tldrawRedisFactory.build(RedisConnectionTypeEnum.SUBSCRIBE); this.pub = this.tldrawRedisFactory.build(RedisConnectionTypeEnum.PUBLISH); + + this.sub.on('messageBuffer', (channel, message) => this.redisMessageHandler(channel, message)); } - public async closeConn(doc: WsSharedDocDo, ws: WebSocket): Promise { + public async closeConnection(doc: WsSharedDocDo, ws: WebSocket): Promise { if (doc.connections.has(ws)) { const controlledIds = doc.connections.get(ws); doc.connections.delete(ws); removeAwarenessStates(doc.awareness, this.forceToArray(controlledIds), null); - await this.storeStateAndDestroyYDocIfPersisted(doc); + await this.finalizeIfNoConnections(doc); this.metricsService.decrementNumberOfUsersOnServerCounter(); } ws.close(); } - public send(doc: WsSharedDocDo, conn: WebSocket, message: Uint8Array): void { - if (this.isClosedOrClosing(conn)) { - this.closeConn(doc, conn).catch((err) => { + public send(doc: WsSharedDocDo, ws: WebSocket, message: Uint8Array): void { + if (this.isClosedOrClosing(ws)) { + this.closeConnection(doc, ws).catch((err) => { this.logger.warning(new CloseConnectionLoggable('send | isClosedOrClosing', err)); }); } - conn.send(message, (err) => { + ws.send(message, (err) => { if (err) { - this.closeConn(doc, conn).catch((e) => { + this.closeConnection(doc, ws).catch((e) => { this.logger.warning(new CloseConnectionLoggable('send', e)); }); } @@ -88,17 +88,16 @@ export class TldrawWsService { public updateHandler(update: Uint8Array, origin, doc: WsSharedDocDo): void { if (this.isFromConnectedWebSocket(doc, origin)) { - this.publishUpdateToRedis(doc, update, 'document'); + this.publishUpdateToRedis(doc, update, UpdateType.DOCUMENT); } - this.propagateUpdate(update, doc); + this.sendUpdateToConnectedClients(update, doc); } public async databaseUpdateHandler(docName: string, update: Uint8Array, origin) { if (this.isFromRedis(origin)) { return; } - await this.tldrawBoardRepo.storeUpdate(docName, update); } @@ -112,34 +111,35 @@ export class TldrawWsService { this.sendAwarenessMessage(buff, doc); }; - public async getYDoc(docName: string) { - const wsSharedDocDo = await map.setIfUndefined(this.docs, docName, async () => { - const doc = await this.tldrawBoardRepo.getYDocFromMdb(docName); + public async getDocument(docName: string) { + const existingDoc = this.docs.get(docName); + if (existingDoc) { + return existingDoc; + } - this.registerAwarenessUpdateHandler(doc); - this.registerUpdateHandler(doc); - this.subscribeToRedisChannels(doc); - this.registerDatabaseUpdateHandler(doc); + const doc = await this.tldrawBoardRepo.getDocumentFromDb(docName); - this.docs.set(docName, doc); - this.metricsService.incrementNumberOfBoardsOnServerCounter(); - return doc; - }); + this.registerAwarenessUpdateHandler(doc); + this.registerUpdateHandler(doc); + this.subscribeToRedisChannels(doc); + this.registerDatabaseUpdateHandler(doc); - return wsSharedDocDo; + this.docs.set(docName, doc); + this.metricsService.incrementNumberOfBoardsOnServerCounter(); + return doc; } public async createDbIndex(): Promise { await this.tldrawBoardRepo.createDbIndex(); } - public messageHandler(conn: WebSocket, doc: WsSharedDocDo, message: Uint8Array): void { + public messageHandler(ws: WebSocket, doc: WsSharedDocDo, message: Uint8Array): void { const encoder = encoding.createEncoder(); const decoder = decoding.createDecoder(message); const messageType = decoding.readVarUint(decoder); switch (messageType) { case WSMessageType.SYNC: - this.handleSyncMessage(doc, encoder, decoder, conn); + this.handleSyncMessage(doc, encoder, decoder, ws); break; case WSMessageType.AWARENESS: { this.handleAwarenessMessage(doc, decoder); @@ -154,41 +154,44 @@ export class TldrawWsService { doc: WsSharedDocDo, encoder: encoding.Encoder, decoder: decoding.Decoder, - conn: WebSocket + ws: WebSocket ): void { encoding.writeVarUint(encoder, WSMessageType.SYNC); - readSyncMessage(decoder, encoder, doc, conn); + readSyncMessage(decoder, encoder, doc, ws); // If the `encoder` only contains the type of reply message and no // message, there is no need to send the message. When `encoder` only // contains the type of reply, its length is 1. if (encoding.length(encoder) > 1) { - this.send(doc, conn, encoding.toUint8Array(encoder)); + this.send(doc, ws, encoding.toUint8Array(encoder)); } } private handleAwarenessMessage(doc: WsSharedDocDo, decoder: decoding.Decoder) { const update = decoding.readVarUint8Array(decoder); - this.publishUpdateToRedis(doc, update, 'awareness'); + this.publishUpdateToRedis(doc, update, UpdateType.AWARENESS); } - public redisMessageHandler = (channel: Buffer, update: Buffer, doc: WsSharedDocDo): void => { + public redisMessageHandler = (channel: Buffer, update: Buffer): void => { const channelId = channel.toString(); - - if (channelId === doc.name) { - applyUpdate(doc, update, UpdateOrigin.REDIS); + const docName = channel.toString().split('-')[0]; + const doc = this.docs.get(docName); + if (!doc) { + return; } - if (channelId === doc.awarenessChannel) { + if (channelId.includes(UpdateType.AWARENESS)) { applyAwarenessUpdate(doc.awareness, update, UpdateOrigin.REDIS); + } else { + applyUpdate(doc, update, UpdateOrigin.REDIS); } }; - public async setupWSConnection(ws: WebSocket, docName: string) { + public async setupWsConnection(ws: WebSocket, docName: string) { ws.binaryType = 'arraybuffer'; // get doc, initialize if it does not exist yet - const doc = await this.getYDoc(docName); + const doc = await this.getDocument(docName); doc.connections.set(ws, new Set()); ws.on('error', (err) => { @@ -207,6 +210,7 @@ export class TldrawWsService { this.sendInitialState(ws, doc); // check if connection is still alive + const pingTimeout = this.configService.get('TLDRAW_PING_TIMEOUT'); let pongReceived = true; const pingInterval = setInterval(() => { if (pongReceived && doc.connections.has(ws)) { @@ -215,14 +219,14 @@ export class TldrawWsService { return; } - this.closeConn(doc, ws).catch((err) => { + this.closeConnection(doc, ws).catch((err) => { this.logger.warning(new CloseConnectionLoggable('pingInterval', err)); }); clearInterval(pingInterval); - }, this.pingTimeout); + }, pingTimeout); ws.on('close', () => { - this.closeConn(doc, ws).catch((err) => { + this.closeConnection(doc, ws).catch((err) => { this.logger.warning(new CloseConnectionLoggable('websocket close', err)); }); clearInterval(pingInterval); @@ -252,24 +256,25 @@ export class TldrawWsService { this.metricsService.incrementNumberOfUsersOnServerCounter(); } - private async storeStateAndDestroyYDocIfPersisted(doc: WsSharedDocDo) { - if (doc.connections.size === 0) { - // if persisted, we store state and destroy yDoc - try { - const usedAssets = this.syncDocumentAssetsWithShapes(doc); + private async finalizeIfNoConnections(doc: WsSharedDocDo) { + if (doc.connections.size > 0) { + return; + } - await this.tldrawBoardRepo.compressDocument(doc.name); - this.unsubscribeFromRedisChannels(doc); + try { + const usedAssets = this.syncDocumentAssetsWithShapes(doc); + await this.tldrawBoardRepo.compressDocument(doc.name); + this.unsubscribeFromRedisChannels(doc); + if (this.configService.get('TLDRAW_ASSETS_SYNC_ENABLED')) { void this.filesStorageTldrawAdapterService.deleteUnusedFilesForDocument(doc.name, usedAssets).catch((err) => { this.logger.warning(new FileStorageErrorLoggable(doc.name, err)); }); - doc.destroy(); - } catch (err) { - this.logger.warning(new WsSharedDocErrorLoggable(doc.name, 'Error while flushing doc', err)); - throw err; } - + } catch (err) { + this.logger.warning(new WsSharedDocErrorLoggable(doc.name, 'Error while finalizing document', err)); + } finally { + doc.destroy(); this.docs.delete(doc.name); this.metricsService.decrementNumberOfBoardsOnServerCounter(); } @@ -283,35 +288,35 @@ export class TldrawWsService { const usedShapesAsAssets: TldrawShape[] = []; const usedAssets: TldrawAsset[] = []; - shapes.forEach((shape) => { + for (const [, shape] of shapes) { if (shape.assetId) { usedShapesAsAssets.push(shape); } - }); + } doc.transact(() => { - assets.forEach((asset) => { - const foundAsset = usedShapesAsAssets.find((shape) => shape.assetId === asset.id); + for (const [, asset] of assets) { + const foundAsset = usedShapesAsAssets.some((shape) => shape.assetId === asset.id); if (!foundAsset) { assets.delete(asset.id); } else { usedAssets.push(asset); } - }); + } }); return usedAssets; } - private propagateUpdate(update: Uint8Array, doc: WsSharedDocDo): void { + private sendUpdateToConnectedClients(update: Uint8Array, doc: WsSharedDocDo): void { const encoder = encoding.createEncoder(); encoding.writeVarUint(encoder, WSMessageType.SYNC); writeUpdate(encoder, update); const message = encoding.toUint8Array(encoder); - doc.connections.forEach((_, conn) => { + for (const [conn] of doc.connections) { this.send(doc, conn, message); - }); + } } private prepareAwarenessMessage(changedClients: number[], doc: WsSharedDocDo): Uint8Array { @@ -322,29 +327,31 @@ export class TldrawWsService { return message; } - private sendAwarenessMessage(buff: Uint8Array, doc: WsSharedDocDo): void { - doc.connections.forEach((_, c) => { - this.send(doc, c, buff); - }); + private sendAwarenessMessage(message: Uint8Array, doc: WsSharedDocDo): void { + for (const [conn] of doc.connections) { + this.send(doc, conn, message); + } } private manageClientsConnections( connectionsUpdate: AwarenessConnectionsUpdate, - wsConnection: WebSocket | null, + ws: WebSocket | null, doc: WsSharedDocDo ): number[] { const changedClients = connectionsUpdate.added.concat(connectionsUpdate.updated, connectionsUpdate.removed); - if (wsConnection !== null) { - const connControlledIDs = doc.connections.get(wsConnection); + if (ws !== null) { + const connControlledIDs = doc.connections.get(ws); if (connControlledIDs !== undefined) { - connectionsUpdate.added.forEach((clientID) => { + for (const clientID of connectionsUpdate.added) { connControlledIDs.add(clientID); - }); - connectionsUpdate.removed.forEach((clientID) => { + } + + for (const clientID of connectionsUpdate.removed) { connControlledIDs.delete(clientID); - }); + } } } + return changedClients; } @@ -363,45 +370,22 @@ export class TldrawWsService { } private subscribeToRedisChannels(doc: WsSharedDocDo) { - this.sub - .subscribe(doc.name, doc.awarenessChannel, (err) => { - if (err) { - this.logger.warning(new WsSharedDocErrorLoggable(doc.name, 'Error while subscribing to Redis channels', err)); - } - }) - .catch((err) => { - this.logger.warning(new WsSharedDocErrorLoggable(doc.name, 'Error while subscribing to Redis channels', err)); - }); - this.sub.on('messageBuffer', (channel, message) => this.redisMessageHandler(channel, message, doc)); + this.sub.subscribe(doc.name, doc.awarenessChannel).catch((err) => { + this.logger.warning(new WsSharedDocErrorLoggable(doc.name, 'Error while subscribing to Redis channels', err)); + }); } private unsubscribeFromRedisChannels(doc: WsSharedDocDo) { - this.sub - .unsubscribe(doc.name, doc.awarenessChannel, (err) => { - if (err) { - this.logger.warning( - new WsSharedDocErrorLoggable(doc.name, 'Error while unsubscribing from Redis channels', err) - ); - } - }) - .catch((err) => { - this.logger.warning( - new WsSharedDocErrorLoggable(doc.name, 'Error while unsubscribing from Redis channels', err) - ); - }); + this.sub.unsubscribe(doc.name, doc.awarenessChannel).catch((err) => { + this.logger.warning(new WsSharedDocErrorLoggable(doc.name, 'Error while unsubscribing from Redis channels', err)); + }); } - private publishUpdateToRedis(doc: WsSharedDocDo, update: Uint8Array, type: 'awareness' | 'document') { - const channel = type === 'awareness' ? doc.awarenessChannel : doc.name; - this.pub - .publish(channel, Buffer.from(update), (err) => { - if (err) { - this.logger.warning(new RedisPublishErrorLoggable('awareness', err)); - } - }) - .catch((err) => { - this.logger.warning(new RedisPublishErrorLoggable('awareness', err)); - }); + private publishUpdateToRedis(doc: WsSharedDocDo, update: Uint8Array, type: UpdateType) { + const channel = type === UpdateType.AWARENESS ? doc.awarenessChannel : doc.name; + this.pub.publish(channel, Buffer.from(update)).catch((err) => { + this.logger.warning(new RedisPublishErrorLoggable(type, err)); + }); } private sendInitialState(ws: WebSocket, doc: WsSharedDocDo): void { @@ -411,8 +395,8 @@ export class TldrawWsService { this.send(doc, ws, encoding.toUint8Array(encoder)); } - private isClosedOrClosing(connection: WebSocket): boolean { - return connection.readyState === WebSocket.CLOSING || connection.readyState === WebSocket.CLOSED; + private isClosedOrClosing(ws: WebSocket): boolean { + return ws.readyState === WebSocket.CLOSING || ws.readyState === WebSocket.CLOSED; } private forceToArray(connections: Set | undefined): number[] { diff --git a/apps/server/src/modules/tldraw/testing/testConfig.ts b/apps/server/src/modules/tldraw/testing/testConfig.ts index e3557bfbaf5..82768116188 100644 --- a/apps/server/src/modules/tldraw/testing/testConfig.ts +++ b/apps/server/src/modules/tldraw/testing/testConfig.ts @@ -7,6 +7,6 @@ export const tldrawTestConfig = () => { } conf.TLDRAW_DB_COMPRESS_THRESHOLD = 2; conf.TLDRAW_PING_TIMEOUT = 0; - conf.TLDRAW_MAX_DOCUMENT_SIZE = 3; + conf.TLDRAW_MAX_DOCUMENT_SIZE = 1; return conf; }; diff --git a/apps/server/src/modules/tldraw/types/index.ts b/apps/server/src/modules/tldraw/types/index.ts index ed5a4a4b6b5..ed1bf3d3226 100644 --- a/apps/server/src/modules/tldraw/types/index.ts +++ b/apps/server/src/modules/tldraw/types/index.ts @@ -3,5 +3,5 @@ export * from './connection-enum'; export * from './y-transaction-type'; export * from './ws-close-enum'; export * from './awareness-connections-update-type'; -export * from './redis-connection-type.enum'; -export * from './update-origin-enum'; +export * from './redis-connection-type-enum'; +export * from './update-enums'; diff --git a/apps/server/src/modules/tldraw/types/redis-connection-type.enum.ts b/apps/server/src/modules/tldraw/types/redis-connection-type-enum.ts similarity index 100% rename from apps/server/src/modules/tldraw/types/redis-connection-type.enum.ts rename to apps/server/src/modules/tldraw/types/redis-connection-type-enum.ts diff --git a/apps/server/src/modules/tldraw/types/update-enums.ts b/apps/server/src/modules/tldraw/types/update-enums.ts new file mode 100644 index 00000000000..826bfe7039c --- /dev/null +++ b/apps/server/src/modules/tldraw/types/update-enums.ts @@ -0,0 +1,8 @@ +export enum UpdateOrigin { + REDIS = 'redis', +} + +export enum UpdateType { + AWARENESS = 'awareness', + DOCUMENT = 'document', +} diff --git a/apps/server/src/modules/tldraw/types/update-origin-enum.ts b/apps/server/src/modules/tldraw/types/update-origin-enum.ts deleted file mode 100644 index 9b2cc0aa505..00000000000 --- a/apps/server/src/modules/tldraw/types/update-origin-enum.ts +++ /dev/null @@ -1,3 +0,0 @@ -export enum UpdateOrigin { - REDIS = 'redis', -} diff --git a/apps/server/src/modules/tldraw/types/ws-close-enum.ts b/apps/server/src/modules/tldraw/types/ws-close-enum.ts index 97b40f4bf57..4c0d7f2f10d 100644 --- a/apps/server/src/modules/tldraw/types/ws-close-enum.ts +++ b/apps/server/src/modules/tldraw/types/ws-close-enum.ts @@ -1,12 +1,13 @@ -export enum WsCloseCodeEnum { - WS_CLIENT_BAD_REQUEST_CODE = 4400, - WS_CLIENT_UNAUTHORISED_CONNECTION_CODE = 4401, - WS_CLIENT_NOT_FOUND_CODE = 4404, - WS_CLIENT_FAILED_CONNECTION_CODE = 4500, +export enum WsCloseCode { + BAD_REQUEST = 4400, + UNAUTHORIZED = 4401, + NOT_FOUND = 4404, + INTERNAL_SERVER_ERROR = 4500, } -export enum WsCloseMessageEnum { - WS_CLIENT_BAD_REQUEST_MESSAGE = 'Document name is mandatory in url or Tldraw Tool is turned off.', - WS_CLIENT_UNAUTHORISED_CONNECTION_MESSAGE = "Unauthorised connection - you don't have permission to this drawing.", - WS_CLIENT_NOT_FOUND_MESSAGE = 'Drawing not found.', - WS_CLIENT_FAILED_CONNECTION_MESSAGE = 'Unable to establish websocket connection. Try again later.', +export enum WsCloseMessage { + FEATURE_DISABLED = 'Tldraw feature is disabled.', + BAD_REQUEST = 'Room name param not found in url.', + UNAUTHORIZED = "You don't have permission to this drawing.", + NOT_FOUND = 'Drawing not found.', + INTERNAL_SERVER_ERROR = 'Unable to establish websocket connection.', } diff --git a/config/default.schema.json b/config/default.schema.json index 12633cc568c..20e77316332 100644 --- a/config/default.schema.json +++ b/config/default.schema.json @@ -1521,6 +1521,10 @@ "type": "boolean", "description": "Enables uploading assets to tldraw board" }, + "ASSETS_SYNC_ENABLED": { + "type": "boolean", + "description": "Enables synchronization of tldraw board assets with file storage" + }, "ASSETS_MAX_SIZE": { "type": "integer", "description": "Maximum asset size in bytes" @@ -1538,6 +1542,7 @@ "DB_COMPRESS_THRESHOLD": 400, "MAX_DOCUMENT_SIZE": 15000000, "ASSETS_ENABLED": true, + "ASSETS_SYNC_ENABLED": false, "ASSETS_MAX_SIZE": 10485760, "ASSETS_ALLOWED_MIME_TYPES_LIST": "image/png,image/jpeg,image/gif,image/svg+xml" } diff --git a/config/test.json b/config/test.json index 99c4c310e8e..9ee70da1159 100644 --- a/config/test.json +++ b/config/test.json @@ -73,6 +73,7 @@ "DB_COMPRESS_THRESHOLD": 400, "MAX_DOCUMENT_SIZE": 15000000, "ASSETS_ENABLED": true, + "ASSETS_SYNC_ENABLED": true, "ASSETS_MAX_SIZE": 25000000, "ASSETS_ALLOWED_MIME_TYPES_LIST": "" },