Skip to content

Commit

Permalink
buffered compute is now working and passing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwellflitton committed Dec 26, 2024
1 parent 9922332 commit cdec7c0
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 37 deletions.
166 changes: 166 additions & 0 deletions modules/c-wrapper/src/api/execution/buffered_compute.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
//! This module contains the buffered_compute function that is called from the C API to compute the model.
use crate::state::STATE;
use std::ffi::{c_float, CStr, CString, c_int, c_char};
use surrealml_core::execution::compute::ModelComputation;
use crate::utils::Vecf32Return;
use std::collections::HashMap;


/// Computes the model with the given data.
///
/// # Arguments
/// * `file_id_ptr` - A pointer to the unique identifier for the loaded model.
/// * `data_ptr` - A pointer to the data to compute.
/// * `length` - The length of the data.
/// * `strings` - A pointer to an array of strings to use as keys for the data.
/// * `string_count` - The number of strings in the array.
///
/// # Returns
/// A Vecf32Return object containing the outcome of the computation.
#[no_mangle]
pub extern "C" fn buffered_compute(
file_id_ptr: *const c_char,
data_ptr: *const c_float,
data_length: usize,
strings: *const *const c_char,
string_count: c_int
) -> Vecf32Return {
if file_id_ptr.is_null() {
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new("File id is null").unwrap().into_raw()
}
}
if data_ptr.is_null() {
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new("Data is null").unwrap().into_raw()
}
}

let file_id = match unsafe { CStr::from_ptr(file_id_ptr) }.to_str() {
Ok(file_id) => file_id.to_owned(),
Err(error) => return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new(format!("Error getting file id: {}", error)).unwrap().into_raw()
}
};

if strings.is_null() {
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new("string pointer is null").unwrap().into_raw()
}
}

// extract the list of strings from the C array
let string_count = string_count as usize;
let c_strings = unsafe { std::slice::from_raw_parts(strings, string_count) };
let rust_strings: Vec<String> = c_strings
.iter()
.map(|&s| {
if s.is_null() {
String::new()
} else {
unsafe { CStr::from_ptr(s).to_string_lossy().into_owned() }
}
})
.collect();

for i in rust_strings.iter() {
if i.is_empty() {
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new("null string passed in as key").unwrap().into_raw()
}
}
}

let data_slice = unsafe { std::slice::from_raw_parts(data_ptr, data_length) };

if rust_strings.len() != data_slice.len() {
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new("String count does not match data length").unwrap().into_raw()
}
}

// stitch the strings and data together
let mut input_map = HashMap::new();
for (i, key) in rust_strings.iter().enumerate() {
input_map.insert(key.clone(), data_slice[i]);
}

let mut state = match STATE.lock() {
Ok(state) => state,
Err(error) => {
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new(format!("Error getting state: {}", error)).unwrap().into_raw()
}
}
};
let mut file = match state.get_mut(&file_id) {
Some(file) => file,
None => {
{
return Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new(format!("File not found for id: {}, here is the state: {:?}", file_id, state.keys())).unwrap().into_raw()
}
}
}
};

let compute_unit = ModelComputation {
surml_file: &mut file
};
match compute_unit.buffered_compute(&mut input_map) {
Ok(mut output) => {
let output_len = output.len();
let output_capacity = output.capacity();
let output_ptr = output.as_mut_ptr();
std::mem::forget(output);
Vecf32Return {
data: output_ptr,
length: output_len,
capacity: output_capacity,
is_error: 0,
error_message: std::ptr::null_mut()
}
},
Err(error) => {
Vecf32Return {
data: std::ptr::null_mut(),
length: 0,
capacity: 0,
is_error: 1,
error_message: CString::new(format!("Error computing model: {}", error)).unwrap().into_raw()
}
}
}
}
38 changes: 2 additions & 36 deletions modules/c-wrapper/src/api/execution/raw_compute.rs
Original file line number Diff line number Diff line change
@@ -1,42 +1,8 @@
//! This module contains the raw_compute function that is called from the C API to compute the model.
use crate::state::STATE;
use std::ffi::{c_float, CStr, CString, c_int, c_char};
use std::ffi::{c_float, CStr, CString, c_char};
use surrealml_core::execution::compute::ModelComputation;


/// Holds the data around the outcome of the raw_compute function.
///
/// # Fields
/// * `data` - The data returned from the computation.
/// * `length` - The length of the data.
/// * `capacity` - The capacity of the data.
/// * `is_error` - A flag indicating if an error occurred (1 for error, 0 for success).
/// * `error_message` - An error message if the computation failed.
#[repr(C)]
pub struct Vecf32Return {
pub data: *mut f32,
pub length: usize,
pub capacity: usize, // Optional if you want to include capacity for clarity
pub is_error: c_int,
pub error_message: *mut c_char
}


