Skip to content

Commit

Permalink
all local functions are now working with the raw C bindings, just nee…
Browse files Browse the repository at this point in the history
…d to test the upload to server left
  • Loading branch information
maxwellflitton committed Dec 30, 2024
1 parent 8adeb23 commit 306df6a
Show file tree
Hide file tree
Showing 11 changed files with 251 additions and 46 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ SurrealML is a feature that allows you to store trained machine learning models
4. Python Environment Setup: A Python environment with necessary libraries installed, including SurrealML, PyTorch or SKLearn (depending on your model preference).
5. SurrealDB Installation: Ensure you have SurrealDB installed and running on your machine or server

## New Clients

We are removing the `PyO3` bindings and just using raw C bindings for the `surrealml-core` library. This will simplfy builds and also enable clients in other languges to use the `surrealml-core` library. The `c-wrapper` module can be found in the `modules/c-wrapper` directory. The new clients can be found in the `clients` directory.

## Installation

To install SurrealML, make sure you have Python installed. Then, install the `SurrealML` library and either `PyTorch` or
Expand Down
4 changes: 4 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

# SurrealML Python Client

The SurrealML Python client using the Rust `surrealml` library without any `PyO3` bindings.
6 changes: 0 additions & 6 deletions clients/python/scripts/build_c_lib.py

This file was deleted.

37 changes: 34 additions & 3 deletions clients/python/surrealml/loader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""
The loader for the dynamic C lib written in Rust.
"""
import ctypes
from pathlib import Path
import platform
from pathlib import Path

from surrealml.c_structs import EmptyReturn, StringReturn, Vecf32Return, FileInfo, VecU8Return


class Singleton(type):

"""
Ensures that the loader only loads once throughout the program's lifetime
"""
_instances = {}

def __call__(cls, *args, **kwargs):
Expand Down Expand Up @@ -45,6 +51,12 @@ def load_library(lib_name: str = "libc_wrapper") -> ctypes.CDLL:
class LibLoader(metaclass=Singleton):

def __init__(self, lib_name: str = "libc_wrapper") -> None:
"""
The constructor for the LibLoader class.
args:
lib_name (str): The base name of the library without extension (e.g., "libc_wrapper").
"""
self.lib = load_library(lib_name=lib_name)
functions = [
self.lib.add_name,
Expand All @@ -60,15 +72,34 @@ def __init__(self, lib_name: str = "libc_wrapper") -> None:
i.restype = EmptyReturn
self.lib.load_model.restype = FileInfo
self.lib.load_model.argtypes = [ctypes.c_char_p]
self.lib.free_file_info.argtypes = [FileInfo]
self.lib.load_cached_raw_model.restype = StringReturn
self.lib.load_cached_raw_model.argtypes = [ctypes.c_char_p]
self.lib.to_bytes.argtypes = [ctypes.c_char_p]
self.lib.to_bytes.restype = VecU8Return
self.lib.save_model.restype = EmptyReturn
self.lib.save_model.argtypes = [ctypes.c_char_p, ctypes.c_char_p]
self.lib.upload_model.argtypes = [
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.c_size_t,
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.c_char_p,
]
self.lib.upload_model.restype = EmptyReturn

# define the compute functions
self.lib.raw_compute.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_float), ctypes.c_size_t]
self.lib.raw_compute.restype = Vecf32Return
self.lib.buffered_compute.argtypes = [
ctypes.c_char_p, # file_id_ptr -> *const c_char
ctypes.POINTER(ctypes.c_float), # data_ptr -> *const c_float
ctypes.c_size_t, # data_length -> usize
ctypes.POINTER(ctypes.c_char_p), # strings -> *const *const c_char
ctypes.c_int # string_count -> c_int
]
self.lib.buffered_compute.restype = Vecf32Return

# Define free alloc functions
self.lib.free_string_return.argtypes = [StringReturn]
Expand Down
132 changes: 97 additions & 35 deletions clients/python/surrealml/rust_adapter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Optional
"""
The adapter to interact with the Rust module compiled to a C dynamic library
"""
import ctypes
import warnings
import platform
import warnings
from pathlib import Path
from surrealml.c_structs import EmptyReturn, StringReturn, Vecf32Return, FileInfo
from surrealml.loader import LibLoader
from typing import List, Tuple
from typing import Optional

from surrealml.c_structs import EmptyReturn, StringReturn, Vecf32Return, FileInfo, VecU8Return
from surrealml.engine import Engine
from surrealml.loader import LibLoader


def load_library(lib_name: str = "libc_wrapper") -> ctypes.CDLL:
Expand Down Expand Up @@ -173,8 +176,13 @@ def add_author(self, author: str) -> None:
:param author: the author of the model.
:return: None
"""
# add_author(self.file_id, author)
pass
outcome: EmptyReturn = self.loader.lib.add_author(
self.file_id.encode("utf-8"),
author.encode("utf-8"),
)
if outcome.is_error == 1:
raise RuntimeError(outcome.error_message.decode("utf-8"))
self.loader.lib.free_empty_return(outcome)

