-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_loader.py
More file actions
264 lines (244 loc) · 9.8 KB
/
data_loader.py
File metadata and controls
264 lines (244 loc) · 9.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import torch
import ipdb
from torch.utils import data
import json
import pickle
import os
import numpy as np
import soundfile as sf
#from .wsj0_mix import wsj0_license

# Numerical-stability constant used by normalize_tensor_wav.
EPS = 1e-8
DATASET = 'WAVESPLIT'

# WaveSplit tasks. Each descriptor names the mixture key, the target
# source keys, any side-information keys, and the default source count.
enh_single = {
    'mixture': 'mix_single',
    'sources': ['s1'],
    'infos': ['noise'],
    'default_nsrc': 1,
}
enh_both = {
    'mixture': 'mix_both',
    'sources': ['mix_clean'],
    'infos': ['noise'],
    'default_nsrc': 1,
}
sep_clean = {
    'mixture': 'mix_clean',
    'sources': ['s1', 's2'],
    'infos': [],
    'default_nsrc': 2,
}
sep_noisy = {
    'mixture': 'mix_both',
    'sources': ['s1', 's2'],
    'infos': ['noise'],
    'default_nsrc': 2,
}

WAVESPLIT_TASKS = {
    'enhance_single': enh_single,
    'enhance_both': enh_both,
    'sep_clean': sep_clean,
    'sep_noisy': sep_noisy,
}
# Short aliases for the enhancement tasks (same descriptor objects).
WAVESPLIT_TASKS['enh_single'] = enh_single
WAVESPLIT_TASKS['enh_both'] = enh_both
# Maps a WSJ0-mix metadata key to its column index in each per-utterance
# record of data.json (see WaveSplitDataset.__getitem__, which indexes
# self.mix[idx][3] / [4] for the s1/s2 entries).
_TASK_INDEX = {
    'mix_clean': 0,
    'mix_single': 1,
    'mix_both': 2,
    's1': 3,
    's2': 4,
    'noise': 5,
}


def get_task(task):
    """Return the metadata-record index for *task*.

    Args:
        task (str): one of 'mix_clean', 'mix_single', 'mix_both',
            's1', 's2' or 'noise'.

    Returns:
        int: the index of the corresponding entry in a data.json record.

    Raises:
        ValueError: if *task* is not a known key.
    """
    # Fix: the original compared strings with ``is`` (identity, not
    # equality), which only worked through CPython string interning and
    # emits a SyntaxWarning on 3.8+; it also called exit() on a bad key,
    # killing the whole process (including DataLoader workers). A dict
    # lookup with an explicit ValueError is both correct and recoverable.
    try:
        return _TASK_INDEX[task]
    except KeyError:
        raise ValueError("Wrong Task! Unexpected task key: {!r}".format(task)) from None
###
def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
    """Zero-mean, unit-variance normalize a waveform along its last axis.

    Args:
        wav_tensor (torch.Tensor): waveform(s), time on the last dimension.
        eps (float): small constant added to the std for numerical stability.
        std (torch.Tensor, optional): externally supplied std (e.g. the
            mixture's, so sources share the mixture's scale). If None,
            the tensor's own std is used.

    Returns:
        torch.Tensor: the normalized waveform.
    """
    centered = wav_tensor - wav_tensor.mean(-1, keepdim=True)
    scale = wav_tensor.std(-1, keepdim=True) if std is None else std
    return centered / (scale + eps)