/// Frees the memory allocated for the Vecf32Return.
///
/// # Arguments
/// * `vecf32_return` - The Vecf32Return to free.
#[no_mangle]
pub extern "C" fn free_vecf32_return(vecf32_return: Vecf32Return) {
// Free the data if it is not null
if !vecf32_return.data.is_null() {
unsafe { drop(Vec::from_raw_parts(vecf32_return.data, vecf32_return.length, vecf32_return.capacity)) };
}
// Free the error message if it is not null
if !vecf32_return.error_message.is_null() {
unsafe { drop(CString::from_raw(vecf32_return.error_message)) };
}
}
use crate::utils::Vecf32Return;


/// Computes the model with the given data.
Expand Down
35 changes: 35 additions & 0 deletions modules/c-wrapper/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,38 @@ pub extern "C" fn free_vec_u8(vec_u8: VecU8Return) {
unsafe { drop(Vec::from_raw_parts(vec_u8.data, vec_u8.length, vec_u8.capacity)) };
}
}


/// Holds the data around the outcome of the raw_compute function.
///
/// # Fields
/// * `data` - The data returned from the computation.
/// * `length` - The length of the data.
/// * `capacity` - The capacity of the data.
/// * `is_error` - A flag indicating if an error occurred (1 for error, 0 for success).
/// * `error_message` - An error message if the computation failed.
#[repr(C)]
pub struct Vecf32Return {
pub data: *mut f32,
pub length: usize,
pub capacity: usize, // Optional if you want to include capacity for clarity
pub is_error: c_int,
pub error_message: *mut c_char
}


/// Frees the memory allocated for the Vecf32Return.
///
/// # Arguments
/// * `vecf32_return` - The Vecf32Return to free.
#[no_mangle]
pub extern "C" fn free_vecf32_return(vecf32_return: Vecf32Return) {
// Free the data if it is not null
if !vecf32_return.data.is_null() {
unsafe { drop(Vec::from_raw_parts(vecf32_return.data, vecf32_return.length, vecf32_return.capacity)) };
}
// Free the error message if it is not null
if !vecf32_return.error_message.is_null() {
unsafe { drop(CString::from_raw(vecf32_return.error_message)) };
}
}
80 changes: 80 additions & 0 deletions modules/c-wrapper/tests/api/execution/test_buffered_compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import ctypes
from unittest import TestCase, main

from test_utils.c_lib_loader import load_library
from test_utils.return_structs import FileInfo, Vecf32Return
from test_utils.routes import TEST_SURML_PATH


class TestExecution(TestCase):

def setUp(self) -> None:
self.lib = load_library()

# Define the Rust function signatures
self.lib.load_model.argtypes = [ctypes.c_char_p]
self.lib.load_model.restype = FileInfo

self.lib.free_file_info.argtypes = [FileInfo]

self.lib.buffered_compute.argtypes = [
ctypes.c_char_p, # file_id_ptr -> *const c_char
ctypes.POINTER(ctypes.c_float), # data_ptr -> *const c_float
ctypes.c_size_t, # data_length -> usize
ctypes.POINTER(ctypes.c_char_p), # strings -> *const *const c_char
ctypes.c_int # string_count -> c_int
]
self.lib.buffered_compute.restype = Vecf32Return

self.lib.free_vecf32_return.argtypes = [Vecf32Return]

def test_buffered_compute(self):
# Load a test model
c_string = str(TEST_SURML_PATH).encode('utf-8')
file_info = self.lib.load_model(c_string)

if file_info.error_message:
self.fail(f"Failed to load model: {file_info.error_message.decode('utf-8')}")

input_data = {
"squarefoot": 500.0,
"num_floors": 2.0
}

string_buffer = []
data_buffer = []
for key, value in input_data.items():
string_buffer.append(key.encode('utf-8'))
data_buffer.append(value)

# Prepare input data as a ctypes array
array_type = ctypes.c_float * len(data_buffer) # Create an array type of the appropriate size
input_data = array_type(*data_buffer) # Instantiate the array with the list elements

# prepare the input strings
string_array = (ctypes.c_char_p * len(string_buffer))(*string_buffer)
string_count = len(string_buffer)

# Call the raw_compute function
result = self.lib.buffered_compute(
file_info.file_id,
input_data,
len(input_data),
string_array,
string_count
)

if result.is_error:
self.fail(f"Error in buffered_compute: {result.error_message.decode('utf-8')}")

# Extract and verify the computation result
outcome = [result.data[i] for i in range(result.length)]
self.assertEqual(362.9851989746094, outcome[0])

# Free allocated memory
self.lib.free_vecf32_return(result)
self.lib.free_file_info(file_info)


if __name__ == '__main__':
main()
4 changes: 3 additions & 1 deletion modules/c-wrapper/tests/api/execution/test_raw_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def test_raw_compute(self):
self.fail(f"Failed to load model: {file_info.error_message.decode('utf-8')}")

# Prepare input data as a ctypes array
input_data = (ctypes.c_float * 2)(1.0, 4.0)
data_buffer = [1.0, 4.0]
array_type = ctypes.c_float * len(data_buffer) # Create an array type of the appropriate size
input_data = array_type(*data_buffer) # Instantiate the array with the list elements

# Call the raw_compute function
result = self.lib.raw_compute(file_info.file_id, input_data, len(input_data))
Expand Down

0 comments on commit cdec7c0

Please sign in to comment.