def save(self, path: str, name: Optional[str]) -> None:
"""
Expand All @@ -185,25 +193,54 @@ def save(self, path: str, name: Optional[str]) -> None:
:return: None
"""
pass
# add_engine(self.file_id, self.engine.value)
# add_origin(self.file_id, "local")
# if name is not None:
# add_name(self.file_id, name)
# else:
# warnings.warn(
# "You are saving a model without a name, you will not be able to upload this model to the database"
# )
# save_model(path, self.file_id)

def to_bytes(self):
outcome: EmptyReturn = self.loader.lib.add_engine(
self.file_id.encode("utf-8"),
self.engine.value.encode("utf-8"),
)
if outcome.is_error == 1:
raise RuntimeError(outcome.error_message.decode("utf-8"))
self.loader.lib.free_empty_return(outcome)
outcome: EmptyReturn = self.loader.lib.add_origin(
self.file_id.encode("utf-8"),
"local".encode("utf-8"),
)
if outcome.is_error == 1:
raise RuntimeError(outcome.error_message.decode("utf-8"))
self.loader.lib.free_empty_return(outcome)
if name is not None:
outcome: EmptyReturn = self.loader.lib.add_name(
self.file_id.encode("utf-8"),
name.encode("utf-8"),
)
if outcome.is_error == 1:
raise RuntimeError(outcome.error_message.decode("utf-8"))
self.loader.lib.free_empty_return(outcome)
else:
warnings.warn(
"You are saving a model without a name, you will not be able to upload this model to the database"
)
outcome: EmptyReturn = self.loader.lib.save_model(
path.encode("utf-8"),
self.file_id.encode("utf-8")
)
if outcome.is_error == 1:
raise RuntimeError(outcome.error_message.decode("utf-8"))
self.loader.lib.free_empty_return(outcome)

def to_bytes(self) -> bytes:
"""
Converts the model to bytes.
:return: the model as bytes.
"""
pass
# return to_bytes(self.file_id)
outcome: VecU8Return = self.loader.lib.to_bytes(
self.file_id.encode("utf-8"),
)
if outcome.is_error == 1:
raise RuntimeError(outcome.error_message.decode("utf-8"))
byte_vec = outcome.data
self.loader.lib.free_vec_u8(outcome)
return byte_vec

@staticmethod
def load(path) -> Tuple[str, str, str, str]:
Expand Down Expand Up @@ -251,16 +288,19 @@ def upload(
:return: None
"""
pass
# upload_model(
# path,
# url,
# chunk_size,
# namespace,
# database,
# username,
# password
# )
loader = LibLoader()
outcome = loader.lib.upload_model(
path.encode("utf-8"),
url.encode("utf-8"),
chunk_size,
namespace.encode("utf-8"),
database.encode("utf-8"),
username.encode("utf-8"),
password.encode("utf-8"),
)
if outcome.is_error == 1:
raise RuntimeError(outcome.error_message.decode("utf-8"))
loader.lib.free_file_info(outcome)

def raw_compute(self, input_vector, dims=None) -> List[float]:
"""
Expand All @@ -283,14 +323,36 @@ def raw_compute(self, input_vector, dims=None) -> List[float]:
self.loader.lib.free_vecf32_return(outcome)
return package

# return raw_compute(self.file_id, input_vector, dims)

def buffered_compute(self, value_map):
def buffered_compute(self, value_map: dict) -> List[float]:
"""
Calculates an output from the model given a value map.
:param value_map: a dictionary of inputs to the model with the column names as keys and floats as values.
:return: the output of the model.
"""
pass
# return buffered_compute(self.file_id, value_map)
string_buffer = []
data_buffer = []
for key, value in value_map.items():
string_buffer.append(key.encode('utf-8'))
data_buffer.append(value)

# Prepare input data as a ctypes array
array_type = ctypes.c_float * len(data_buffer) # Create an array type of the appropriate size
input_data = array_type(*data_buffer) # Instantiate the array with the list elements

# prepare the input strings
string_array = (ctypes.c_char_p * len(string_buffer))(*string_buffer)
string_count = len(string_buffer)

outcome = self.loader.lib.buffered_compute(
self.file_id.encode("utf-8"),
input_data,
len(input_data),
string_array,
string_count
)
if outcome.is_error == 1:
raise RuntimeError(outcome.error_message.decode("utf-8"))
return_data = [outcome.data[i] for i in range(outcome.length)]
self.loader.lib.free_vecf32_return(outcome)
return return_data
5 changes: 5 additions & 0 deletions modules/c-wrapper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,10 @@ surrealml-core = { path = "../core" }
uuid = { version = "1.4.1", features = ["v4"] }
ndarray = "0.16.1"

# for the uploading the model to the server
tokio = { version = "1.42.0", features = ["full"] }
hyper = { version = "0.14.27", features = ["full"] }
base64 = "0.13"

[lib]
crate-type = ["cdylib"]
2 changes: 0 additions & 2 deletions modules/c-wrapper/src/api/execution/buffered_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ pub extern "C" fn buffered_compute(
}
})
.collect();

