From 5691efd5ee8ac315175fb5017cbc89fc3816d74d Mon Sep 17 00:00:00 2001 From: Julien Goux Date: Fri, 13 Sep 2024 11:04:26 +0200 Subject: [PATCH 1/2] error story --- packages/pg-gateway/src/auth/cert.ts | 14 +- packages/pg-gateway/src/auth/md5.ts | 8 +- packages/pg-gateway/src/auth/password.ts | 6 +- .../pg-gateway/src/auth/sasl/scram-sha-256.ts | 12 +- packages/pg-gateway/src/backend-error.ts | 205 ++++++++++-------- packages/pg-gateway/src/connection.ts | 99 +++++---- 6 files changed, 196 insertions(+), 148 deletions(-) diff --git a/packages/pg-gateway/src/auth/cert.ts b/packages/pg-gateway/src/auth/cert.ts index cb9413c..94e3528 100644 --- a/packages/pg-gateway/src/auth/cert.ts +++ b/packages/pg-gateway/src/auth/cert.ts @@ -1,4 +1,4 @@ -import { createBackendErrorMessage } from '../backend-error.js'; +import { BackendError } from '../backend-error.js'; import type { BufferReader } from '../buffer-reader.js'; import type { BufferWriter } from '../buffer-writer.js'; import type { ConnectionState } from '../connection.types'; @@ -45,21 +45,21 @@ export class CertAuthFlow extends BaseAuthFlow { async *handleClientMessage(message: BufferSource) { if (!this.connectionState.tlsInfo) { - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'FATAL', code: '08000', message: `ssl connection required when auth mode is 'certificate'`, - }); + }).flush(); yield closeSignal; return; } if (!this.connectionState.tlsInfo.clientCertificate) { - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'FATAL', code: '08000', message: 'client certificate required', - }); + }).flush(); yield closeSignal; return; } @@ -73,11 +73,11 @@ export class CertAuthFlow extends BaseAuthFlow { ); if (!isValid) { - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'FATAL', code: '08000', message: 'client certificate is invalid', - }); + }).flush(); yield closeSignal; return; } diff --git a/packages/pg-gateway/src/auth/md5.ts b/packages/pg-gateway/src/auth/md5.ts index fdb7568..fe3f558 100644 --- a/packages/pg-gateway/src/auth/md5.ts +++ b/packages/pg-gateway/src/auth/md5.ts @@ -1,13 +1,13 @@ import { concat } from '@std/bytes/concat'; import { crypto } from '@std/crypto'; import { encodeHex } from '@std/encoding/hex'; -import { createBackendErrorMessage } from '../backend-error.js'; +import { BackendError } from '../backend-error.js'; import type { BufferReader } from '../buffer-reader.js'; import type { BufferWriter } from '../buffer-writer.js'; import type { ConnectionState } from '../connection.types'; import { BackendMessageCode } from '../message-codes'; -import { BaseAuthFlow } from './base-auth-flow'; import { closeSignal } from '../signals.js'; +import { BaseAuthFlow } from './base-auth-flow'; export type Md5AuthOptions = { method: 'md5'; @@ -76,11 +76,11 @@ export class Md5AuthFlow extends BaseAuthFlow { ); if (!isValid) { - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'FATAL', code: '28P01', message: `password authentication failed for user "${this.username}"`, - }); + }).flush(); yield closeSignal; return; } diff --git a/packages/pg-gateway/src/auth/password.ts b/packages/pg-gateway/src/auth/password.ts index 2714a95..0344ab6 100644 --- a/packages/pg-gateway/src/auth/password.ts +++ b/packages/pg-gateway/src/auth/password.ts @@ -1,4 +1,4 @@ -import { createBackendErrorMessage } from '../backend-error.js'; +import { BackendError } from '../backend-error.js'; import type { BufferReader } from '../buffer-reader.js'; import type { BufferWriter } from '../buffer-writer.js'; import type { ConnectionState } from '../connection.types'; @@ -72,11 +72,11 @@ export class PasswordAuthFlow extends BaseAuthFlow { ); if (!isValid) { - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'FATAL', code: '28P01', message: `password authentication failed for user "${this.username}"`, - }); + }).flush(); yield closeSignal; return; } diff --git a/packages/pg-gateway/src/auth/sasl/scram-sha-256.ts b/packages/pg-gateway/src/auth/sasl/scram-sha-256.ts index 46c19b9..04e186e 100644 --- a/packages/pg-gateway/src/auth/sasl/scram-sha-256.ts +++ b/packages/pg-gateway/src/auth/sasl/scram-sha-256.ts @@ -1,12 +1,12 @@ import { decodeBase64, encodeBase64 } from '@std/encoding/base64'; -import { createBackendErrorMessage } from '../../backend-error.js'; +import { BackendError } from '../../backend-error.js'; import type { BufferReader } from '../../buffer-reader.js'; import type { BufferWriter } from '../../buffer-writer.js'; import type { ConnectionState } from '../../connection.types'; import { createHashKey, createHmacKey, pbkdf2, timingSafeEqual } from '../../crypto.js'; +import { closeSignal } from '../../signals.js'; import type { AuthFlow } from '../base-auth-flow'; import { SaslMechanism } from './sasl-mechanism'; -import { closeSignal } from '../../signals.js'; export type ScramSha256Data = { salt: string; @@ -163,11 +163,11 @@ export class ScramSha256AuthFlow extends SaslMechanism implements AuthFlow { const saslMechanism = this.reader.cstring(); if (saslMechanism !== 'SCRAM-SHA-256') { - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'FATAL', code: '28000', message: 'Unsupported SASL authentication mechanism', - }); + }).flush(); yield closeSignal; return; } @@ -207,11 +207,11 @@ export class ScramSha256AuthFlow extends SaslMechanism implements AuthFlow { this.step = ScramSha256Step.Completed; yield this.createAuthenticationSASLFinal(serverFinalMessage); } catch (error) { - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'FATAL', code: '28000', message: (error as Error).message, - }); + }).flush(); yield closeSignal; return; } diff --git a/packages/pg-gateway/src/backend-error.ts b/packages/pg-gateway/src/backend-error.ts index 83469ff..2a2207e 100644 --- a/packages/pg-gateway/src/backend-error.ts +++ b/packages/pg-gateway/src/backend-error.ts @@ -1,7 +1,7 @@ import { BufferWriter } from './buffer-writer.js'; import { BackendMessageCode } from './message-codes.js'; -export interface BackendError { +interface BackendErrorParams { severity: 'ERROR' | 'FATAL' | 'PANIC'; code: string; message: string; @@ -22,99 +22,134 @@ export interface BackendError { } /** - * Creates a backend error message + * Represents a backend error message * * @see https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-ERRORRESPONSE * * For error fields, @see https://www.postgresql.org/docs/current/protocol-error-fields.html#PROTOCOL-ERROR-FIELDS */ -export function createBackendErrorMessage(error: BackendError) { - const writer = new BufferWriter(); - - writer.addString('S'); - writer.addCString(error.severity); - - writer.addString('V'); - writer.addCString(error.severity); - - writer.addString('C'); - writer.addCString(error.code); - - writer.addString('M'); - writer.addCString(error.message); - - if (error.detail !== undefined) { - writer.addString('D'); - writer.addCString(error.detail); - } - - if (error.hint !== undefined) { - writer.addString('H'); - writer.addCString(error.hint); - } - - if (error.position !== undefined) { - writer.addString('P'); - writer.addCString(error.position); - } - - if (error.internalPosition !== undefined) { - writer.addString('p'); - writer.addCString(error.internalPosition); - } - - if (error.internalQuery !== undefined) { - writer.addString('q'); - writer.addCString(error.internalQuery); - } - - if (error.where !== undefined) { - writer.addString('W'); - writer.addCString(error.where); - } - - if (error.schema !== undefined) { - writer.addString('s'); - writer.addCString(error.schema); - } - - if (error.table !== undefined) { - writer.addString('t'); - writer.addCString(error.table); - } - - if (error.column !== undefined) { - writer.addString('c'); - writer.addCString(error.column); - } - - if (error.dataType !== undefined) { - writer.addString('d'); - writer.addCString(error.dataType); - } - - if (error.constraint !== undefined) { - writer.addString('n'); - writer.addCString(error.constraint); - } +export class BackendError { + severity!: 'ERROR' | 'FATAL' | 'PANIC'; + code!: string; + message!: string; + detail?: string; + hint?: string; + position?: string; + internalPosition?: string; + internalQuery?: string; + where?: string; + schema?: string; + table?: string; + column?: string; + dataType?: string; + constraint?: string; + file?: string; + line?: string; + routine?: string; - if (error.file !== undefined) { - writer.addString('F'); - writer.addCString(error.file); + constructor(params: BackendErrorParams) { + Object.assign(this, params); } - if (error.line !== undefined) { - writer.addString('L'); - writer.addCString(error.line); + /** + * Creates a backend error message + * + * @see https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-ERRORRESPONSE + * + * For error fields, @see https://www.postgresql.org/docs/current/protocol-error-fields.html#PROTOCOL-ERROR-FIELDS + */ + static create(params: BackendErrorParams) { + return new BackendError(params); } - if (error.routine !== undefined) { - writer.addString('R'); - writer.addCString(error.routine); + flush() { + const writer = new BufferWriter(); + + writer.addString('S'); + writer.addCString(this.severity); + + writer.addString('V'); + writer.addCString(this.severity); + + writer.addString('C'); + writer.addCString(this.code); + + writer.addString('M'); + writer.addCString(this.message); + + if (this.detail !== undefined) { + writer.addString('D'); + writer.addCString(this.detail); + } + + if (this.hint !== undefined) { + writer.addString('H'); + writer.addCString(this.hint); + } + + if (this.position !== undefined) { + writer.addString('P'); + writer.addCString(this.position); + } + + if (this.internalPosition !== undefined) { + writer.addString('p'); + writer.addCString(this.internalPosition); + } + + if (this.internalQuery !== undefined) { + writer.addString('q'); + writer.addCString(this.internalQuery); + } + + if (this.where !== undefined) { + writer.addString('W'); + writer.addCString(this.where); + } + + if (this.schema !== undefined) { + writer.addString('s'); + writer.addCString(this.schema); + } + + if (this.table !== undefined) { + writer.addString('t'); + writer.addCString(this.table); + } + + if (this.column !== undefined) { + writer.addString('c'); + writer.addCString(this.column); + } + + if (this.dataType !== undefined) { + writer.addString('d'); + writer.addCString(this.dataType); + } + + if (this.constraint !== undefined) { + writer.addString('n'); + writer.addCString(this.constraint); + } + + if (this.file !== undefined) { + writer.addString('F'); + writer.addCString(this.file); + } + + if (this.line !== undefined) { + writer.addString('L'); + writer.addCString(this.line); + } + + if (this.routine !== undefined) { + writer.addString('R'); + writer.addCString(this.routine); + } + + // Add null byte to the end + writer.addCString(''); + + return writer.flush(BackendMessageCode.ErrorMessage); } - - // Add null byte to the end - writer.addCString(''); - - return writer.flush(BackendMessageCode.ErrorMessage); } diff --git a/packages/pg-gateway/src/connection.ts b/packages/pg-gateway/src/connection.ts index bc4e217..7443bfb 100644 --- a/packages/pg-gateway/src/connection.ts +++ b/packages/pg-gateway/src/connection.ts @@ -1,6 +1,6 @@ import type { AuthFlow } from './auth/base-auth-flow.js'; import { type AuthOptions, createAuthFlow } from './auth/index.js'; -import { createBackendErrorMessage } from './backend-error.js'; +import { BackendError } from './backend-error.js'; import { BufferReader } from './buffer-reader.js'; import { BufferWriter } from './buffer-writer.js'; import { @@ -12,7 +12,7 @@ import { import { AsyncIterableWithMetadata } from './iterable-util.js'; import { MessageBuffer } from './message-buffer.js'; import { BackendMessageCode, FrontendMessageCode } from './message-codes.js'; -import { tlsUpgradeSignal, closeSignal, type ConnectionSignal } from './signals.js'; +import { type ConnectionSignal, closeSignal, tlsUpgradeSignal } from './signals.js'; import { type DuplexStream, toAsyncIterator } from './streams.js'; export type TlsOptions = { @@ -137,11 +137,13 @@ export default class PostgresConnection { hasStarted = false; isAuthenticated = false; detached = false; - writer = new BufferWriter(); - reader = new BufferReader(); + bufferWriter = new BufferWriter(); + bufferReader = new BufferReader(); clientParams?: ClientParameters; tlsInfo?: TlsInfo; messageBuffer = new MessageBuffer(); + // reference to the stream writer when processing data + streamWriter?: WritableStreamDefaultWriter; constructor( public duplex: DuplexStream, @@ -199,6 +201,7 @@ export default class PostgresConnection { this.messageBuffer = new MessageBuffer(); await this.options.onTlsUpgrade?.(this.state); + if (this.detached) { return; } @@ -212,8 +215,18 @@ export default class PostgresConnection { return; } } catch (err) { - await this.duplex.writable.abort(); - console.error(err); + if (err instanceof BackendError) { + const writer = this.duplex.writable.getWriter(); + await writer.write(err.flush()); + writer.releaseLock(); + await this.duplex.writable.close(); + } else { + // ignore ABORT_ERR errors which are common, like a user closing its terminal while running a psql session + if (!(err instanceof Error && 'code' in err && err.code === 'ABORT_ERR')) { + console.error(err); + } + await this.duplex.writable.abort(); + } } } /** @@ -234,7 +247,7 @@ export default class PostgresConnection { } async processData(duplex: DuplexStream): Promise { - const writer = duplex.writable.getWriter(); + this.streamWriter = duplex.writable.getWriter(); try { for await (const data of toAsyncIterator(duplex.readable, { preventCancel: true })) { this.messageBuffer.mergeBuffer(data); @@ -249,19 +262,19 @@ export default class PostgresConnection { if (responseMessage === closeSignal) { return closeSignal; } - await writer.write(responseMessage); + await this.streamWriter.write(responseMessage); } } } } finally { - writer.releaseLock(); + this.streamWriter.releaseLock(); } } async *handleClientMessage( message: Uint8Array, ): AsyncGenerator { - this.reader.setBuffer(message); + this.bufferReader.setBuffer(message); const messageResponse = await this.options.onMessage?.(message, this.state); @@ -307,11 +320,11 @@ export default class PostgresConnection { } else if (this.isStartupMessage(message)) { // Guard against SSL connection not being established when `tls` is enabled if (this.options.tls && !this.tlsInfo) { - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'FATAL', code: '08P01', message: 'SSL connection is required', - }); + }).flush(); yield closeSignal; return; } @@ -341,14 +354,14 @@ export default class PostgresConnection { async *handleSslRequest() { if (!this.options.tls || !this.adapters.upgradeTls) { - this.writer.addString('N'); - yield this.writer.flush(); + this.bufferWriter.addString('N'); + yield this.bufferWriter.flush(); return; } // Otherwise respond with 'S' to indicate it is supported - this.writer.addString('S'); - yield this.writer.flush(); + this.bufferWriter.addString('S'); + yield this.bufferWriter.flush(); // From now on the frontend will communicate via TLS, so upgrade the connection yield tlsUpgradeSignal; @@ -359,21 +372,21 @@ export default class PostgresConnection { // user is required if (!parameters.user) { - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'FATAL', code: '08000', message: 'user is required', - }); + }).flush(); yield closeSignal; return; } if (majorVersion !== 3 || minorVersion !== 0) { - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'FATAL', code: '08000', message: `Unsupported protocol version ${majorVersion.toString()}.${minorVersion.toString()}`, - }); + }).flush(); yield closeSignal; return; } @@ -397,8 +410,8 @@ export default class PostgresConnection { } this.authFlow = createAuthFlow({ - reader: this.reader, - writer: this.writer, + reader: this.bufferReader, + writer: this.bufferWriter, username: this.clientParams.user, auth: this.options.auth, connectionState: this.state, @@ -422,7 +435,7 @@ export default class PostgresConnection { } async *handleAuthenticationMessage(message: BufferSource) { - const code = this.reader.byte(); + const code = this.bufferReader.byte(); if (code !== FrontendMessageCode.Password) { throw new Error(`Unexpected authentication message code: ${code}`); @@ -438,18 +451,18 @@ export default class PostgresConnection { } private async *handleRegularMessage(message: BufferSource) { - const code = this.reader.byte(); + const code = this.bufferReader.byte(); switch (code) { case FrontendMessageCode.Terminate: yield closeSignal; return; default: - yield createBackendErrorMessage({ + yield BackendError.create({ severity: 'ERROR', code: '123', message: 'Message code not yet implemented', - }); + }).flush(); yield this.createReadyForQuery(); } } @@ -518,15 +531,15 @@ export default class PostgresConnection { * @see https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE */ readStartupMessage() { - const length = this.reader.int32(); - const majorVersion = this.reader.int16(); - const minorVersion = this.reader.int16(); + const length = this.bufferReader.int32(); + const majorVersion = this.bufferReader.int16(); + const minorVersion = this.bufferReader.int16(); const parameters: Record = {}; // biome-ignore lint/suspicious/noAssignInExpressions: - for (let key: string; (key = this.reader.cstring()) !== ''; ) { - parameters[key] = this.reader.cstring(); + for (let key: string; (key = this.bufferReader.cstring()) !== ''; ) { + parameters[key] = this.bufferReader.cstring(); } return { @@ -542,7 +555,7 @@ export default class PostgresConnection { * @see https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-QUERY */ readQuery() { - const query = this.reader.cstring(); + const query = this.bufferReader.cstring(); return { query, @@ -555,8 +568,8 @@ export default class PostgresConnection { * @see https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONOK */ createAuthenticationOk() { - this.writer.addInt32(0); - return this.writer.flush(BackendMessageCode.AuthenticationResponse); + this.bufferWriter.addInt32(0); + return this.bufferWriter.flush(BackendMessageCode.AuthenticationResponse); } /** @@ -567,9 +580,9 @@ export default class PostgresConnection { * @see https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-ASYNC */ createParameterStatus(name: string, value: string) { - this.writer.addCString(name); - this.writer.addCString(value); - return this.writer.flush(BackendMessageCode.ParameterStatus); + this.bufferWriter.addCString(name); + this.bufferWriter.addCString(value); + return this.bufferWriter.flush(BackendMessageCode.ParameterStatus); } /** @@ -580,28 +593,28 @@ export default class PostgresConnection { createReadyForQuery(transactionStatus: 'idle' | 'transaction' | 'error' = 'idle') { switch (transactionStatus) { case 'idle': - this.writer.addString('I'); + this.bufferWriter.addString('I'); break; case 'transaction': - this.writer.addString('T'); + this.bufferWriter.addString('T'); break; case 'error': - this.writer.addString('E'); + this.bufferWriter.addString('E'); break; default: throw new Error(`Unknown transaction status '${transactionStatus}'`); } - return this.writer.flush(BackendMessageCode.ReadyForQuery); + return this.bufferWriter.flush(BackendMessageCode.ReadyForQuery); } createAuthenticationFailedError() { - return createBackendErrorMessage({ + return BackendError.create({ severity: 'FATAL', code: '28P01', message: this.clientParams?.user ? `password authentication failed for user "${this.clientParams.user}"` : 'password authentication failed', - }); + }).flush(); } } From 4f7fc2421e30e1d495e4e88b9fc714520c547a23 Mon Sep 17 00:00:00 2001 From: Julien Goux Date: Fri, 13 Sep 2024 12:10:35 +0200 Subject: [PATCH 2/2] tests --- package-lock.json | 8 +- packages/pg-gateway/package.json | 2 +- packages/pg-gateway/test/node/certs.ts | 21 ++++ packages/pg-gateway/test/node/errors.test.ts | 110 +++++++++++++++++++ packages/pg-gateway/test/node/tls.test.ts | 80 +++----------- packages/pg-gateway/test/util.ts | 50 ++++++++- 6 files changed, 199 insertions(+), 72 deletions(-) create mode 100644 packages/pg-gateway/test/node/errors.test.ts diff --git a/package-lock.json b/package-lock.json index a612bb3..d30d101 100644 --- a/package-lock.json +++ b/package-lock.json @@ -373,9 +373,9 @@ } }, "node_modules/@electric-sql/pglite": { - "version": "0.2.6", - "resolved": "https://registry.npmjs.org/@electric-sql/pglite/-/pglite-0.2.6.tgz", - "integrity": "sha512-tyWWxj1Z1Pd4BqBZL1ER2SXaCn5s9N0bxTQCAkGaaWe8r9EEe1bNs20RAG3/+ZeBJtDrk8y5xjocactL+4aIXg==", + "version": "0.2.7", + "resolved": "https://registry.npmjs.org/@electric-sql/pglite/-/pglite-0.2.7.tgz", + "integrity": "sha512-8Il//XHTAtZ8VeQF+6P1UjsIoaAJyO4LwOMoXhSFaHpmkwKs63cUhHHNzLzUmcZvP/ZTmlT3+xTiWfU/EyoxwQ==", "license": "Apache-2.0" }, "node_modules/@esbuild/aix-ppc64": { @@ -4992,7 +4992,7 @@ "license": "MIT", "devDependencies": { "@biomejs/biome": "^1.8.3", - "@electric-sql/pglite": "^0.2.6", + "@electric-sql/pglite": "^0.2.7", "@nodeweb/knex": "^3.1.0-alpha.13", "@nodeweb/pg": "^8.12.0-alpha.5", "@std/bytes": "npm:@jsr/std__bytes@^1.0.2", diff --git a/packages/pg-gateway/package.json b/packages/pg-gateway/package.json index 38d6f9b..07cd2cc 100644 --- a/packages/pg-gateway/package.json +++ b/packages/pg-gateway/package.json @@ -42,7 +42,7 @@ }, "devDependencies": { "@biomejs/biome": "^1.8.3", - "@electric-sql/pglite": "^0.2.6", + "@electric-sql/pglite": "^0.2.7", "@nodeweb/knex": "^3.1.0-alpha.13", "@nodeweb/pg": "^8.12.0-alpha.5", "@std/bytes": "npm:@jsr/std__bytes@^1.0.2", diff --git a/packages/pg-gateway/test/node/certs.ts b/packages/pg-gateway/test/node/certs.ts index eb9c237..3ef096b 100644 --- a/packages/pg-gateway/test/node/certs.ts +++ b/packages/pg-gateway/test/node/certs.ts @@ -145,3 +145,24 @@ export async function signCert(caCert: ArrayBuffer, caKey: ArrayBuffer, csr: Arr return certBytes; } + +export async function generateAllCertificates() { + const { caKey, caCert } = await generateCA('My Root CA'); + + const { key: serverKey, csr: serverCsr } = await generateCSR('localhost'); + const serverCert = await signCert(caCert, caKey, serverCsr); + + const { key: clientKey, csr: clientCsr } = await generateCSR('postgres'); + const clientCert = await signCert(caCert, caKey, clientCsr); + + const encoder = new TextEncoder(); + + return { + caKey: encoder.encode(toPEM(caKey, 'PRIVATE KEY')), + caCert: encoder.encode(toPEM(caCert, 'CERTIFICATE')), + serverKey: encoder.encode(toPEM(serverKey, 'PRIVATE KEY')), + serverCert: encoder.encode(toPEM(serverCert, 'CERTIFICATE')), + clientKey: encoder.encode(toPEM(clientKey, 'PRIVATE KEY')), + clientCert: encoder.encode(toPEM(clientCert, 'CERTIFICATE')), + }; +} diff --git a/packages/pg-gateway/test/node/errors.test.ts b/packages/pg-gateway/test/node/errors.test.ts new file mode 100644 index 0000000..cb85611 --- /dev/null +++ b/packages/pg-gateway/test/node/errors.test.ts @@ -0,0 +1,110 @@ +import { BackendError } from 'pg-gateway'; +import { describe, expect, it, vi } from 'vitest'; +import { createPostgresClient, createPostgresServer, getPort } from '../util'; +import { generateAllCertificates } from './certs'; + +describe('errors', () => { + it('sends backend error thrown in onTlsUpgrade to the client', async () => { + const { caCert, serverKey, serverCert } = await generateAllCertificates(); + await using server = await createPostgresServer({ + tls: { + cert: serverCert, + key: serverKey, + }, + async onTlsUpgrade() { + throw BackendError.create({ + message: 'onTlsUpgrade failed', + code: 'P0000', + severity: 'FATAL', + }); + }, + }); + const promise = createPostgresClient({ + port: getPort(server), + ssl: { + ca: Buffer.from(caCert), + }, + }); + await expect(promise).rejects.toThrow('onTlsUpgrade failed'); + }); + + it('sends backend error thrown in onAuthenticated to the client', async () => { + await using server = await createPostgresServer({ + async onAuthenticated() { + throw BackendError.create({ + message: 'onAuthenticated failed', + code: 'P0000', + severity: 'FATAL', + }); + }, + }); + const promise = createPostgresClient({ + port: getPort(server), + }); + await expect(promise).rejects.toThrow('onAuthenticated failed'); + }); + + it('sends backend error thrown in onStartup to the client', async () => { + await using server = await createPostgresServer({ + async onStartup() { + throw BackendError.create({ + message: 'onStartup failed', + code: 'P0000', + severity: 'FATAL', + }); + }, + }); + const promise = createPostgresClient({ + port: getPort(server), + }); + await expect(promise).rejects.toThrow('onStartup failed'); + }); + + it('sends backend error thrown in onMessage to the client', async () => { + await using server = await createPostgresServer({ + async onMessage() { + throw BackendError.create({ + message: 'onMessage failed', + code: 'P0000', + severity: 'FATAL', + }); + }, + }); + const promise = createPostgresClient({ + port: getPort(server), + }); + await expect(promise).rejects.toThrow('onMessage failed'); + }); + + const mockOutput = () => { + const output = { + stderr: '', + [Symbol.dispose]() { + consoleErrorMock.mockRestore(); + }, + }; + const consoleErrorMock = vi.spyOn(console, 'error').mockImplementation((...args) => { + output.stderr += args.join(' '); + }); + return output; + }; + + it('does not send non backend errors to the client', async () => { + using output = mockOutput(); + await using server = await createPostgresServer({ + async onMessage() { + throw Error('wat?'); + }, + }); + const promise = createPostgresClient({ + port: getPort(server), + }); + try { + await promise; + } catch (error) { + expect(error.message).not.toContain('wat?'); + expect(output.stderr).toContain('wat?'); + expect(error.message).toContain('Connection terminated unexpectedly'); + } + }); +}); diff --git a/packages/pg-gateway/test/node/tls.test.ts b/packages/pg-gateway/test/node/tls.test.ts index 57ed9c8..bd9e11b 100644 --- a/packages/pg-gateway/test/node/tls.test.ts +++ b/packages/pg-gateway/test/node/tls.test.ts @@ -1,64 +1,16 @@ -import { createServer, type Server } from 'node:net'; -import type { ClientConfig } from 'pg'; -import { type PostgresConnectionOptions, createDuplexPair } from 'pg-gateway'; -import { fromDuplexStream, fromNodeSocket } from 'pg-gateway/node'; +import { createDuplexPair } from 'pg-gateway'; +import { fromDuplexStream } from 'pg-gateway/node'; import { describe, expect, it } from 'vitest'; -import { DisposablePgClient, socketFromDuplexStream } from '../util'; -import { generateCA, generateCSR, signCert, toPEM } from './certs'; -import { once } from 'node:events'; - -async function generateAllCertificates() { - const { caKey, caCert } = await generateCA('My Root CA'); - - const { key: serverKey, csr: serverCsr } = await generateCSR('localhost'); - const serverCert = await signCert(caCert, caKey, serverCsr); - - const { key: clientKey, csr: clientCsr } = await generateCSR('postgres'); - const clientCert = await signCert(caCert, caKey, clientCsr); - - const encoder = new TextEncoder(); - - return { - caKey: encoder.encode(toPEM(caKey, 'PRIVATE KEY')), - caCert: encoder.encode(toPEM(caCert, 'CERTIFICATE')), - serverKey: encoder.encode(toPEM(serverKey, 'PRIVATE KEY')), - serverCert: encoder.encode(toPEM(serverCert, 'CERTIFICATE')), - clientKey: encoder.encode(toPEM(clientKey, 'PRIVATE KEY')), - clientCert: encoder.encode(toPEM(clientCert, 'CERTIFICATE')), - }; -} +import { + createPostgresClient, + createPostgresServer, + getPort, + socketFromDuplexStream, +} from '../util'; +import { generateAllCertificates } from './certs'; const { caCert, serverKey, serverCert, clientKey, clientCert } = await generateAllCertificates(); -async function createPostgresServer(options?: PostgresConnectionOptions) { - const server = createServer((socket) => fromNodeSocket(socket, options)); - - // Listen on a random free port - server.listen(0); - await once(server, 'listening'); - return server; -} - -function getPort(server: Server) { - const address = server.address(); - - if (typeof address !== 'object') { - throw new Error(`Invalid server address '${address}'`); - } - - if (!address) { - throw new Error('Server has no address'); - } - - return address.port; -} - -async function connectPg(config: string | ClientConfig) { - const client = new DisposablePgClient(config); - await client.connect(); - return client; -} - describe('tls', () => { it('basic tls over tcp', async () => { await using server = await createPostgresServer({ @@ -71,7 +23,7 @@ describe('tls', () => { }, }); - await using client = await connectPg({ + await using client = await createPostgresClient({ port: getPort(server), ssl: { ca: Buffer.from(caCert), @@ -93,7 +45,7 @@ describe('tls', () => { }, }); - await using client = await connectPg({ + await using client = await createPostgresClient({ host: 'localhost', port: getPort(server), ssl: { @@ -116,7 +68,7 @@ describe('tls', () => { }, }); - await using client = await connectPg({ + await using client = await createPostgresClient({ host: '127.0.0.1', port: getPort(server), ssl: { @@ -137,7 +89,7 @@ describe('tls', () => { }, }); - await using client = await connectPg({ + await using client = await createPostgresClient({ port: getPort(server), user: 'postgres', ssl: { @@ -160,7 +112,7 @@ describe('tls', () => { }, }); - const promise = connectPg({ + const promise = createPostgresClient({ port: getPort(server), user: 'bob', ssl: { @@ -185,7 +137,7 @@ describe('tls', () => { }, }); - const promise = connectPg({ + const promise = createPostgresClient({ port: getPort(server), ssl: { ca: Buffer.from(serverCert), @@ -209,7 +161,7 @@ describe('tls', () => { }, }); - await using client = await connectPg({ + await using client = await createPostgresClient({ stream: socketFromDuplexStream(clientDuplex), ssl: { ca: Buffer.from(caCert), diff --git a/packages/pg-gateway/test/util.ts b/packages/pg-gateway/test/util.ts index 3136645..a9f8958 100644 --- a/packages/pg-gateway/test/util.ts +++ b/packages/pg-gateway/test/util.ts @@ -1,6 +1,8 @@ -import EventEmitter from 'node:events'; -import type { DuplexStream } from 'pg-gateway'; -import { Client } from 'pg'; +import EventEmitter, { once } from 'node:events'; +import { Server, createServer } from 'node:net'; +import { Client, type ClientConfig } from 'pg'; +import type { DuplexStream, PostgresConnectionOptions } from 'pg-gateway'; +import { fromNodeSocket } from 'pg-gateway/node'; /** * Creates a passthrough socket object that can be passed @@ -128,3 +130,45 @@ export class DisposablePgClient extends Client { await this.end(); } } + +export async function createPostgresClient(config: string | ClientConfig) { + const client = new DisposablePgClient(config); + await client.connect(); + return client; +} + +export class DisposableServer extends Server { + async [Symbol.asyncDispose]() { + await new Promise((resolve, reject) => { + this.close((err) => { + if (err) { + reject(err); + } else { + resolve(undefined); + } + }); + }); + } +} + +export async function createPostgresServer(options?: PostgresConnectionOptions) { + const server = new DisposableServer((socket) => fromNodeSocket(socket, options)); + // Listen on a random free port + server.listen(0); + await once(server, 'listening'); + return server; +} + +export function getPort(server: Server) { + const address = server.address(); + + if (typeof address !== 'object') { + throw new Error(`Invalid server address '${address}'`); + } + + if (!address) { + throw new Error('Server has no address'); + } + + return address.port; +}