Skip to content

Commit

Permalink
Merge branch 'main' into feat/cross-db-params
Browse files Browse the repository at this point in the history
  • Loading branch information
azmy60 committed Feb 27, 2024
2 parents f0eec03 + c17b973 commit 2f8c611
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 33 deletions.
5 changes: 2 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "sql-query-identifier",
"version": "2.6.0",
"version": "2.7.0",
"description": "A SQL query identifier",
"license": "MIT",
"main": "lib/index.js",
Expand Down Expand Up @@ -34,12 +34,11 @@
"prettier": "^2.3.2",
"terser-webpack-plugin": "^5.1.1",
"ts-loader": "^8.0.17",
"ts-node": "^10.9.1",
"ts-node": "^10.9.2",
"typescript": "^4.1.5",
"webpack": "^5.11.1",
"webpack-cli": "^4.3.1"
},
"dependencies": {},
"engines": {
"node": ">= 10.13"
}
Expand Down
5 changes: 5 additions & 0 deletions src/defines.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export type ExecutionType = 'LISTING' | 'MODIFICATION' | 'INFORMATION' | 'ANON_B
/**
 * Options accepted by `identify`. Every field is optional.
 */
export interface IdentifyOptions {
/** Strict-mode flag. `identify` treats an omitted value as `true`; any value other than `true` counts as `false`. */
strict?: boolean;
/** SQL dialect. `identify` uses `'generic'` when omitted and throws if the value is not listed in `DIALECTS`. */
dialect?: Dialect;
/**
 * Forwarded to `parse`, where it defaults to `false`. When enabled, the parser records the token
 * following a FROM / JOIN / INTO keyword in `tables` for non-CTE statements whose type matches
 * SELECT or INSERT.
 */
identifyTables?: boolean;
/** Forwarded to `parse` as `enableCrossDBParameters`, where it defaults to `false`. */
enableCrossDBParameters?: boolean;
}

Expand All @@ -89,6 +90,7 @@ export interface IdentifyResult {
type: StatementType;
executionType: ExecutionType;
parameters: string[];
tables: string[];
}

