Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config option to allow custom parameter syntax #78

Merged
merged 16 commits into from
Apr 30, 2024
Merged
10 changes: 10 additions & 0 deletions src/defines.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,20 @@ export type StatementType =

export type ExecutionType = 'LISTING' | 'MODIFICATION' | 'INFORMATION' | 'ANON_BLOCK' | 'UNKNOWN';

export interface ParamTypes {
positional?: boolean;
numbered?: ('?' | ':' | '$')[];
named?: (':' | '@' | '$')[];
quoted?: (':' | '@' | '$')[];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirming that there is not any error checking if someone passes a value that's not in this array of values? I'm fine merging this as is, and then making my own PR doing this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. There is no error checking besides typescript type checking

// regex for identifying that it is a param
custom?: string[];
}

export interface IdentifyOptions {
strict?: boolean;
dialect?: Dialect;
identifyTables?: boolean;
paramTypes?: ParamTypes;
}

export interface IdentifyResult {
Expand Down
10 changes: 7 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { parse, EXECUTION_TYPES } from './parser';
import { parse, EXECUTION_TYPES, defaultParamTypesFor } from './parser';
import { DIALECTS } from './defines';
import type { ExecutionType, IdentifyOptions, IdentifyResult, StatementType } from './defines';

Expand All @@ -21,7 +21,11 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify
throw new Error(`Unknown dialect. Allowed values: ${DIALECTS.join(', ')}`);
}

const result = parse(query, isStrict, dialect, options.identifyTables);
// Default parameter types for each dialect
const paramTypes = options.paramTypes || defaultParamTypesFor(dialect);

const result = parse(query, isStrict, dialect, options.identifyTables, paramTypes);
const sort = dialect === 'psql' && !options.paramTypes;