###
class WaveSplitDataset(data.Dataset):
    """ Dataset class for WHAM source separation and speech enhancement tasks.

    Loads per-utterance metadata from ``<json_dir>/data.json`` and a speaker
    dictionary pickle, then serves (mixture, sources, labels) triples.

    Args:
        json_dir (str): The path to the directory containing the json files.
        task (str): One of ``'enh_single'``, ``'enh_both'``, ``'sep_clean'`` or
            ``'sep_noisy'``.

            * ``'enh_single'`` for single speaker speech enhancement.
            * ``'enh_both'`` for multi speaker speech enhancement.
            * ``'sep_clean'`` for two-speaker clean source separation.
            * ``'sep_noisy'`` for two-speaker noisy source separation.
        spk_dict (str): Path to a pickled dict mapping a 3-char speaker key
            to an integer speaker ID.
        sample_rate (int, optional): The sampling rate of the wav files.
        segment (float, optional): Length of the segments used for training,
            in seconds. If None, use full utterances (e.g. for test).
        nondefault_nsrc (int, optional): Number of sources in the training
            targets.
            If None, defaults to one for enhancement tasks and two for
            separation tasks.
        normalize_audio (bool): If True then both sources and the mixture are
            normalized with the standard deviation of the mixture.
    """
    dataset_name = 'WAVESPLIT'

    def __init__(self, json_dir, task, spk_dict, sample_rate=8000, segment=4.0,
                 nondefault_nsrc=None, normalize_audio=False):
        super(WaveSplitDataset, self).__init__()
        if task not in WAVESPLIT_TASKS.keys():
            raise ValueError('Unexpected task {}, expected one of '
                             '{}'.format(task, WAVESPLIT_TASKS.keys()))
        # Task setting
        self.json_dir = json_dir
        self.task = task
        self.task_dict = WAVESPLIT_TASKS[task]
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio
        # seg_len is in samples; None means "use full utterances".
        self.seg_len = None if segment is None else int(segment * sample_rate)
        if not nondefault_nsrc:
            self.n_src = self.task_dict['default_nsrc']
        else:
            assert nondefault_nsrc >= self.task_dict['default_nsrc']
            self.n_src = nondefault_nsrc
        self.like_test = self.seg_len is None
        # Load the single metadata json. Each record is a list indexed by
        # get_task(); entries hold at least 'Sample' (length) and 'Src' (path).
        mix_json = os.path.join(json_dir, 'data.json')
        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)
        # Filter out utterances shorter than one segment (training only).
        orig_len = len(mix_infos)
        drop_utt, drop_len = 0, 0
        if not self.like_test:
            for i in range(len(mix_infos) - 1, -1, -1):  # Go backward
                if mix_infos[i][get_task(self.task_dict['mixture'])]['Sample'] < self.seg_len:
                    drop_utt += 1
                    drop_len += mix_infos[i][get_task(self.task_dict['mixture'])]['Sample']
                    del mix_infos[i]
        # NOTE(review): 36000 looks like a typo for 3600 (seconds per hour),
        # which would make the reported hours 10x too small — confirm; kept
        # as-is to preserve the printed output.
        print("Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
            drop_utt, drop_len/sample_rate/36000, orig_len, self.seg_len))
        self.mix = mix_infos
        # Frame-level speaker-activity labels; None means "assume always
        # active" in __getitem__.
        self.labels = None
        # Fix: the original opened the pickle file without ever closing it,
        # leaking the file handle; a with-block closes it deterministically.
        with open(spk_dict, 'rb') as f:
            self.spk_dict = pickle.load(f)

    def __add__(self, wham):
        """Concatenate another dataset into this one (in place) and return it.

        Fix: the original returned None, so ``a + b`` evaluated to None;
        the mutation semantics are unchanged, only ``self`` is now returned.
        """
        if self.n_src != wham.n_src:
            raise ValueError('Only datasets having the same number of sources'
                             'can be added together. Received '
                             '{} and {}'.format(self.n_src, wham.n_src))
        if self.seg_len != wham.seg_len:
            self.seg_len = min(self.seg_len, wham.seg_len)
            print('Segment length mismatched between the two Dataset'
                  'passed one the smallest to the sum.')
        self.mix = self.mix + wham.mix
        # NOTE(review): __init__ never sets self.sources (the loading code
        # was disabled), so this line raises AttributeError unless sources
        # was assigned externally — confirm before relying on addition.
        self.sources = [a + b for a, b in zip(self.sources, wham.sources)]
        return self

    def __len__(self):
        # One item per metadata record surviving the length filter.
        return len(self.mix)

    def __getitem__(self, idx):
        """ Gets a mixture/sources pair.

        Returns:
            tuple: (mixture FloatTensor, stacked sources FloatTensor,
            per-frame speaker-ID labels LongTensor of shape (2, seg_len)).
        """
        # Random start: pick a segment offset, or 0 when the utterance is
        # exactly one segment long / in full-utterance (test) mode.
        if self.mix[idx][get_task(self.task_dict['mixture'])]['Sample'] == self.seg_len or self.like_test:
            rand_start = 0
        else:
            rand_start = np.random.randint(0, self.mix[idx][get_task(self.task_dict['mixture'])]['Sample'] - self.seg_len)
        if self.like_test:
            stop = None  # read to end of file
        else:
            stop = rand_start + self.seg_len
        # Load mixture. NOTE(review): paths in data.json are resolved
        # relative to the parent of the working directory — confirm the
        # hard-coded '../' prefix matches how training is launched.
        x, _ = sf.read('../' + self.mix[idx][get_task(self.task_dict['mixture'])]['Src'], start=rand_start,
                       stop=stop, dtype='float32')
        seg_len = torch.as_tensor([len(x)])
        # Load sources
        source_arrays = []
        for SPK in ['s1', 's2']:
            if self.mix[idx][get_task(SPK)] is None:
                # Target is filled with zeros if n_src > default_nsrc.
                # NOTE(review): seg_len is a 1-element tensor here, which
                # np.zeros may reject — confirm this branch is exercised.
                s = np.zeros((seg_len, ))
            else:
                s, _ = sf.read('../' + self.mix[idx][get_task(SPK)]['Src'], start=rand_start,
                               stop=stop, dtype='float32')
            source_arrays.append(s)
        # Load per-frame activity labels; with no label files every frame
        # is marked active (ones).
        _labels = np.zeros((2, self.seg_len))
        if self.labels is None:
            _labels = np.ones((2, self.seg_len))
        else:
            lab = np.load(self.labels[idx])
            _labels[0] = lab[0][rand_start:stop]
            _labels[1] = lab[1][rand_start:stop]
        # Scale activity by speaker ID (record indices 3/4 are s1/s2).
        _labels[0] = _labels[0] * self.spk_dict[self.mix[idx][3]['spkID']]
        _labels[1] = _labels[1] * self.spk_dict[self.mix[idx][4]['spkID']]
        sources = torch.from_numpy(np.vstack(source_arrays)).float()
        mixture = torch.from_numpy(x).float()
        labels = torch.from_numpy(_labels).long()
        if self.normalize_audio:
            # Sources share the mixture's std so relative scales survive.
            m_std = mixture.std(-1, keepdim=True)
            mixture = normalize_tensor_wav(mixture, eps=EPS, std=m_std)
            sources = normalize_tensor_wav(sources, eps=EPS, std=m_std)
        return mixture, sources, labels

    def get_speakerID(self, src_path):
        """Map a source wav path to its integer speaker ID.

        Fix: the original signature omitted ``self`` although the body used
        it, so every call raised NameError.

        Args:
            src_path (str): path whose basename starts with the 3-character
                speaker key used in ``self.spk_dict``.

        Returns:
            The speaker ID looked up in ``self.spk_dict``.
        """
        key = src_path.split("/")[-1][0:3]
        return self.spk_dict[key]
# License metadata for the WHAM! noise corpus (used when publishing models).
wham_noise_license = {
    'title': 'The WSJ0 Hipster Ambient Mixtures dataset',
    'title_link': 'http://wham.whisper.ai/',
    'author': 'Whisper.ai',
    'author_link': 'https://whisper.ai/',
    'license': 'CC BY-NC 4.0',
    'license_link': 'https://creativecommons.org/licenses/by-nc/4.0/',
    'non_commercial': True,
}