Skip to content

Commit 5f58a48

Browse files
committed
Added tentative Apple Silicon GPU support
1 parent 0cd3563 commit 5f58a48

File tree

4 files changed

+122
-82
lines changed

4 files changed

+122
-82
lines changed

audio_separator/separator.py

+22
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
model_file_dir="/tmp/audio-separator-models/",
2424
output_dir=None,
2525
use_cuda=False,
26+
use_coreml=False,
2627
output_format="WAV",
2728
output_subtype=None,
2829
normalization_enabled=True,
@@ -50,6 +51,7 @@ def __init__(
5051
self.model_file_dir = model_file_dir
5152
self.output_dir = output_dir
5253
self.use_cuda = use_cuda
54+
self.use_coreml = use_coreml
5355

5456
# Create the model directory if it does not exist
5557
os.makedirs(self.model_file_dir, exist_ok=True)
@@ -116,6 +118,26 @@ def __init__(
116118
else:
117119
raise Exception("CUDA requested but not available with current Torch installation. Do you have an Nvidia GPU?")
118120

121+
elif self.use_coreml:
122+
self.logger.debug("Apple Silicon CoreML requested, checking Torch version")
123+
self.logger.debug(f"Torch version: {str(torch.__version__)}")
124+
125+
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
126+
self.logger.debug(f"Is Apple Silicon CoreML MPS available? {str(mps_available)}")
127+
128+
if mps_available:
129+
self.logger.debug("Running in Apple Silicon MPS GPU mode")
130+
131+
# TODO: Change this to use MPS once FFTs are supported, see https://github.com/pytorch/pytorch/issues/78044
132+
# self.device = torch.device("mps")
133+
134+
self.device = torch.device("cpu")
135+
self.run_type = ["CoreMLExecutionProvider"]
136+
else:
137+
raise Exception(
138+
"Apple Silicon CoreML / MPS requested but not available with current Torch installation. Do you have an Apple Silicon GPU?"
139+
)
140+
119141
else:
120142
self.logger.debug("Running in CPU mode")
121143
self.device = torch.device("cpu")

audio_separator/utils/cli.py

+7
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ def main():
5252
help="Optional: use Nvidia GPU with CUDA for separation (default: %(default)s). Example: --use_cuda=true",
5353
)
5454

55+
parser.add_argument(
56+
"--use_coreml",
57+
action="store_true",
58+
help="Optional: use Apple Silicon GPU with CoreML for separation (default: %(default)s). Example: --use_coreml=true",
59+
)
60+
5561
parser.add_argument(
5662
"--output_format",
5763
default="FLAC",
@@ -97,6 +103,7 @@ def main():
97103
model_file_dir=args.model_file_dir,
98104
output_dir=args.output_dir,
99105
use_cuda=args.use_cuda,
106+
use_coreml=args.use_coreml,
100107
output_format=args.output_format,
101108
denoise_enabled=args.denoise,
102109
normalization_enabled=args.normalize,

0 commit comments

Comments
 (0)