Skip to content

Commit

Permalink
Add a Pip-based CI testing for inference code, with Colab dependencies.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627135153
  • Loading branch information
sdenton4 authored and copybara-github committed May 11, 2024
1 parent c829814 commit 35ce265
Show file tree
Hide file tree
Showing 6 changed files with 1,462 additions and 1,353 deletions.
40 changes: 40 additions & 0 deletions .github/install_colab_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Installs Colab dependencies for CI testing."""

from typing import Sequence

from absl import app
import requests


REQS_FILE = 'https://raw.githubusercontent.com/googlecolab/backend-info/main/pip-freeze.txt'
COLAB_REQS_FILE = '/tmp/colab_reqs.txt'


def main(unused_argv: Sequence[str]) -> None:
got = requests.get(REQS_FILE)
requirements_str = str(got.content, 'utf8')
# Skip the file:// lines, which we do not have access to.
lines = [
ln + '\n' for ln in requirements_str.split('\n') if 'file://' not in ln
]
with open(COLAB_REQS_FILE, 'w') as f:
f.writelines(lines)


if __name__ == '__main__':
app.run(main)
41 changes: 41 additions & 0 deletions .github/workflows/ci_colab_no_jaxtrain.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: CI_colab_no_jaxtrain

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
# Allows to run this workflow manually from the Actions tab on GitHub.
workflow_dispatch:

jobs:
test-ubuntu:
name: "test on ${{ matrix.python-version }} on ${{ matrix.os }}"
runs-on: "${{ matrix.os }}"
strategy:
matrix:
python-version: ["3.10", "3.11"]
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v3
- name: Set up Poetry
run: |
pipx install poetry
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: 'poetry'
- name: Install Chirp and its dependencies via pip.
run: |
sudo apt-get update
sudo apt-get install libsndfile1 ffmpeg
pip install absl-py
pip install requests
pip install tensorflow-cpu
python3 .github/install_colab_deps.py
pip install -r /tmp/colab_reqs.txt
pip install git+https://github.com/google-research/perch.git
- name: Test with unittest
# TODO: Group together jaxtrain tests so they can be easily excluded.
run: poetry run python -m unittest discover -s chirp/inference/tests -p "*test.py"
31 changes: 21 additions & 10 deletions chirp/inference/tests/bootstrap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Tests for project state handling."""

import os
import shutil
import tempfile

from chirp import audio_utils
Expand All @@ -33,20 +34,29 @@

class BootstrapTest(absltest.TestCase):

def setUp(self):
super().setUp()
# `self.create_tempdir()` raises an UnparsedFlagAccessError, which is why
# we use `tempdir` directly.
self.tempdir = tempfile.mkdtemp()

def tearDown(self):
super().tearDown()
shutil.rmtree(self.tempdir)

def make_wav_files(self, classes, filenames):
# Create a pile of files.
rng = np.random.default_rng(seed=42)
tmpdir = self.create_tempdir()
for subdir in classes:
subdir_path = os.path.join(tmpdir.full_path, subdir)
subdir_path = os.path.join(self.tempdir, subdir)
os.mkdir(subdir_path)
for filename in filenames:
with open(
os.path.join(subdir_path, f'{filename}_{subdir}.wav'), 'wb'
) as f:
noise = rng.normal(scale=0.2, size=16000)
wavfile.write(f, 16000, noise)
audio_glob = os.path.join(tmpdir.full_path, '*/*.wav')
audio_glob = os.path.join(self.tempdir, '*/*.wav')
return audio_glob

def write_placeholder_embeddings(self, audio_glob, source_infos, embed_dir):
Expand Down Expand Up @@ -129,15 +139,16 @@ def test_bootstrap_from_embeddings(self):
source_infos = embed_lib.create_source_infos([audio_glob], shard_len_s=5.0)
self.assertLen(source_infos, len(classes) * len(filenames))

embed_dir = self.create_tempdir()
labeled_dir = self.create_tempdir()
self.write_placeholder_embeddings(
audio_glob, source_infos, embed_dir.full_path
)
embed_dir = os.path.join(self.tempdir, 'embeddings')
labeled_dir = os.path.join(self.tempdir, 'labeled')
epath.Path(embed_dir).mkdir(parents=True, exist_ok=True)
epath.Path(labeled_dir).mkdir(parents=True, exist_ok=True)

self.write_placeholder_embeddings(audio_glob, source_infos, embed_dir)

bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_path(
embeddings_path=embed_dir.full_path,
annotated_path=labeled_dir.full_path,
embeddings_path=embed_dir,
annotated_path=labeled_dir,
)
print('config hash : ', bootstrap_config.embedding_config_hash())

Expand Down
2 changes: 1 addition & 1 deletion chirp/inference/tests/embed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def test_logits_output_head(self):
logits_model = _make_output_head_model(
'/tmp/logits_model', embedding_dim=128
)
base_outputs = base_model.embed(np.zeros(5 * 22050))
base_outputs = base_model.embed(np.zeros(5 * 22050, dtype=np.float32))
updated_outputs = logits_model.add_logits(base_outputs, keep_original=True)
self.assertSequenceEqual(
updated_outputs.logits['other_label'].shape,
Expand Down
Loading

0 comments on commit 35ce265

Please sign in to comment.