export interface Statement {
Expand All @@ -102,6 +104,8 @@ export interface Statement {
algorithm?: number;
sqlSecurity?: number;
parameters: string[];
tables: string[];
isCte?: boolean;
}

export interface ConcreteStatement extends Statement {
Expand All @@ -125,6 +129,7 @@ export interface Token {
| 'semicolon'
| 'keyword'
| 'parameter'
| 'table'
| 'unknown';
value: string;
start: number;
Expand Down
13 changes: 8 additions & 5 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ export type {
export function identify(query: string, options: IdentifyOptions = {}): IdentifyResult[] {
const isStrict = typeof options.strict === 'undefined' ? true : options.strict === true;
const dialect = typeof options.dialect === 'undefined' ? 'generic' : options.dialect;
const enableCrossDBParameters =
typeof options.enableCrossDBParameters === 'undefined'
? false
: options.enableCrossDBParameters;

if (!DIALECTS.includes(dialect)) {
throw new Error(`Unknown dialect. Allowed values: ${DIALECTS.join(', ')}`);
}

const result = parse(query, isStrict, dialect, enableCrossDBParameters);
const result = parse(
query,
isStrict,
dialect,
options.identifyTables,
options.enableCrossDBParameters,
);

return result.body.map((statement) => {
const result: IdentifyResult = {
Expand All @@ -36,6 +38,7 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify
executionType: statement.executionType,
// we want to sort the postgres params: $1 $2 $3, regardless of the order they appear
parameters: dialect === 'psql' ? statement.parameters.sort() : statement.parameters,
tables: statement.tables || [],
};
return result;
});
Expand Down
69 changes: 45 additions & 24 deletions src/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ const statementsWithEnds = [
'UNKNOWN',
];

// Matches a token value that is exactly `from`, `join`, or `into` (case-insensitive).
// These are the keywords that come directly before a table name: when table
// identification is enabled, the token right after a match is recorded as a table.
// v1 - keeping it very simple.
const PRE_TABLE_KEYWORDS = /^from$|^join$|^into$/i;

const blockOpeners: Record<Dialect, string[]> = {
generic: ['BEGIN', 'CASE'],
psql: ['BEGIN', 'CASE', 'LOOP', 'IF'],
Expand All @@ -111,13 +115,15 @@ const blockOpeners: Record<Dialect, string[]> = {
/**
 * Settings handed from `parse` to the statement parsers it creates
 * (passed as the last argument of `createStatementParserByToken`).
 */
interface ParseOptions {
// NOTE(review): strict-mode flag forwarded from `parse`; its exact effect is not visible in this excerpt.
isStrict: boolean;
/** SQL dialect forwarded from `parse`. */
dialect: Dialect;
/** When true, the state-machine parser collects table names into `statement.tables`. */
identifyTables: boolean;
}

/**
 * Builds the blank statement record that a statement parser starts from.
 *
 * @returns A new `Statement` with `start` set to -1, `end` set to 0, and
 * freshly allocated empty `parameters` and `tables` arrays (never shared
 * between calls).
 */
function createInitialStatement(): Statement {
  const parameters: string[] = [];
  const tables: string[] = [];
  const blank: Statement = { start: -1, end: 0, parameters, tables };
  return blank;
}

Expand All @@ -141,6 +147,7 @@ export function parse(
input: string,
isStrict = true,
dialect: Dialect = 'generic',
identifyTables = false,
enableCrossDBParameters = false,
): ParseResult {
const topLevelState = initState({ input });
Expand Down Expand Up @@ -180,7 +187,6 @@ export function parse(
if (!cteState.isCte && ignoreOutsideBlankTokens.includes(token.type)) {
topLevelStatement.tokens.push(token);
prevState = tokenState;
continue;
} else if (
!cteState.isCte &&
token.type === 'keyword' &&
Expand All @@ -190,7 +196,7 @@ export function parse(
topLevelStatement.tokens.push(token);
cteState.state = tokenState;
prevState = tokenState;
continue;

// If we're scanning in a CTE, handle someone putting a semicolon anywhere (after 'with',
// after semicolon, etc.) along it to "early terminate".
} else if (cteState.isCte && token.type === 'semicolon') {
Expand All @@ -202,12 +208,12 @@ export function parse(
type: 'UNKNOWN',
executionType: 'UNKNOWN',
parameters: [],
tables: [],
});
cteState.isCte = false;
cteState.asSeen = false;
cteState.statementEnd = false;
cteState.parens = 0;
continue;
} else if (cteState.isCte && !cteState.statementEnd) {
if (cteState.asSeen) {
if (token.value === '(') {
Expand All @@ -224,14 +230,13 @@ export function parse(

topLevelStatement.tokens.push(token);
prevState = tokenState;
continue;
} else if (cteState.isCte && cteState.statementEnd && token.value === ',') {
cteState.asSeen = false;
cteState.statementEnd = false;

topLevelStatement.tokens.push(token);
prevState = tokenState;
continue;

// Ignore blank tokens after the end of the CTE till start of statement
} else if (
cteState.isCte &&
Expand All @@ -240,28 +245,32 @@ export function parse(
) {
topLevelStatement.tokens.push(token);
prevState = tokenState;
continue;
} else {
statementParser = createStatementParserByToken(token, nextToken, {
isStrict,
dialect,
identifyTables,
});
if (cteState.isCte) {
statementParser.getStatement().start = cteState.state.start;
statementParser.getStatement().isCte = true;
cteState.isCte = false;
cteState.asSeen = false;
cteState.statementEnd = false;
}
}
} else {
statementParser.addToken(token, nextToken);
topLevelStatement.tokens.push(token);
prevState = tokenState;

statementParser = createStatementParserByToken(token, nextToken, { isStrict, dialect });
if (cteState.isCte) {
statementParser.getStatement().start = cteState.state.start;
cteState.isCte = false;
cteState.asSeen = false;
cteState.statementEnd = false;
const statement = statementParser.getStatement();
if (statement.endStatement) {
statement.end = token.end;
topLevelStatement.body.push(statement as ConcreteStatement);
statementParser = null;
}
}

statementParser.addToken(token, nextToken);
topLevelStatement.tokens.push(token);
prevState = tokenState;

const statement = statementParser.getStatement();
if (statement.endStatement) {
statement.end = token.end;
topLevelStatement.body.push(statement as ConcreteStatement);
statementParser = null;
}
}

// last statement without ending key
Expand Down Expand Up @@ -717,7 +726,7 @@ function createUnknownStatementParser(options: ParseOptions) {
function stateMachineStatementParser(
statement: Statement,
steps: Step[],
{ isStrict, dialect }: ParseOptions,
{ isStrict, dialect, identifyTables }: ParseOptions,
): StatementParser {
let currentStepIndex = 0;
let prevToken: Token | undefined;
Expand Down Expand Up @@ -817,6 +826,18 @@ function stateMachineStatementParser(
}
}

if (
identifyTables &&
PRE_TABLE_KEYWORDS.exec(token.value) &&
!statement.isCte &&
statement.type?.match(/SELECT|INSERT/)
) {
const tableValue = nextToken.value;
if (!statement.tables.includes(tableValue)) {
statement.tables.push(tableValue);
}
}

if (
token.type === 'parameter' &&
(token.value === '?' || !statement.parameters.includes(token.value))
Expand Down
4 changes: 4 additions & 0 deletions test/identifier/inner-statements.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ describe('identifier', () => {
type: 'INSERT',
executionType: 'MODIFICATION',
parameters: [],
tables: [],
},
];

Expand All @@ -34,6 +35,7 @@ describe('identifier', () => {
type: 'INSERT',
executionType: 'MODIFICATION',
parameters: [],
tables: [],
},
];

Expand All @@ -54,6 +56,7 @@ describe('identifier', () => {
type: 'INSERT',
executionType: 'MODIFICATION',
parameters: [],
tables: [],
},
];

Expand All @@ -75,6 +78,7 @@ describe('identifier', () => {
type: 'INSERT',
executionType: 'MODIFICATION',
parameters: [],
tables: [],
},
];

Expand Down
Loading

0 comments on commit 2f8c611

Please sign in to comment.