Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
chanwutk committed Feb 9, 2024
1 parent 375f60e commit 12dc08f
Show file tree
Hide file tree
Showing 11 changed files with 65 additions and 64 deletions.
4 changes: 3 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ ignore =
# multiple '#' leading a line comment is ok
E266,
# module-level import not on top of file is ok
E402
E402,
# multiple statements on one line (def)
E704
7 changes: 3 additions & 4 deletions spatialyze/predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def __call__(self, node: "PredicateNode"):
return self._is_detection_only

def visit_TableAttrNode(self, node: TableAttrNode) -> Any:
if isinstance(node.table, ObjectTableNode) and node.name == "itemHeadings":
if isinstance(node.table, ObjectTableNode) and node.name == "itemHeading":
self._is_detection_only = False
return super().visit_TableAttrNode(node)

Expand Down Expand Up @@ -603,9 +603,8 @@ def resolve_camera_attr(attr: str, num: "int | None" = None):
"itemId": False,
"cameraId": False,
"objectType": False,
"translations": True,
"translations": True,
"itemHeadings": True,
"translation": True,
"itemHeading": True,
"bbox": True,
}

Expand Down
56 changes: 28 additions & 28 deletions tests/interface/test_predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,28 @@


@pytest.mark.parametrize("fn, sql", [
(o, "valueAtTimestamp(t0.translations,c0.timestamp)"),
(o.trans, "valueAtTimestamp(t0.translations,c0.timestamp)"),
(o.trans, "valueAtTimestamp(t0.translations,c0.timestamp)"),
(o, "t0.translation"),
(o.trans, "t0.translation"),
(o.trans, "t0.translation"),
(o.id, "t0.itemId"),
(o.type, "t0.objectType"),
(o.heading, "valueAtTimestamp(t0.itemHeadings,c0.timestamp)"),
(o.heading, "t0.itemHeading"),
(c, "c0.cameraTranslation"),
(c.ego, "c0.egoTranslation"),
(c.heading, "c0.cameraHeading"),
(c.egoheading, "c0.egoHeading"),
(o.trans + c, "(valueAtTimestamp(t0.translations,c0.timestamp)+c0.cameraTranslation)"),
(o.trans - c, "(valueAtTimestamp(t0.translations,c0.timestamp)-c0.cameraTranslation)"),
(o.trans * c, "(valueAtTimestamp(t0.translations,c0.timestamp)*c0.cameraTranslation)"),
(o.trans / c, "(valueAtTimestamp(t0.translations,c0.timestamp)/c0.cameraTranslation)"),
(o.trans == c, "(valueAtTimestamp(t0.translations,c0.timestamp)=c0.cameraTranslation)"),
(o.trans < c, "(valueAtTimestamp(t0.translations,c0.timestamp)<c0.cameraTranslation)"),
(o.trans > c, "(valueAtTimestamp(t0.translations,c0.timestamp)>c0.cameraTranslation)"),
(o.trans <= c, "(valueAtTimestamp(t0.translations,c0.timestamp)<=c0.cameraTranslation)"),
(o.trans >= c, "(valueAtTimestamp(t0.translations,c0.timestamp)>=c0.cameraTranslation)"),
(o.trans != c, "(valueAtTimestamp(t0.translations,c0.timestamp)<>c0.cameraTranslation)"),
(o.trans + c, "(t0.translation+c0.cameraTranslation)"),
(o.trans - c, "(t0.translation-c0.cameraTranslation)"),
(o.trans * c, "(t0.translation*c0.cameraTranslation)"),
(o.trans / c, "(t0.translation/c0.cameraTranslation)"),
(o.trans == c, "(t0.translation=c0.cameraTranslation)"),
(o.trans < c, "(t0.translation<c0.cameraTranslation)"),
(o.trans > c, "(t0.translation>c0.cameraTranslation)"),
(o.trans <= c, "(t0.translation<=c0.cameraTranslation)"),
(o.trans >= c, "(t0.translation>=c0.cameraTranslation)"),
(o.trans != c, "(t0.translation<>c0.cameraTranslation)"),
(1 + c, "(1+c0.cameraTranslation)"),
(1 - c, "(1-c0.cameraTranslation)"),
Expand All @@ -52,12 +52,12 @@
(c0.cam, "c0.cameraTranslation"),
(cast(c0.heading, 'real'), "(c0.cameraHeading)::real"),
(-o.trans, "(-valueAtTimestamp(t0.translations,c0.timestamp))"),
(~o.trans, "(NOT valueAtTimestamp(t0.translations,c0.timestamp))"),
(o.trans & o & o, "(valueAtTimestamp(t0.translations,c0.timestamp) AND valueAtTimestamp(t0.translations,c0.timestamp) AND valueAtTimestamp(t0.translations,c0.timestamp))"),
(o.trans | o | o, "(valueAtTimestamp(t0.translations,c0.timestamp) OR valueAtTimestamp(t0.translations,c0.timestamp) OR valueAtTimestamp(t0.translations,c0.timestamp))"),
(-o.trans, "(-t0.translation)"),
(~o.trans, "(NOT t0.translation)"),
(o.trans & o & o, "(t0.translation AND t0.translation AND t0.translation)"),
(o.trans | o | o, "(t0.translation OR t0.translation OR t0.translation)"),
(c.time, "c0.timestamp"),
(arr(o.trans, o), "ARRAY[valueAtTimestamp(t0.translations,c0.timestamp),valueAtTimestamp(t0.translations,c0.timestamp)]"),
(arr(o.trans, o), "ARRAY[t0.translation,t0.translation]"),
(o.bbox, "objectBBox(t0.itemId,c0.timestamp)"),
])
def test_simple_ops(fn, sql):
Expand All @@ -77,7 +77,7 @@ def test_unnormalized_node_exception(fn, msg):


