diff --git a/source/optims/constantFolding.ts b/source/optims/constantFolding.ts index 8e90243..7064a70 100644 --- a/source/optims/constantFolding.ts +++ b/source/optims/constantFolding.ts @@ -17,10 +17,20 @@ type RefMaps = { childs: RefMap } -export type PredicateOnRule = (rule: [RuleName, RuleNode]) => boolean +export type PredicateOnRule = (rule: RuleNode) => boolean +/** + * Parameters for the constant folding optimization pass. + * + * @field toAvoid A predicate that returns true if the rule should be avoided to be folded, + * if not present, all rules will be folded. + * @field toKeep A predicate that returns true if the rule should be kept AFTER being folded, + * if not present, all folded rules will be kept. + * @field isFoldedAttr The attribute name to use to mark a rule as folded, default to 'optimized'. + */ export type FoldingParams = { - // The attribute name to use to mark a rule as folded, default to 'optimized'. + toAvoid?: PredicateOnRule + toKeep?: PredicateOnRule isFoldedAttr?: string } @@ -58,7 +68,6 @@ function addMapEntry(map: RefMap, key: RuleName, values: Set) { function initFoldingCtx( engine: Engine, - toKeep?: PredicateOnRule, foldingParams?: FoldingParams, ): FoldingCtx { const refs: RefMaps = { @@ -138,9 +147,9 @@ function initFoldingCtx( engine, parsedRules, refs, - toKeep, unfoldableRules, params: { + ...foldingParams, isFoldedAttr: foldingParams?.isFoldedAttr ?? 'optimized', }, } @@ -161,6 +170,7 @@ function isFoldable(ctx: FoldingCtx, rule: RuleNode): boolean { return ( rule !== undefined && + !ctx.params.toAvoid?.(rule) && !unfoldableAttr.find((attr) => attr in rule.rawNode) && !ctx.unfoldableRules.has(rule.dottedName) && !childInContext @@ -184,20 +194,23 @@ function searchAndReplaceConstantValueInParentRefs( if (refs) { for (const parentName of refs) { const parentRule = ctx.parsedRules[parentName] - const newRule = traverseASTNode( - transformAST((node, _) => { - if (node.nodeKind === 'reference' && node.dottedName === ruleName) { - return constantNode - } - }), - parentRule, - ) as RuleNode - - if (newRule !== undefined) { - ctx.parsedRules[parentName] = newRule - ctx.parsedRules[parentName].rawNode[ctx.params.isFoldedAttr] = - 'partially' - removeInMap(ctx.refs.parents, ruleName, parentName) + + if (!ctx.params.toAvoid?.(parentRule)) { + const newRule = traverseASTNode( + transformAST((node, _) => { + if (node.nodeKind === 'reference' && node.dottedName === ruleName) { + return constantNode + } + }), + parentRule, + ) as RuleNode + + if (newRule !== undefined) { + ctx.parsedRules[parentName] = newRule + ctx.parsedRules[parentName].rawNode[ctx.params.isFoldedAttr] = + 'partially' + removeInMap(ctx.refs.parents, ruleName, parentName) + } } } } @@ -227,7 +240,7 @@ function tryToDeleteRule(ctx: FoldingCtx, dottedName: RuleName): boolean { const ruleNode = ctx.parsedRules[dottedName] if ( - (ctx.toKeep === undefined || !ctx.toKeep([dottedName, ruleNode])) && + (ctx.params.toKeep === undefined || !ctx.params.toKeep(ruleNode)) && isFoldable(ctx, ruleNode) ) { removeRuleFromRefs(ctx.refs.parents, dottedName) @@ -420,18 +433,15 @@ function copyFullParsedRules(engine: Engine): ParsedRules { * Applies a constant folding optimisation pass on parsed rules of [engine]. * * @param engine The engine instantiated with the rules to fold. - * @param toKeep A predicate that returns true if the rule should be kept, if not present, - * all folded rules will be kept. * @param params The folding parameters. * * @returns The parsed rules with constant folded rules. */ export function constantFolding( engine: Engine, - toKeep?: PredicateOnRule, params?: FoldingParams, ): ParsedRules { - let ctx = initFoldingCtx(engine, toKeep, params) + let ctx = initFoldingCtx(engine, params) let nbRules = Object.keys(ctx.parsedRules).length let nbRulesBefore = undefined @@ -448,14 +458,14 @@ export function constantFolding( nbRules = Object.keys(ctx.parsedRules).length } - if (toKeep) { + if (ctx.params.toKeep) { for (const ruleName in ctx.parsedRules) { const ruleNode = ctx.parsedRules[ruleName] const parents = ctx.refs.parents.get(ruleName) if ( isFoldable(ctx, ruleNode) && - !toKeep([ruleName, ruleNode]) && + !ctx.params.toKeep(ruleNode) && (!parents || parents?.size === 0) ) { delete ctx.parsedRules[ruleName] diff --git a/test/optims/constantFolding.test.ts b/test/optims/constantFolding.test.ts index d6fc4a4..f61da87 100644 --- a/test/optims/constantFolding.test.ts +++ b/test/optims/constantFolding.test.ts @@ -1,4 +1,4 @@ -import Engine from 'publicodes' +import Engine, { RuleNode } from 'publicodes' import { serializeParsedRules } from '../../source' import { RuleName, RawRules, disabledLogger } from '../../source/commons' import { constantFolding } from '../../source/optims/' @@ -7,10 +7,11 @@ import { callWithEngine } from '../utils.test' function constantFoldingWith(rawRules: any, targets?: RuleName[]): RawRules { const res = callWithEngine( (engine) => - constantFolding( - engine, - targets ? ([ruleName, _]) => targets.includes(ruleName) : undefined, - ), + constantFolding(engine, { + toKeep: targets + ? (rule: RuleNode) => targets.includes(rule.dottedName) + : undefined, + }), rawRules, ) return serializeParsedRules(res) @@ -39,7 +40,7 @@ describe('Constant folding [meta]', () => { const baseParsedRules = engine.getParsedRules() const serializedBaseParsedRules = serializeParsedRules(baseParsedRules) - constantFolding(engine, ([ruleName, _]) => ruleName === 'ruleA') + constantFolding(engine, { toKeep: (rule) => rule.dottedName === 'ruleA' }) const shouldNotBeModifiedRules = engine.getParsedRules() const serializedShouldNotBeModifiedRules = serializeParsedRules( @@ -51,6 +52,45 @@ describe('Constant folding [meta]', () => { serializedShouldNotBeModifiedRules, ) }) + + it('should not fold a rule specified in the [toAvoid] option', () => { + const rawRules = { + ruleA: { + titre: 'Rule A', + valeur: 'B . C * D', + }, + ruleB: { + valeur: 'ruleA . B . C * 3', + }, + 'ruleA . D': { + question: "What's the value of D?", + }, + 'ruleA . B . C': { + valeur: '10', + }, + } + const engine = new Engine(rawRules, { + logger: disabledLogger, + allowOrphanRules: true, + }) + const foldedRules = serializeParsedRules( + constantFolding(engine, { + toAvoid: (rule) => rule.dottedName === 'ruleB', + }), + ) + expect(foldedRules).toEqual({ + ...rawRules, + ruleA: { + optimized: 'partially', + titre: 'Rule A', + valeur: '10 * D', + }, + 'ruleA . B . C': { + optimized: 'fully', + valeur: 10, + }, + }) + }) }) describe('Constant folding [base]', () => {