Skip to content

Commit

Permalink
Config option to allow custom parameter syntax (#78)
Browse files Browse the repository at this point in the history
Co-authored-by: Matthew Peveler <mpeveler@timescale.com>
  • Loading branch information
not-night-but and MasterOdin authored Apr 30, 2024
1 parent c17b973 commit e15e335
Show file tree
Hide file tree
Showing 7 changed files with 550 additions and 61 deletions.
10 changes: 10 additions & 0 deletions src/defines.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,20 @@ export type StatementType =

export type ExecutionType = 'LISTING' | 'MODIFICATION' | 'INFORMATION' | 'ANON_BLOCK' | 'UNKNOWN';

export interface ParamTypes {
positional?: boolean;
numbered?: ('?' | ':' | '$')[];
named?: (':' | '@' | '$')[];
quoted?: (':' | '@' | '$')[];
// regex for identifying that it is a param
custom?: string[];
}

export interface IdentifyOptions {
strict?: boolean;
dialect?: Dialect;
identifyTables?: boolean;
paramTypes?: ParamTypes;
}

export interface IdentifyResult {
Expand Down
10 changes: 7 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { parse, EXECUTION_TYPES } from './parser';
import { parse, EXECUTION_TYPES, defaultParamTypesFor } from './parser';
import { DIALECTS } from './defines';
import type { ExecutionType, IdentifyOptions, IdentifyResult, StatementType } from './defines';

Expand All @@ -21,7 +21,11 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify
throw new Error(`Unknown dialect. Allowed values: ${DIALECTS.join(', ')}`);
}

const result = parse(query, isStrict, dialect, options.identifyTables);
// Default parameter types for each dialect
const paramTypes = options.paramTypes || defaultParamTypesFor(dialect);

const result = parse(query, isStrict, dialect, options.identifyTables, paramTypes);
const sort = dialect === 'psql' && !options.paramTypes;

return result.body.map((statement) => {
const result: IdentifyResult = {
Expand All @@ -31,7 +35,7 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify
type: statement.type,
executionType: statement.executionType,
// we want to sort the postgres params: $1 $2 $3, regardless of the order they appear
parameters: dialect === 'psql' ? statement.parameters.sort() : statement.parameters,
parameters: sort ? statement.parameters.sort() : statement.parameters,
tables: statement.tables || [],
};
return result;
Expand Down
33 changes: 32 additions & 1 deletion src/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type {
Step,
ParseResult,
ConcreteStatement,
ParamTypes,
} from './defines';

interface StatementParser {
Expand Down Expand Up @@ -144,6 +145,7 @@ export function parse(
isStrict = true,
dialect: Dialect = 'generic',
identifyTables = false,
paramTypes?: ParamTypes,
): ParseResult {
const topLevelState = initState({ input });
const topLevelStatement: ParseResult = {
Expand Down Expand Up @@ -174,7 +176,7 @@ export function parse(

while (prevState.position < topLevelState.end) {
const tokenState = initState({ prevState });
const token = scanToken(tokenState, dialect);
const token = scanToken(tokenState, dialect, paramTypes);
const nextToken = nextNonWhitespaceToken(tokenState, dialect);

if (!statementParser) {
Expand Down Expand Up @@ -1013,3 +1015,32 @@ function stateMachineStatementParser(
},
};
}

export function defaultParamTypesFor(dialect: Dialect): ParamTypes {
switch (dialect) {
case 'psql':
return {
numbered: ['$'],
};
case 'mssql':
return {
named: [':'],
};
case 'bigquery':
return {
positional: true,
named: ['@'],
quoted: ['@'],
};
case 'sqlite':
return {
positional: true,
numbered: ['?'],
named: [':', '@'],
};
default:
return {
positional: true,
};
}
}
148 changes: 106 additions & 42 deletions src/tokenizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Tokenizer
*/

import type { Token, State, Dialect } from './defines';
import type { Token, State, Dialect, ParamTypes } from './defines';

type Char = string | null;

Expand Down Expand Up @@ -76,7 +76,11 @@ const ENDTOKENS: Record<string, Char> = {
'[': ']',
};

export function scanToken(state: State, dialect: Dialect = 'generic'): Token {
export function scanToken(
state: State,
dialect: Dialect = 'generic',
paramTypes: ParamTypes = { positional: true },
): Token {
const ch = read(state);

if (isWhitespace(ch)) {
Expand All @@ -95,8 +99,8 @@ export function scanToken(state: State, dialect: Dialect = 'generic'): Token {
return scanString(state, ENDTOKENS[ch]);
}

if (isParameter(ch, state, dialect)) {
return scanParameter(state, dialect);
if (isParameter(ch, state, paramTypes)) {
return scanParameter(state, dialect, paramTypes);
}

if (isDollarQuotedString(state)) {
Expand Down Expand Up @@ -253,52 +257,92 @@ function scanString(state: State, endToken: Char): Token {
};
}

function scanParameter(state: State, dialect: Dialect): Token {
if (['mysql', 'generic', 'sqlite'].includes(dialect)) {
return {
type: 'parameter',
value: state.input.slice(state.start, state.position + 1),
start: state.start,
end: state.start,
};
}
function getCustomParam(state: State, paramTypes: ParamTypes): string | null | undefined {
const matches = paramTypes?.custom
?.map((regex) => {
const reg = new RegExp(`^(?:${regex})`, 'u');
return reg.exec(state.input.slice(state.start));
})
.filter((value) => !!value)[0];

if (dialect === 'psql') {
let nextChar: Char;
return matches ? matches[0] : null;
}

do {
nextChar = read(state);
} while (nextChar !== null && !isNaN(Number(nextChar)) && !isWhitespace(nextChar));
function scanParameter(state: State, dialect: Dialect, paramTypes: ParamTypes): Token {
const curCh = state.input[state.start];
const nextChar = peek(state);
let matched = false;

if (paramTypes.numbered?.length && paramTypes.numbered.some((type) => type === curCh)) {
const endIndex = state.input
.slice(state.start + 1)
.split('')
.findIndex((val) => /^\W+/.test(val));
const maybeNumbers = state.input.slice(
state.start + 1,
endIndex > 0 ? state.start + endIndex + 1 : state.end + 1,
);
if (nextChar !== null && !isNaN(Number(nextChar)) && /^\d+$/.test(maybeNumbers)) {
let nextChar: Char = null;
do {
nextChar = read(state);
} while (nextChar !== null && !isNaN(Number(nextChar)) && !isWhitespace(nextChar));

if (nextChar !== null) unread(state);
matched = true;
}
}

if (nextChar !== null) unread(state);
if (!matched && paramTypes.named?.length && paramTypes.named.some((type) => type === curCh)) {
if (!isQuotedIdentifier(nextChar, dialect)) {
while (isAlphaNumeric(peek(state))) read(state);
matched = true;
}
}

const value = state.input.slice(state.start, state.position + 1);
if (!matched && paramTypes.quoted?.length && paramTypes.quoted.some((type) => type === curCh)) {
if (isQuotedIdentifier(nextChar, dialect)) {
const quoteChar = read(state) as string;
// end when we reach the end quote
while (
(isAlphaNumeric(peek(state)) || peek(state) === ' ') &&
peek(state) != ENDTOKENS[quoteChar]
) {
read(state);
}

return {
type: 'parameter',
value,
start: state.start,
end: state.start + value.length - 1,
};
// read the end quote
read(state);

matched = true;
}
}

if (dialect === 'mssql') {
while (isAlphaNumeric(peek(state))) read(state);
if (!matched && paramTypes.custom && paramTypes.custom.length) {
const custom = getCustomParam(state, paramTypes);

if (custom) {
read(state, custom.length);
matched = true;
}
}
const value = state.input.slice(state.start, state.position + 1);

const value = state.input.slice(state.start, state.position + 1);
if (!matched && !paramTypes.positional && curCh !== '?') {
// not positional, panic
return {
type: 'parameter',
value,
type: 'unknown',
value: value,
start: state.start,
end: state.start + value.length - 1,
};
}

return {
type: 'parameter',
value: 'unknown',
value,
start: state.start,
end: state.end,
end: state.start + value.length - 1,
};
}

Expand Down Expand Up @@ -413,18 +457,38 @@ function isString(ch: Char, dialect: Dialect): boolean {
return stringStart.includes(ch);
}

function isParameter(ch: Char, state: State, dialect: Dialect): boolean {
let pStart = '?'; // ansi standard - sqlite, mysql
if (dialect === 'psql') {
pStart = '$';
const nextChar = peek(state);
if (nextChar === null || isNaN(Number(nextChar))) {
return false;
function isCustomParam(state: State, customParamType: NonNullable<ParamTypes['custom']>): boolean {
return customParamType.some((regex) => {
const reg = new RegExp(`^(?:${regex})`, 'uy');
return reg.test(state.input.slice(state.start));
});
}

function isParameter(ch: Char, state: State, paramTypes: ParamTypes): boolean {
if (!ch) {
return false;
}
const nextChar = peek(state);
if (paramTypes.positional && ch === '?') return true;

if (paramTypes.numbered?.length && paramTypes.numbered.some((type) => ch === type)) {
if (nextChar !== null && !isNaN(Number(nextChar))) {
return true;
}
}
if (dialect === 'mssql') pStart = ':';

return ch === pStart;
if (
(paramTypes.named?.length && paramTypes.named.some((type) => type === ch)) ||
(paramTypes.quoted?.length && paramTypes.quoted.some((type) => type === ch))
) {
return true;
}

if (paramTypes.custom?.length && isCustomParam(state, paramTypes.custom)) {
return true;
}

return false;
}

function isDollarQuotedString(state: State): boolean {
Expand Down
44 changes: 44 additions & 0 deletions test/index.spec.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Dialect, getExecutionType, identify } from '../src/index';
import { expect } from 'chai';
import { ParamTypes } from '../src/defines';

describe('identify', () => {
it('should throw error for invalid dialect', () => {
Expand All @@ -22,6 +23,49 @@ describe('identify', () => {
]);
});

it('should identify custom parameters', () => {
const paramTypes: ParamTypes = {
positional: true,
numbered: ['$'],
named: [':'],
quoted: [':'],
custom: ['\\{[a-zA-Z0-9_]+\\}'],
};
const query = `SELECT * FROM foo WHERE bar = ? AND baz = $1 AND fizz = :fizzz AND buzz = :"buzz buzz" AND foo2 = {fooo}`;

expect(identify(query, { dialect: 'psql', paramTypes })).to.eql([
{
start: 0,
end: 104,
text: query,
type: 'SELECT',
executionType: 'LISTING',
parameters: ['?', '$1', ':fizzz', ':"buzz buzz"', '{fooo}'],
tables: [],
},
]);
});

it('custom params should override defaults for dialect', () => {
const paramTypes: ParamTypes = {
positional: true,
};

const query = 'SELECT * FROM foo WHERE bar = $1 AND bar = :named AND fizz = :`quoted`';

expect(identify(query, { dialect: 'psql', paramTypes })).to.eql([
{
start: 0,
end: 69,
text: query,
type: 'SELECT',
executionType: 'LISTING',
parameters: [],
tables: [],
},
]);
});

it('should identify tables in simple for basic cases', () => {
expect(
identify('SELECT * FROM foo JOIN bar ON foo.id = bar.id', { identifyTables: true }),
Expand Down
Loading

0 comments on commit e15e335

Please sign in to comment.