for i in rust_strings.iter() {
if i.is_empty() {
return Vecf32Return {
Expand Down Expand Up @@ -135,7 +134,6 @@ pub extern "C" fn buffered_compute(
}
}
};

let compute_unit = ModelComputation {
surml_file: &mut file
};
Expand Down
1 change: 1 addition & 0 deletions modules/c-wrapper/src/api/execution/raw_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pub extern "C" fn raw_compute(file_id_ptr: *const c_char, data_ptr: *const c_flo
}
}
};

let mut file = match state.get_mut(&file_id) {
Some(file) => file,
None => {
Expand Down
1 change: 1 addition & 0 deletions modules/c-wrapper/src/api/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pub mod save_model;
pub mod load_cached_raw_model;
pub mod to_bytes;
pub mod meta;
pub mod upload_model;
86 changes: 86 additions & 0 deletions modules/c-wrapper/src/api/storage/upload_model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Standard library imports
use std::ffi::{CStr, CString};
use std::os::raw::c_char;

// External crate imports
use base64::encode;
use hyper::{
Body, Client, Method, Request, Uri,
header::{AUTHORIZATION, CONTENT_TYPE, HeaderValue},
};
use surrealml_core::storage::stream_adapter::StreamAdapter;

// Local module imports
use crate::utils::EmptyReturn;
use crate::{empty_return_safe_eject, process_string_for_empty_return};


/// Uploads a model to a remote server.
///
/// # Arguments
/// * `file_path_ptr` - The path to the file to upload.
/// * `url_ptr` - The URL to upload the file to.
/// * `chunk_size` - The size of the chunks to upload the file in.
/// * `ns_ptr` - The namespace to upload the file to.
/// * `db_ptr` - The database to upload the file to.
/// * `username_ptr` - The username to use for authentication.
/// * `password_ptr` - The password to use for authentication.
///
/// # Returns
/// An empty return object indicating success or failure.
#[no_mangle]
pub extern "C" fn upload_model(
file_path_ptr: *const c_char,
url_ptr: *const c_char,
chunk_size: usize,
ns_ptr: *const c_char,
db_ptr: *const c_char,
username_ptr: *const c_char,
password_ptr: *const c_char
) -> EmptyReturn {
// process the inputs
let file_path = process_string_for_empty_return!(file_path_ptr, "file path");
let url = process_string_for_empty_return!(url_ptr, "url");
let ns = process_string_for_empty_return!(ns_ptr, "namespace");
let db = process_string_for_empty_return!(db_ptr, "database");
let username = match username_ptr.is_null() {
true => None,
false => Some(process_string_for_empty_return!(username_ptr, "username"))
};
let password = match password_ptr.is_null() {
true => None,
false => Some(process_string_for_empty_return!(password_ptr, "password"))
};

let client = Client::new();

let uri: Uri = empty_return_safe_eject!(url.parse());
let generator = empty_return_safe_eject!(StreamAdapter::new(chunk_size, file_path));
let body = Body::wrap_stream(generator);

let part_req = Request::builder()
.method(Method::POST)
.uri(uri)
.header(CONTENT_TYPE, "application/octet-stream")
.header("surreal-ns", empty_return_safe_eject!(HeaderValue::from_str(&ns)))
.header("surreal-db", empty_return_safe_eject!(HeaderValue::from_str(&db)));

let req;
if username.is_none() == false && password.is_none() == false {
// unwraps are safe because we have already checked that the values are not None
let encoded_credentials = encode(format!("{}:{}", username.unwrap(), password.unwrap()));
req = empty_return_safe_eject!(part_req.header(AUTHORIZATION, format!("Basic {}", encoded_credentials))
.body(body));
}
else {
req = empty_return_safe_eject!(part_req.body(body));
}

let tokio_runtime = empty_return_safe_eject!(tokio::runtime::Builder::new_current_thread().enable_io()
.enable_time()
.build());
tokio_runtime.block_on( async move {
let _response = client.request(req).await.unwrap();
});
EmptyReturn::success()
}
Loading

0 comments on commit 306df6a

Please sign in to comment.