@@ -74,10 +74,33 @@ def __init__(
74
74
output_single_stem = None ,
75
75
invert_using_spec = False ,
76
76
sample_rate = 44100 ,
77
- mdx_params = {"hop_length" : 1024 , "segment_size" : 256 , "overlap" : 0.25 , "batch_size" : 1 , "enable_denoise" : False },
78
- vr_params = {"batch_size" : 16 , "window_size" : 512 , "aggression" : 5 , "enable_tta" : False , "enable_post_process" : False , "post_process_threshold" : 0.2 , "high_end_process" : False },
79
- demucs_params = {"segment_size" : "Default" , "shifts" : 2 , "overlap" : 0.25 , "segments_enabled" : True },
80
- mdxc_params = {"segment_size" : 256 , "batch_size" : 1 , "overlap" : 8 },
77
+ mdx_params = {
78
+ "hop_length" : 1024 ,
79
+ "segment_size" : 256 ,
80
+ "overlap" : 0.25 ,
81
+ "batch_size" : 1 ,
82
+ "enable_denoise" : False ,
83
+ },
84
+ vr_params = {
85
+ "batch_size" : 16 ,
86
+ "window_size" : 512 ,
87
+ "aggression" : 5 ,
88
+ "enable_tta" : False ,
89
+ "enable_post_process" : False ,
90
+ "post_process_threshold" : 0.2 ,
91
+ "high_end_process" : False ,
92
+ },
93
+ demucs_params = {
94
+ "segment_size" : "Default" ,
95
+ "shifts" : 2 ,
96
+ "overlap" : 0.25 ,
97
+ "segments_enabled" : True ,
98
+ },
99
+ mdxc_params = {
100
+ "segment_size" : 256 ,
101
+ "batch_size" : 1 ,
102
+ "overlap" : 8 ,
103
+ },
81
104
):
82
105
self .logger = logging .getLogger (__name__ )
83
106
self .logger .setLevel (log_level )
@@ -143,7 +166,12 @@ def __init__(
143
166
144
167
# These are parameters which users may want to configure so we expose them to the top-level Separator class,
145
168
# even though they are specific to a single model architecture
146
- self .arch_specific_params = {"MDX" : mdx_params , "VR" : vr_params , "Demucs" : demucs_params , "MDXC" : mdxc_params }
169
+ self .arch_specific_params = {
170
+ "MDX" : mdx_params ,
171
+ "VR" : vr_params ,
172
+ "Demucs" : demucs_params ,
173
+ "MDXC" : mdxc_params ,
174
+ }
147
175
148
176
self .torch_device = None
149
177
self .torch_device_cpu = None
@@ -385,9 +413,16 @@ def list_supported_model_files(self):
385
413
# Return object with list of model names, which are the keys in vr_download_list, mdx_download_list, demucs_download_list, mdx23_download_list, mdx23c_download_list, grouped by type: VR, MDX, Demucs, MDX23, MDX23C
386
414
model_files_grouped_by_type = {
387
415
"VR" : model_downloads_list ["vr_download_list" ],
388
- "MDX" : {** model_downloads_list ["mdx_download_list" ], ** model_downloads_list ["mdx_download_vip_list" ]},
416
+ "MDX" : {
417
+ ** model_downloads_list ["mdx_download_list" ],
418
+ ** model_downloads_list ["mdx_download_vip_list" ],
419
+ },
389
420
"Demucs" : filtered_demucs_v4 ,
390
- "MDXC" : {** model_downloads_list ["mdx23c_download_list" ], ** model_downloads_list ["mdx23c_download_vip_list" ], ** model_downloads_list ["roformer_download_list" ]},
421
+ "MDXC" : {
422
+ ** model_downloads_list ["mdx23c_download_list" ],
423
+ ** model_downloads_list ["mdx23c_download_vip_list" ],
424
+ ** model_downloads_list ["roformer_download_list" ],
425
+ },
391
426
}
392
427
return model_files_grouped_by_type
393
428
@@ -664,7 +699,12 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360
664
699
}
665
700
666
701
# Instantiate the appropriate separator class depending on the model type
667
- separator_classes = {"MDX" : "mdx_separator.MDXSeparator" , "VR" : "vr_separator.VRSeparator" , "Demucs" : "demucs_separator.DemucsSeparator" , "MDXC" : "mdxc_separator.MDXCSeparator" }
702
+ separator_classes = {
703
+ "MDX" : "mdx_separator.MDXSeparator" ,
704
+ "VR" : "vr_separator.VRSeparator" ,
705
+ "Demucs" : "demucs_separator.DemucsSeparator" ,
706
+ "MDXC" : "mdxc_separator.MDXCSeparator" ,
707
+ }
668
708
669
709
if model_type not in self .arch_specific_params or model_type not in separator_classes :
670
710
raise ValueError (f"Model type not supported (yet): { model_type } " )
0 commit comments