return result.body.map((statement) => {
const result: IdentifyResult = {
Expand All @@ -31,7 +35,7 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify
type: statement.type,
executionType: statement.executionType,
// we want to sort the postgres params: $1 $2 $3, regardless of the order they appear
parameters: dialect === 'psql' ? statement.parameters.sort() : statement.parameters,
parameters: sort ? statement.parameters.sort() : statement.parameters,
tables: statement.tables || [],
};
return result;
Expand Down
33 changes: 32 additions & 1 deletion src/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type {
Step,
ParseResult,
ConcreteStatement,
ParamTypes,
} from './defines';

interface StatementParser {
Expand Down Expand Up @@ -144,6 +145,7 @@ export function parse(
isStrict = true,
dialect: Dialect = 'generic',
identifyTables = false,
paramTypes?: ParamTypes,
): ParseResult {
const topLevelState = initState({ input });
const topLevelStatement: ParseResult = {
Expand Down Expand Up @@ -174,7 +176,7 @@ export function parse(

while (prevState.position < topLevelState.end) {
const tokenState = initState({ prevState });
const token = scanToken(tokenState, dialect);
const token = scanToken(tokenState, dialect, paramTypes);
const nextToken = nextNonWhitespaceToken(tokenState, dialect);

if (!statementParser) {
Expand Down Expand Up @@ -1013,3 +1015,32 @@ function stateMachineStatementParser(
},
};
}

export function defaultParamTypesFor(dialect: Dialect): ParamTypes {
switch (dialect) {
case 'psql':
return {
numbered: ['$'],
};
case 'mssql':
return {
named: [':'],
};
case 'bigquery':
return {
positional: true,
named: ['@'],
quoted: ['@'],
};
case 'sqlite':
return {
positional: true,
numbered: ['?'],
named: [':', '@'],
};
default:
return {
positional: true,
};
}
}
148 changes: 106 additions & 42 deletions src/tokenizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Tokenizer
*/

import type { Token, State, Dialect } from './defines';
import type { Token, State, Dialect, ParamTypes } from './defines';

type Char = string | null;

Expand Down Expand Up @@ -76,7 +76,11 @@ const ENDTOKENS: Record<string, Char> = {
'[': ']',
};

export function scanToken(state: State, dialect: Dialect = 'generic'): Token {
export function scanToken(
state: State,
dialect: Dialect = 'generic',
paramTypes: ParamTypes = { positional: true },
): Token {
const ch = read(state);

if (isWhitespace(ch)) {
Expand All @@ -95,8 +99,8 @@ export function scanToken(state: State, dialect: Dialect = 'generic'): Token {
return scanString(state, ENDTOKENS[ch]);
}

if (isParameter(ch, state, dialect)) {
return scanParameter(state, dialect);
if (isParameter(ch, state, paramTypes)) {
return scanParameter(state, dialect, paramTypes);
}

if (isDollarQuotedString(state)) {
Expand Down Expand Up @@ -253,52 +257,92 @@ function scanString(state: State, endToken: Char): Token {
};
}

function scanParameter(state: State, dialect: Dialect): Token {
if (['mysql', 'generic', 'sqlite'].includes(dialect)) {
return {
type: 'parameter',
value: state.input.slice(state.start, state.position + 1),
start: state.start,
end: state.start,
};
}
function getCustomParam(state: State, paramTypes: ParamTypes): string | null | undefined {
const matches = paramTypes?.custom
?.map((regex) => {
const reg = new RegExp(`^(?:${regex})`, 'u');
return reg.exec(state.input.slice(state.start));
})
.filter((value) => !!value)[0];

if (dialect === 'psql') {
let nextChar: Char;
return matches ? matches[0] : null;
}

do {
nextChar = read(state);
} while (nextChar !== null && !isNaN(Number(nextChar)) && !isWhitespace(nextChar));
function scanParameter(state: State, dialect: Dialect, paramTypes: ParamTypes): Token {
const curCh = state.input[state.start];
const nextChar = peek(state);
let matched = false;

if (paramTypes.numbered?.length && paramTypes.numbered.some((type) => type === curCh)) {
const endIndex = state.input
.slice(state.start + 1)
.split('')
.findIndex((val) => /^\W+/.test(val));
const maybeNumbers = state.input.slice(
state.start + 1,
endIndex > 0 ? state.start + endIndex + 1 : state.end + 1,
);
if (nextChar !== null && !isNaN(Number(nextChar)) && /^\d+$/.test(maybeNumbers)) {
let nextChar: Char = null;
do {
nextChar = read(state);
} while (nextChar !== null && !isNaN(Number(nextChar)) && !isWhitespace(nextChar));

if (nextChar !== null) unread(state);
matched = true;
}
}

if (nextChar !== null) unread(state);
if (!matched && paramTypes.named?.length && paramTypes.named.some((type) => type === curCh)) {
if (!isQuotedIdentifier(nextChar, dialect)) {
while (isAlphaNumeric(peek(state))) read(state);
matched = true;
}
}

const value = state.input.slice(state.start, state.position + 1);
if (!matched && paramTypes.quoted?.length && paramTypes.quoted.some((type) => type === curCh)) {
if (isQuotedIdentifier(nextChar, dialect)) {
const quoteChar = read(state) as string;
// end when we reach the end quote
while (
(isAlphaNumeric(peek(state)) || peek(state) === ' ') &&
peek(state) != ENDTOKENS[quoteChar]
) {
read(state);
}

return {
type: 'parameter',
value,
start: state.start,
end: state.start + value.length - 1,
};
// read the end quote
read(state);

matched = true;
}
}

if (dialect === 'mssql') {
while (isAlphaNumeric(peek(state))) read(state);
if (!matched && paramTypes.custom && paramTypes.custom.length) {
const custom = getCustomParam(state, paramTypes);

if (custom) {
read(state, custom.length);
matched = true;
}
}
const value = state.input.slice(state.start, state.position + 1);

const value = state.input.slice(state.start, state.position + 1);
if (!matched && !paramTypes.positional && curCh !== '?') {
// not positional, panic
return {
type: 'parameter',
value,
type: 'unknown',
value: value,
start: state.start,
end: state.start + value.length - 1,
};
}

return {
type: 'parameter',
value: 'unknown',
value,
start: state.start,
end: state.end,
end: state.start + value.length - 1,
};
}

Expand Down Expand Up @@ -413,18 +457,38 @@ function isString(ch: Char, dialect: Dialect): boolean {
return stringStart.includes(ch);
}

function isParameter(ch: Char, state: State, dialect: Dialect): boolean {
let pStart = '?'; // ansi standard - sqlite, mysql
if (dialect === 'psql') {
pStart = '$';
const nextChar = peek(state);
if (nextChar === null || isNaN(Number(nextChar))) {
return false;
function isCustomParam(state: State, customParamType: NonNullable<ParamTypes['custom']>): boolean {
return customParamType.some((regex) => {
const reg = new RegExp(`^(?:${regex})`, 'uy');
return reg.test(state.input.slice(state.start));
});
}

function isParameter(ch: Char, state: State, paramTypes: ParamTypes): boolean {
if (!ch) {
return false;
}
const nextChar = peek(state);
if (paramTypes.positional && ch === '?') return true;

if (paramTypes.numbered?.length && paramTypes.numbered.some((type) => ch === type)) {
if (nextChar !== null && !isNaN(Number(nextChar))) {
return true;
}
}
if (dialect === 'mssql') pStart = ':';

return ch === pStart;
if (
(paramTypes.named?.length && paramTypes.named.some((type) => type === ch)) ||
(paramTypes.quoted?.length && paramTypes.quoted.some((type) => type === ch))
) {
return true;
}

if (paramTypes.custom?.length && isCustomParam(state, paramTypes.custom)) {
return true;
}

return false;
}

function isDollarQuotedString(state: State): boolean {
Expand Down
44 changes: 44 additions & 0 deletions test/index.spec.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Dialect, getExecutionType, identify } from '../src/index';
import { expect } from 'chai';
import { ParamTypes } from '../src/defines';

describe('identify', () => {
it('should throw error for invalid dialect', () => {
Expand All @@ -22,6 +23,49 @@ describe('identify', () => {
]);
});

it('should identify custom parameters', () => {
not-night-but marked this conversation as resolved.
Show resolved Hide resolved
const paramTypes: ParamTypes = {
positional: true,
numbered: ['$'],
named: [':'],
quoted: [':'],
custom: ['\\{[a-zA-Z0-9_]+\\}'],
};
const query = `SELECT * FROM foo WHERE bar = ? AND baz = $1 AND fizz = :fizzz AND buzz = :"buzz buzz" AND foo2 = {fooo}`;

expect(identify(query, { dialect: 'psql', paramTypes })).to.eql([
{
start: 0,
end: 104,
text: query,
type: 'SELECT',
executionType: 'LISTING',
parameters: ['?', '$1', ':fizzz', ':"buzz buzz"', '{fooo}'],
tables: [],
},
]);
});

it('custom params should override defaults for dialect', () => {
const paramTypes: ParamTypes = {
positional: true,
};

const query = 'SELECT * FROM foo WHERE bar = $1 AND bar = :named AND fizz = :`quoted`';

expect(identify(query, { dialect: 'psql', paramTypes })).to.eql([
{
start: 0,
end: 69,
text: query,
type: 'SELECT',
executionType: 'LISTING',
parameters: [],
tables: [],
},
]);
});

it('should identify tables in simple for basic cases', () => {
expect(
identify('SELECT * FROM foo JOIN bar ON foo.id = bar.id', { identifyTables: true }),
Expand Down
Loading