@@ -22,6 +22,8 @@ def __init__(
22
22
model_name = "UVR_MDXNET_KARA_2" ,
23
23
model_file_dir = "/tmp/audio-separator-models/" ,
24
24
output_dir = None ,
25
+ primary_stem_path = None ,
26
+ secondary_stem_path = None ,
25
27
use_cuda = False ,
26
28
use_coreml = False ,
27
29
output_format = "WAV" ,
@@ -41,7 +43,9 @@ def __init__(
41
43
self .log_formatter = logging .Formatter ("%(asctime)s - %(levelname)s - %(module)s - %(message)s" )
42
44
43
45
self .log_handler .setFormatter (self .log_formatter )
44
- self .logger .addHandler (self .log_handler )
46
+
47
+ if not self .logger .hasHandlers ():
48
+ self .logger .addHandler (self .log_handler )
45
49
46
50
self .logger .debug (
47
51
f"Separator instantiating with input file: { audio_file_path } , model_name: { model_name } , output_dir: { output_dir } , use_cuda: { use_cuda } , output_format: { output_format } "
@@ -52,6 +56,8 @@ def __init__(
52
56
self .output_dir = output_dir
53
57
self .use_cuda = use_cuda
54
58
self .use_coreml = use_coreml
59
+ self .primary_stem_path = primary_stem_path
60
+ self .secondary_stem_path = secondary_stem_path
55
61
56
62
# Create the model directory if it does not exist
57
63
os .makedirs (self .model_file_dir , exist_ok = True )
@@ -206,17 +212,21 @@ def separate(self):
206
212
207
213
if not self .output_single_stem or self .output_single_stem .lower () == self .primary_stem .lower ():
208
214
self .logger .info (f"Saving { self .primary_stem } stem..." )
209
- primary_stem_path = os .path .join (f"{ self .audio_file_base } _({ self .primary_stem } )_{ self .model_name } .{ self .output_format .lower ()} " )
210
- self .write_audio (primary_stem_path , self .primary_source , samplerate )
211
- output_files .append (primary_stem_path )
215
+ if not self .primary_stem_path :
216
+ self .primary_stem_path = os .path .join (
217
+ f"{ self .audio_file_base } _({ self .primary_stem } )_{ self .model_name } .{ self .output_format .lower ()} "
218
+ )
219
+ self .write_audio (self .primary_stem_path , self .primary_source , samplerate )
220
+ output_files .append (self .primary_stem_path )
212
221
213
222
if not self .output_single_stem or self .output_single_stem .lower () == self .secondary_stem .lower ():
214
223
self .logger .info (f"Saving { self .secondary_stem } stem..." )
215
- secondary_stem_path = os .path .join (
216
- f"{ self .audio_file_base } _({ self .secondary_stem } )_{ self .model_name } .{ self .output_format .lower ()} "
217
- )
218
- self .write_audio (secondary_stem_path , self .secondary_source , samplerate )
219
- output_files .append (secondary_stem_path )
224
+ if not self .secondary_stem_path :
225
+ self .secondary_stem_path = os .path .join (
226
+ f"{ self .audio_file_base } _({ self .secondary_stem } )_{ self .model_name } .{ self .output_format .lower ()} "
227
+ )
228
+ self .write_audio (self .secondary_stem_path , self .secondary_source , samplerate )
229
+ output_files .append (self .secondary_stem_path )
220
230
221
231
torch .cuda .empty_cache ()
222
232
return output_files
0 commit comments