Skip to content

JuanFMontesinos/torch_mir_eval

Repository files navigation

torch_mir_eval

PyPI Status Build Status Code Coverage Python Versions

Pytorch implementation of mir_eval.
Nvidia RTX 3090, ADM Threadripper 1920X for a single run

Pytorch 2.1.1
.bss_eval_sources test> permutation:         False      float32         CPU: 1.073      torch-CPU: 0.813
.bss_eval_sources test> Compute permutation: False      float32         CPU: 1.067      GPU: 0.592
.bss_eval_sources test> compute_permutation: True       float32         CPU: 5.224      torch-CPU: 3.737
.bss_eval_sources test> compute_permutation: True       float32         CPU: 5.417      GPU: 2.883
.bss_eval_sources test> Compute permutation: False      float64         CPU: 1.087      torch-CPU: 1.710
.bss_eval_sources test> Compute permutation: False      float64         CPU: 1.115      GPU: 0.940
.bss_eval_sources test> Compute permutation: False      float64         CPU: 5.390      torch-CPU: 8.451
.bss_eval_sources test> Compute permutation: False      float64         CPU: 5.558      GPU: 4.623
*Sources vary across tests  

Usage

PIP install:
pip install torch_mir_eval

The easy way (work as drop-in replacement or batches):

from torch_mir_eval import bss_eval_sources
N=4
S=44000
src = torch.rand(N,S).cuda()
est = torch.rand(N,S).cuda()
sdr,sir,sar,perm = bss_eval_sources(src,est,compute_permutation=True)
from torch_mir_eval import bss_eval_sources
B=3
N=4
S=44000
src = torch.rand(B,N,S).cuda()
est = torch.rand(B,N,S).cuda()
sdr,sir,sar,perm = bss_eval_sources(src,est,compute_permutation=True)

Just pass tensors instead of numpy arrays. Everything else is the same.
For the batched version we follow pytorch convention of batch first. Therefore the expected format is b, nsrc, samples

How to contribute

  • Implementing any other function from the original mir_eval
  • Addresing
    # TODO case in which only few of the Gs are singular but the rest are determined
    if all(torch.det(G) > 0.1):
    C = torch.solve(D.unsqueeze(-1), G).solution.reshape(b, nsrc, flen).permute(0, 2, 1)
    else:
    solutions = []

Changelog

  • Version 0.1: bss_eval_sources function implemented
  • Version 0.2: bss_eval_sources is now backpropagable. There algorithm now accepts batches.
  • Version 0.3: Incorporates new PyTorch's fft package for versions>1.7 and deprecates torch.rfft and
    torch.ifft following pytorch's roadmap.
  • Version 0.4: Support for PyTorch 1.9 onwards and the new linalg package which deprecates previous algebra solvers.
    • Partially Solves inconsistencies between GPU results and CPU results shown at #5

Current available functions

  • Separation:
    • mir_eval.separation.bss_eval_sources
    • mir_eval.batch_separation.bss_eval_sources