Skip to content

Commit

Permalink
c-wrapper working
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwellflitton committed Dec 24, 2024
1 parent e1eb109 commit d33d9af
Show file tree
Hide file tree
Showing 19 changed files with 953 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ surrealml.egg-info/
.vscode/
./modules/utils/target/
modules/core/target/
modules/c-wrapper/target/
./modules/onnx_driver/target/
modules/onnx_driver/target/
surrealdb_build/
Expand All @@ -27,3 +28,4 @@ surrealml/rust_surrealml.cpython-310-darwin.so
./modules/pipelines/runners/integrated_training_runner/run_env/
modules/pipelines/runners/integrated_training_runner/run_env/
modules/pipelines/data_access/target/
*.dylib
12 changes: 12 additions & 0 deletions modules/c-wrapper/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "c-wrapper"
version = "0.1.0"
edition = "2021"

[dependencies]
surrealml-core = { path = "../core" }
uuid = { version = "1.4.1", features = ["v4"] }
ndarray = "0.16.1"

[lib]
crate-type = ["cdylib"]
Empty file added modules/c-wrapper/README.md
Empty file.
45 changes: 45 additions & 0 deletions modules/c-wrapper/scripts/prep_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/usr/bin/env bash

# navigate to directory
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
cd $SCRIPTPATH

cd ..

cargo build

# Get the operating system
OS=$(uname)

# Set the library name and extension based on the OS
case "$OS" in
"Linux")
LIB_NAME="libc_wrapper.so"
;;
"Darwin")
LIB_NAME="libc_wrapper.dylib"
;;
"CYGWIN"*|"MINGW"*)
LIB_NAME="libc_wrapper.dll"
;;
*)
echo "Unsupported operating system: $OS"
exit 1
;;
esac

# Source directory (where Cargo outputs the compiled library)
SOURCE_DIR="target/debug"

# Destination directory (tests directory)
DEST_DIR="tests"


# Copy the library to the tests directory
if [ -f "$SOURCE_DIR/$LIB_NAME" ]; then
cp "$SOURCE_DIR/$LIB_NAME" "$DEST_DIR/"
echo "Copied $LIB_NAME to $DEST_DIR"
else
echo "Library not found: $SOURCE_DIR/$LIB_NAME"
exit 1
fi
1 change: 1 addition & 0 deletions modules/c-wrapper/src/api/execution/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod raw_compute;
120 changes: 120 additions & 0 deletions modules/c-wrapper/src/api/execution/raw_compute.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use crate::state::STATE;
use std::ffi::{c_float, CStr, CString, c_int, c_char};
use surrealml_core::execution::compute::ModelComputation;


#[repr(C)]
pub struct Vecf32Return {
pub data: *mut f32,
pub length: usize,
pub capacity: usize, // Optional if you want to include capacity for clarity
pub is_error: c_int,
pub error_message: *mut c_char
}


#[no_mangle]
pub extern "C" fn free_vecf32_return(vecf32_return: Vecf32Return) {
// Free the data if it is not null
if !vecf32_return.data.is_null() {
unsafe { drop(Vec::from_raw_parts(vecf32_return.data, vecf32_return.length, vecf32_return.capacity)) };
}
// Free the error message if it is not null
if !vecf32_return.error_message.is_null() {
unsafe { drop(CString::from_raw(vecf32_return.error_message)) };
}
}



#[no_mangle]
pub extern "C" fn raw_compute(file_id_ptr: *const c_char, data_ptr: *const c_float, length: usize) -> Vecf32Return {

if file_id_ptr.is_null() {
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new("File id is null").unwrap().into_raw()
}
}
if data_ptr.is_null() {
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new("Data is null").unwrap().into_raw()
}
}

let file_id = match unsafe { CStr::from_ptr(file_id_ptr) }.to_str() {
Ok(file_id) => file_id.to_owned(),
Err(error) => return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new(format!("Error getting file id: {}", error)).unwrap().into_raw()
}
};

let mut state = match STATE.lock() {
Ok(state) => state,
Err(error) => {
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new(format!("Error getting state: {}", error)).unwrap().into_raw()
}
}
};
let mut file = match state.get_mut(&file_id) {
Some(file) => file,
None => {
{
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new(format!("File not found for id: {}, here is the state: {:?}", file_id, state.keys())).unwrap().into_raw()
}
}
}
};

let slice = unsafe { std::slice::from_raw_parts(data_ptr, length) };
let tensor = ndarray::arr1(slice).into_dyn();
let compute_unit = ModelComputation {
surml_file: &mut file
};

