-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
raw compute with C dylib now working for surrealml python client
- Loading branch information
1 parent
cdec7c0
commit 8adeb23
Showing
31 changed files
with
1,316 additions
and
3 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[build-system] | ||
requires = ["setuptools", "wheel", "build"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
onnxruntime==1.17.3 | ||
numpy==1.26.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from pathlib import Path | ||
|
||
current_dir = Path(__file__).parent.joinpath("..").joinpath("surrealml").joinpath("test.py") | ||
|
||
with open(current_dir, "w") as f: | ||
f.write("argh yeah") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
python -m build |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import os | ||
import platform | ||
import shutil | ||
import subprocess | ||
from pathlib import Path | ||
|
||
from setuptools import setup | ||
|
||
|
||
def get_c_lib_name() -> str: | ||
system = platform.system() | ||
if system == "Linux": | ||
return "libc_wrapper.so" | ||
elif system == "Darwin": # macOS | ||
return "libc_wrapper.dylib" | ||
elif system == "Windows": | ||
return "libc_wrapper.dll" | ||
raise ValueError(f"Unsupported system: {system}") | ||
|
||
# define the paths to the C wrapper and root | ||
DIR_PATH = Path(__file__).parent | ||
ROOT_PATH = DIR_PATH.joinpath("..").joinpath("..") | ||
C_PATH = ROOT_PATH.joinpath("modules").joinpath("c-wrapper") | ||
BINARY_PATH = C_PATH.joinpath("target").joinpath("release").joinpath(get_c_lib_name()) | ||
BINARY_DIST = DIR_PATH.joinpath("surrealml").joinpath(get_c_lib_name()) | ||
|
||
build_flag = False | ||
|
||
# build the C lib and copy it over to the python lib | ||
if BINARY_DIST.exists() is False: | ||
subprocess.Popen("cargo build --release", cwd=str(C_PATH), shell=True).wait() | ||
shutil.copyfile(BINARY_PATH, BINARY_DIST) | ||
build_flag = True | ||
|
||
setup( | ||
name="surrealml", | ||
version="0.1.0", | ||
description="A machine learning package for interfacing with various frameworks.", | ||
author="Maxwell Flitton", | ||
author_email="[email protected]", | ||
url="https://github.com/surrealdb/surrealml", | ||
license="MIT", | ||
classifiers=[ | ||
"Programming Language :: Python :: 3", | ||
"License :: OSI Approved :: MIT License", | ||
"Operating System :: OS Independent", | ||
], | ||
python_requires=">=3.6", | ||
install_requires=[ | ||
"numpy==1.26.3", | ||
], | ||
extras_require={ | ||
"sklearn": [ | ||
"skl2onnx==1.16.0", | ||
"scikit-learn==1.4.0", | ||
], | ||
"torch": [ | ||
"torch==2.1.2", | ||
], | ||
"tensorflow": [ | ||
"tf2onnx==1.16.1", | ||
"tensorflow==2.16.1", | ||
], | ||
}, | ||
packages=[ | ||
"surrealml", | ||
"surrealml.engine", | ||
"surrealml.model_templates", | ||
"surrealml.model_templates.datasets", | ||
"surrealml.model_templates.sklearn", | ||
"surrealml.model_templates.torch", | ||
"surrealml.model_templates.tensorflow", | ||
], | ||
package_data={ | ||
"surrealml": ["libc_wrapper.so", "libc_wrapper.dylib", "libc_wrapper.dll"] | ||
}, | ||
include_package_data=True, | ||
zip_safe=False, | ||
) | ||
|
||
# cleanup after install | ||
if build_flag is True: | ||
os.remove(BINARY_DIST) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from surrealml.surml_file import SurMlFile | ||
from surrealml.engine import Engine |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
""" | ||
Defines all the C structs that are returned from the C lib. | ||
""" | ||
from ctypes import Structure, c_char_p, c_int, c_size_t, POINTER, c_float, c_byte | ||
|
||
|
||
class StringReturn(Structure): | ||
""" | ||
A return type that just returns a string | ||
Fields: | ||
string: the string that is being returned (only present if successful) | ||
is_error: 1 if error, 0 if not | ||
error_message: the error message (only present if error) | ||
""" | ||
_fields_ = [ | ||
("string", c_char_p), # Corresponds to *mut c_char | ||
("is_error", c_int), # Corresponds to c_int | ||
("error_message", c_char_p) # Corresponds to *mut c_char | ||
] | ||
|
||
class EmptyReturn(Structure): | ||
""" | ||
A return type that just returns nothing | ||
Fields: | ||
is_error: 1 if error, 0 if not | ||
error_message: the error message (only present if error) | ||
""" | ||
_fields_ = [ | ||
("is_error", c_int), # Corresponds to c_int | ||
("error_message", c_char_p) # Corresponds to *mut c_char | ||
] | ||
|
||
|
||
class FileInfo(Structure): | ||
""" | ||
A return type when loading the meta of a surml file. | ||
Fields: | ||
file_id: a unique identifier for the file in the state of the C lib | ||
name: a name of the model | ||
description: a description of the model | ||
error_message: the error message (only present if error) | ||
is_error: 1 if error, 0 if not | ||
""" | ||
_fields_ = [ | ||
("file_id", c_char_p), # Corresponds to *mut c_char | ||
("name", c_char_p), # Corresponds to *mut c_char | ||
("description", c_char_p), # Corresponds to *mut c_char | ||
("version", c_char_p), # Corresponds to *mut c_char | ||
("error_message", c_char_p), # Corresponds to *mut c_char | ||
("is_error", c_int) # Corresponds to c_int | ||
] | ||
|
||
|
||
class Vecf32Return(Structure): | ||
""" | ||
A return type when loading the meta of a surml vector. | ||
Fields: | ||
data: the result of the ML execution | ||
length: the length of the vector | ||
capacity: the capacity of the vector | ||
is_error: 1 if error, 0 if not | ||
error_message: the error message (only present if error) | ||
""" | ||
_fields_ = [ | ||
("data", POINTER(c_float)), # Pointer to f32 array | ||
("length", c_size_t), # Length of the array | ||
("capacity", c_size_t), # Capacity of the array | ||
("is_error", c_int), # Indicates if it's an error | ||
("error_message", c_char_p), # Optional error message | ||
] | ||
|
||
|
||
class VecU8Return(Structure): | ||
""" | ||
A return type returning bytes. | ||
Fields: | ||
data: bytes | ||
length: the length of the vector | ||
capacity: the capacity of the vector | ||
is_error: 1 if error, 0 if not | ||
error_message: the error message (only present if error) | ||
""" | ||
_fields_ = [ | ||
("data", POINTER(c_byte)), # Pointer to bytes | ||
("length", c_size_t), # Length of the array | ||
("capacity", c_size_t), # Capacity of the array | ||
("is_error", c_int), # Indicates if it's an error | ||
("error_message", c_char_p), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from enum import Enum | ||
|
||
from surrealml.engine.sklearn import SklearnOnnxAdapter | ||
from surrealml.engine.torch import TorchOnnxAdapter | ||
from surrealml.engine.tensorflow import TensorflowOnnxAdapter | ||
from surrealml.engine.onnx import OnnxAdapter | ||
|
||
|
||
class Engine(Enum): | ||
""" | ||
The Engine enum is used to specify the engine to use for a given model. | ||
Attributes: | ||
PYTORCH: The PyTorch engine which will be PyTorch and ONNX. | ||
NATIVE: The native engine which will be native rust and linfa. | ||
SKLEARN: The sklearn engine which will be sklearn and ONNX | ||
TENSOFRLOW: The TensorFlow engine which will be TensorFlow and ONNX | ||
ONNX: The ONNX engine which bypasses the conversion to ONNX. | ||
""" | ||
PYTORCH = "pytorch" | ||
NATIVE = "native" | ||
SKLEARN = "sklearn" | ||
TENSORFLOW = "tensorflow" | ||
ONNX = "onnx" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
This file defines the adapter for the ONNX file format. This adapter does not convert anything as the input | ||
model is already in the ONNX format. It simply saves the model to a file. However, I have added this adapter | ||
to keep the same structure as the other adapters for different engines (maxwell flitton). | ||
""" | ||
from surrealml.engine.utils import create_file_cache_path | ||
|
||
|
||
class OnnxAdapter: | ||
|
||
@staticmethod | ||
def save_model_to_onnx(model, inputs) -> str: | ||
""" | ||
Saves a model to an onnx file. | ||
:param model: the raw ONNX model to directly save | ||
:param inputs: the inputs to the model needed to trace the model | ||
:return: the path to the cache created with a unique id to prevent collisions. | ||
""" | ||
file_path = create_file_cache_path() | ||
|
||
with open(file_path, "wb") as f: | ||
f.write(model.SerializeToString()) | ||
|
||
return file_path | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
""" | ||
This file defines the adapter that converts an sklearn model to an onnx model and saves the onnx model to a file. | ||
""" | ||
try: | ||
import skl2onnx | ||
except ImportError: | ||
skl2onnx = None | ||
|
||
from surrealml.engine.utils import create_file_cache_path | ||
|
||
|
||
class SklearnOnnxAdapter: | ||
""" | ||
Converts and saves sklearn models to onnx format. | ||
""" | ||
|
||
@staticmethod | ||
def check_dependency() -> None: | ||
""" | ||
Checks if the sklearn dependency is installed raising an error if not. | ||
Please call this function when performing any sklearn related operations. | ||
""" | ||
if skl2onnx is None: | ||
raise ImportError("sklearn feature needs to be installed to use sklearn features") | ||
|
||
@staticmethod | ||
def save_model_to_onnx(model, inputs) -> str: | ||
""" | ||
Saves a sklearn model to an onnx file. | ||
:param model: the sklearn model to convert. | ||
:param inputs: the inputs to the model needed to trace the model | ||
:return: the path to the cache created with a unique id to prevent collisions. | ||
""" | ||
SklearnOnnxAdapter.check_dependency() | ||
file_path = create_file_cache_path() | ||
onnx = skl2onnx.to_onnx(model, inputs) | ||
|
||
with open(file_path, "wb") as f: | ||
f.write(onnx.SerializeToString()) | ||
|
||
return file_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import os | ||
import shutil | ||
try: | ||
import tf2onnx | ||
import tensorflow as tf | ||
except ImportError: | ||
tf2onnx = None | ||
tf = None | ||
|
||
from surrealml.engine.utils import TensorflowCache | ||
|
||
|
||
class TensorflowOnnxAdapter: | ||
|
||
@staticmethod | ||
def check_dependency() -> None: | ||
""" | ||
Checks if the tensorflow dependency is installed raising an error if not. | ||
Please call this function when performing any tensorflow related operations. | ||
""" | ||
if tf2onnx is None or tf is None: | ||
raise ImportError("tensorflow feature needs to be installed to use tensorflow features") | ||
|
||
@staticmethod | ||
def save_model_to_onnx(model, inputs) -> str: | ||
""" | ||
Saves a tensorflow model to an onnx file. | ||
:param model: the tensorflow model to convert. | ||
:param inputs: the inputs to the model needed to trace the model | ||
:return: the path to the cache created with a unique id to prevent collisions. | ||
""" | ||
TensorflowOnnxAdapter.check_dependency() | ||
cache = TensorflowCache() | ||
|
||
model_file_path = cache.new_cache_path | ||
onnx_file_path = cache.new_cache_path | ||
|
||
tf.saved_model.save(model, model_file_path) | ||
|
||
os.system( | ||
f"python -m tf2onnx.convert --saved-model {model_file_path} --output {onnx_file_path}" | ||
) | ||
shutil.rmtree(model_file_path) | ||
return onnx_file_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
try: | ||
import torch | ||
except ImportError: | ||
torch = None | ||
|
||
from surrealml.engine.utils import create_file_cache_path | ||
|
||
|
||
class TorchOnnxAdapter: | ||
|
||
@staticmethod | ||
def check_dependency() -> None: | ||
""" | ||
Checks if the sklearn dependency is installed raising an error if not. | ||
Please call this function when performing any sklearn related operations. | ||
""" | ||
if torch is None: | ||
raise ImportError("torch feature needs to be installed to use torch features") | ||
|
||
@staticmethod | ||
def save_model_to_onnx(model, inputs) -> str: | ||
""" | ||
Saves a torch model to an onnx file. | ||
:param model: the torch model to convert. | ||
:param inputs: the inputs to the model needed to trace the model | ||
:return: the path to the cache created with a unique id to prevent collisions. | ||
""" | ||
# the dynamic import it to prevent the torch dependency from being required for the whole package. | ||
file_path = create_file_cache_path() | ||
# below is to satisfy type checkers | ||
if torch is not None: | ||
traced_script_module = torch.jit.trace(model, inputs) | ||
torch.onnx.export(traced_script_module, inputs, file_path) | ||
return file_path |
Oops, something went wrong.