1
1
""" This file contains the Separator class, to facilitate the separation of stems from audio. """
2
2
3
- from importlib import metadata
3
+ from importlib import metadata , resources
4
4
import os
5
5
import sys
6
6
import platform
@@ -74,33 +74,10 @@ def __init__(
74
74
output_single_stem = None ,
75
75
invert_using_spec = False ,
76
76
sample_rate = 44100 ,
77
- mdx_params = {
78
- "hop_length" : 1024 ,
79
- "segment_size" : 256 ,
80
- "overlap" : 0.25 ,
81
- "batch_size" : 1 ,
82
- "enable_denoise" : False ,
83
- },
84
- vr_params = {
85
- "batch_size" : 16 ,
86
- "window_size" : 512 ,
87
- "aggression" : 5 ,
88
- "enable_tta" : False ,
89
- "enable_post_process" : False ,
90
- "post_process_threshold" : 0.2 ,
91
- "high_end_process" : False ,
92
- },
93
- demucs_params = {
94
- "segment_size" : "Default" ,
95
- "shifts" : 2 ,
96
- "overlap" : 0.25 ,
97
- "segments_enabled" : True ,
98
- },
99
- mdxc_params = {
100
- "segment_size" : 256 ,
101
- "batch_size" : 1 ,
102
- "overlap" : 8 ,
103
- },
77
+ mdx_params = {"hop_length" : 1024 , "segment_size" : 256 , "overlap" : 0.25 , "batch_size" : 1 , "enable_denoise" : False },
78
+ vr_params = {"batch_size" : 16 , "window_size" : 512 , "aggression" : 5 , "enable_tta" : False , "enable_post_process" : False , "post_process_threshold" : 0.2 , "high_end_process" : False },
79
+ demucs_params = {"segment_size" : "Default" , "shifts" : 2 , "overlap" : 0.25 , "segments_enabled" : True },
80
+ mdxc_params = {"segment_size" : 256 , "batch_size" : 1 , "overlap" : 8 },
104
81
):
105
82
self .logger = logging .getLogger (__name__ )
106
83
self .logger .setLevel (log_level )
@@ -166,12 +143,7 @@ def __init__(
166
143
167
144
# These are parameters which users may want to configure so we expose them to the top-level Separator class,
168
145
# even though they are specific to a single model architecture
169
- self .arch_specific_params = {
170
- "MDX" : mdx_params ,
171
- "VR" : vr_params ,
172
- "Demucs" : demucs_params ,
173
- "MDXC" : mdxc_params ,
174
- }
146
+ self .arch_specific_params = {"MDX" : mdx_params , "VR" : vr_params , "Demucs" : demucs_params , "MDXC" : mdxc_params }
175
147
176
148
self .torch_device = None
177
149
self .torch_device_cpu = None
@@ -351,7 +323,7 @@ def list_supported_model_files(self):
351
323
self .download_file_if_not_exists ("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json" , download_checks_path )
352
324
353
325
model_downloads_list = json .load (open (download_checks_path , encoding = "utf-8" ))
354
- self .logger .debug (f"Model download list loaded" )
326
+ self .logger .debug (f"UVR model download list loaded" )
355
327
356
328
# model_downloads_list JSON structure / example snippet:
357
329
# {
@@ -410,18 +382,21 @@ def list_supported_model_files(self):
410
382
# Only show Demucs v4 models as we've only implemented support for v4
411
383
filtered_demucs_v4 = {key : value for key , value in model_downloads_list ["demucs_download_list" ].items () if key .startswith ("Demucs v4" )}
412
384
385
+ # Load the JSON file using importlib.resources
386
+ with resources .open_text ("audio_separator" , "models.json" ) as f :
387
+ audio_separator_models_list = json .load (f )
388
+ self .logger .debug (f"Audio-Separator model list loaded" )
389
+
413
390
# Return the supported model files grouped by architecture (VR, MDX, Demucs, MDXC); each group merges the matching *_download_list entries (including the VIP and roformer lists) from the model lists loaded above
414
391
model_files_grouped_by_type = {
415
- "VR" : model_downloads_list ["vr_download_list" ],
416
- "MDX" : {
417
- ** model_downloads_list ["mdx_download_list" ],
418
- ** model_downloads_list ["mdx_download_vip_list" ],
419
- },
392
+ "VR" : {** model_downloads_list ["vr_download_list" ], ** audio_separator_models_list ["vr_download_list" ]},
393
+ "MDX" : {** model_downloads_list ["mdx_download_list" ], ** model_downloads_list ["mdx_download_vip_list" ], ** audio_separator_models_list ["mdx_download_list" ]},
420
394
"Demucs" : filtered_demucs_v4 ,
421
395
"MDXC" : {
422
396
** model_downloads_list ["mdx23c_download_list" ],
423
397
** model_downloads_list ["mdx23c_download_vip_list" ],
424
398
** model_downloads_list ["roformer_download_list" ],
399
+ ** audio_separator_models_list ["roformer_download_list" ],
425
400
},
426
401
}
427
402
return model_files_grouped_by_type
@@ -444,6 +419,8 @@ def download_model_files(self, model_filename):
444
419
public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
445
420
vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
446
421
422
+ audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
423
+
447
424
yaml_config_filename = None
448
425
449
426
self .logger .debug (f"Searching for model_filename { model_filename } in supported_model_files_grouped" )
@@ -457,7 +434,12 @@ def download_model_files(self, model_filename):
457
434
self .logger .debug (f"Single file model identified: { model_friendly_name } " )
458
435
self .model_friendly_name = model_friendly_name
459
436
460
- self .download_file_if_not_exists (f"{ model_repo_url_prefix } /{ model_filename } " , model_path )
437
+ try :
438
+ self .download_file_if_not_exists (f"{ model_repo_url_prefix } /{ model_filename } " , model_path )
439
+ except RuntimeError :
440
+ self .logger .debug ("Model not found in UVR repo, attempting to download from audio-separator models repo..." )
441
+ self .download_file_if_not_exists (f"{ audio_separator_models_repo_url_prefix } /{ model_filename } " , model_path )
442
+
461
443
self .print_uvr_vip_message ()
462
444
463
445
self .logger .debug (f"Returning path for single model file: { model_path } " )
@@ -488,8 +470,13 @@ def download_model_files(self, model_filename):
488
470
# Checkpoint models apparently use config_key as the model filename, but the value is a YAML config file name...
489
471
# Both need to be downloaded, but the model data YAML file actually comes from the application data repo...
490
472
elif config_key .endswith (".ckpt" ):
491
- download_url = f"{ model_repo_url_prefix } /{ config_key } "
492
- self .download_file_if_not_exists (download_url , os .path .join (self .model_file_dir , config_key ))
473
+ try :
474
+ download_url = f"{ model_repo_url_prefix } /{ config_key } "
475
+ self .download_file_if_not_exists (download_url , os .path .join (self .model_file_dir , config_key ))
476
+ except RuntimeError :
477
+ self .logger .debug ("Model not found in UVR repo, attempting to download from audio-separator models repo..." )
478
+ download_url = f"{ audio_separator_models_repo_url_prefix } /{ config_key } "
479
+ self .download_file_if_not_exists (download_url , os .path .join (self .model_file_dir , config_key ))
493
480
494
481
# In case the user specified the YAML filename as the model input instead of the model filename, correct that
495
482
if model_filename .endswith (".yaml" ):
@@ -503,11 +490,15 @@ def download_model_files(self, model_filename):
503
490
yaml_config_filename = config_value
504
491
yaml_config_filepath = os .path .join (self .model_file_dir , yaml_config_filename )
505
492
506
- # Repo for model data and configuration sources from UVR
507
- model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
508
- yaml_config_url = f"{ model_data_url_prefix } /mdx_model_data/mdx_c_configs/{ yaml_config_filename } "
509
-
510
- self .download_file_if_not_exists (f"{ yaml_config_url } " , yaml_config_filepath )
493
+ try :
494
+ # Repo for model data and configuration sources from UVR
495
+ model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
496
+ yaml_config_url = f"{ model_data_url_prefix } /mdx_model_data/mdx_c_configs/{ yaml_config_filename } "
497
+ self .download_file_if_not_exists (f"{ yaml_config_url } " , yaml_config_filepath )
498
+ except RuntimeError :
499
+ self .logger .debug ("Model YAML config file not found in UVR repo, attempting to download from audio-separator models repo..." )
500
+ yaml_config_url = f"{ audio_separator_models_repo_url_prefix } /{ yaml_config_filename } "
501
+ self .download_file_if_not_exists (f"{ yaml_config_url } " , yaml_config_filepath )
511
502
512
503
# MDX and VR models have config_value set to the model filename
513
504
else :
@@ -699,12 +690,7 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360
699
690
}
700
691
701
692
# Instantiate the appropriate separator class depending on the model type
702
- separator_classes = {
703
- "MDX" : "mdx_separator.MDXSeparator" ,
704
- "VR" : "vr_separator.VRSeparator" ,
705
- "Demucs" : "demucs_separator.DemucsSeparator" ,
706
- "MDXC" : "mdxc_separator.MDXCSeparator" ,
707
- }
693
+ separator_classes = {"MDX" : "mdx_separator.MDXSeparator" , "VR" : "vr_separator.VRSeparator" , "Demucs" : "demucs_separator.DemucsSeparator" , "MDXC" : "mdxc_separator.MDXCSeparator" }
708
694
709
695
if model_type not in self .arch_specific_params or model_type not in separator_classes :
710
696
raise ValueError (f"Model type not supported (yet): { model_type } " )
0 commit comments