Skip to content

Commit

Permalink
Nested type tabulator fix (#884)
Browse files Browse the repository at this point in the history
* Fix tabulator for nested report types

* Add more tests

* Fix test
  • Loading branch information
hugohills-regnosys authored Dec 13, 2024
1 parent f1f8f37 commit 6bd6e67
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ class TabulatorGenerator {
} else {
rawAttrType
}
if (attrType instanceof RDataType) {
needsTabulator(attrType, visited)
if (attrType instanceof RDataType && needsTabulator(attrType as RDataType, visited)) {
true
} else {
ruleMap.containsKey(attr)
}
Expand Down Expand Up @@ -198,12 +198,7 @@ class TabulatorGenerator {
def generateTabulatorForReportData(IFileSystemAccess2 fsa, RDataType type, Optional<RosettaExternalRuleSource> ruleSource) {
val context = getReportTabulatorContext(type, ruleSource)
if (context.needsTabulator(type)) {
val tabulatorClass = type.EObject.toTabulatorJavaClass(ruleSource)
val topScope = new JavaScope(tabulatorClass.packageName)

val classBody = type.tabulatorClassBody(context, topScope, tabulatorClass)
val content = buildClass(tabulatorClass.packageName, classBody, topScope)
fsa.generateFile(tabulatorClass.canonicalName.withForwardSlashes + ".java", content)
recursivelyGenerateTabulators(fsa, type, context, newHashSet)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ class FunctionGeneratorHelper {
}

def createFunc(Map<String, Class<?>> classes, String funcName) {
injector.getInstance(classes.get(rootPackage.functions + '.' + funcName)) as RosettaFunction
createFunc(classes, rootPackage.functions.toString, funcName)
}

def createFunc(Map<String, Class<?>> classes, String namespace, String funcName) {
injector.getInstance(classes.get(namespace + '.' + funcName)) as RosettaFunction
}

def <T> invokeFunc(RosettaFunction func, Class<T> resultClass, Object... inputs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package com.regnosys.rosetta.generator.java.rule

import com.google.inject.Guice
import com.google.inject.Injector
import com.regnosys.rosetta.generator.java.function.FunctionGeneratorHelper
import com.regnosys.rosetta.tests.RosettaInjectorProvider
import com.regnosys.rosetta.tests.util.CodeGeneratorTestHelper
import com.regnosys.rosetta.tests.util.ModelHelper
import com.regnosys.rosetta.validation.RosettaIssueCodes
import com.rosetta.model.lib.RosettaModelObject
import java.util.Map
import javax.inject.Inject
import org.eclipse.xtext.testing.InjectWith
import org.eclipse.xtext.testing.extensions.InjectionExtension
import org.eclipse.xtext.testing.validation.ValidationTestHelper
Expand All @@ -17,7 +20,6 @@ import org.junit.jupiter.api.^extension.ExtendWith
import static com.regnosys.rosetta.rosetta.expression.ExpressionPackage.Literals.*
import static org.hamcrest.MatcherAssert.*
import static org.junit.jupiter.api.Assertions.*
import javax.inject.Inject

@InjectWith(RosettaInjectorProvider)
@ExtendWith(InjectionExtension)
Expand All @@ -26,7 +28,8 @@ class RosettaRuleGeneratorTest {
@Inject extension CodeGeneratorTestHelper
@Inject extension ModelHelper
@Inject extension ValidationTestHelper

@Inject extension FunctionGeneratorHelper

static final CharSequence REPORT_TYPES = '''
namespace com.rosetta.test.model
Expand Down Expand Up @@ -631,6 +634,149 @@ class RosettaRuleGeneratorTest {
code.compileToClasses
}

@Test
def void parseSimpleReportForTypeWithExternalRuleReferencesWithEmptyAs() {
val model = #[
REPORT_TYPES,
'''
namespace com.rosetta.test.model
eligibility rule FooRule from Bar:
filter bar1 exists
reporting rule BarToBazReport from Bar:
extract BazReport {
qux: QuxReport {
attr: item -> bar1
}
}
as "Label 1"
reporting rule EmptyWithAs from Bar:
empty
as "Label 2"
''',
'''
namespace com.rosetta.test.model
body Authority TEST_REG
corpus TEST_REG MiFIR
report TEST_REG MiFIR in T+1
from Bar
when FooRule
with type BarReport
with source RuleSource
type BarReport:
baz BazReport (1..1)
type BazReport:
qux QuxReport (1..1)
type QuxReport:
attr string (1..1)
rule source RuleSource {
BarReport:
+ baz
[ruleReference BarToBazReport]
QuxReport:
+ attr
[ruleReference EmptyWithAs]
}
''']
val code = model.generateCode
//println(code)
val reportJava = code.get("com.rosetta.test.model.reports.TEST_REGMiFIRReportFunction")
try {
assertThat(reportJava, CoreMatchers.notNullValue())
val expected = '''
package com.rosetta.test.model.reports;
import com.google.inject.ImplementedBy;
import com.rosetta.model.lib.annotations.RosettaReport;
import com.rosetta.model.lib.functions.ModelObjectValidator;
import com.rosetta.model.lib.reports.ReportFunction;
import com.rosetta.test.model.Bar;
import com.rosetta.test.model.BarReport;
import com.rosetta.test.model.BarReport.BarReportBuilder;
import java.util.Optional;
import javax.inject.Inject;
@RosettaReport(namespace="com.rosetta.test.model", body="TEST_REG", corpusList={"MiFIR"})
@ImplementedBy(TEST_REGMiFIRReportFunction.TEST_REGMiFIRReportFunctionDefault.class)
public abstract class TEST_REGMiFIRReportFunction implements ReportFunction<Bar, BarReport> {
@Inject protected ModelObjectValidator objectValidator;
// RosettaFunction dependencies
//
@Inject protected BarToBazReportRule barToBazReportRule;
/**
* @param input
* @return output
*/
@Override
public BarReport evaluate(Bar input) {
BarReport.BarReportBuilder outputBuilder = doEvaluate(input);
final BarReport output;
if (outputBuilder == null) {
output = null;
} else {
output = outputBuilder.build();
objectValidator.validate(BarReport.class, output);
}
return output;
}
protected abstract BarReport.BarReportBuilder doEvaluate(Bar input);
public static class TEST_REGMiFIRReportFunctionDefault extends TEST_REGMiFIRReportFunction {
@Override
protected BarReport.BarReportBuilder doEvaluate(Bar input) {
BarReport.BarReportBuilder output = BarReport.builder();
return assignOutput(output, input);
}
protected BarReport.BarReportBuilder assignOutput(BarReport.BarReportBuilder output, Bar input) {
output
.setBaz(barToBazReportRule.evaluate(input));
return Optional.ofNullable(output)
.map(o -> o.prune())
.orElse(null);
}
}
}
'''
assertEquals(expected, reportJava)

} finally {
}
val classes = code.compileToClasses

val test = classes.createFunc("com.rosetta.test.model.reports", "TEST_REGMiFIRReportFunction")

val input = classes.createInstanceUsingBuilder("Bar", #{"bar1" -> "bar1Value"})

val output = test.invokeFunc(RosettaModelObject, input)

// expected output
val expectedQuxReport = classes.createInstanceUsingBuilder("QuxReport", #{"attr" -> "bar1Value"})
val expectedBazReport = classes.createInstanceUsingBuilder("BazReport", #{"qux" -> expectedQuxReport})
val expectedBarReport = classes.createInstanceUsingBuilder("BarReport", #{"baz" -> expectedBazReport})

assertEquals(expectedBarReport, output)
}

@Test
def void parseSimpleReportWithEmptyType() {
val model = '''
Expand Down

0 comments on commit 6bd6e67

Please sign in to comment.