@pytest.mark.parametrize("fn, msg", [
(AtTimeNode(o.trans), "AtTimeNode is illegal prior NormalizeDefaultValue: AtTimeNode(attr=TableAttrNode(name='translations', table=ObjectTableNode[0], shorten=True))"),
(AtTimeNode(o.trans), "AtTimeNode is illegal prior NormalizeDefaultValue: AtTimeNode(attr=TableAttrNode(name='translation', table=ObjectTableNode[0], shorten=True))"),
])
def test_normalize_exception(fn, msg):
with pytest.raises(Exception) as e_info:
Expand All @@ -100,8 +100,8 @@ def test_predicate_node_exception(args, kwargs, msg):


@pytest.mark.parametrize("fn, sql", [
((o.trans + c) - c.cam + o.type * c.ego / o, "(((valueAtTimestamp(t0.translations,c0.timestamp)+c0.cameraTranslation)-c0.cameraTranslation)+((t0.objectType*c0.egoTranslation)/valueAtTimestamp(t0.translations,c0.timestamp)))"),
((o.trans == c) & ((o < c.cam) | (o == c.ego)), "((valueAtTimestamp(t0.translations,c0.timestamp)=c0.cameraTranslation) AND ((valueAtTimestamp(t0.translations,c0.timestamp)<c0.cameraTranslation) OR (valueAtTimestamp(t0.translations,c0.timestamp)=c0.egoTranslation)))"),
((o.trans + c) - c.cam + o.type * c.ego / o, "(((t0.translation+c0.cameraTranslation)-c0.cameraTranslation)+((t0.objectType*c0.egoTranslation)/t0.translation))"),
((o.trans == c) & ((o < c.cam) | (o == c.ego)), "((t0.translation=c0.cameraTranslation) AND ((t0.translation<c0.cameraTranslation) OR (t0.translation=c0.egoTranslation)))"),
])
def test_nested(fn, sql):
assert gen(normalize(fn)) == sql
Expand All @@ -127,18 +127,18 @@ def test_find_all_tables(fn, tables, camera):


