Update speed_check(), add type hints. #119

Open: wants to merge 1 commit into main
129 changes: 60 additions & 69 deletions reader/utils.py
@@ -1,87 +1,78 @@
"""Reader utilities."""
import itertools
import time
from typing import Optional

from tml.common.batch import DataclassBatch
from tml.ml_logging.torch_logging import logging

import pyarrow as pa
import torch

from tml.common.batch import DataclassBatch
from tml.ml_logging.torch_logging import logging

-def roundrobin(*iterables):
-  """Round robin through provided iterables, useful for simple load balancing.
-
-  Adapted from https://docs.python.org/3/library/itertools.html.
-
-  """
-  num_active = len(iterables)
-  nexts = itertools.cycle(iter(it).__next__ for it in iterables)
-  while num_active:
-    try:
-      for _next in nexts:
-        result = _next()
-        yield result
-    except StopIteration:
-      # Remove the iterator we just exhausted from the cycle.
-      num_active -= 1
-      nexts = itertools.cycle(itertools.islice(nexts, num_active))
-      logging.warning(f"Iterable exhausted, {num_active} iterables left.")
-    except Exception as exc:
-      logging.warning(f"Iterable raised exception {exc}, ignoring.")
-      # continue
-      raise
+def roundrobin(*iterables):
+  """Round robin through provided iterables, useful for simple load balancing.
+  Adapted from https://docs.python.org/3/library/itertools.html.
+  """
+  num_active = len(iterables)
+  nexts = itertools.cycle(iter(it).__next__ for it in iterables)
+  while num_active:
+    try:
+      for _next in nexts:
+        result = _next()
+        yield result
+    except StopIteration:
+      num_active -= 1
+      nexts = itertools.cycle(itertools.islice(nexts, num_active))
+      logging.warning(f"Iterable exhausted, {num_active} iterables left.")
+    except Exception as exc:
+      logging.warning(f"Iterable raised exception {exc}, ignoring.")
+      raise
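
Reviewer note: a quick sanity check of the interleaving order, matching the itertools recipe this is adapted from (each iterable yields one element per turn until all are exhausted):

# roundrobin drains the shortest iterables first while preserving turn order.
assert list(roundrobin("ABC", "D", "EF")) == ["A", "D", "E", "B", "F", "C"]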


 def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
-  num_examples = 0
-  prev = time.perf_counter()
-  for idx, batch in enumerate(data_loader):
-    if idx > max_steps:
-      break
-    if peek and idx % peek == 0:
-      logging.info(f"Batch: {batch}")
-    num_examples += batch.batch_size
-    if idx % frequency == 0:
-      now = time.perf_counter()
-      elapsed = now - prev
-      logging.info(
-        f"step: {idx}, "
-        f"elapsed(s): {elapsed}, "
-        f"examples: {num_examples}, "
-        f"ex/s: {num_examples / elapsed}, "
-      )
-      prev = now
-      num_examples = 0
+  prev = time.perf_counter()
+  for idx, batch in enumerate(data_loader):
+    if idx > max_steps:
+      break
+    if peek and idx % peek == 0:
+      logging.info(f"Batch: {batch}")
+    if idx % frequency == 0 and idx > 0:
+      now = time.perf_counter()
+      elapsed = now - prev
+      examples_per_second = batch.batch_size / elapsed
+      logging.info(
+        f"step: {idx}, elapsed(s): {elapsed:.2f}, examples: {batch.batch_size}, "
+        f"ex/s: {examples_per_second:.2f}"
+      )
+      prev = now
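
Reviewer note: a minimal sketch of driving speed_check, assuming each batch exposes a batch_size attribute the way DataclassBatch does; FakeBatch and fake_loader are hypothetical stand-ins, not part of this repo:

from dataclasses import dataclass

@dataclass
class FakeBatch:
  batch_size: int = 4096

def fake_loader():
  # Unbounded stream of fixed-size batches; speed_check stops at max_steps.
  while True:
    yield FakeBatch()

# Log throughput every 100 steps, peek at a batch every 500 steps.
speed_check(fake_loader(), max_steps=1000, frequency=100, peek=500)

One caveat worth flagging: the new version divides the size of only the most recent batch by the whole interval, so the logged ex/s understates throughput by roughly a factor of frequency whenever frequency > 1.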


 def pa_to_torch(array: pa.array) -> torch.Tensor:
   return torch.from_numpy(array.to_numpy())
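
Reviewer note: pa_to_torch is zero-copy for primitive arrays; Array.to_numpy() defaults to zero_copy_only=True and raises on arrays that contain nulls, which is why create_default_pa_to_batch below imputes first. A small example:

arr = pa.array([1.0, 2.0, 3.0], type=pa.float64())
t = pa_to_torch(arr)  # tensor shares memory with the Arrow buffer
assert t.dtype == torch.float64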


 def create_default_pa_to_batch(schema) -> DataclassBatch:
-  """ """
-  _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
-
-  def get_imputation_value(pa_type):
-    type_map = {
-      pa.float64(): pa.scalar(0, type=pa.float64()),
-      pa.int64(): pa.scalar(0, type=pa.int64()),
-      pa.string(): pa.scalar("", type=pa.string()),
-    }
-    if pa_type not in type_map:
-      raise Exception(f"Imputation for type {pa_type} not supported.")
-    return type_map[pa_type]
-
-  def _impute(array: pa.array) -> pa.array:
-    return array.fill_null(get_imputation_value(array.type))
-
-  def _column_to_tensor(record_batch: pa.RecordBatch):
-    tensors = {
-      col_name: pa_to_torch(_impute(record_batch.column(col_name)))
-      for col_name in record_batch.schema.names
-    }
-    return _CustomBatch(**tensors)
-
-  return _column_to_tensor
+  """ """
+  _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
+
+  def get_imputation_value(pa_type):
+    type_map = {
+      pa.float64(): pa.scalar(0, type=pa.float64()),
+      pa.int64(): pa.scalar(0, type=pa.int64()),
+      pa.string(): pa.scalar("", type=pa.string()),
+    }
+    if pa_type not in type_map:
+      raise ValueError(f"Imputation for type {pa_type} not supported.")
+    return type_map[pa_type]
+
+  def _impute(array: pa.array) -> pa.array:
+    return array.fill_null(get_imputation_value(array.type))
+
+  def _column_to_tensor(record_batch: pa.RecordBatch):
+    tensors = {
+      col_name: pa_to_torch(_impute(record_batch.column(col_name)))
+      for col_name in record_batch.schema.names
+    }
+    return _CustomBatch(**tensors)
+
+  return _column_to_tensor
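
Reviewer note: an end-to-end sketch of the default converter, assuming the tml package is importable; the column names here are made up for illustration:

schema = pa.schema([("user_id", pa.int64()), ("score", pa.float64())])
to_batch = create_default_pa_to_batch(schema)

# None values become Arrow nulls and are imputed to 0 / 0.0 before conversion.
record_batch = pa.RecordBatch.from_pydict(
  {"user_id": [1, None, 3], "score": [0.5, 1.5, None]},
  schema=schema,
)
batch = to_batch(record_batch)  # DataclassBatch with one torch tensor per column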