1
1
import soundfile as sf
2
- import torch ,pdb ,time ,argparse ,os ,warnings ,sys ,librosa
2
+ import torch , pdb , time , argparse , os , warnings , sys , librosa
3
3
import numpy as np
4
4
import onnxruntime as ort
5
5
from scipy .io .wavfile import write
8
8
import torch .nn as nn
9
9
10
10
dim_c = 4
11
- class Conv_TDF_net_trim ():
12
- def __init__ (self , device , model_name , target_name ,
13
- L , dim_f , dim_t , n_fft , hop = 1024 ):
11
+
12
+
13
+ class Conv_TDF_net_trim :
14
+ def __init__ (
15
+ self , device , model_name , target_name , L , dim_f , dim_t , n_fft , hop = 1024
16
+ ):
14
17
super (Conv_TDF_net_trim , self ).__init__ ()
15
18
16
19
self .dim_f = dim_f
17
- self .dim_t = 2 ** dim_t
20
+ self .dim_t = 2 ** dim_t
18
21
self .n_fft = n_fft
19
22
self .hop = hop
20
23
self .n_bins = self .n_fft // 2 + 1
21
24
self .chunk_size = hop * (self .dim_t - 1 )
22
- self .window = torch .hann_window (window_length = self .n_fft , periodic = True ).to (device )
25
+ self .window = torch .hann_window (window_length = self .n_fft , periodic = True ).to (
26
+ device
27
+ )
23
28
self .target_name = target_name
24
- self .blender = ' blender' in model_name
29
+ self .blender = " blender" in model_name
25
30
26
- out_c = dim_c * 4 if target_name == '*' else dim_c
27
- self .freq_pad = torch .zeros ([1 , out_c , self .n_bins - self .dim_f , self .dim_t ]).to (device )
31
+ out_c = dim_c * 4 if target_name == "*" else dim_c
32
+ self .freq_pad = torch .zeros (
33
+ [1 , out_c , self .n_bins - self .dim_f , self .dim_t ]
34
+ ).to (device )
28
35
29
36
self .n = L // 2
30
37
31
38
def stft (self , x ):
32
39
x = x .reshape ([- 1 , self .chunk_size ])
33
- x = torch .stft (x , n_fft = self .n_fft , hop_length = self .hop , window = self .window , center = True , return_complex = True )
40
+ x = torch .stft (
41
+ x ,
42
+ n_fft = self .n_fft ,
43
+ hop_length = self .hop ,
44
+ window = self .window ,
45
+ center = True ,
46
+ return_complex = True ,
47
+ )
34
48
x = torch .view_as_real (x )
35
49
x = x .permute ([0 , 3 , 1 , 2 ])
36
- x = x .reshape ([- 1 , 2 , 2 , self .n_bins , self .dim_t ]).reshape ([- 1 , dim_c , self .n_bins , self .dim_t ])
37
- return x [:, :, :self .dim_f ]
50
+ x = x .reshape ([- 1 , 2 , 2 , self .n_bins , self .dim_t ]).reshape (
51
+ [- 1 , dim_c , self .n_bins , self .dim_t ]
52
+ )
53
+ return x [:, :, : self .dim_f ]
38
54
39
55
def istft (self , x , freq_pad = None ):
40
- freq_pad = self .freq_pad .repeat ([x .shape [0 ], 1 , 1 , 1 ]) if freq_pad is None else freq_pad
56
+ freq_pad = (
57
+ self .freq_pad .repeat ([x .shape [0 ], 1 , 1 , 1 ])
58
+ if freq_pad is None
59
+ else freq_pad
60
+ )
41
61
x = torch .cat ([x , freq_pad ], - 2 )
42
- c = 4 * 2 if self .target_name == '*' else 2
43
- x = x .reshape ([- 1 , c , 2 , self .n_bins , self .dim_t ]).reshape ([- 1 , 2 , self .n_bins , self .dim_t ])
62
+ c = 4 * 2 if self .target_name == "*" else 2
63
+ x = x .reshape ([- 1 , c , 2 , self .n_bins , self .dim_t ]).reshape (
64
+ [- 1 , 2 , self .n_bins , self .dim_t ]
65
+ )
44
66
x = x .permute ([0 , 2 , 3 , 1 ])
45
67
x = x .contiguous ()
46
68
x = torch .view_as_complex (x )
47
- x = torch .istft (x , n_fft = self .n_fft , hop_length = self .hop , window = self .window , center = True )
69
+ x = torch .istft (
70
+ x , n_fft = self .n_fft , hop_length = self .hop , window = self .window , center = True
71
+ )
48
72
return x .reshape ([- 1 , c , self .chunk_size ])
73
+
74
+
49
75
def get_models (device , dim_f , dim_t , n_fft ):
50
76
return Conv_TDF_net_trim (
51
77
device = device ,
52
- model_name = 'Conv-TDF' , target_name = 'vocals' ,
78
+ model_name = "Conv-TDF" ,
79
+ target_name = "vocals" ,
53
80
L = 11 ,
54
- dim_f = dim_f , dim_t = dim_t ,
55
- n_fft = n_fft
81
+ dim_f = dim_f ,
82
+ dim_t = dim_t ,
83
+ n_fft = n_fft ,
56
84
)
57
85
86
+
58
87
warnings .filterwarnings ("ignore" )
59
- cpu = torch .device ('cpu' )
60
- device = torch .device ('cuda:0' if torch .cuda .is_available () else 'cpu' )
88
+ cpu = torch .device ("cpu" )
89
+ device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
90
+
61
91
62
92
class Predictor :
63
- def __init__ (self ,args ):
64
- self .args = args
65
- self .model_ = get_models (device = cpu , dim_f = args .dim_f , dim_t = args .dim_t , n_fft = args .n_fft )
66
- self .model = ort .InferenceSession (os .path .join (args .onnx ,self .model_ .target_name + '.onnx' ), providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
67
- print ('onnx load done' )
93
+ def __init__ (self , args ):
94
+ self .args = args
95
+ self .model_ = get_models (
96
+ device = cpu , dim_f = args .dim_f , dim_t = args .dim_t , n_fft = args .n_fft
97
+ )
98
+ self .model = ort .InferenceSession (
99
+ os .path .join (args .onnx , self .model_ .target_name + ".onnx" ),
100
+ providers = ["CUDAExecutionProvider" , "CPUExecutionProvider" ],
101
+ )
102
+ print ("onnx load done" )
103
+
68
104
def demix (self , mix ):
69
105
samples = mix .shape [- 1 ]
70
106
margin = self .args .margin
71
- chunk_size = self .args .chunks * 44100
72
- assert not margin == 0 , ' margin cannot be zero!'
107
+ chunk_size = self .args .chunks * 44100
108
+ assert not margin == 0 , " margin cannot be zero!"
73
109
if margin > chunk_size :
74
110
margin = chunk_size
75
111
76
112
segmented_mix = {}
77
-
113
+
78
114
if self .args .chunks == 0 or samples < chunk_size :
79
115
chunk_size = samples
80
-
116
+
81
117
counter = - 1
82
118
for skip in range (0 , samples , chunk_size ):
83
- counter += 1
84
-
119
+ counter += 1
120
+
85
121
s_margin = 0 if counter == 0 else margin
86
- end = min (skip + chunk_size + margin , samples )
122
+ end = min (skip + chunk_size + margin , samples )
87
123
88
- start = skip - s_margin
124
+ start = skip - s_margin
89
125
90
- segmented_mix [skip ] = mix [:,start :end ].copy ()
126
+ segmented_mix [skip ] = mix [:, start :end ].copy ()
91
127
if end == samples :
92
128
break
93
129
94
130
sources = self .demix_base (segmented_mix , margin_size = margin )
95
- '''
131
+ """
96
132
mix:(2,big_sample)
97
133
segmented_mix:offset->(2,small_sample)
98
134
sources:(1,2,big_sample)
99
- '''
135
+ """
100
136
return sources
137
+
101
138
def demix_base (self , mixes , margin_size ):
102
139
chunked_sources = []
103
140
progress_bar = tqdm (total = len (mixes ))
@@ -106,84 +143,102 @@ def demix_base(self, mixes, margin_size):
106
143
cmix = mixes [mix ]
107
144
sources = []
108
145
n_sample = cmix .shape [1 ]
109
- model = self .model_
110
- trim = model .n_fft // 2
111
- gen_size = model .chunk_size - 2 * trim
112
- pad = gen_size - n_sample % gen_size
113
- mix_p = np .concatenate ((np .zeros ((2 ,trim )), cmix , np .zeros ((2 ,pad )), np .zeros ((2 ,trim ))), 1 )
146
+ model = self .model_
147
+ trim = model .n_fft // 2
148
+ gen_size = model .chunk_size - 2 * trim
149
+ pad = gen_size - n_sample % gen_size
150
+ mix_p = np .concatenate (
151
+ (np .zeros ((2 , trim )), cmix , np .zeros ((2 , pad )), np .zeros ((2 , trim ))), 1
152
+ )
114
153
mix_waves = []
115
154
i = 0
116
155
while i < n_sample + pad :
117
- waves = np .array (mix_p [:, i : i + model .chunk_size ])
156
+ waves = np .array (mix_p [:, i : i + model .chunk_size ])
118
157
mix_waves .append (waves )
119
158
i += gen_size
120
159
mix_waves = torch .tensor (mix_waves , dtype = torch .float32 ).to (cpu )
121
160
with torch .no_grad ():
122
161
_ort = self .model
123
162
spek = model .stft (mix_waves )
124
163
if self .args .denoise :
125
- spec_pred = - _ort .run (None , {'input' : - spek .cpu ().numpy ()})[0 ]* 0.5 + _ort .run (None , {'input' : spek .cpu ().numpy ()})[0 ]* 0.5
164
+ spec_pred = (
165
+ - _ort .run (None , {"input" : - spek .cpu ().numpy ()})[0 ] * 0.5
166
+ + _ort .run (None , {"input" : spek .cpu ().numpy ()})[0 ] * 0.5
167
+ )
126
168
tar_waves = model .istft (torch .tensor (spec_pred ))
127
169
else :
128
- tar_waves = model .istft (torch .tensor (_ort .run (None , {'input' : spek .cpu ().numpy ()})[0 ]))
129
- tar_signal = tar_waves [:,:,trim :- trim ].transpose (0 ,1 ).reshape (2 , - 1 ).numpy ()[:, :- pad ]
170
+ tar_waves = model .istft (
171
+ torch .tensor (_ort .run (None , {"input" : spek .cpu ().numpy ()})[0 ])
172
+ )
173
+ tar_signal = (
174
+ tar_waves [:, :, trim :- trim ]
175
+ .transpose (0 , 1 )
176
+ .reshape (2 , - 1 )
177
+ .numpy ()[:, :- pad ]
178
+ )
130
179
131
180
start = 0 if mix == 0 else margin_size
132
181
end = None if mix == list (mixes .keys ())[::- 1 ][0 ] else - margin_size
133
182
if margin_size == 0 :
134
183
end = None
135
- sources .append (tar_signal [:,start :end ])
184
+ sources .append (tar_signal [:, start :end ])
136
185
137
186
progress_bar .update (1 )
138
-
187
+
139
188
chunked_sources .append (sources )
140
189
_sources = np .concatenate (chunked_sources , axis = - 1 )
141
190
# del self.model
142
191
progress_bar .close ()
143
192
return _sources
144
- def prediction (self , m ,vocal_root ,others_root ,format ):
145
- os .makedirs (vocal_root ,exist_ok = True )
146
- os .makedirs (others_root ,exist_ok = True )
193
+
194
+ def prediction (self , m , vocal_root , others_root , format ):
195
+ os .makedirs (vocal_root , exist_ok = True )
196
+ os .makedirs (others_root , exist_ok = True )
147
197
basename = os .path .basename (m )
148
198
mix , rate = librosa .load (m , mono = False , sr = 44100 )
149
199
if mix .ndim == 1 :
150
- mix = np .asfortranarray ([mix ,mix ])
200
+ mix = np .asfortranarray ([mix , mix ])
151
201
mix = mix .T
152
202
sources = self .demix (mix .T )
153
- opt = sources [0 ].T
154
- sf .write ("%s/%s_main_vocal.%s" % (vocal_root ,basename ,format ), mix - opt , rate )
155
- sf .write ("%s/%s_others.%s" % (others_root ,basename ,format ), opt , rate )
156
-
157
- class MDXNetDereverb ():
158
- def __init__ (self ,chunks ):
159
- self .onnx = "uvr5_weights/onnx_dereverb_By_FoxJoy"
160
- self .shifts = 10 #'Predict with randomised equivariant stabilisation'
161
- self .mixing = "min_mag" #['default','min_mag','max_mag']
162
- self .chunks = chunks
163
- self .margin = 44100
164
- self .dim_t = 9
165
- self .dim_f = 3072
166
- self .n_fft = 6144
167
- self .denoise = True
168
- self .pred = Predictor (self )
169
-
170
- def _path_audio_ (self ,input ,vocal_root ,others_root ,format ):
171
- self .pred .prediction (input ,vocal_root ,others_root ,format )
172
-
173
- if __name__ == '__main__' :
174
- dereverb = MDXNetDereverb (15 )
203
+ opt = sources [0 ].T
204
+ sf .write (
205
+ "%s/%s_main_vocal.%s" % (vocal_root , basename , format ), mix - opt , rate
206
+ )
207
+ sf .write ("%s/%s_others.%s" % (others_root , basename , format ), opt , rate )
208
+
209
+
210
+ class MDXNetDereverb :
211
+ def __init__ (self , chunks ):
212
+ self .onnx = "uvr5_weights/onnx_dereverb_By_FoxJoy"
213
+ self .shifts = 10 #'Predict with randomised equivariant stabilisation'
214
+ self .mixing = "min_mag" # ['default','min_mag','max_mag']
215
+ self .chunks = chunks
216
+ self .margin = 44100
217
+ self .dim_t = 9
218
+ self .dim_f = 3072
219
+ self .n_fft = 6144
220
+ self .denoise = True
221
+ self .pred = Predictor (self )
222
+
223
+ def _path_audio_ (self , input , vocal_root , others_root , format ):
224
+ self .pred .prediction (input , vocal_root , others_root , format )
225
+
226
+
227
+ if __name__ == "__main__" :
228
+ dereverb = MDXNetDereverb (15 )
175
229
from time import time as ttime
176
- t0 = ttime ()
230
+
231
+ t0 = ttime ()
177
232
dereverb ._path_audio_ (
178
233
"雪雪伴奏对消HP5.wav" ,
179
234
"vocal" ,
180
235
"others" ,
181
236
)
182
- t1 = ttime ()
183
- print (t1 - t0 )
237
+ t1 = ttime ()
238
+ print (t1 - t0 )
184
239
185
240
186
- '''
241
+ """
187
242
188
243
runtime\python.exe MDXNet.py
189
244
@@ -195,4 +250,4 @@ def _path_audio_(self,input,vocal_root,others_root,format):
195
250
half15:0.7G->6.6G,22.69s
196
251
fp32-15:0.7G->6.6G,20.85s
197
252
198
- '''
253
+ """
0 commit comments