@pytest.mark.parametrize("fn, mapping, sql", [
(o.trans & o1.trans & c.ego, {0:1, 1:2}, '(valueAtTimestamp(t1.translations,c0.timestamp) AND valueAtTimestamp(t2.translations,c0.timestamp) AND c0.egoTranslation)'),
((o.trans + c) - c.ego + o.type * c.heading / o1.heading, {0:1, 1:0}, '(((valueAtTimestamp(t1.translations,c0.timestamp)+c0.cameraTranslation)-c0.egoTranslation)+((t1.objectType*c0.cameraHeading)/valueAtTimestamp(t0.itemHeadings,c0.timestamp)))'),
(o.trans & o1.trans & c.ego, {0:1, 1:2}, '(t1.translation AND t2.translation AND c0.egoTranslation)'),
((o.trans + c) - c.ego + o.type * c.heading / o1.heading, {0:1, 1:0}, '(((t1.translation+c0.cameraTranslation)-c0.egoTranslation)+((t1.objectType*c0.cameraHeading)/t0.itemHeading))'),
])
def test_map_tables(fn, mapping, sql):
assert gen(MapTablesTransformer(mapping)(normalize(fn))) == sql


@pytest.mark.parametrize("fn, sql", [
(arr(o.trans, c, 3), "ARRAY[valueAtTimestamp(t0.translations,c0.timestamp),c0.cameraTranslation,3]"),
(arr(o.trans, [c, 3]), "ARRAY[valueAtTimestamp(t0.translations,c0.timestamp),[c0.cameraTranslation,3]]"),
# (c in o.trans[1:2], "(trans IN valueAtTimestamp(t0.translations,c0.timestamp)[1:2])"),
# (c[o.trans[3]] in o.trans[1:2], "(trans[valueAtTimestamp(t0.translations,c0.timestamp)[3]] IN valueAtTimestamp(t0.translations,c0.timestamp)[1:2])"),
(arr(o.trans, c, 3), "ARRAY[t0.translation,c0.cameraTranslation,3]"),
(arr(o.trans, [c, 3]), "ARRAY[t0.translation,[c0.cameraTranslation,3]]"),
# (c in o.trans[1:2], "(trans IN t0.translation[1:2])"),
# (c[o.trans[3]] in o.trans[1:2], "(trans[t0.translation[3]] IN t0.translation[1:2])"),
])
def test_array(fn, sql):
assert gen(normalize(fn)) == sql
Expand Down
6 changes: 3 additions & 3 deletions tests/interface/utils/F/test_ahead.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

