Best way to compute dynamic slices of arrays where the slices are known beforehand? #23628
Unanswered
Mattias421 asked this question in Q&A
Replies: 1 comment 3 replies
-
If you want those features to be static, you'll need to keep them static in your code. That said, it looks like you are scanning over these values, so by definition they cannot be static/constant from iteration to iteration. To make the semantics clearer and to let me answer your question better, could you provide a minimal reproducible example of what exactly you're hoping to do?
-
I am trying to align sequences of text feature vectors with sequences of speech feature vectors.

My data is organised with `text_feature_array`, an array of text feature vectors in which each vector represents a character/token in the dataset. You can think of this as flattening out the entire dataset and then extracting feature vectors. The length of each text segment is recorded in a separate array `text_lens`, which allows the text features for each sample to be extracted using a cumulative sum over `text_lens`: e.g. the 0th text feature segment would be `text_feature_array[0 : text_lens[0]]`, the 1st would be `text_feature_array[text_lens[0] : text_lens[0] + text_lens[1]]`, and so on. The same applies to `speech_feature_array`, but the lengths are much larger than for the text segments.

I have a function
align
that, given a text segment and a speech segment, assigns a duration to each text feature vector in the segment, i.e. how many speech vectors correspond to it. `align` returns a vector whose length equals the length of the text segment and whose sum equals the length of the speech segment.

I'm trying to make use of JAX's flexible concurrency to compute all the alignments as fast as possible. I have considered using
jax.lax.scan
like so (I've simplified the code for readability).

My main issue is that `txt_idx`, `txt_len`, `sp_idx`, `sp_len` become abstract tracers, which means they cannot be used to slice the feature arrays. I have tried `lax.dynamic_slice_in_dim`, but the same problem persists.

Since this problem involves processing two large data structures to produce a new large data structure, parallel computing is highly desirable and should be possible, but I cannot quite figure out what angle to take.
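For concreteness, here is a minimal reconstruction of the setup being described (the original snippet is not shown above, so the array shapes and the `scan` body below are assumptions of mine, reusing the variable names from the post). The slice inside the scan body is exactly where the tracer problem appears:

```python
import jax.numpy as jnp
from jax import lax

# Toy stand-ins for the flattened feature arrays (shapes are arbitrary).
text_feature_array = jnp.ones((12, 8))   # 12 token vectors, dim 8
text_lens = jnp.array([3, 4, 5])         # per-sample text segment lengths

# Start offsets via an exclusive cumulative sum, as described above.
txt_starts = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32),
                              jnp.cumsum(text_lens)[:-1]])

def body(carry, xs):
    txt_idx, txt_len = xs
    # Inside scan, txt_idx and txt_len are abstract tracers, so ordinary
    # Python slicing cannot produce a concrete, static shape here:
    seg = text_feature_array[txt_idx : txt_idx + txt_len]
    return carry, seg.sum()

try:
    lax.scan(body, 0, (txt_starts, text_lens))
except Exception as err:
    print(f"scan failed at trace time: {type(err).__name__}")
```

The usual workaround is to pad each segment to a static maximum length and slice with `lax.dynamic_slice_in_dim`, whose size argument is static even though its start index may be traced.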
Any thoughts?