Skip to content

Commit

Permalink
First needs to work with nested loops (#226)
Browse files Browse the repository at this point in the history
* A first loop test

* Fix up type errors

* Fix up how we mark fills and how we fetch prior context for variables

* Remove dead line of code

* Fix up test to get tracks from the method
  • Loading branch information
gordonwatts authored Jul 1, 2024
1 parent 7a7f9e4 commit b352b78
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 14 deletions.
34 changes: 23 additions & 11 deletions func_adl_xAOD/common/ast_to_cpp_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,9 @@ def visit_call_Aggregate_initial(self, node: ast.Call, args: List[ast.AST]):
else:
self._gc.set_scope(sv.scope())
call = ast.Call(
func=agg_lambda, args=[accumulator.as_ast(), seq.sequence_value().as_ast()]
func=agg_lambda,
args=[accumulator.as_ast(), seq.sequence_value().as_ast()], # type: ignore
keywords=[],
)
update_lambda = cast(crep.cpp_value, self.get_rep(call))

Expand Down Expand Up @@ -1159,13 +1161,15 @@ def call_ResultTTree(self, node: ast.Call, args: List[ast.AST]):

# What we have is a sequence of the data values we want to fill. The iterator at play
# here is the scope we want to use to run our Fill() calls to the TTree.
scope_fill = v_rep_not_norm.iterator_value().scope()
iterator_scope = v_rep_not_norm.iterator_value().scope()

# Clean the data up so it is uniform and the next bit can proceed smoothly.
# If we don't have a tuple of data to log, turn it into a tuple.
seq_values = v_rep_not_norm.sequence_value()
if not isinstance(seq_values, crep.cpp_tuple):
seq_values = crep.cpp_tuple((v_rep_not_norm.sequence_value(),), scope_fill)
seq_values = crep.cpp_tuple(
(v_rep_not_norm.sequence_value(),), iterator_scope
)

# Make sure the number of items is the same as the number of columns specified.
if len(seq_values.values()) != len(column_names):
Expand Down Expand Up @@ -1207,7 +1211,7 @@ def call_ResultTTree(self, node: ast.Call, args: List[ast.AST]):
# Make sure that it happens at the proper scope, where what we are after is defined!
s_orig = self._gc.current_scope()
for e_rep, e_name in zip(seq_values.values(), var_names):
scope_fill = self.code_fill_ttree(e_rep, e_name[1], scope_fill)
scope_fill = self.code_fill_ttree(e_rep, e_name[1], iterator_scope)

# The fill statement. This should happen at the scope where the tuple was defined.
# The scope where this should be done is a bit tricky (note the update above):
Expand Down Expand Up @@ -1237,7 +1241,9 @@ def call_Select(self, node: ast.Call, args: List[ast.arg]):

# Simulate this as a "call"
c = ast.Call(
func=lambda_unwrap(selection), args=[seq.sequence_value().as_ast()]
func=lambda_unwrap(selection),
args=[seq.sequence_value().as_ast()], # type: ignore
keywords=[],
)
new_sequence_value = cast(crep.cpp_value, self.get_rep(c))

Expand Down Expand Up @@ -1267,7 +1273,9 @@ def call_SelectMany(self, node: ast.AST, args: List[ast.AST]):
# We need to "call" the source with the function. So build up a new
# call, and then visit it.
c = ast.Call(
func=lambda_unwrap(selection), args=[seq.sequence_value().as_ast()]
func=lambda_unwrap(selection),
args=[seq.sequence_value().as_ast()], # type: ignore
keywords=[],
)

# Get the collection, and then generate the loop over it.
Expand All @@ -1291,7 +1299,11 @@ def call_Where(self, node: ast.AST, args: List[ast.AST]):

# Simulate the filtering call - we want the resulting value to test.
filter = lambda_unwrap(filter)
c = ast.Call(func=filter, args=[seq.sequence_value().as_ast()])
c = ast.Call(
func=filter,
args=[seq.sequence_value().as_ast()], # type: ignore
keywords=[],
)
rep = self.get_rep(c)

# Create an if statement
Expand Down Expand Up @@ -1359,12 +1371,13 @@ def call_Range(self, node: ast.Call, args: List[ast.AST]):
)

c = ast.Call(
func=FunctionAST("std::iota", ["numeric"], "void"),
func=FunctionAST("std::iota", ["numeric"], "void"), # type: ignore
args=[
vector_value_begin.as_ast(),
vector_value_end.as_ast(),
begin_value.as_ast(),
],
], # type: ignore
keywords=[],
)

self._gc.add_statement(statement.arbitrary_statement(self.get_rep(c).as_cpp())) # type: ignore
Expand All @@ -1381,7 +1394,6 @@ def call_First(self, node: ast.AST, args: List[ast.AST]) -> Any:
source = args[0]

# Make sure we are in a loop.
cs = self._gc.current_scope()
seq = self.as_sequence(source)

# The First terminal works by protecting the code with a if (first_time) {} block.
Expand Down Expand Up @@ -1429,4 +1441,4 @@ def call_First(self, node: ast.AST, args: List[ast.AST]) -> Any:
else sv.copy_with_new_scope(self._gc.current_scope())
)

crep.set_rep(node, first_value, cs)
crep.set_rep(node, first_value, self._gc.current_scope())
45 changes: 42 additions & 3 deletions tests/atlas/xaod/test_first_last.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Code to do the testing starts here.
from tests.utils.locators import find_line_with, find_open_blocks # type: ignore
import func_adl_xAOD.common.cpp_types as ctyp
from tests.utils.locators import find_line_numbers_with, find_line_with, find_next_closing_bracket, find_open_blocks # type: ignore
from tests.utils.general import get_lines_of_code, print_lines # type: ignore
from tests.atlas.xaod.utils import atlas_xaod_dataset # type: ignore
import re
Expand All @@ -11,7 +12,7 @@ def test_first_jet_in_event():
).value()


def test_first_after_selectmany():
def test_first_after_SelectMany():
r = (
atlas_xaod_dataset()
.Select(
Expand Down Expand Up @@ -111,4 +112,42 @@ def test_First_with_dict():
assert l_pt_r is not None
assert l_eta_r is not None

assert l_pt_r[1] == l_eta_r[1]

def test_First_with_inner_loop():
"Check we can loop over tracks"
ctyp.add_method_type_info(
"xAOD::Jet",
"JetTracks",
ctyp.collection(
ctyp.terminal("xAOD::Track", p_depth=1),
"std::vector<xAOD::Track>",
p_depth=1,
),
)

r = (
atlas_xaod_dataset()
.Select(lambda e: e.Jets("Anti").First())
.Select(lambda j: j.JetTracks("fork"))
.Select(
lambda tracks: {
"pt": tracks.Select(lambda t: t.pt()),
"eta": tracks.Select(lambda t: t.eta()),
}
)
.value()
)

lines = get_lines_of_code(r)
print_lines(lines)

# Make sure the eta capture is inside the is first.
first_lines = find_line_numbers_with("if (is_first", lines)
assert len(first_lines) == 2
assert lines[first_lines[0] + 1].strip() == "{"
lines_post_if = lines[first_lines[0] + 2 :] # noqa
is_first_closing = find_next_closing_bracket(lines_post_if)

eta_line = find_line_numbers_with("->pt()", lines_post_if)
assert len(eta_line) == 1
assert is_first_closing > eta_line[0]

0 comments on commit b352b78

Please sign in to comment.