@pytest.mark.parametrize("fn, sql", [
(ahead(o, c.ego),
"ahead(valueAtTimestamp(t0.translations,c0.timestamp),c0.egoTranslation,(c0.egoHeading)::real)"),
"ahead(t0.translation,c0.egoTranslation,(c0.egoHeading)::real)"),
(ahead(o, c.cam),
"ahead(valueAtTimestamp(t0.translations,c0.timestamp),c0.cameraTranslation,(c0.cameraHeading)::real)"),
"ahead(t0.translation,c0.cameraTranslation,(c0.cameraHeading)::real)"),
(ahead(o, o),
"ahead(valueAtTimestamp(t0.translations,c0.timestamp),valueAtTimestamp(t0.translations,c0.timestamp),(valueAtTimestamp(t0.itemHeadings,c0.timestamp))::real)"),
"ahead(t0.translation,t0.translation,(t0.itemHeading)::real)"),
])
def test_ahead(fn, sql):
assert gen(fn) == sql
10 changes: 5 additions & 5 deletions tests/interface/utils/F/test_contains.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
@pytest.mark.parametrize("fn, sql", [
(contains('lane', o),
"(EXISTS(SELECT 1 FROM SegmentPolygon WHERE SegmentPolygon.__RoadType__lane__ AND\n"
" ST_Covers(SegmentPolygon.elementPolygon, valueAtTimestamp(t0.translations,c0.timestamp))\n"
" ST_Covers(SegmentPolygon.elementPolygon, t0.translation)\n"
"))"),
(contains(road_segment('lane'), o),
"(EXISTS(SELECT 1 FROM SegmentPolygon WHERE SegmentPolygon.__RoadType__lane__ AND\n"
" ST_Covers(SegmentPolygon.elementPolygon, valueAtTimestamp(t0.translations,c0.timestamp))\n"
" ST_Covers(SegmentPolygon.elementPolygon, t0.translation)\n"
"))"),
(contains('lane', [o, o1, o2]),
"(EXISTS(SELECT 1 FROM SegmentPolygon WHERE SegmentPolygon.__RoadType__lane__ AND\n"
" ST_Covers(SegmentPolygon.elementPolygon, valueAtTimestamp(t0.translations,c0.timestamp)) AND "
"ST_Covers(SegmentPolygon.elementPolygon, valueAtTimestamp(t1.translations,c0.timestamp)) AND "
"ST_Covers(SegmentPolygon.elementPolygon, valueAtTimestamp(t2.translations,c0.timestamp))\n"
" ST_Covers(SegmentPolygon.elementPolygon, t0.translation) AND "
"ST_Covers(SegmentPolygon.elementPolygon, t1.translation) AND "
"ST_Covers(SegmentPolygon.elementPolygon, t2.translation)\n"
"))"),
(contains('lane', [c, c.cam, c.ego]),
"(EXISTS(SELECT 1 FROM SegmentPolygon WHERE SegmentPolygon.__RoadType__lane__ AND\n"
Expand Down
6 changes: 3 additions & 3 deletions tests/interface/utils/F/test_convert_camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

@pytest.mark.parametrize("fn, sql", [
(convert_camera(o, c.ego),
"ConvertCamera(valueAtTimestamp(t0.translations,c0.timestamp),c0.egoTranslation,c0.egoHeading)"),
"ConvertCamera(t0.translation,c0.egoTranslation,c0.egoHeading)"),
(convert_camera(o, c),
"ConvertCamera(valueAtTimestamp(t0.translations,c0.timestamp),c0.cameraTranslation,c0.cameraHeading)"),
"ConvertCamera(t0.translation,c0.cameraTranslation,c0.cameraHeading)"),
# (convert_camera(o, o),
# "ConvertCamera(valueAtTimestamp(t0.translations,c0.timestamp),valueAtTimestamp(t0.translations,timestamp),valueAtTimestamp(t0.itemHeadings,c0.timestamp))"),
# "ConvertCamera(t0.translation,t0.translations,timestamp),t0.itemHeading)"),
])
def test_convert_camera(fn, sql):
assert gen(fn) == sql
4 changes: 2 additions & 2 deletions tests/interface/utils/F/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@


@pytest.mark.parametrize("fn, sql", [
(distance(o, o), "ST_Distance(valueAtTimestamp(t0.translations,c0.timestamp),valueAtTimestamp(t0.translations,c0.timestamp))"),
(distance(o, c), "ST_Distance(valueAtTimestamp(t0.translations,c0.timestamp),c0.cameraTranslation)"),
(distance(o, o), "ST_Distance(t0.translation,t0.translation)"),
(distance(o, c), "ST_Distance(t0.translation,c0.cameraTranslation)"),
(distance(c.cam, c.ego), "ST_Distance(c0.cameraTranslation,c0.egoTranslation)"),
])
def test_distance(fn, sql):
Expand Down
24 changes: 12 additions & 12 deletions tests/interface/utils/F/test_heading_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@


@pytest.mark.parametrize("fn, sql", [
(heading_diff(o, o1), "(((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)"),
(heading_diff(o, c), "(((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-c0.cameraHeading))::numeric%360)+360)%360)"),
(heading_diff(o, c.cam), "(((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-c0.cameraHeading))::numeric%360)+360)%360)"),
(heading_diff(o, c.heading), "(((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-c0.cameraHeading))::numeric%360)+360)%360)"),
(heading_diff(o, c.ego), "(((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-c0.egoHeading))::numeric%360)+360)%360)"),
(heading_diff(o, c.egoheading), "(((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-c0.egoHeading))::numeric%360)+360)%360)"),
(heading_diff(o, o1, between=[40, 50]), "(((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)>40) AND ((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)<50))"),
(heading_diff(o, o1, between=[40+360, 50]), "(((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)>40) AND ((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)<50))"),
(heading_diff(o, o1, between=[40-360, 50]), "(((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)>40) AND ((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)<50))"),
(heading_diff(o, o1, between=[50, 40]), "(((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)<50) OR ((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)>40))"),
(heading_diff(o, o1, excluding=[40, 50]), "(((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)<40) OR ((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)>50))"),
(heading_diff(o, o1, excluding=[50, 40]), "(((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)>50) AND ((((((valueAtTimestamp(t0.itemHeadings,c0.timestamp)-valueAtTimestamp(t1.itemHeadings,c0.timestamp)))::numeric%360)+360)%360)<40))"),
(heading_diff(o, o1), "(((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)"),
(heading_diff(o, c), "(((((t0.itemHeading-c0.cameraHeading))::numeric%360)+360)%360)"),
(heading_diff(o, c.cam), "(((((t0.itemHeading-c0.cameraHeading))::numeric%360)+360)%360)"),
(heading_diff(o, c.heading), "(((((t0.itemHeading-c0.cameraHeading))::numeric%360)+360)%360)"),
(heading_diff(o, c.ego), "(((((t0.itemHeading-c0.egoHeading))::numeric%360)+360)%360)"),
(heading_diff(o, c.egoheading), "(((((t0.itemHeading-c0.egoHeading))::numeric%360)+360)%360)"),
(heading_diff(o, o1, between=[40, 50]), "(((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)>40) AND ((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)<50))"),
(heading_diff(o, o1, between=[40+360, 50]), "(((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)>40) AND ((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)<50))"),
(heading_diff(o, o1, between=[40-360, 50]), "(((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)>40) AND ((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)<50))"),
(heading_diff(o, o1, between=[50, 40]), "(((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)<50) OR ((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)>40))"),
(heading_diff(o, o1, excluding=[40, 50]), "(((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)<40) OR ((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)>50))"),
(heading_diff(o, o1, excluding=[50, 40]), "(((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)>50) AND ((((((t0.itemHeading-t1.itemHeading))::numeric%360)+360)%360)<40))"),
])
def test_heading_diff(fn, sql):
assert gen(fn) == sql
Expand Down
6 changes: 3 additions & 3 deletions tests/interface/utils/F/test_road_direction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

@pytest.mark.parametrize("fn, sql", [
(road_direction(o),
"roadDirection(valueAtTimestamp(t0.translations,c0.timestamp),(valueAtTimestamp(t0.itemHeadings,c0.timestamp))::real)"),
"roadDirection(t0.translation,(t0.itemHeading)::real)"),
(road_direction(o, c.ego),
"roadDirection(valueAtTimestamp(t0.translations,c0.timestamp),(c0.egoHeading)::real)"),
"roadDirection(t0.translation,(c0.egoHeading)::real)"),
(road_direction(c),
"roadDirection(c0.cameraTranslation,(c0.cameraHeading)::real)"),
(road_direction(c.ego, o),
"roadDirection(c0.egoTranslation,(valueAtTimestamp(t0.itemHeadings,c0.timestamp))::real)"),
"roadDirection(c0.egoTranslation,(t0.itemHeading)::real)"),
])
def test_road_direction(fn, sql):
assert gen(fn) == sql
Expand Down
2 changes: 1 addition & 1 deletion tests/interface/utils/F/test_same_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@pytest.mark.parametrize("fn, sql", [
(same_region('intersection', o, c), "sameRegion('intersection',valueAtTimestamp(t0.translations,c0.timestamp),c0.cameraTranslation)"),
(same_region('intersection', o, c), "sameRegion('intersection',t0.translation,c0.cameraTranslation)"),
])
def test_same_retion(fn, sql):
assert gen(fn) == sql
Expand Down
4 changes: 2 additions & 2 deletions tests/interface/utils/F/test_view_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

@pytest.mark.parametrize("fn, sql", [
(view_angle(o, c),
"viewAngle(valueAtTimestamp(t0.translations,c0.timestamp),c0.cameraHeading,c0.cameraTranslation)"),
"viewAngle(t0.translation,c0.cameraHeading,c0.cameraTranslation)"),
(view_angle(o, o),
"viewAngle(valueAtTimestamp(t0.translations,c0.timestamp),valueAtTimestamp(t0.itemHeadings,c0.timestamp),valueAtTimestamp(t0.translations,c0.timestamp))")
"viewAngle(t0.translation,t0.itemHeading,t0.translation)")
])
def test_view_angle(fn, sql):
assert gen(fn) == sql

0 comments on commit 12dc08f

Please sign in to comment.