-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_data.py
53 lines (45 loc) · 2.05 KB
/
generate_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# ref code
# https://github.com/guacomolia/ptr_net/blob/master/generate_data.py
import numpy as np
import random
def generate_single_seq(length=30, min_len=5, max_len=10):
    # https://medium.com/@devnag/pointer-networks-in-tensorflow-with-sample-code-14645063f264
    """Generate a zero-padded sequence containing one run of greater numbers.

    The sequence is built from three consecutive runs. Each run has a random
    length drawn uniformly from ``[min_len, max_len]``:

    * a run of "low" values drawn from 1..5,
    * a run of "high" values drawn from 6..10,
    * another run of "low" values drawn from 1..5,

    The runs are followed by ``0`` padding up to ``length``.

    :param length: total length of the returned sequence
    :param min_len: minimum length of each of the three runs
    :param max_len: maximum length of each of the three runs
    :return: ``(seq, start, end)`` -- the sequence, the index of the first
        high value, and the index of the last high value (inclusive)
    :raises ValueError: if ``3 * max_len > length`` (the runs could overflow
        ``length``) or if ``min_len > max_len``
    """
    # The three runs can each be up to max_len long and must always fit.
    # Otherwise the padding count below would be negative, and the sequence
    # would silently come back longer than ``length``.
    if 3 * max_len > length:
        raise ValueError(
            f"3 * max_len ({3 * max_len}) exceeds sequence length ({length})"
        )
    if min_len > max_len:
        raise ValueError(f"min_len ({min_len}) must not exceed max_len ({max_len})")

    seq_before = [random.randint(1, 5) for _ in range(random.randint(min_len, max_len))]
    seq_during = [random.randint(6, 10) for _ in range(random.randint(min_len, max_len))]
    seq_after = [random.randint(1, 5) for _ in range(random.randint(min_len, max_len))]

    seq = seq_before + seq_during + seq_after
    seq += [0] * (length - len(seq))

    start = len(seq_before)
    end = start + len(seq_during) - 1  # inclusive index of the last high value
    return seq, start, end
def generate_set_seq(N, length=30, min_len=5, max_len=10):
    """Generate N fixed-length sequences for the boundary task.

    Each sequence comes from ``generate_single_seq``: one run of greater
    numbers placed between two runs of smaller numbers, zero-padded to
    ``length``.

    :param N: number of sequences to generate
    :param length: total length of every sequence (forwarded to
        ``generate_single_seq``)
    :param min_len: minimum length of each run (forwarded)
    :param max_len: maximum length of each run (forwarded)
    :return: ``(data, starts, ends)`` where ``data`` is a list of N lists of
        ints, ``starts[i]`` is the index where the greater-number run begins
        in ``data[i]``, and ``ends[i]`` is the index where it ends (inclusive)
    """
    data, starts, ends = [], [], []
    for _ in range(N):
        seq, ind_start, ind_end = generate_single_seq(
            length=length, min_len=min_len, max_len=max_len
        )
        data.append(seq)
        starts.append(ind_start)
        ends.append(ind_end)
    return data, starts, ends
def make_seq_data(n_samples, seq_len):
    """Generate random permutations paired with the index order that sorts them.

    This is a sorting task, not the boundary task. The input is a shuffled
    sequence, and the target lists its positions in ascending order of value.

    :param n_samples: number of (sequence, target) pairs to generate
    :param seq_len: length of each sequence
    :return: ``(data, labels)`` -- ``data[i]`` is a random permutation of
        ``0..seq_len-1`` as a list of ints, and ``labels[i]`` is its argsort,
        so ``[data[i][j] for j in labels[i]]`` is in ascending order
    """
    data, labels = [], []
    for _ in range(n_samples):
        perm = np.random.permutation(seq_len)
        data.append(perm.tolist())
        # All values are distinct, so argsort is unambiguous (sort stability
        # does not matter).
        labels.append(np.argsort(perm).tolist())
    return data, labels