Skip to content

Commit

Permalink
Various Fixes (#15004)
Browse files Browse the repository at this point in the history
* Don't track shared memory in frame tracker

* Don't track any instance

* Don't assign sub label to objects when multiple cars are overlapping

* Formatting

* Fix assignment
  • Loading branch information
NickM-27 authored Nov 15, 2024
1 parent 4eea541 commit 7fdf42a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 23 deletions.
17 changes: 13 additions & 4 deletions frigate/track/tracked_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from collections import defaultdict
from statistics import median
from typing import Optional

import cv2
import numpy as np
Expand Down Expand Up @@ -423,10 +424,11 @@ def get_tracking_data(self) -> dict[str, any]:
"box": self.box,
}

def find_best_object(self, objects: list[dict[str, any]]) -> str:
def find_best_object(self, objects: list[dict[str, any]]) -> Optional[str]:
"""Find the best attribute for each object and return its ID."""
best_object_area = None
best_object_id = None
best_object_label = None

for obj in objects:
if not box_inside(obj["box"], self.box):
Expand All @@ -440,8 +442,15 @@ def find_best_object(self, objects: list[dict[str, any]]) -> str:
if best_object_area is None:
best_object_area = object_area
best_object_id = obj["id"]
elif object_area < best_object_area:
best_object_area = object_area
best_object_id = obj["id"]
best_object_label = obj["label"]
else:
if best_object_label == "car" and obj["label"] == "car":
# if multiple cars are overlapping with the same label then the label will not be assigned
return None
elif object_area < best_object_area:
# if a car and person are overlapping then assign the label to the smaller object (which should be the person)
best_object_area = object_area
best_object_id = obj["id"]
best_object_label = obj["label"]

return best_object_id
64 changes: 45 additions & 19 deletions frigate/util/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import datetime
import logging
import subprocess as sp
import threading
from abc import ABC, abstractmethod
from multiprocessing import shared_memory
from multiprocessing import resource_tracker as _mprt
from multiprocessing import shared_memory as _mpshm
from string import printable
from typing import AnyStr, Optional

Expand Down Expand Up @@ -731,32 +733,56 @@ def delete(self, name):
pass


class DictFrameManager(FrameManager):
def __init__(self):
self.frames = {}
class SharedMemory(_mpshm.SharedMemory):
# https://github.com/python/cpython/issues/82300#issuecomment-2169035092

def create(self, name, size) -> AnyStr:
mem = bytearray(size)
self.frames[name] = mem
return mem
__lock = threading.Lock()

def get(self, name, shape):
mem = self.frames[name]
return np.ndarray(shape, dtype=np.uint8, buffer=mem)
def __init__(
self,
name: Optional[str] = None,
create: bool = False,
size: int = 0,
*,
track: bool = True,
) -> None:
self._track = track

def close(self, name):
pass
# if tracking, normal init will suffice
if track:
return super().__init__(name=name, create=create, size=size)

def delete(self, name):
del self.frames[name]
# lock so that other threads don't attempt to use the
# register function during this time
with self.__lock:
# temporarily disable registration during initialization
orig_register = _mprt.register
_mprt.register = self.__tmp_register

# initialize; ensure original register function is
# re-instated
try:
super().__init__(name=name, create=create, size=size)
finally:
_mprt.register = orig_register

@staticmethod
def __tmp_register(*args, **kwargs) -> None:
return

def unlink(self) -> None:
if _mpshm._USE_POSIX and self._name:
_mpshm._posixshmem.shm_unlink(self._name)
if self._track:
_mprt.unregister(self._name, "shared_memory")


class SharedMemoryFrameManager(FrameManager):
def __init__(self):
self.shm_store: dict[str, shared_memory.SharedMemory] = {}
self.shm_store: dict[str, SharedMemory] = {}

def create(self, name: str, size) -> AnyStr:
shm = shared_memory.SharedMemory(name=name, create=True, size=size)
shm = SharedMemory(name=name, create=True, size=size, track=False)
self.shm_store[name] = shm
return shm.buf

Expand All @@ -765,7 +791,7 @@ def get(self, name: str, shape) -> Optional[np.ndarray]:
if name in self.shm_store:
shm = self.shm_store[name]
else:
shm = shared_memory.SharedMemory(name=name)
shm = SharedMemory(name=name, track=False)
self.shm_store[name] = shm
return np.ndarray(shape, dtype=np.uint8, buffer=shm.buf)
except FileNotFoundError:
Expand All @@ -788,7 +814,7 @@ def delete(self, name: str):
del self.shm_store[name]
else:
try:
shm = shared_memory.SharedMemory(name=name)
shm = SharedMemory(name=name, track=False)
shm.close()
shm.unlink()
except FileNotFoundError:
Expand Down

0 comments on commit 7fdf42a

Please sign in to comment.