Skip to content

Commit 6ae982c

Browse files
authored
minor changes (#107)
* Update cli.py * Create code_formatter.yml * Update calculate-model-hashes.py * Update test_cli.py * Update separator.py * Update common_separator.py * Delete .github/workflows/code_formatter.yml
1 parent d1fcf5b commit 6ae982c

File tree

5 files changed

+105
-16
lines changed

5 files changed

+105
-16
lines changed

audio_separator/separator/common_separator.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,19 @@ class CommonSeparator:
4747
LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
4848
BV_VOCAL_STEM_LABEL = "Backing Vocals"
4949

50-
NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)
50+
NON_ACCOM_STEMS = (
51+
VOCAL_STEM,
52+
OTHER_STEM,
53+
BASS_STEM,
54+
DRUM_STEM,
55+
GUITAR_STEM,
56+
PIANO_STEM,
57+
SYNTH_STEM,
58+
STRINGS_STEM,
59+
WOODWINDS_STEM,
60+
BRASS_STEM,
61+
WIND_INST_STEM,
62+
)
5163

5264
def __init__(self, config):
5365

audio_separator/separator/separator.py

+48-8
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,33 @@ def __init__(
7474
output_single_stem=None,
7575
invert_using_spec=False,
7676
sample_rate=44100,
77-
mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
78-
vr_params={"batch_size": 16, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
79-
demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
80-
mdxc_params={"segment_size": 256, "batch_size": 1, "overlap": 8},
77+
mdx_params={
78+
"hop_length": 1024,
79+
"segment_size": 256,
80+
"overlap": 0.25,
81+
"batch_size": 1,
82+
"enable_denoise": False,
83+
},
84+
vr_params={
85+
"batch_size": 16,
86+
"window_size": 512,
87+
"aggression": 5,
88+
"enable_tta": False,
89+
"enable_post_process": False,
90+
"post_process_threshold": 0.2,
91+
"high_end_process": False,
92+
},
93+
demucs_params={
94+
"segment_size": "Default",
95+
"shifts": 2,
96+
"overlap": 0.25,
97+
"segments_enabled": True,
98+
},
99+
mdxc_params={
100+
"segment_size": 256,
101+
"batch_size": 1,
102+
"overlap": 8,
103+
},
81104
):
82105
self.logger = logging.getLogger(__name__)
83106
self.logger.setLevel(log_level)
@@ -143,7 +166,12 @@ def __init__(
143166

144167
# These are parameters which users may want to configure so we expose them to the top-level Separator class,
145168
# even though they are specific to a single model architecture
146-
self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
169+
self.arch_specific_params = {
170+
"MDX": mdx_params,
171+
"VR": vr_params,
172+
"Demucs": demucs_params,
173+
"MDXC": mdxc_params,
174+
}
147175

148176
self.torch_device = None
149177
self.torch_device_cpu = None
@@ -385,9 +413,16 @@ def list_supported_model_files(self):
385413
# Return object with list of model names, which are the keys in vr_download_list, mdx_download_list, demucs_download_list, mdx23_download_list, mdx23c_download_list, grouped by type: VR, MDX, Demucs, MDX23, MDX23C
386414
model_files_grouped_by_type = {
387415
"VR": model_downloads_list["vr_download_list"],
388-
"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]},
416+
"MDX": {
417+
**model_downloads_list["mdx_download_list"],
418+
**model_downloads_list["mdx_download_vip_list"],
419+
},
389420
"Demucs": filtered_demucs_v4,
390-
"MDXC": {**model_downloads_list["mdx23c_download_list"], **model_downloads_list["mdx23c_download_vip_list"], **model_downloads_list["roformer_download_list"]},
421+
"MDXC": {
422+
**model_downloads_list["mdx23c_download_list"],
423+
**model_downloads_list["mdx23c_download_vip_list"],
424+
**model_downloads_list["roformer_download_list"],
425+
},
391426
}
392427
return model_files_grouped_by_type
393428

@@ -664,7 +699,12 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360
664699
}
665700

666701
# Instantiate the appropriate separator class depending on the model type
667-
separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"}
702+
separator_classes = {
703+
"MDX": "mdx_separator.MDXSeparator",
704+
"VR": "vr_separator.VRSeparator",
705+
"Demucs": "demucs_separator.DemucsSeparator",
706+
"MDXC": "mdxc_separator.MDXCSeparator",
707+
}
668708

669709
if model_type not in self.arch_specific_params or model_type not in separator_classes:
670710
raise ValueError(f"Model type not supported (yet): {model_type}")

audio_separator/utils/cli.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,12 @@ def main():
174174
"post_process_threshold": args.vr_post_process_threshold,
175175
"high_end_process": args.vr_high_end_process,
176176
},
177-
demucs_params={"segment_size": args.demucs_segment_size, "shifts": args.demucs_shifts, "overlap": args.demucs_overlap, "segments_enabled": args.demucs_segments_enabled},
177+
demucs_params={
178+
"segment_size": args.demucs_segment_size,
179+
"shifts": args.demucs_shifts,
180+
"overlap": args.demucs_overlap,
181+
"segments_enabled": args.demucs_segments_enabled,
182+
},
178183
mdxc_params={
179184
"segment_size": args.mdxc_segment_size,
180185
"batch_size": args.mdxc_batch_size,

tests/unit/test_cli.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,35 @@ def common_expected_args():
2020
"output_single_stem": None,
2121
"invert_using_spec": False,
2222
"sample_rate": 44100,
23-
"mdx_params": {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
24-
"vr_params": {"batch_size": 4, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
25-
"demucs_params": {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
26-
"mdxc_params": {"segment_size": 256, "batch_size": 1, "overlap": 8, "override_model_segment_size": False, "pitch_shift": 0},
23+
"mdx_params": {
24+
"hop_length": 1024,
25+
"segment_size": 256,
26+
"overlap": 0.25,
27+
"batch_size": 1,
28+
"enable_denoise": False,
29+
},
30+
"vr_params": {
31+
"batch_size": 4,
32+
"window_size": 512,
33+
"aggression": 5,
34+
"enable_tta": False,
35+
"enable_post_process": False,
36+
"post_process_threshold": 0.2,
37+
"high_end_process": False,
38+
},
39+
"demucs_params": {
40+
"segment_size": "Default",
41+
"shifts": 2,
42+
"overlap": 0.25,
43+
"segments_enabled": True,
44+
},
45+
"mdxc_params": {
46+
"segment_size": 256,
47+
"batch_size": 1,
48+
"overlap": 8,
49+
"override_model_segment_size": False,
50+
"pitch_shift": 0,
51+
},
2752
}
2853

2954

tools/calculate-model-hashes.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,19 @@ def iterate_and_hash(directory):
8080
vr_model_data = load_json_data(VR_MODEL_DATA_LOCAL_PATH)
8181
mdx_model_data = load_json_data(MDX_MODEL_DATA_LOCAL_PATH)
8282

83-
combined_model_params = {**vr_model_data, **mdx_model_data}
83+
combined_model_params = {
84+
**vr_model_data,
85+
**mdx_model_data,
86+
}
8487

8588
model_info_list = []
8689
for file, file_path in sorted(model_files):
8790
file_hash = get_model_hash(file_path)
88-
model_info = {"file": file, "hash": file_hash, "params": combined_model_params.get(file_hash, "Parameters not found")}
91+
model_info = {
92+
"file": file,
93+
"hash": file_hash,
94+
"params": combined_model_params.get(file_hash, "Parameters not found"),
95+
}
8996
model_info_list.append(model_info)
9097

9198
print(f"Writing model info list to {OUTPUT_PATH}")

0 commit comments

Comments
 (0)