diff --git a/package-lock.json b/package-lock.json index a612bb3..d30d101 100644 --- a/package-lock.json +++ b/package-lock.json @@ -373,9 +373,9 @@ } }, "node_modules/@electric-sql/pglite": { - "version": "0.2.6", - "resolved": "https://registry.npmjs.org/@electric-sql/pglite/-/pglite-0.2.6.tgz", - "integrity": "sha512-tyWWxj1Z1Pd4BqBZL1ER2SXaCn5s9N0bxTQCAkGaaWe8r9EEe1bNs20RAG3/+ZeBJtDrk8y5xjocactL+4aIXg==", + "version": "0.2.7", + "resolved": "https://registry.npmjs.org/@electric-sql/pglite/-/pglite-0.2.7.tgz", + "integrity": "sha512-8Il//XHTAtZ8VeQF+6P1UjsIoaAJyO4LwOMoXhSFaHpmkwKs63cUhHHNzLzUmcZvP/ZTmlT3+xTiWfU/EyoxwQ==", "license": "Apache-2.0" }, "node_modules/@esbuild/aix-ppc64": { @@ -4992,7 +4992,7 @@ "license": "MIT", "devDependencies": { "@biomejs/biome": "^1.8.3", - "@electric-sql/pglite": "^0.2.6", + "@electric-sql/pglite": "^0.2.7", "@nodeweb/knex": "^3.1.0-alpha.13", "@nodeweb/pg": "^8.12.0-alpha.5", "@std/bytes": "npm:@jsr/std__bytes@^1.0.2", diff --git a/packages/pg-gateway/package.json b/packages/pg-gateway/package.json index 38d6f9b..07cd2cc 100644 --- a/packages/pg-gateway/package.json +++ b/packages/pg-gateway/package.json @@ -42,7 +42,7 @@ }, "devDependencies": { "@biomejs/biome": "^1.8.3", - "@electric-sql/pglite": "^0.2.6", + "@electric-sql/pglite": "^0.2.7", "@nodeweb/knex": "^3.1.0-alpha.13", "@nodeweb/pg": "^8.12.0-alpha.5", "@std/bytes": "npm:@jsr/std__bytes@^1.0.2", diff --git a/packages/pg-gateway/test/node/certs.ts b/packages/pg-gateway/test/node/certs.ts index eb9c237..3ef096b 100644 --- a/packages/pg-gateway/test/node/certs.ts +++ b/packages/pg-gateway/test/node/certs.ts @@ -145,3 +145,24 @@ export async function signCert(caCert: ArrayBuffer, caKey: ArrayBuffer, csr: Arr return certBytes; } + +export async function generateAllCertificates() { + const { caKey, caCert } = await generateCA('My Root CA'); + + const { key: serverKey, csr: serverCsr } = await generateCSR('localhost'); + const serverCert = await signCert(caCert, caKey, serverCsr); + + const { key: clientKey, csr: clientCsr } = await generateCSR('postgres'); + const clientCert = await signCert(caCert, caKey, clientCsr); + + const encoder = new TextEncoder(); + + return { + caKey: encoder.encode(toPEM(caKey, 'PRIVATE KEY')), + caCert: encoder.encode(toPEM(caCert, 'CERTIFICATE')), + serverKey: encoder.encode(toPEM(serverKey, 'PRIVATE KEY')), + serverCert: encoder.encode(toPEM(serverCert, 'CERTIFICATE')), + clientKey: encoder.encode(toPEM(clientKey, 'PRIVATE KEY')), + clientCert: encoder.encode(toPEM(clientCert, 'CERTIFICATE')), + }; +} diff --git a/packages/pg-gateway/test/node/errors.test.ts b/packages/pg-gateway/test/node/errors.test.ts new file mode 100644 index 0000000..cb85611 --- /dev/null +++ b/packages/pg-gateway/test/node/errors.test.ts @@ -0,0 +1,110 @@ +import { BackendError } from 'pg-gateway'; +import { describe, expect, it, vi } from 'vitest'; +import { createPostgresClient, createPostgresServer, getPort } from '../util'; +import { generateAllCertificates } from './certs'; + +describe('errors', () => { + it('sends backend error thrown in onTlsUpgrade to the client', async () => { + const { caCert, serverKey, serverCert } = await generateAllCertificates(); + await using server = await createPostgresServer({ + tls: { + cert: serverCert, + key: serverKey, + }, + async onTlsUpgrade() { + throw BackendError.create({ + message: 'onTlsUpgrade failed', + code: 'P0000', + severity: 'FATAL', + }); + }, + }); + const promise = createPostgresClient({ + port: getPort(server), + ssl: { + ca: Buffer.from(caCert), + }, + }); + await expect(promise).rejects.toThrow('onTlsUpgrade failed'); + }); + + it('sends backend error thrown in onAuthenticated to the client', async () => { + await using server = await createPostgresServer({ + async onAuthenticated() { + throw BackendError.create({ + message: 'onAuthenticated failed', + code: 'P0000', + severity: 'FATAL', + }); + }, + }); + const promise = createPostgresClient({ + port: getPort(server), + }); + await expect(promise).rejects.toThrow('onAuthenticated failed'); + }); + + it('sends backend error thrown in onStartup to the client', async () => { + await using server = await createPostgresServer({ + async onStartup() { + throw BackendError.create({ + message: 'onStartup failed', + code: 'P0000', + severity: 'FATAL', + }); + }, + }); + const promise = createPostgresClient({ + port: getPort(server), + }); + await expect(promise).rejects.toThrow('onStartup failed'); + }); + + it('sends backend error thrown in onMessage to the client', async () => { + await using server = await createPostgresServer({ + async onMessage() { + throw BackendError.create({ + message: 'onMessage failed', + code: 'P0000', + severity: 'FATAL', + }); + }, + }); + const promise = createPostgresClient({ + port: getPort(server), + }); + await expect(promise).rejects.toThrow('onMessage failed'); + }); + + const mockOutput = () => { + const output = { + stderr: '', + [Symbol.dispose]() { + consoleErrorMock.mockRestore(); + }, + }; + const consoleErrorMock = vi.spyOn(console, 'error').mockImplementation((...args) => { + output.stderr += args.join(' '); + }); + return output; + }; + + it('does not send non backend errors to the client', async () => { + using output = mockOutput(); + await using server = await createPostgresServer({ + async onMessage() { + throw Error('wat?'); + }, + }); + const promise = createPostgresClient({ + port: getPort(server), + }); + try { + await promise; + } catch (error) { + expect(error.message).not.toContain('wat?'); + expect(output.stderr).toContain('wat?'); + expect(error.message).toContain('Connection terminated unexpectedly'); + } + }); +}); diff --git a/packages/pg-gateway/test/node/tls.test.ts b/packages/pg-gateway/test/node/tls.test.ts index 57ed9c8..bd9e11b 100644 --- a/packages/pg-gateway/test/node/tls.test.ts +++ b/packages/pg-gateway/test/node/tls.test.ts @@ -1,64 +1,16 @@ -import { createServer, type Server } from 'node:net'; -import type { ClientConfig } from 'pg'; -import { type PostgresConnectionOptions, createDuplexPair } from 'pg-gateway'; -import { fromDuplexStream, fromNodeSocket } from 'pg-gateway/node'; +import { createDuplexPair } from 'pg-gateway'; +import { fromDuplexStream } from 'pg-gateway/node'; import { describe, expect, it } from 'vitest'; -import { DisposablePgClient, socketFromDuplexStream } from '../util'; -import { generateCA, generateCSR, signCert, toPEM } from './certs'; -import { once } from 'node:events'; - -async function generateAllCertificates() { - const { caKey, caCert } = await generateCA('My Root CA'); - - const { key: serverKey, csr: serverCsr } = await generateCSR('localhost'); - const serverCert = await signCert(caCert, caKey, serverCsr); - - const { key: clientKey, csr: clientCsr } = await generateCSR('postgres'); - const clientCert = await signCert(caCert, caKey, clientCsr); - - const encoder = new TextEncoder(); - - return { - caKey: encoder.encode(toPEM(caKey, 'PRIVATE KEY')), - caCert: encoder.encode(toPEM(caCert, 'CERTIFICATE')), - serverKey: encoder.encode(toPEM(serverKey, 'PRIVATE KEY')), - serverCert: encoder.encode(toPEM(serverCert, 'CERTIFICATE')), - clientKey: encoder.encode(toPEM(clientKey, 'PRIVATE KEY')), - clientCert: encoder.encode(toPEM(clientCert, 'CERTIFICATE')), - }; -} +import { + createPostgresClient, + createPostgresServer, + getPort, + socketFromDuplexStream, +} from '../util'; +import { generateAllCertificates } from './certs'; const { caCert, serverKey, serverCert, clientKey, clientCert } = await generateAllCertificates(); -async function createPostgresServer(options?: PostgresConnectionOptions) { - const server = createServer((socket) => fromNodeSocket(socket, options)); - - // Listen on a random free port - server.listen(0); - await once(server, 'listening'); - return server; -} - -function getPort(server: Server) { - const address = server.address(); - - if (typeof address !== 'object') { - throw new Error(`Invalid server address '${address}'`); - } - - if (!address) { - throw new Error('Server has no address'); - } - - return address.port; -} - -async function connectPg(config: string | ClientConfig) { - const client = new DisposablePgClient(config); - await client.connect(); - return client; -} - describe('tls', () => { it('basic tls over tcp', async () => { await using server = await createPostgresServer({ @@ -71,7 +23,7 @@ describe('tls', () => { }, }); - await using client = await connectPg({ + await using client = await createPostgresClient({ port: getPort(server), ssl: { ca: Buffer.from(caCert), @@ -93,7 +45,7 @@ describe('tls', () => { }, }); - await using client = await connectPg({ + await using client = await createPostgresClient({ host: 'localhost', port: getPort(server), ssl: { @@ -116,7 +68,7 @@ describe('tls', () => { }, }); - await using client = await connectPg({ + await using client = await createPostgresClient({ host: '127.0.0.1', port: getPort(server), ssl: { @@ -137,7 +89,7 @@ describe('tls', () => { }, }); - await using client = await connectPg({ + await using client = await createPostgresClient({ port: getPort(server), user: 'postgres', ssl: { @@ -160,7 +112,7 @@ describe('tls', () => { }, }); - const promise = connectPg({ + const promise = createPostgresClient({ port: getPort(server), user: 'bob', ssl: { @@ -185,7 +137,7 @@ describe('tls', () => { }, }); - const promise = connectPg({ + const promise = createPostgresClient({ port: getPort(server), ssl: { ca: Buffer.from(serverCert), @@ -209,7 +161,7 @@ describe('tls', () => { }, }); - await using client = await connectPg({ + await using client = await createPostgresClient({ stream: socketFromDuplexStream(clientDuplex), ssl: { ca: Buffer.from(caCert), diff --git a/packages/pg-gateway/test/util.ts b/packages/pg-gateway/test/util.ts index 3136645..a9f8958 100644 --- a/packages/pg-gateway/test/util.ts +++ b/packages/pg-gateway/test/util.ts @@ -1,6 +1,8 @@ -import EventEmitter from 'node:events'; -import type { DuplexStream } from 'pg-gateway'; -import { Client } from 'pg'; +import EventEmitter, { once } from 'node:events'; +import { Server, createServer } from 'node:net'; +import { Client, type ClientConfig } from 'pg'; +import type { DuplexStream, PostgresConnectionOptions } from 'pg-gateway'; +import { fromNodeSocket } from 'pg-gateway/node'; /** * Creates a passthrough socket object that can be passed @@ -128,3 +130,45 @@ export class DisposablePgClient extends Client { await this.end(); } } + +export async function createPostgresClient(config: string | ClientConfig) { + const client = new DisposablePgClient(config); + await client.connect(); + return client; +} + +export class DisposableServer extends Server { + async [Symbol.asyncDispose]() { + await new Promise((resolve, reject) => { + this.close((err) => { + if (err) { + reject(err); + } else { + resolve(undefined); + } + }); + }); + } +} + +export async function createPostgresServer(options?: PostgresConnectionOptions) { + const server = new DisposableServer((socket) => fromNodeSocket(socket, options)); + // Listen on a random free port + server.listen(0); + await once(server, 'listening'); + return server; +} + +export function getPort(server: Server) { + const address = server.address(); + + if (typeof address !== 'object') { + throw new Error(`Invalid server address '${address}'`); + } + + if (!address) { + throw new Error('Server has no address'); + } + + return address.port; +}