// perform the computation
let mut outcome = match compute_unit.raw_compute(tensor, None) {
Ok(outcome) => outcome,
Err(error) => {
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new(format!("Error computing model: {}", error.message)).unwrap().into_raw()
}
}
};
let outcome_ptr = outcome.as_mut_ptr();
let outcome_len = outcome.len();
let outcome_capacity = outcome.capacity();
std::mem::forget(outcome);
Vecf32Return {
data: outcome_ptr,
length: outcome_len,
capacity: outcome_capacity,
is_error: 0,
error_message: std::ptr::null_mut()
}
}
2 changes: 2 additions & 0 deletions modules/c-wrapper/src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod execution;
pub mod storage;
31 changes: 31 additions & 0 deletions modules/c-wrapper/src/api/storage/load_cached_raw_model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use crate::state::{STATE, generate_unique_id};
use surrealml_core::storage::surml_file::SurMlFile;
use std::fs::File;
use std::io::Read;
use std::os::raw::c_char;
use std::ffi::CString;
use std::ffi::CStr;
use crate::utils::StringReturn;
use crate::process_string_for_string_return;


/// Loads a PyTorch C model from a file wrapping it in a SurMlFile struct
/// which is stored in memory and referenced by a unique ID.
///
/// # Arguments
/// * `file_path` - The path to the file to load.
///
/// # Returns
/// A unique identifier for the loaded model.
#[no_mangle]
pub extern "C" fn load_cached_raw_model(file_path_ptr: *const c_char) -> StringReturn {
let file_path_str = process_string_for_string_return!(file_path_ptr, "file path");
let file_id = generate_unique_id();
let mut model = File::open(file_path_str).unwrap();
let mut data = vec![];
model.read_to_end(&mut data).unwrap();
let file = SurMlFile::fresh(data);
let mut python_state = STATE.lock().unwrap();
python_state.insert(file_id.clone(), file);
StringReturn::success(file_id)
}
111 changes: 111 additions & 0 deletions modules/c-wrapper/src/api/storage/load_model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use crate::state::{STATE, generate_unique_id};
use surrealml_core::storage::surml_file::SurMlFile;
use std::ffi::CStr;
use std::ffi::CString;
use std::os::raw::c_char;


#[repr(C)]
pub struct FileInfo {
file_id: *mut c_char,
name: *mut c_char,
description: *mut c_char,
version: *mut c_char,
error_message: *mut c_char
}

#[no_mangle]
pub extern "C" fn free_file_info(info: FileInfo) {
// Free all allocated strings if they are not null
if !info.file_id.is_null() {
unsafe { drop(CString::from_raw(info.file_id)) };
}
if !info.name.is_null() {
unsafe { drop(CString::from_raw(info.name)) };
}
if !info.description.is_null() {
unsafe { drop(CString::from_raw(info.description)) };
}
if !info.version.is_null() {
unsafe { drop(CString::from_raw(info.version)) };
}
if !info.error_message.is_null() {
unsafe { drop(CString::from_raw(info.error_message)) };
}
}

/// Loads a model from a file and returns a unique identifier for the loaded model.
///
/// # Arguments
/// * `file_path_ptr` - A pointer to the file path of the model to load.
///
/// # Returns
/// Meta data around the model and a unique identifier for the loaded model.
#[no_mangle]
pub extern "C" fn load_model(file_path_ptr: *const c_char) -> FileInfo {

// checking that the file path pointer is not null
if file_path_ptr.is_null() {
return FileInfo {
file_id: std::ptr::null_mut(),
name: std::ptr::null_mut(),
description: std::ptr::null_mut(),
version: std::ptr::null_mut(),
error_message: CString::new("Received a null pointer for file path").unwrap().into_raw()
};
}

// Convert the raw C string to a Rust string
let c_str = unsafe { CStr::from_ptr(file_path_ptr) };

// convert the CStr into a &str
let file_path = match c_str.to_str() {
Ok(rust_str) => rust_str,
Err(_) => {
return FileInfo {
file_id: std::ptr::null_mut(),
name: std::ptr::null_mut(),
description: std::ptr::null_mut(),
version: std::ptr::null_mut(),
error_message: CString::new("Invalid UTF-8 string received for file path").unwrap().into_raw()
};
}
};

let file = match SurMlFile::from_file(&file_path) {
Ok(file) => file,
Err(e) => {
return FileInfo {
file_id: std::ptr::null_mut(),
name: std::ptr::null_mut(),
description: std::ptr::null_mut(),
version: std::ptr::null_mut(),
error_message: CString::new(e.to_string()).unwrap().into_raw()
};
}
};

// get the meta data from the file
let name = file.header.name.to_string();
let description = file.header.description.to_string();
let version = file.header.version.to_string();

// insert the file into the state
let file_id = generate_unique_id();
let mut state = STATE.lock().unwrap();
state.insert(file_id.clone(), file);

// return the meta data
let file_id = CString::new(file_id).unwrap();
let name = CString::new(name).unwrap();
let description = CString::new(description).unwrap();
let version = CString::new(version).unwrap();

FileInfo {
file_id: file_id.into_raw(),
name: name.into_raw(),
description: description.into_raw(),
version: version.into_raw(),
error_message: std::ptr::null_mut()
}
}
Loading

0 comments on commit d33d9af

Please sign in to comment.