Skip to content

Commit

Permalink
update dependency versions
Browse files Browse the repository at this point in the history
  • Loading branch information
themattinthehatt committed Jun 13, 2024
1 parent cb8f449 commit d1e2800
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions daart/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def on_epoch_end(self, data_generator, model, trainer, **kwargs):
batch[np.sum(batch, axis=1) == 0, 0] = 1
# turn into a one-hot vector
batch = np.argmax(batch, axis=1)
pseudo_labels_data.append(batch.astype(np.int))
pseudo_labels_data.append(batch.astype(int))
pseudo_labels.append(pseudo_labels_data)

# total_new_pseudos = \
Expand Down Expand Up @@ -210,7 +210,7 @@ def on_epoch_end(self, data_generator, model, trainer, **kwargs):
new_batch[np.sum(new_batch, axis=1) == 0, 0] = 1
# turn into a one-hot vector
new_batch = np.argmax(new_batch, axis=1)
pseudo_labels_data.append(new_batch.astype(np.int))
pseudo_labels_data.append(new_batch.astype(int))
pseudo_labels.append(pseudo_labels_data)

# update the data generator with the new psuedo-labels
Expand Down
8 changes: 4 additions & 4 deletions daart/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def load_data(self, sequence_length: int, input_type: str) -> None:
# if no path given, assume same size as markers and set all to background
if 'markers' in self.data.keys():
data_curr = np.zeros(
(len(self.data['markers']) * sequence_length,), dtype=np.int)
(len(self.data['markers']) * sequence_length,), dtype=int)
else:
raise FileNotFoundError(
'Could not load "labels_strong" from None file without markers')
Expand Down Expand Up @@ -636,12 +636,12 @@ def count_class_examples(self) -> np.array:

assert 'labels_strong' in self.signals[0], 'Cannot count examples without hand labels'

totals = np.zeros(len(self.label_names), dtype=np.int)
totals = np.zeros(len(self.label_names), dtype=int)
for dataset in self.datasets:
pad = dataset.sequence_pad
for b, batch in enumerate(dataset.data['labels_strong']):
# log number of examples for batch
counts = np.bincount(batch[pad:-pad].astype('int'))
counts = np.bincount(batch[pad:-pad].astype(int))
if len(counts) == len(totals):
totals += counts
else:
Expand Down Expand Up @@ -874,7 +874,7 @@ def load_label_csv(filepath: str) -> tuple:
"""
labels = np.genfromtxt(
filepath, delimiter=',', dtype=np.int, encoding=None, skip_header=1)[:, 1:]
filepath, delimiter=',', dtype=int, encoding=None, skip_header=1)[:, 1:]
label_names = list(
np.genfromtxt(filepath, delimiter=',', dtype=None, encoding=None, max_rows=1)[1:])
return labels, label_names
Expand Down
14 changes: 7 additions & 7 deletions daart/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def update_metrics(
self,
dtype: str,
loss_dict: dict,
dataset: Union[int, np.int64, list, None] = None
dataset: Union[int, int, list, None] = None
) -> None:
"""Update metrics for a specific dtype/dataset.
Expand Down Expand Up @@ -123,12 +123,12 @@ def update_metrics(
def create_metric_row(
self,
dtype: str,
epoch: Union[int, np.int64],
batch: Union[int, np.int64],
dataset: Union[int, np.int64],
trial: Union[int, np.int64, None],
best_epoch: Optional[Union[int, np.int64]] = None,
by_dataset: bool = False
epoch: Union[int, int],
batch: Union[int, int],
dataset: Union[int, int],
trial: Union[int, int, None],
best_epoch: Optional[Union[int, int]] = None,
by_dataset: bool = False,
) -> dict:
"""Export metrics and other data (e.g. epoch) for logging train progress.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Then, create a conda environment:

.. code-block:: console
conda create --name daart python=3.6
conda create --name daart python=3.10
Active the new environment:

Expand Down
2 changes: 1 addition & 1 deletion examples/fit_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@
"source": [
"# load hand labels\n",
"from numpy import genfromtxt\n",
"labels = genfromtxt(hand_labels_file, delimiter=',', dtype=np.int, encoding=None)\n",
"labels = genfromtxt(hand_labels_file, delimiter=',', dtype=int, encoding=None)\n",
"labels = labels[1:, 1:] # get rid of headers, etc.\n",
"states = np.argmax(labels, axis=1)\n",
"\n",
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from distutils.core import setup


VERSION = '1.0.2'
VERSION = '1.1.0'

# add the README.md file to the long_description
with open('README.md', 'r') as fh:
Expand All @@ -12,7 +12,7 @@
'jupyter',
'matplotlib',
'numpy',
'opencv-python',
'opencv-python-headless',
'pandas',
'pytest',
'pyyaml',
Expand All @@ -21,7 +21,7 @@
'seaborn',
'tables',
'test-tube',
'torch==1.8.0',
'torch',
'tqdm',
'typeguard',
]
Expand Down

0 comments on commit d1e2800

Please sign in to comment.