@@ -23,6 +23,7 @@ def __init__(
23
23
model_file_dir = "/tmp/audio-separator-models/" ,
24
24
output_dir = None ,
25
25
use_cuda = False ,
26
+ use_coreml = False ,
26
27
output_format = "WAV" ,
27
28
output_subtype = None ,
28
29
normalization_enabled = True ,
@@ -50,6 +51,7 @@ def __init__(
50
51
self .model_file_dir = model_file_dir
51
52
self .output_dir = output_dir
52
53
self .use_cuda = use_cuda
54
+ self .use_coreml = use_coreml
53
55
54
56
# Create the model directory if it does not exist
55
57
os .makedirs (self .model_file_dir , exist_ok = True )
@@ -116,6 +118,26 @@ def __init__(
116
118
else :
117
119
raise Exception ("CUDA requested but not available with current Torch installation. Do you have an Nvidia GPU?" )
118
120
121
+ elif self .use_coreml :
122
+ self .logger .debug ("Apple Silicon CoreML requested, checking Torch version" )
123
+ self .logger .debug (f"Torch version: { str (torch .__version__ )} " )
124
+
125
+ mps_available = hasattr (torch .backends , "mps" ) and torch .backends .mps .is_available ()
126
+ self .logger .debug (f"Is Apple Silicon CoreML MPS available? { str (mps_available )} " )
127
+
128
+ if mps_available :
129
+ self .logger .debug ("Running in Apple Silicon MPS GPU mode" )
130
+
131
+ # TODO: Change this to use MPS once FFTs are supported, see https://github.com/pytorch/pytorch/issues/78044
132
+ # self.device = torch.device("mps")
133
+
134
+ self .device = torch .device ("cpu" )
135
+ self .run_type = ["CoreMLExecutionProvider" ]
136
+ else :
137
+ raise Exception (
138
+ "Apple Silicon CoreML / MPS requested but not available with current Torch installation. Do you have an Apple Silicon GPU?"
139
+ )
140
+
119
141
else :
120
142
self .logger .debug ("Running in CPU mode" )
121
143
self .device = torch .device ("cpu" )
0 commit comments