diff --git a/src/defines.ts b/src/defines.ts index 955d04f..02eec84 100644 --- a/src/defines.ts +++ b/src/defines.ts @@ -76,10 +76,20 @@ export type StatementType = export type ExecutionType = 'LISTING' | 'MODIFICATION' | 'INFORMATION' | 'ANON_BLOCK' | 'UNKNOWN'; +export interface ParamTypes { + positional?: boolean; + numbered?: ('?' | ':' | '$')[]; + named?: (':' | '@' | '$')[]; + quoted?: (':' | '@' | '$')[]; + // regex for identifying that it is a param + custom?: string[]; +} + export interface IdentifyOptions { strict?: boolean; dialect?: Dialect; identifyTables?: boolean; + paramTypes?: ParamTypes; } export interface IdentifyResult { diff --git a/src/index.ts b/src/index.ts index c57dafd..f600339 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,4 +1,4 @@ -import { parse, EXECUTION_TYPES } from './parser'; +import { parse, EXECUTION_TYPES, defaultParamTypesFor } from './parser'; import { DIALECTS } from './defines'; import type { ExecutionType, IdentifyOptions, IdentifyResult, StatementType } from './defines'; @@ -21,7 +21,11 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify throw new Error(`Unknown dialect. Allowed values: ${DIALECTS.join(', ')}`); } - const result = parse(query, isStrict, dialect, options.identifyTables); + // Default parameter types for each dialect + const paramTypes = options.paramTypes || defaultParamTypesFor(dialect); + + const result = parse(query, isStrict, dialect, options.identifyTables, paramTypes); + const sort = dialect === 'psql' && !options.paramTypes; return result.body.map((statement) => { const result: IdentifyResult = { @@ -31,7 +35,7 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify type: statement.type, executionType: statement.executionType, // we want to sort the postgres params: $1 $2 $3, regardless of the order they appear - parameters: dialect === 'psql' ? statement.parameters.sort() : statement.parameters, + parameters: sort ? statement.parameters.sort() : statement.parameters, tables: statement.tables || [], }; return result; diff --git a/src/parser.ts b/src/parser.ts index 3aec638..06150d7 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -9,6 +9,7 @@ import type { Step, ParseResult, ConcreteStatement, + ParamTypes, } from './defines'; interface StatementParser { @@ -144,6 +145,7 @@ export function parse( isStrict = true, dialect: Dialect = 'generic', identifyTables = false, + paramTypes?: ParamTypes, ): ParseResult { const topLevelState = initState({ input }); const topLevelStatement: ParseResult = { @@ -174,7 +176,7 @@ export function parse( while (prevState.position < topLevelState.end) { const tokenState = initState({ prevState }); - const token = scanToken(tokenState, dialect); + const token = scanToken(tokenState, dialect, paramTypes); const nextToken = nextNonWhitespaceToken(tokenState, dialect); if (!statementParser) { @@ -1013,3 +1015,32 @@ function stateMachineStatementParser( }, }; } + +export function defaultParamTypesFor(dialect: Dialect): ParamTypes { + switch (dialect) { + case 'psql': + return { + numbered: ['$'], + }; + case 'mssql': + return { + named: [':'], + }; + case 'bigquery': + return { + positional: true, + named: ['@'], + quoted: ['@'], + }; + case 'sqlite': + return { + positional: true, + numbered: ['?'], + named: [':', '@'], + }; + default: + return { + positional: true, + }; + } +} diff --git a/src/tokenizer.ts b/src/tokenizer.ts index eccfbc7..1214b72 100644 --- a/src/tokenizer.ts +++ b/src/tokenizer.ts @@ -2,7 +2,7 @@ * Tokenizer */ -import type { Token, State, Dialect } from './defines'; +import type { Token, State, Dialect, ParamTypes } from './defines'; type Char = string | null; @@ -76,7 +76,11 @@ const ENDTOKENS: Record = { '[': ']', }; -export function scanToken(state: State, dialect: Dialect = 'generic'): Token { +export function scanToken( + state: State, + dialect: Dialect = 'generic', + paramTypes: ParamTypes = { positional: true }, +): Token { const ch = read(state); if (isWhitespace(ch)) { @@ -95,8 +99,8 @@ export function scanToken(state: State, dialect: Dialect = 'generic'): Token { return scanString(state, ENDTOKENS[ch]); } - if (isParameter(ch, state, dialect)) { - return scanParameter(state, dialect); + if (isParameter(ch, state, paramTypes)) { + return scanParameter(state, dialect, paramTypes); } if (isDollarQuotedString(state)) { @@ -253,42 +257,82 @@ function scanString(state: State, endToken: Char): Token { }; } -function scanParameter(state: State, dialect: Dialect): Token { - if (['mysql', 'generic', 'sqlite'].includes(dialect)) { - return { - type: 'parameter', - value: state.input.slice(state.start, state.position + 1), - start: state.start, - end: state.start, - }; - } +function getCustomParam(state: State, paramTypes: ParamTypes): string | null | undefined { + const matches = paramTypes?.custom + ?.map((regex) => { + const reg = new RegExp(`^(?:${regex})`, 'u'); + return reg.exec(state.input.slice(state.start)); + }) + .filter((value) => !!value)[0]; - if (dialect === 'psql') { - let nextChar: Char; + return matches ? matches[0] : null; +} - do { - nextChar = read(state); - } while (nextChar !== null && !isNaN(Number(nextChar)) && !isWhitespace(nextChar)); +function scanParameter(state: State, dialect: Dialect, paramTypes: ParamTypes): Token { + const curCh = state.input[state.start]; + const nextChar = peek(state); + let matched = false; + + if (paramTypes.numbered?.length && paramTypes.numbered.some((type) => type === curCh)) { + const endIndex = state.input + .slice(state.start + 1) + .split('') + .findIndex((val) => /^\W+/.test(val)); + const maybeNumbers = state.input.slice( + state.start + 1, + endIndex > 0 ? state.start + endIndex + 1 : state.end + 1, + ); + if (nextChar !== null && !isNaN(Number(nextChar)) && /^\d+$/.test(maybeNumbers)) { + let nextChar: Char = null; + do { + nextChar = read(state); + } while (nextChar !== null && !isNaN(Number(nextChar)) && !isWhitespace(nextChar)); + + if (nextChar !== null) unread(state); + matched = true; + } + } - if (nextChar !== null) unread(state); + if (!matched && paramTypes.named?.length && paramTypes.named.some((type) => type === curCh)) { + if (!isQuotedIdentifier(nextChar, dialect)) { + while (isAlphaNumeric(peek(state))) read(state); + matched = true; + } + } - const value = state.input.slice(state.start, state.position + 1); + if (!matched && paramTypes.quoted?.length && paramTypes.quoted.some((type) => type === curCh)) { + if (isQuotedIdentifier(nextChar, dialect)) { + const quoteChar = read(state) as string; + // end when we reach the end quote + while ( + (isAlphaNumeric(peek(state)) || peek(state) === ' ') && + peek(state) != ENDTOKENS[quoteChar] + ) { + read(state); + } - return { - type: 'parameter', - value, - start: state.start, - end: state.start + value.length - 1, - }; + // read the end quote + read(state); + + matched = true; + } } - if (dialect === 'mssql') { - while (isAlphaNumeric(peek(state))) read(state); + if (!matched && paramTypes.custom && paramTypes.custom.length) { + const custom = getCustomParam(state, paramTypes); + + if (custom) { + read(state, custom.length); + matched = true; + } + } + const value = state.input.slice(state.start, state.position + 1); - const value = state.input.slice(state.start, state.position + 1); + if (!matched && !paramTypes.positional && curCh !== '?') { + // not positional, panic return { - type: 'parameter', - value, + type: 'unknown', + value: value, start: state.start, end: state.start + value.length - 1, }; @@ -296,9 +340,9 @@ function scanParameter(state: State, dialect: Dialect): Token { return { type: 'parameter', - value: 'unknown', + value, start: state.start, - end: state.end, + end: state.start + value.length - 1, }; } @@ -413,18 +457,38 @@ function isString(ch: Char, dialect: Dialect): boolean { return stringStart.includes(ch); } -function isParameter(ch: Char, state: State, dialect: Dialect): boolean { - let pStart = '?'; // ansi standard - sqlite, mysql - if (dialect === 'psql') { - pStart = '$'; - const nextChar = peek(state); - if (nextChar === null || isNaN(Number(nextChar))) { - return false; +function isCustomParam(state: State, customParamType: NonNullable): boolean { + return customParamType.some((regex) => { + const reg = new RegExp(`^(?:${regex})`, 'uy'); + return reg.test(state.input.slice(state.start)); + }); +} + +function isParameter(ch: Char, state: State, paramTypes: ParamTypes): boolean { + if (!ch) { + return false; + } + const nextChar = peek(state); + if (paramTypes.positional && ch === '?') return true; + + if (paramTypes.numbered?.length && paramTypes.numbered.some((type) => ch === type)) { + if (nextChar !== null && !isNaN(Number(nextChar))) { + return true; } } - if (dialect === 'mssql') pStart = ':'; - return ch === pStart; + if ( + (paramTypes.named?.length && paramTypes.named.some((type) => type === ch)) || + (paramTypes.quoted?.length && paramTypes.quoted.some((type) => type === ch)) + ) { + return true; + } + + if (paramTypes.custom?.length && isCustomParam(state, paramTypes.custom)) { + return true; + } + + return false; } function isDollarQuotedString(state: State): boolean { diff --git a/test/index.spec.ts b/test/index.spec.ts index d7cd5e9..d2ef657 100644 --- a/test/index.spec.ts +++ b/test/index.spec.ts @@ -1,5 +1,6 @@ import { Dialect, getExecutionType, identify } from '../src/index'; import { expect } from 'chai'; +import { ParamTypes } from '../src/defines'; describe('identify', () => { it('should throw error for invalid dialect', () => { @@ -22,6 +23,49 @@ describe('identify', () => { ]); }); + it('should identify custom parameters', () => { + const paramTypes: ParamTypes = { + positional: true, + numbered: ['$'], + named: [':'], + quoted: [':'], + custom: ['\\{[a-zA-Z0-9_]+\\}'], + }; + const query = `SELECT * FROM foo WHERE bar = ? AND baz = $1 AND fizz = :fizzz AND buzz = :"buzz buzz" AND foo2 = {fooo}`; + + expect(identify(query, { dialect: 'psql', paramTypes })).to.eql([ + { + start: 0, + end: 104, + text: query, + type: 'SELECT', + executionType: 'LISTING', + parameters: ['?', '$1', ':fizzz', ':"buzz buzz"', '{fooo}'], + tables: [], + }, + ]); + }); + + it('custom params should override defaults for dialect', () => { + const paramTypes: ParamTypes = { + positional: true, + }; + + const query = 'SELECT * FROM foo WHERE bar = $1 AND bar = :named AND fizz = :`quoted`'; + + expect(identify(query, { dialect: 'psql', paramTypes })).to.eql([ + { + start: 0, + end: 69, + text: query, + type: 'SELECT', + executionType: 'LISTING', + parameters: [], + tables: [], + }, + ]); + }); + it('should identify tables in simple for basic cases', () => { expect( identify('SELECT * FROM foo JOIN bar ON foo.id = bar.id', { identifyTables: true }), diff --git a/test/parser/single-statements.spec.ts b/test/parser/single-statements.spec.ts index b0a15cc..4fa5b0b 100644 --- a/test/parser/single-statements.spec.ts +++ b/test/parser/single-statements.spec.ts @@ -1,7 +1,7 @@ import { expect } from 'chai'; import { aggregateUnknownTokens } from '../spec-helper'; -import { parse } from '../../src/parser'; +import { defaultParamTypesFor, parse } from '../../src/parser'; import { Token } from '../../src/defines'; describe('parser', () => { @@ -725,7 +725,13 @@ describe('parser', () => { }); it('should extract PSQL parameters', () => { - const actual = parse('select x from a where x = $1', true, 'psql'); + const actual = parse( + 'select x from a where x = $1', + true, + 'psql', + false, + defaultParamTypesFor('psql'), + ); actual.tokens = aggregateUnknownTokens(actual.tokens); const expected: Token[] = [ { @@ -752,7 +758,13 @@ describe('parser', () => { }); it('should extract multiple PSQL parameters', () => { - const actual = parse('select x from a where x = $1 and y = $2', true, 'psql'); + const actual = parse( + 'select x from a where x = $1 and y = $2', + true, + 'psql', + false, + defaultParamTypesFor('psql'), + ); actual.tokens = aggregateUnknownTokens(actual.tokens); const expected: Token[] = [ { @@ -791,7 +803,13 @@ describe('parser', () => { }); it('should extract mssql parameters', () => { - const actual = parse('select x from a where x = :foo', true, 'mssql'); + const actual = parse( + 'select x from a where x = :foo', + true, + 'mssql', + false, + defaultParamTypesFor('mssql'), + ); actual.tokens = aggregateUnknownTokens(actual.tokens); const expected: Token[] = [ { @@ -856,7 +874,13 @@ describe('parser', () => { }); it('should extract multiple mssql parameters', () => { - const actual = parse('select x from a where x = :foo and y = :bar', true, 'mssql'); + const actual = parse( + 'select x from a where x = :foo and y = :bar', + true, + 'mssql', + false, + defaultParamTypesFor('mssql'), + ); actual.tokens = aggregateUnknownTokens(actual.tokens); const expected: Token[] = [ { diff --git a/test/tokenizer/index.spec.ts b/test/tokenizer/index.spec.ts index 9d16f34..7195735 100644 --- a/test/tokenizer/index.spec.ts +++ b/test/tokenizer/index.spec.ts @@ -1,6 +1,7 @@ import { expect } from 'chai'; import { scanToken } from '../../src/tokenizer'; -import type { Dialect } from '../../src/defines'; +import type { Dialect, ParamTypes, Token } from '../../src/defines'; +import { defaultParamTypesFor } from '../../src/parser'; describe('scan', () => { const initState = (input: string) => ({ @@ -274,7 +275,11 @@ describe('scan', () => { ].forEach(([ch, dialect]) => { it(`scans just ${ch} as parameter for ${dialect}`, () => { const input = `${ch}`; - const actual = scanToken(initState(input), dialect as Dialect); + const actual = scanToken( + initState(input), + dialect as Dialect, + defaultParamTypesFor(dialect as Dialect), + ); const expected = { type: 'parameter', value: input, @@ -286,7 +291,7 @@ describe('scan', () => { }); it('does not scan just $ as parameter for psql', () => { const input = '$'; - const actual = scanToken(initState(input), 'psql'); + const actual = scanToken(initState(input), 'psql', defaultParamTypesFor('psql')); const expected = { type: 'unknown', value: input, @@ -300,11 +305,14 @@ describe('scan', () => { [ ['?', 'generic'], ['?', 'mysql'], - ['?', 'sqlite'], ].forEach(([ch, dialect]) => { it(`should only scan ${ch} from ${ch}1 for ${dialect}`, () => { const input = `${ch}1`; - const actual = scanToken(initState(input), dialect as Dialect); + const actual = scanToken( + initState(input), + dialect as Dialect, + defaultParamTypesFor(dialect as Dialect), + ); const expected = { type: 'parameter', value: ch, @@ -320,7 +328,11 @@ describe('scan', () => { ].forEach(([ch, dialect]) => { it(`should scan ${ch}1 for ${dialect}`, () => { const input = `${ch}1`; - const actual = scanToken(initState(input), dialect as Dialect); + const actual = scanToken( + initState(input), + dialect as Dialect, + defaultParamTypesFor(dialect as Dialect), + ); const expected = { type: 'parameter', value: input, @@ -333,7 +345,7 @@ describe('scan', () => { it('should not scan $a for psql', () => { const input = '$a'; - const actual = scanToken(initState(input), 'psql'); + const actual = scanToken(initState(input), 'psql', defaultParamTypesFor('psql')); const expected = { type: 'unknown', value: '$', @@ -344,7 +356,7 @@ describe('scan', () => { }); it('should not include trailing non-numbers for psql', () => { - const actual = scanToken(initState('$1,'), 'psql'); + const actual = scanToken(initState('$1,'), 'psql', defaultParamTypesFor('psql')); const expected = { type: 'parameter', value: '$1', @@ -355,9 +367,10 @@ describe('scan', () => { }); it('should not include trailing non-alphanumerics for mssql', () => { + const paramTypes = defaultParamTypesFor('mssql'); [ { - actual: scanToken(initState(':one,'), 'mssql'), + actual: scanToken(initState(':one,'), 'mssql', paramTypes), expected: { type: 'parameter', value: ':one', @@ -366,7 +379,7 @@ describe('scan', () => { }, }, { - actual: scanToken(initState(':two)'), 'mssql'), + actual: scanToken(initState(':two)'), 'mssql', paramTypes), expected: { type: 'parameter', value: ':two', @@ -376,6 +389,305 @@ describe('scan', () => { }, ].forEach(({ actual, expected }) => expect(actual).to.eql(expected)); }); + + describe('custom parameters', () => { + describe('positional parameters', () => { + const paramTypes = { + positional: true, + }; + + const expected = [ + { + type: 'parameter', + value: '?', + start: 0, + end: 0, + }, + ]; + + ( + ['mssql', 'psql', 'oracle', 'bigquery', 'sqlite', 'mysql', 'generic'] as Array + ).forEach((dialect) => { + [ + { + actual: scanToken(initState('?'), dialect, paramTypes), + expected: expected[0], + }, + ].forEach(({ actual, expected }) => { + it(`should allow positional parameters for ${dialect}`, () => { + expect(actual).to.eql(expected); + }); + }); + }); + }); + + describe('numeric parameters', () => { + const paramTypes: ParamTypes = { + numbered: ['$', '?', ':'], + }; + + const expected = [ + { + type: 'parameter', + value: '$1', + start: 0, + end: 1, + }, + { + type: 'parameter', + value: '?1', + start: 0, + end: 1, + }, + { + type: 'parameter', + value: ':1', + start: 0, + end: 1, + }, + { + type: 'unknown', + value: '$', + start: 0, + end: 0, + }, + ]; + + ( + ['mssql', 'psql', 'oracle', 'bigquery', 'sqlite', 'mysql', 'generic'] as Array + ).forEach((dialect) => { + [ + { + actual: scanToken(initState('$1'), dialect, paramTypes), + expected: expected[0], + description: '$ numeric', + }, + { + actual: scanToken(initState('?1'), dialect, paramTypes), + expected: expected[1], + description: '? numeric', + }, + { + actual: scanToken(initState(':1'), dialect, paramTypes), + expected: expected[2], + description: ': numeric', + }, + { + actual: scanToken(initState('$123hello'), dialect, paramTypes), // won't recognize + expected: expected[3], + description: 'numeric trailing alpha', + }, + ].forEach(({ actual, expected, description }) => { + it(`should allow numeric parameters for ${dialect} - ${description}`, () => { + expect(actual).to.eql(expected); + }); + }); + }); + }); + + describe('named parameters', () => { + const paramTypes: ParamTypes = { + named: ['$', '@', ':'], + }; + + const expected = [ + { + type: 'parameter', + value: '$namedParam', + start: 0, + end: 10, + }, + { + type: 'parameter', + value: '@namedParam', + start: 0, + end: 10, + }, + { + type: 'parameter', + value: ':namedParam', + start: 0, + end: 10, + }, + { + type: 'parameter', + value: '$123hello', // allow starting with a number + start: 0, + end: 8, + }, + ]; + + ( + ['mssql', 'psql', 'oracle', 'bigquery', 'sqlite', 'mysql', 'generic'] as Array + ).forEach((dialect) => { + [ + { + actual: scanToken(initState('$namedParam'), dialect, paramTypes), + expected: expected[0], + description: '$ named', + }, + { + actual: scanToken(initState('@namedParam'), dialect, paramTypes), + expected: expected[1], + description: '@ named', + }, + { + actual: scanToken(initState(':namedParam'), dialect, paramTypes), + expected: expected[2], + description: ': named', + }, + { + actual: scanToken(initState('$123hello'), dialect, paramTypes), + expected: expected[3], + description: 'named starting with numbers', + }, + ].forEach(({ actual, expected, description }) => { + it(`should allow named parameters for ${dialect} - ${description}`, () => { + expect(actual).to.eql(expected); + }); + }); + }); + }); + + describe('quoted parameters', () => { + const paramTypes: ParamTypes = { + quoted: ['$', '@', ':'], + }; + + const expected = [ + { + type: 'parameter', + value: '$', + start: 0, + end: 14, + }, + { + type: 'parameter', + value: '@', + start: 0, + end: 14, + }, + { + type: 'parameter', + value: ':', + start: 0, + end: 14, + }, + ]; + + ( + [ + { dialect: 'mssql', quotes: ['""', '[]'] }, + { dialect: 'psql', quotes: ['""', '``'] }, + { dialect: 'oracle', quotes: ['""', '``'] }, + { dialect: 'bigquery', quotes: ['""', '``'] }, + { dialect: 'sqlite', quotes: ['""', '``'] }, + { dialect: 'mysql', quotes: ['""', '``'] }, + { dialect: 'generic', quotes: ['""', '``'] }, + ] as Array<{ dialect: Dialect; quotes: Array }> + ).forEach(({ dialect, quotes }) => { + const dialectExpected = Array.prototype.concat.apply( + [], + expected.map((exp) => { + return quotes.map((quote) => { + return { + expected: { + ...exp, + value: `${exp.value}${quote[0]}quoted param${quote[1]}`, + }, + description: `${exp.value} quoted with ${quote[0]}`, + }; + }); + }), + ); + dialectExpected + .map(({ expected, description }) => ({ + actual: scanToken(initState((expected as Token).value), dialect, paramTypes), + expected: expected as Token, + description: description as string, + })) + .forEach(({ actual, expected, description }) => { + it(`should allow quoted parameters for ${dialect} - ${description}`, () => { + expect(actual).to.eql(expected); + }); + }); + }); + }); + + describe('custom parameters', () => { + const paramTypes: ParamTypes = { + custom: ['\\{[a-zA-Z0-9_]+\\}'], + }; + + const expected = [ + { + type: 'parameter', + value: '{namedParam}', + start: 0, + end: 11, + }, + ]; + + ( + ['mssql', 'psql', 'oracle', 'bigquery', 'sqlite', 'mysql', 'generic'] as Array + ).forEach((dialect) => { + it(`should allow custom parameters for ${dialect}`, () => { + expect(scanToken(initState('{namedParam}'), dialect, paramTypes)).to.eql(expected[0]); + }); + }); + }); + + describe('should not have collision between param types', () => { + const paramTypes: ParamTypes = { + positional: true, + numbered: [':'], + named: [':'], + quoted: [':'], + custom: ['\\{[a-zA-Z0-9_]+\\}'], + }; + + const type = ['positional', 'numeric', 'named', 'quoted', 'custom']; + + const expected = [ + { + type: 'parameter', + value: '?', + start: 0, + end: 0, + }, + { + type: 'parameter', + value: ':123', + start: 0, + end: 3, + }, + { + type: 'parameter', + value: ':123hello', + start: 0, + end: 8, + }, + { + type: 'parameter', + value: ':"named param"', + start: 0, + end: 13, + }, + { + type: 'parameter', + value: '{namedParam}', + start: 0, + end: 11, + }, + ]; + + expected.forEach((expected, index) => { + it(`parameter types don't collide, finds ${type[index]}`, () => { + expect(scanToken(initState(expected.value), 'mssql', paramTypes)).to.eql(expected); + }); + }); + }); + }); }); }); });