-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
buffered compute is now working and passing tests
- Loading branch information
1 parent
9922332
commit cdec7c0
Showing
5 changed files
with
286 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
//! This module contains the buffered_compute function that is called from the C API to compute the model. | ||
use crate::state::STATE; | ||
use std::ffi::{c_float, CStr, CString, c_int, c_char}; | ||
use surrealml_core::execution::compute::ModelComputation; | ||
use crate::utils::Vecf32Return; | ||
use std::collections::HashMap; | ||
|
||
|
||
/// Computes the model with the given data. | ||
/// | ||
/// # Arguments | ||
/// * `file_id_ptr` - A pointer to the unique identifier for the loaded model. | ||
/// * `data_ptr` - A pointer to the data to compute. | ||
/// * `length` - The length of the data. | ||
/// * `strings` - A pointer to an array of strings to use as keys for the data. | ||
/// * `string_count` - The number of strings in the array. | ||
/// | ||
/// # Returns | ||
/// A Vecf32Return object containing the outcome of the computation. | ||
#[no_mangle] | ||
pub extern "C" fn buffered_compute( | ||
file_id_ptr: *const c_char, | ||
data_ptr: *const c_float, | ||
data_length: usize, | ||
strings: *const *const c_char, | ||
string_count: c_int | ||
) -> Vecf32Return { | ||
if file_id_ptr.is_null() { | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new("File id is null").unwrap().into_raw() | ||
} | ||
} | ||
if data_ptr.is_null() { | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new("Data is null").unwrap().into_raw() | ||
} | ||
} | ||
|
||
let file_id = match unsafe { CStr::from_ptr(file_id_ptr) }.to_str() { | ||
Ok(file_id) => file_id.to_owned(), | ||
Err(error) => return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new(format!("Error getting file id: {}", error)).unwrap().into_raw() | ||
} | ||
}; | ||
|
||
if strings.is_null() { | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new("string pointer is null").unwrap().into_raw() | ||
} | ||
} | ||
|
||
// extract the list of strings from the C array | ||
let string_count = string_count as usize; | ||
let c_strings = unsafe { std::slice::from_raw_parts(strings, string_count) }; | ||
let rust_strings: Vec<String> = c_strings | ||
.iter() | ||
.map(|&s| { | ||
if s.is_null() { | ||
String::new() | ||
} else { | ||
unsafe { CStr::from_ptr(s).to_string_lossy().into_owned() } | ||
} | ||
}) | ||
.collect(); | ||
|
||
for i in rust_strings.iter() { | ||
if i.is_empty() { | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new("null string passed in as key").unwrap().into_raw() | ||
} | ||
} | ||
} | ||
|
||
let data_slice = unsafe { std::slice::from_raw_parts(data_ptr, data_length) }; | ||
|
||
if rust_strings.len() != data_slice.len() { | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new("String count does not match data length").unwrap().into_raw() | ||
} | ||
} | ||
|
||
// stitch the strings and data together | ||
let mut input_map = HashMap::new(); | ||
for (i, key) in rust_strings.iter().enumerate() { | ||
input_map.insert(key.clone(), data_slice[i]); | ||
} | ||
|
||
let mut state = match STATE.lock() { | ||
Ok(state) => state, | ||
Err(error) => { | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new(format!("Error getting state: {}", error)).unwrap().into_raw() | ||
} | ||
} | ||
}; | ||
let mut file = match state.get_mut(&file_id) { | ||
Some(file) => file, | ||
None => { | ||
{ | ||
return Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new(format!("File not found for id: {}, here is the state: {:?}", file_id, state.keys())).unwrap().into_raw() | ||
} | ||
} | ||
} | ||
}; | ||
|
||
let compute_unit = ModelComputation { | ||
surml_file: &mut file | ||
}; | ||
match compute_unit.buffered_compute(&mut input_map) { | ||
Ok(mut output) => { | ||
let output_len = output.len(); | ||
let output_capacity = output.capacity(); | ||
let output_ptr = output.as_mut_ptr(); | ||
std::mem::forget(output); | ||
Vecf32Return { | ||
data: output_ptr, | ||
length: output_len, | ||
capacity: output_capacity, | ||
is_error: 0, | ||
error_message: std::ptr::null_mut() | ||
} | ||
}, | ||
Err(error) => { | ||
Vecf32Return { | ||
data: std::ptr::null_mut(), | ||
length: 0, | ||
capacity: 0, | ||
is_error: 1, | ||
error_message: CString::new(format!("Error computing model: {}", error)).unwrap().into_raw() | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
80 changes: 80 additions & 0 deletions
80
modules/c-wrapper/tests/api/execution/test_buffered_compute.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import ctypes | ||
from unittest import TestCase, main | ||
|
||
from test_utils.c_lib_loader import load_library | ||
from test_utils.return_structs import FileInfo, Vecf32Return | ||
from test_utils.routes import TEST_SURML_PATH | ||
|
||
|
||
class TestExecution(TestCase): | ||
|
||
def setUp(self) -> None: | ||
self.lib = load_library() | ||
|
||
# Define the Rust function signatures | ||
self.lib.load_model.argtypes = [ctypes.c_char_p] | ||
self.lib.load_model.restype = FileInfo | ||
|
||
self.lib.free_file_info.argtypes = [FileInfo] | ||
|
||
self.lib.buffered_compute.argtypes = [ | ||
ctypes.c_char_p, # file_id_ptr -> *const c_char | ||
ctypes.POINTER(ctypes.c_float), # data_ptr -> *const c_float | ||
ctypes.c_size_t, # data_length -> usize | ||
ctypes.POINTER(ctypes.c_char_p), # strings -> *const *const c_char | ||
ctypes.c_int # string_count -> c_int | ||
] | ||
self.lib.buffered_compute.restype = Vecf32Return | ||
|
||
self.lib.free_vecf32_return.argtypes = [Vecf32Return] | ||
|
||
def test_buffered_compute(self): | ||
# Load a test model | ||
c_string = str(TEST_SURML_PATH).encode('utf-8') | ||
file_info = self.lib.load_model(c_string) | ||
|
||
if file_info.error_message: | ||
self.fail(f"Failed to load model: {file_info.error_message.decode('utf-8')}") | ||
|
||
input_data = { | ||
"squarefoot": 500.0, | ||
"num_floors": 2.0 | ||
} | ||
|
||
string_buffer = [] | ||
data_buffer = [] | ||
for key, value in input_data.items(): | ||
string_buffer.append(key.encode('utf-8')) | ||
data_buffer.append(value) | ||
|
||
# Prepare input data as a ctypes array | ||
array_type = ctypes.c_float * len(data_buffer) # Create an array type of the appropriate size | ||
input_data = array_type(*data_buffer) # Instantiate the array with the list elements | ||
|
||
# prepare the input strings | ||
string_array = (ctypes.c_char_p * len(string_buffer))(*string_buffer) | ||
string_count = len(string_buffer) | ||
|
||
# Call the raw_compute function | ||
result = self.lib.buffered_compute( | ||
file_info.file_id, | ||
input_data, | ||
len(input_data), | ||
string_array, | ||
string_count | ||
) | ||
|
||
if result.is_error: | ||
self.fail(f"Error in buffered_compute: {result.error_message.decode('utf-8')}") | ||
|
||
# Extract and verify the computation result | ||
outcome = [result.data[i] for i in range(result.length)] | ||
self.assertEqual(362.9851989746094, outcome[0]) | ||
|
||
# Free allocated memory | ||
self.lib.free_vecf32_return(result) | ||
self.lib.free_file_info(file_info) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters