Commit
update ablation
chanwutk committed Aug 11, 2023
1 parent 756086f commit c9251de
Showing 10 changed files with 51 additions and 66 deletions.
2 changes: 1 addition & 1 deletion .env
@@ -1 +1 @@
mamba activate spatialyze
mamba activate spatialyze-ablation
2 changes: 1 addition & 1 deletion environment.yml
@@ -1,4 +1,4 @@
name: spatialyze
name: spatialyze-ablation

channels:
- conda-forge
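Taken together with the .env change above, the conda environment is renamed: recreating it with mamba env create -f environment.yml presumably now produces an environment called spatialyze-ablation, which the updated .env activates.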
24 changes: 11 additions & 13 deletions playground/run-ablation.ipynb
@@ -172,6 +172,7 @@
"from spatialyze.video_processor.stages.tracking_3d.from_tracking_2d_and_road import FromTracking2DAndRoad\n",
"from spatialyze.video_processor.stages.tracking_3d.from_tracking_2d_and_depth import FromTracking2DAndDepth\n",
"from spatialyze.video_processor.stages.tracking_3d.tracking_3d import Tracking3DResult, Tracking3D\n",
"from spatialyze.video_processor.stages.tracking_3d.from_tracking_2d_and_detection_3d import FromTracking2DAndDetection3D as FromT2DAndD3D\n",
"\n",
"from spatialyze.video_processor.stages.segment_trajectory import SegmentTrajectory\n",
"from spatialyze.video_processor.stages.segment_trajectory.construct_segment_trajectory import SegmentPoint\n",
@@ -362,7 +363,7 @@
" # names = set(sampled_scenes)\n",
" filtered_videos = [\n",
" n for n in videos\n",
" if n[6:10] in names # and 'FRONT' in n # and n.endswith('FRONT')\n",
" if n[6:10] in names and 'FRONT' in n # and n.endswith('FRONT')\n",
" ]\n",
" N = len(filtered_videos)\n",
" print('# of filtered videos:', N)\n",
@@ -420,14 +421,14 @@
"\n",
" # Ingest Trackings\n",
" ego_meta = frames.interpolated_frames\n",
" sortmeta = FromTracking2DAndRoad.get(output)\n",
" sortmeta = Tracking3D.get(output)\n",
" assert sortmeta is not None\n",
" segment_trajectories = FromTracking3D.get(output)\n",
" tracks = get_tracks(sortmeta, ego_meta, segment_trajectories)\n",
" for obj_id, track in tracks.items():\n",
" trajectory = format_trajectory(name, obj_id, track)\n",
" if trajectory:\n",
" insert_trajectory(database, *trajectory)\n",
" insert_trajectory(database, *trajectory[0])\n",
"\n",
" # Ingest Camera\n",
" accs: 'ACameraConfig' = []\n",
@@ -569,9 +570,6 @@
"\n",
" # Object Filter\n",
" if object_filter:\n",
" # if isinstance(object_filter, bool):\n",
" # object_filter = ['car', 'truck']\n",
" # TODO: filter objects based on predicate\n",
" pipeline.add_filter(ObjectTypeFilter(predicate=predicate))\n",
"\n",
" # 3D Detection\n",
@@ -592,10 +590,11 @@
" cache=ss_cache,\n",
" ))\n",
"\n",
" if geo_depth:\n",
" pipeline.add_filter(FromTracking2DAndRoad())\n",
" else:\n",
" pipeline.add_filter(FromTracking2DAndDepth())\n",
" pipeline.add_filter(FromT2DAndD3D())\n",
" # if geo_depth:\n",
" # pipeline.add_filter(FromTracking2DAndRoad())\n",
" # else:\n",
" # pipeline.add_filter(FromTracking2DAndDepth())\n",
"\n",
" # Segment Trajectory\n",
" # pipeline.add_filter(FromTracking3D())\n",
@@ -902,9 +901,8 @@
},
"outputs": [],
"source": [
"tests = ['de', 'optde', 'noopt', 'inview', 'objectfilter', 'geo', 'opt']\n",
"# tests = ['de', 'optde']\n",
"# tests = ['de']\n",
"tests = ['optde', 'de', 'noopt', 'inview', 'objectfilter', 'geo', 'opt']\n",
"tests = ['de', 'noopt', 'inview', 'objectfilter']\n",
"# random.shuffle(tests)\n",
"\n",
"for _test in tests:\n",
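In summary, the notebook changes above: import the new FromTracking2DAndDetection3D stage (aliased FromT2DAndD3D), restrict the filtered videos to FRONT-camera clips, read tracking output through the generic Tracking3D.get(output) accessor, always add FromT2DAndD3D() instead of choosing FromTracking2DAndRoad or FromTracking2DAndDepth via geo_depth, spread trajectory[0] rather than trajectory into insert_trajectory, and adjust the default list of ablation tests.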
10 changes: 10 additions & 0 deletions playground/run-all-local
@@ -0,0 +1,10 @@
#!/bin/bash

tmux new-session -d -s run-test -n run-test-local || echo 'hi'
tmux send-keys -t run-test-local 'docker container start mobilitydb' Enter
tmux send-keys -t run-test-local 'cd ~/Documents/spatialyze' Enter
tmux send-keys -t run-test-local 'rm -rf ./outputs/run/*' Enter
tmux send-keys -t run-test-local 'rm -rf ./run-ablation.py' Enter
tmux send-keys -t run-test-local 'python ./spatialyze/utils/ingest_road.py "./data/scenic/road-network/boston-seaport"' Enter
tmux send-keys -t run-test-local 'jupyter nbconvert --to python ./playground/run-ablation.ipynb && mv playground/run-ablation.py .' Enter
tmux send-keys -t run-test-local 'python run-ablation.py' Enter
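A hedged usage note (not part of the commit): the helper is presumably run as bash playground/run-all-local, after which tmux attach -t run-test shows progress; it assumes a local mobilitydb Docker container, the boston-seaport road network under ./data/scenic/road-network/, and jupyter nbconvert on the PATH.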
(changes to another file; file name not shown in this view)
@@ -43,7 +43,7 @@ def _run(self, payload: "Payload") -> "tuple[bitarray | None, dict[str, list] |
t.detection_id,
t.object_id,
point_from_camera,
point,
tuple(point.tolist()),
t.bbox_left,
t.bbox_top,
t.bbox_w,
(changes to another file; file name not shown in this view)
@@ -6,7 +6,7 @@
from .tracking_3d import Metadatum, Tracking3D, Tracking3DResult


class FromTracking2DAndDetection2D(Tracking3D):
class FromTracking2DAndDetection3D(Tracking3D):
def _run(self, payload: "Payload") -> "tuple[bitarray | None, dict[str, list] | None]":
metadata: "list[Metadatum]" = []
trajectories: "dict[int, list[Tracking3DResult]]" = {}
@@ -34,19 +34,18 @@ def _run(self, payload: "Payload") -> "tuple[bitarray | None, dict[str, list] |
detection_map = {
did: (det, p, pfc)
for det, p, pfc, did
in zip(dets, points, points_from_camera, dids)
in zip(dets, points.tolist(), points_from_camera.tolist(), dids)
}
trackings3d: "dict[int, Tracking3DResult]" = {}
for object_id, t in tracking.items():
did = t.detection_id
det, p, pfc = detection_map[did]

xfc, yxc, zxc = pfc.tolist()
trackings3d[object_id] = Tracking3DResult(
t.frame_idx,
t.detection_id,
t.object_id,
(xfc, yxc, zxc),
pfc,
p,
t.bbox_left,
t.bbox_top,
(changes to another file; file name not shown in this view)
@@ -63,21 +63,16 @@ def _run(self, payload: "Payload"):
points = rotated_directions * ts + translation[:, np.newaxis]
points_from_camera = rotate(points - translation[:, np.newaxis], rotation.inverse)

for t, oid, point, point_from_camera in zip(_ts, oids, points.T, points_from_camera.T):
for t, oid, point, point_from_camera in zip(_ts, oids, points.T.tolist(), points_from_camera.T.tolist()):
assert point_from_camera.shape == (3,)
assert isinstance(oid, int) or oid.is_integer()
oid = int(oid)
point_from_camera = (
point_from_camera[0],
point_from_camera[1],
point_from_camera[2],
)
trackings3d[oid] = Tracking3DResult(
t.frame_idx,
t.detection_id,
oid,
point_from_camera,
point,
tuple(point_from_camera),
tuple(point),
t.bbox_left,
t.bbox_top,
t.bbox_w,
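The three hunks above make the tracking-3D stages hand Tracking3DResult plain tuples and lists for point and point_from_camera instead of numpy arrays, matching the dataclass change in tracking_3d.py below.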
5 changes: 1 addition & 4 deletions spatialyze/video_processor/stages/tracking_3d/tracking_3d.py
@@ -2,9 +2,6 @@
from dataclasses import dataclass
from typing import Any, Dict

import numpy as np
import numpy.typing as npt

from ...types import DetectionId, Float3
from ..stage import Stage

@@ -15,7 +12,7 @@ class Tracking3DResult:
detection_id: DetectionId
object_id: float
point_from_camera: "Float3"
point: "npt.NDArray[np.floating]"
point: "Float3"
bbox_left: float
bbox_top: float
bbox_w: float
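For reference, a rough sketch of Tracking3DResult after this change (fields past bbox_w are unchanged and elided; the frame_idx type is assumed from how callers pass t.frame_idx):

@dataclass
class Tracking3DResult:
    frame_idx: int                   # assumed; callers above pass t.frame_idx positionally
    detection_id: DetectionId
    object_id: float
    point_from_camera: "Float3"
    point: "Float3"                  # was npt.NDArray[np.floating]; the numpy imports are now dropped
    bbox_left: float
    bbox_top: float
    bbox_w: float
    # remaining fields unchanged (not shown in this diff)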
51 changes: 19 additions & 32 deletions spatialyze/world.py
@@ -22,12 +22,10 @@
)
from .video_processor.stages.detection_estimation import DetectionEstimation
from .video_processor.stages.in_view.in_view import InView
from .video_processor.stages.stage import Stage
from .video_processor.stages.tracking_2d.strongsort import StrongSORT
from .video_processor.stages.tracking_3d.from_tracking_2d_and_depth import (
FromTracking2DAndDepth,
)
from .video_processor.stages.tracking_3d.from_tracking_2d_and_road import (
FromTracking2DAndRoad,
from .video_processor.stages.tracking_3d.from_tracking_2d_and_detection_3d import (
FromTracking2DAndDetection3D,
)
from .video_processor.stages.tracking_3d.tracking_3d import Metadatum as T3DMetadatum
from .video_processor.stages.tracking_3d.tracking_3d import Tracking3D
@@ -107,35 +105,24 @@ def _execute(world: "World", optimization=True):
# for gc in world._geogConstructs:
# gc.ingest(database)
# analyze predicates to generate pipeline
objtypes_filter = ObjectTypeFilter(predicate=world.predicates)
steps: "list[Stage]" = []
if optimization:
steps.append(InView(distance=50, predicate=world.predicates))
steps.append(DecodeFrame())
steps.append(YoloDetection())
if optimization:
pipeline = Pipeline(
[
DecodeFrame(),
InView(distance=50, predicate=world.predicates),
YoloDetection(),
objtypes_filter,
FromDetection2DAndRoad(),
*(
[DetectionEstimation()]
if all(t in ["car", "truck"] for t in objtypes_filter.types)
else []
),
StrongSORT(),
FromTracking2DAndRoad(),
]
)
objtypes_filter = ObjectTypeFilter(predicate=world.predicates)
steps.append(objtypes_filter)
steps.append(FromDetection2DAndRoad())
if all(t in ["car", "truck"] for t in objtypes_filter.types):
steps.append(DetectionEstimation())
else:
pipeline = Pipeline(
[
DecodeFrame(),
YoloDetection(),
DepthEstimation(),
FromDetection2DAndDepth(),
StrongSORT(),
FromTracking2DAndDepth(),
]
)
steps.append(DepthEstimation())
steps.append(FromDetection2DAndDepth())
steps.append(StrongSORT())
steps.append(FromTracking2DAndDetection3D())

pipeline = Pipeline(steps)

qresults: "dict[str, list[tuple]]" = {}
vresults: "dict[str, list[T3DMetadatum]]" = {}
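Net effect, as read from the interleaved old and new lines above: _execute now builds one steps list, both modes share the final tracking stages, and FromTracking2DAndRoad / FromTracking2DAndDepth are no longer used here. A hedged sketch of the resulting assembly (reconstructed from the diff; indentation of the shared tail is assumed):

# reconstructed sketch, not verbatim from the commit
steps: "list[Stage]" = []
if optimization:
    steps.append(InView(distance=50, predicate=world.predicates))
steps.append(DecodeFrame())
steps.append(YoloDetection())
if optimization:
    objtypes_filter = ObjectTypeFilter(predicate=world.predicates)
    steps.append(objtypes_filter)
    steps.append(FromDetection2DAndRoad())
    if all(t in ["car", "truck"] for t in objtypes_filter.types):
        steps.append(DetectionEstimation())
else:
    steps.append(DepthEstimation())
    steps.append(FromDetection2DAndDepth())
steps.append(StrongSORT())
steps.append(FromTracking2DAndDetection3D())
pipeline = Pipeline(steps)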
3 changes: 1 addition & 2 deletions tests/workflow/test_optimized_workflow.py
@@ -12,7 +12,6 @@
from spatialyze.video_processor.stages.tracking_3d.tracking_3d import Tracking3DResult
from spatialyze.world import World, _execute
from spatialyze.video_processor.cache import disable_cache
from spatialyze.video_processor.metadata_json_encoder import MetadataJSONEncoder


OUTPUT_DIR = './data/pipeline/test-results'
@@ -75,7 +74,7 @@ def test_optimized_workflow():
assert tuple(p.detection_id) == tuple(g.detection_id), (p.detection_id, g.detection_id)
assert p.object_id == g.object_id, (p.object_id, g.object_id)
assert np.allclose(np.array(p.point_from_camera), np.array(g.point_from_camera)), (p.point_from_camera, g.point_from_camera)
assert np.allclose(np.array(p.point.tolist()), np.array(g.point)), (p.point, g.point)
assert np.allclose(np.array(p.point), np.array(g.point)), (p.point, g.point)
assert p.bbox_left == g.bbox_left, (p.bbox_left, g.bbox_left)
assert p.bbox_top == g.bbox_top, (p.bbox_top, g.bbox_top)
assert p.bbox_w == g.bbox_w, (p.bbox_w, g.bbox_w)
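With point now a plain tuple, the test compares it directly via np.allclose instead of calling .tolist() first, and the MetadataJSONEncoder import is no longer needed.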
