Skip to content

Commit

Permalink
overlap added while creating task and gt
Browse files Browse the repository at this point in the history
  • Loading branch information
rohan220217 committed Nov 22, 2024
1 parent b81a78d commit 2c6ed64
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 24 deletions.
4 changes: 2 additions & 2 deletions cvat/apps/engine/frame_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ def get_chunk_number(self, frame_number):

def _validate_chunk_number(self, chunk_number):
chunk_number_ = int(chunk_number)
if chunk_number_ < 0 or chunk_number_ >= math.ceil(self._db_data.size / self._db_data.chunk_size):
raise ValidationError('requested chunk does not exist')
# if chunk_number_ < 0 or chunk_number_ >= math.ceil(self._db_data.size / self._db_data.chunk_size):
# raise ValidationError('requested chunk does not exist')

return chunk_number_

Expand Down
27 changes: 14 additions & 13 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ def create(self, validated_data):
size = task.data.size
valid_frame_ids = task.data.get_valid_frame_indices()
segment_size = task.segment_size
overlap = task.overlap

frame_selection_method = validated_data.pop("frame_selection_method", None)
if frame_selection_method == models.JobFrameSelectionMethod.RANDOM_UNIFORM:
Expand All @@ -712,26 +713,26 @@ def create(self, validated_data):
)

if task.data.original_chunk_type == DataChoice.AUDIO:
num_segments = size // segment_size
jobs_frame_list = []
for i in range(num_segments):
start = i * segment_size
end = (i+1) * segment_size - 1
array = [j for j in range(start,end+1)]
jobs_frame_list.append(array)
effective_increment = segment_size - overlap

# if there's a remainder, create the last array
if size % segment_size != 0:
start = num_segments * segment_size
end = size - 1
array = [j for j in range(start,end+1)]
# Create overlapping segments
jobs_frame_list = []
start = 0
while start < size:
end = min(start + segment_size - 1, size - 1) # last frame does not exceed the total size
array = [j for j in range(start, end + 1)]
jobs_frame_list.append(array)
start += effective_increment # Move to the next start position considering the overlap

#Random select from the list
# Randomly select from the list
import math, random

random_jobs_no = math.ceil(frame_count / segment_size)
selected_jobs_frames = random.sample(jobs_frame_list, random_jobs_no)

# Flatten and sort the selected frames
frames = sorted([item for sublist in selected_jobs_frames for item in sublist])

else:
seed = validated_data.pop("seed", None)

Expand Down
29 changes: 26 additions & 3 deletions cvat/apps/engine/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _segments():
if segment_size == 0:
raise ValueError("Segment size cannot be zero.")

overlap = 0
overlap = db_task.overlap
segment_size = segment_step
# if db_task.overlap is not None:
# overlap = min(db_task.overlap, segment_size // 2)
Expand Down Expand Up @@ -1060,9 +1060,12 @@ def get_audio_duration(file_path):

db_task.audio_total_duration = None

num_frames_per_millisecond = 0
# calculate chunk size if it isn't specified
if MEDIA_TYPE == "audio":
segment_duration = db_task.segment_duration if db_task.segment_duration is not None else 600000
overlap_duration = 5*1000

db_task.audio_total_duration = get_audio_duration(details['source_path'][0])
# db_task.data.audio_total_duration = 720000 #get_audio_duration(details['source_path'][0])
total_audio_frames = extractor.get_total_frames()
Expand All @@ -1075,6 +1078,7 @@ def get_audio_duration(file_path):

num_frames_per_segment_duration = num_frames_per_millisecond*segment_duration
db_task.segment_size = int(round(num_frames_per_segment_duration))
db_task.overlap = int(round(num_frames_per_millisecond * overlap_duration)) # we want to hardcode overlap for audio

# num_segments = max(1, int(math.ceil(db_task.audio_total_duration / segment_duration)))

Expand Down Expand Up @@ -1206,9 +1210,23 @@ def get_audio_duration(file_path):
frame=frame, width=w, height=h)
for (path, frame), (w, h) in zip(chunk_paths, img_sizes)
])

if db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM or not settings.USE_CACHE:
def generate_chunks_with_overlap(extractor, chunk_size, overlap):
chunk = []
chunk_idx = 0
for frame in extractor:
chunk.append(frame)
if len(chunk) == chunk_size + overlap: # Full chunk including overlap
yield chunk_idx, chunk[:chunk_size] # Yield the main chunk
chunk_idx += 1
chunk = chunk[chunk_size - overlap:] # Retain the overlap portion for the next chunk
if chunk: # Yield remaining frames as the last chunk
yield chunk_idx, chunk

counter = itertools.count()
generator = itertools.groupby(extractor, lambda _: next(counter) // db_data.chunk_size)
# generator = itertools.groupby(extractor, lambda _: next(counter) // db_data.chunk_size)
generator = generate_chunks_with_overlap(extractor, chunk_size=db_data.chunk_size, overlap=db_task.overlap)
generator = ((idx, list(chunk_data)) for idx, chunk_data in generator)

def save_chunks(
Expand Down Expand Up @@ -1262,8 +1280,13 @@ def process_results(img_meta: list[tuple[str, int, tuple[int, int]]]):

futures = queue.Queue(maxsize=settings.CVAT_CONCURRENT_CHUNK_PROCESSING)
with concurrent.futures.ThreadPoolExecutor(max_workers=2*settings.CVAT_CONCURRENT_CHUNK_PROCESSING) as executor:
seen_frames = set() # To track unique frames
for chunk_idx, chunk_data in generator:
db_data.size += len(chunk_data)
unique_frames = [frame for frame in chunk_data if frame not in seen_frames]
seen_frames.update(unique_frames)
db_data.size += len(unique_frames)

# db_data.size += len(chunk_data)
if futures.full():
process_results(futures.get().result())
futures.put(executor.submit(save_chunks, executor, chunk_idx, chunk_data))
Expand Down
12 changes: 6 additions & 6 deletions cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,12 +701,12 @@ def __call__(self, request, start: int, stop: int, db_data: Optional[Data]):

try:
if self.type == 'chunk':
start_chunk = frame_provider.get_chunk_number(start)
stop_chunk = frame_provider.get_chunk_number(stop)
# pylint: disable=superfluous-parens
if not (start_chunk <= self.number <= stop_chunk):
raise ValidationError('The chunk number should be in the ' +
f'[{start_chunk}, {stop_chunk}] range')
# start_chunk = frame_provider.get_chunk_number(start)
# stop_chunk = frame_provider.get_chunk_number(stop)
# # pylint: disable=superfluous-parens
# if not (start_chunk <= self.number <= stop_chunk):
# raise ValidationError('The chunk number should be in the ' +
# f'[{start_chunk}, {stop_chunk}] range')

# TODO: av.FFmpegError processing
if settings.USE_CACHE and db_data.storage_method == StorageMethodChoice.CACHE:
Expand Down

0 comments on commit 2c6ed64

Please sign in to comment.