@@ -46,20 +46,11 @@ class CommonSeparator:
46
46
BV_VOCAL_STEM_I = "with_backing_vocals"
47
47
LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
48
48
BV_VOCAL_STEM_LABEL = "Backing Vocals"
49
+ NO_STEM = "No "
49
50
50
- NON_ACCOM_STEMS = (
51
- VOCAL_STEM ,
52
- OTHER_STEM ,
53
- BASS_STEM ,
54
- DRUM_STEM ,
55
- GUITAR_STEM ,
56
- PIANO_STEM ,
57
- SYNTH_STEM ,
58
- STRINGS_STEM ,
59
- WOODWINDS_STEM ,
60
- BRASS_STEM ,
61
- WIND_INST_STEM ,
62
- )
51
+ STEM_PAIR_MAPPER = {VOCAL_STEM : INST_STEM , INST_STEM : VOCAL_STEM , LEAD_VOCAL_STEM : BV_VOCAL_STEM , BV_VOCAL_STEM : LEAD_VOCAL_STEM , PRIMARY_STEM : SECONDARY_STEM }
52
+
53
+ NON_ACCOM_STEMS = (VOCAL_STEM , OTHER_STEM , BASS_STEM , DRUM_STEM , GUITAR_STEM , PIANO_STEM , SYNTH_STEM , STRINGS_STEM , WOODWINDS_STEM , BRASS_STEM , WIND_INST_STEM )
63
54
64
55
def __init__ (self , config ):
65
56
@@ -91,7 +82,7 @@ def __init__(self, config):
91
82
92
83
# Model specific properties
93
84
self .primary_stem_name = self .model_data .get ("primary_stem" , "Vocals" )
94
- self .secondary_stem_name = "Vocals" if self .primary_stem_name == "Instrumental" else "Instrumental"
85
+ self .secondary_stem_name = self .secondary_stem ( self . primary_stem_name )
95
86
self .is_karaoke = self .model_data .get ("is_karaoke" , False )
96
87
self .is_bv_model = self .model_data .get ("is_bv_model" , False )
97
88
self .bv_model_rebalance = self .model_data .get ("is_bv_model_rebalanced" , 0 )
@@ -117,6 +108,17 @@ def __init__(self, config):
117
108
118
109
self .cached_sources_map = {}
119
110
111
+ def secondary_stem (self , primary_stem : str ):
112
+ """Determines secondary stem name based on the primary stem name."""
113
+ primary_stem = primary_stem if primary_stem else self .NO_STEM
114
+
115
+ if primary_stem in self .STEM_PAIR_MAPPER :
116
+ secondary_stem = self .STEM_PAIR_MAPPER [primary_stem ]
117
+ else :
118
+ secondary_stem = primary_stem .replace (self .NO_STEM , "" ) if self .NO_STEM in primary_stem else f"{ self .NO_STEM } { primary_stem } "
119
+
120
+ return secondary_stem
121
+
120
122
def separate (self , audio_file_path ):
121
123
"""
122
124
Placeholder method for separating audio sources. Should be overridden by subclasses.
0 commit comments