-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e1eb109
commit d33d9af
Showing
19 changed files
with
953 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
[package] | ||
name = "c-wrapper" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
surrealml-core = { path = "../core" } | ||
uuid = { version = "1.4.1", features = ["v4"] } | ||
ndarray = "0.16.1" | ||
|
||
[lib] | ||
crate-type = ["cdylib"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#!/usr/bin/env bash | ||
|
||
# navigate to directory | ||
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" | ||
cd $SCRIPTPATH | ||
|
||
cd .. | ||
|
||
cargo build | ||
|
||
# Get the operating system | ||
OS=$(uname) | ||
|
||
# Set the library name and extension based on the OS | ||
case "$OS" in | ||
"Linux") | ||
LIB_NAME="libc_wrapper.so" | ||
;; | ||
"Darwin") | ||
LIB_NAME="libc_wrapper.dylib" | ||
;; | ||
"CYGWIN"*|"MINGW"*) | ||
LIB_NAME="libc_wrapper.dll" | ||
;; | ||
*) | ||
echo "Unsupported operating system: $OS" | ||
exit 1 | ||
;; | ||
esac | ||
|
||
# Source directory (where Cargo outputs the compiled library) | ||
SOURCE_DIR="target/debug" | ||
|
||
# Destination directory (tests directory) | ||
DEST_DIR="tests" | ||
|
||
|
||
# Copy the library to the tests directory | ||
if [ -f "$SOURCE_DIR/$LIB_NAME" ]; then | ||
cp "$SOURCE_DIR/$LIB_NAME" "$DEST_DIR/" | ||
echo "Copied $LIB_NAME to $DEST_DIR" | ||
else | ||
echo "Library not found: $SOURCE_DIR/$LIB_NAME" | ||
exit 1 | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pub mod raw_compute; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
use crate::state::STATE; | ||
use std::ffi::{c_float, CStr, CString, c_int, c_char}; | ||
use surrealml_core::execution::compute::ModelComputation; | ||
|
||
|
||
#[repr(C)] | ||
pub struct Vecf32Return { | ||
pub data: *mut f32, | ||
pub length: usize, | ||
pub capacity: usize, // Optional if you want to include capacity for clarity | ||
pub is_error: c_int, | ||
pub error_message: *mut c_char | ||
} | ||
|
||
|
||
#[no_mangle] | ||
pub extern "C" fn free_vecf32_return(vecf32_return: Vecf32Return) { | ||
// Free the data if it is not null | ||
if !vecf32_return.data.is_null() { | ||
unsafe { drop(Vec::from_raw_parts(vecf32_return.data, vecf32_return.length, vecf32_return.capacity)) }; | ||
} | ||
// Free the error message if it is not null | ||
if !vecf32_return.error_message.is_null() { | ||
unsafe { drop(CString::from_raw(vecf32_return.error_message)) }; | ||
} | ||
} | ||
|
||
|
||
|
||
#[no_mangle] | ||
pub extern "C" fn raw_compute(file_id_ptr: *const c_char, data_ptr: *const c_float, length: usize) -> Vecf32Return { | ||
|
||
if file_id_ptr.is_null() { | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new("File id is null").unwrap().into_raw() | ||
} | ||
} | ||
if data_ptr.is_null() { | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new("Data is null").unwrap().into_raw() | ||
} | ||
} | ||
|
||
let file_id = match unsafe { CStr::from_ptr(file_id_ptr) }.to_str() { | ||
Ok(file_id) => file_id.to_owned(), | ||
Err(error) => return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new(format!("Error getting file id: {}", error)).unwrap().into_raw() | ||
} | ||
}; | ||
|
||
let mut state = match STATE.lock() { | ||
Ok(state) => state, | ||
Err(error) => { | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new(format!("Error getting state: {}", error)).unwrap().into_raw() | ||
} | ||
} | ||
}; | ||
let mut file = match state.get_mut(&file_id) { | ||
Some(file) => file, | ||
None => { | ||
{ | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new(format!("File not found for id: {}, here is the state: {:?}", file_id, state.keys())).unwrap().into_raw() | ||
} | ||
} | ||
} | ||
}; | ||
|
||
let slice = unsafe { std::slice::from_raw_parts(data_ptr, length) }; | ||
let tensor = ndarray::arr1(slice).into_dyn(); | ||
let compute_unit = ModelComputation { | ||
surml_file: &mut file | ||
}; | ||
|
||
// perform the computation | ||
let mut outcome = match compute_unit.raw_compute(tensor, None) { | ||
Ok(outcome) => outcome, | ||
Err(error) => { | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new(format!("Error computing model: {}", error.message)).unwrap().into_raw() | ||
} | ||
} | ||
}; | ||
let outcome_ptr = outcome.as_mut_ptr(); | ||
let outcome_len = outcome.len(); | ||
let outcome_capacity = outcome.capacity(); | ||
std::mem::forget(outcome); | ||
Vecf32Return { | ||
data: outcome_ptr, | ||
length: outcome_len, | ||
capacity: outcome_capacity, | ||
is_error: 0, | ||
error_message: std::ptr::null_mut() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
pub mod execution; | ||
pub mod storage; |
31 changes: 31 additions & 0 deletions
31
modules/c-wrapper/src/api/storage/load_cached_raw_model.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
use crate::state::{STATE, generate_unique_id}; | ||
use surrealml_core::storage::surml_file::SurMlFile; | ||
use std::fs::File; | ||
use std::io::Read; | ||
use std::os::raw::c_char; | ||
use std::ffi::CString; | ||
use std::ffi::CStr; | ||
use crate::utils::StringReturn; | ||
use crate::process_string_for_string_return; | ||
|
||
|
||
/// Loads a PyTorch C model from a file wrapping it in a SurMlFile struct | ||
/// which is stored in memory and referenced by a unique ID. | ||
/// | ||
/// # Arguments | ||
/// * `file_path` - The path to the file to load. | ||
/// | ||
/// # Returns | ||
/// A unique identifier for the loaded model. | ||
#[no_mangle] | ||
pub extern "C" fn load_cached_raw_model(file_path_ptr: *const c_char) -> StringReturn { | ||
let file_path_str = process_string_for_string_return!(file_path_ptr, "file path"); | ||
let file_id = generate_unique_id(); | ||
let mut model = File::open(file_path_str).unwrap(); | ||
let mut data = vec![]; | ||
model.read_to_end(&mut data).unwrap(); | ||
let file = SurMlFile::fresh(data); | ||
let mut python_state = STATE.lock().unwrap(); | ||
python_state.insert(file_id.clone(), file); | ||
StringReturn::success(file_id) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
use crate::state::{STATE, generate_unique_id}; | ||
use surrealml_core::storage::surml_file::SurMlFile; | ||
use std::ffi::CStr; | ||
use std::ffi::CString; | ||
use std::os::raw::c_char; | ||
|
||
|
||
#[repr(C)] | ||
pub struct FileInfo { | ||
file_id: *mut c_char, | ||
name: *mut c_char, | ||
description: *mut c_char, | ||
version: *mut c_char, | ||
error_message: *mut c_char | ||
} | ||
|
||
#[no_mangle] | ||
pub extern "C" fn free_file_info(info: FileInfo) { | ||
// Free all allocated strings if they are not null | ||
if !info.file_id.is_null() { | ||
unsafe { drop(CString::from_raw(info.file_id)) }; | ||
} | ||
if !info.name.is_null() { | ||
unsafe { drop(CString::from_raw(info.name)) }; | ||
} | ||
if !info.description.is_null() { | ||
unsafe { drop(CString::from_raw(info.description)) }; | ||
} | ||
if !info.version.is_null() { | ||
unsafe { drop(CString::from_raw(info.version)) }; | ||
} | ||
if !info.error_message.is_null() { | ||
unsafe { drop(CString::from_raw(info.error_message)) }; | ||
} | ||
} | ||
|
||
/// Loads a model from a file and returns a unique identifier for the loaded model. | ||
/// | ||
/// # Arguments | ||
/// * `file_path_ptr` - A pointer to the file path of the model to load. | ||
/// | ||
/// # Returns | ||
/// Meta data around the model and a unique identifier for the loaded model. | ||
#[no_mangle] | ||
pub extern "C" fn load_model(file_path_ptr: *const c_char) -> FileInfo { | ||
|
||
// checking that the file path pointer is not null | ||
if file_path_ptr.is_null() { | ||
return FileInfo { | ||
file_id: std::ptr::null_mut(), | ||
name: std::ptr::null_mut(), | ||
description: std::ptr::null_mut(), | ||
version: std::ptr::null_mut(), | ||
error_message: CString::new("Received a null pointer for file path").unwrap().into_raw() | ||
}; | ||
} | ||
|
||
// Convert the raw C string to a Rust string | ||
let c_str = unsafe { CStr::from_ptr(file_path_ptr) }; | ||
|
||
// convert the CStr into a &str | ||
let file_path = match c_str.to_str() { | ||
Ok(rust_str) => rust_str, | ||
Err(_) => { | ||
return FileInfo { | ||
file_id: std::ptr::null_mut(), | ||
name: std::ptr::null_mut(), | ||
description: std::ptr::null_mut(), | ||
version: std::ptr::null_mut(), | ||
error_message: CString::new("Invalid UTF-8 string received for file path").unwrap().into_raw() | ||
}; | ||
} | ||
}; | ||
|
||
let file = match SurMlFile::from_file(&file_path) { | ||
Ok(file) => file, | ||
Err(e) => { | ||
return FileInfo { | ||
file_id: std::ptr::null_mut(), | ||
name: std::ptr::null_mut(), | ||
description: std::ptr::null_mut(), | ||
version: std::ptr::null_mut(), | ||
error_message: CString::new(e.to_string()).unwrap().into_raw() | ||
}; | ||
} | ||
}; | ||
|
||
// get the meta data from the file | ||
let name = file.header.name.to_string(); | ||
let description = file.header.description.to_string(); | ||
let version = file.header.version.to_string(); | ||
|
||
// insert the file into the state | ||
let file_id = generate_unique_id(); | ||
let mut state = STATE.lock().unwrap(); | ||
state.insert(file_id.clone(), file); | ||
|
||
// return the meta data | ||
let file_id = CString::new(file_id).unwrap(); | ||
let name = CString::new(name).unwrap(); | ||
let description = CString::new(description).unwrap(); | ||
let version = CString::new(version).unwrap(); | ||
|
||
FileInfo { | ||
file_id: file_id.into_raw(), | ||
name: name.into_raw(), | ||
description: description.into_raw(), | ||
version: version.into_raw(), | ||
error_message: std::ptr::null_mut() | ||
} | ||
} |
Oops, something went wrong.