Skip to content

Commit

Permalink
unit testing for C dylib is now working
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwellflitton committed Dec 25, 2024
1 parent d33d9af commit ac1c133
Show file tree
Hide file tree
Showing 30 changed files with 539 additions and 103 deletions.
20 changes: 20 additions & 0 deletions modules/c-wrapper/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

# C Wrapper

This workspace is a C wrapper for the `surrealml-core` library. This enables us to no longer need `PyO3` and we can also use this library in other languages.

## Testing

To test this C wrapper we first need to build the C lib and position it in the correct location for the Python tests to load the library. We can perform this setup with the following command:

```bash
sh ./scripts/prep_tests.sh
```

This will build the C lib in debug mode and place it in the correct location for the Python tests to load the library. We can then run the tests with the following command:

```bash
sh ./scripts/run_tests.sh
```

If you setup pycharm to put your Python tests through a debugger, you need to open pycharm in the root of this workspace and set the `tests` directory as the sources root. This will allow you to point and click on specific tests and run them through a debugger.
2 changes: 1 addition & 1 deletion modules/c-wrapper/scripts/prep_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ esac
SOURCE_DIR="target/debug"

# Destination directory (tests directory)
DEST_DIR="tests"
DEST_DIR="tests/test_utils"


# Copy the library to the tests directory
Expand Down
11 changes: 11 additions & 0 deletions modules/c-wrapper/scripts/run_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env bash

# navigate to directory
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
cd $SCRIPTPATH

cd ..

cd tests

python3 -m unittest discover .
1 change: 1 addition & 0 deletions modules/c-wrapper/src/api/execution/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
//! The C API for executing ML models.
pub mod raw_compute;
25 changes: 23 additions & 2 deletions modules/c-wrapper/src/api/execution/raw_compute.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
//! This module contains the raw_compute function that is called from the C API to compute the model.
use crate::state::STATE;
use std::ffi::{c_float, CStr, CString, c_int, c_char};
use surrealml_core::execution::compute::ModelComputation;


/// Holds the data around the outcome of the raw_compute function.
///
/// # Fields
/// * `data` - The data returned from the computation.
/// * `length` - The length of the data.
/// * `capacity` - The capacity of the data.
/// * `is_error` - A flag indicating if an error occurred (1 for error, 0 for success).
/// * `error_message` - An error message if the computation failed.
#[repr(C)]
pub struct Vecf32Return {
pub data: *mut f32,
Expand All @@ -13,6 +22,10 @@ pub struct Vecf32Return {
}


/// Frees the memory allocated for the Vecf32Return.
///
/// # Arguments
/// * `vecf32_return` - The Vecf32Return to free.
#[no_mangle]
pub extern "C" fn free_vecf32_return(vecf32_return: Vecf32Return) {
// Free the data if it is not null
Expand All @@ -26,7 +39,15 @@ pub extern "C" fn free_vecf32_return(vecf32_return: Vecf32Return) {
}



/// Computes the model with the given data.
///
/// # Arguments
/// * `file_id_ptr` - A pointer to the unique identifier for the loaded model.
/// * `data_ptr` - A pointer to the data to compute.
/// * `length` - The length of the data.
///
/// # Returns
/// A Vecf32Return object containing the outcome of the computation.
#[no_mangle]
pub extern "C" fn raw_compute(file_id_ptr: *const c_char, data_ptr: *const c_float, length: usize) -> Vecf32Return {

Expand Down Expand Up @@ -117,4 +138,4 @@ pub extern "C" fn raw_compute(file_id_ptr: *const c_char, data_ptr: *const c_flo
is_error: 0,
error_message: std::ptr::null_mut()
}
}
}
1 change: 1 addition & 0 deletions modules/c-wrapper/src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
//! C API for interacting with the SurML file storage and executing models.
pub mod execution;
pub mod storage;
24 changes: 15 additions & 9 deletions modules/c-wrapper/src/api/storage/load_cached_raw_model.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
use crate::state::{STATE, generate_unique_id};
use surrealml_core::storage::surml_file::SurMlFile;
//! Defines the C interface for loading an ONNX model from a file and storing it in memory.
// Standard library imports
use std::ffi::{CStr, CString};
use std::fs::File;
use std::io::Read;
use std::os::raw::c_char;
use std::ffi::CString;
use std::ffi::CStr;

// External crate imports
use surrealml_core::storage::surml_file::SurMlFile;

// Local module imports
use crate::state::{generate_unique_id, STATE};
use crate::utils::StringReturn;
use crate::process_string_for_string_return;
use crate::{process_string_for_string_return, string_return_safe_eject};



/// Loads a PyTorch C model from a file wrapping it in a SurMlFile struct
/// Loads a ONNX model from a file wrapping it in a SurMlFile struct
/// which is stored in memory and referenced by a unique ID.
///
/// # Arguments
Expand All @@ -21,11 +27,11 @@ use crate::process_string_for_string_return;
pub extern "C" fn load_cached_raw_model(file_path_ptr: *const c_char) -> StringReturn {
let file_path_str = process_string_for_string_return!(file_path_ptr, "file path");
let file_id = generate_unique_id();
let mut model = File::open(file_path_str).unwrap();
let mut model = string_return_safe_eject!(File::open(file_path_str));
let mut data = vec![];
model.read_to_end(&mut data).unwrap();
string_return_safe_eject!(model.read_to_end(&mut data));
let file = SurMlFile::fresh(data);
let mut python_state = STATE.lock().unwrap();
python_state.insert(file_id.clone(), file);
StringReturn::success(file_id)
}
}
50 changes: 37 additions & 13 deletions modules/c-wrapper/src/api/storage/load_model.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,39 @@
use crate::state::{STATE, generate_unique_id};
//! Defines the C interface for loading a surml file and getting the meta data around the model.
// Standard library imports
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_int};

// External crate imports
use surrealml_core::storage::surml_file::SurMlFile;
use std::ffi::CStr;
use std::ffi::CString;
use std::os::raw::c_char;

// Local module imports
use crate::state::{generate_unique_id, STATE};


/// Holds the data around the outcome of the load_model function.
///
/// # Fields
/// * `file_id` - The unique identifier for the loaded model.
/// * `name` - The name of the model.
/// * `description` - The description of the model.
/// * `version` - The version of the model.
/// * `error_message` - An error message if the loading failed.
/// * `is_error` - A flag indicating if an error occurred (1 for error, 0 for success).
#[repr(C)]
pub struct FileInfo {
file_id: *mut c_char,
name: *mut c_char,
description: *mut c_char,
version: *mut c_char,
error_message: *mut c_char
pub file_id: *mut c_char,
pub name: *mut c_char,
pub description: *mut c_char,
pub version: *mut c_char,
pub error_message: *mut c_char,
pub is_error: c_int,
}


/// Frees the memory allocated for the file info.
///
/// # Arguments
/// * `info` - The file info to free.
#[no_mangle]
pub extern "C" fn free_file_info(info: FileInfo) {
// Free all allocated strings if they are not null
Expand Down Expand Up @@ -51,7 +71,8 @@ pub extern "C" fn load_model(file_path_ptr: *const c_char) -> FileInfo {
name: std::ptr::null_mut(),
description: std::ptr::null_mut(),
version: std::ptr::null_mut(),
error_message: CString::new("Received a null pointer for file path").unwrap().into_raw()
error_message: CString::new("Received a null pointer for file path").unwrap().into_raw(),
is_error: 1
};
}

Expand All @@ -67,7 +88,8 @@ pub extern "C" fn load_model(file_path_ptr: *const c_char) -> FileInfo {
name: std::ptr::null_mut(),
description: std::ptr::null_mut(),
version: std::ptr::null_mut(),
error_message: CString::new("Invalid UTF-8 string received for file path").unwrap().into_raw()
error_message: CString::new("Invalid UTF-8 string received for file path").unwrap().into_raw(),
is_error: 1
};
}
};
Expand All @@ -80,7 +102,8 @@ pub extern "C" fn load_model(file_path_ptr: *const c_char) -> FileInfo {
name: std::ptr::null_mut(),
description: std::ptr::null_mut(),
version: std::ptr::null_mut(),
error_message: CString::new(e.to_string()).unwrap().into_raw()
error_message: CString::new(e.to_string()).unwrap().into_raw(),
is_error: 1
};
}
};
Expand All @@ -106,6 +129,7 @@ pub extern "C" fn load_model(file_path_ptr: *const c_char) -> FileInfo {
name: name.into_raw(),
description: description.into_raw(),
version: version.into_raw(),
error_message: std::ptr::null_mut()
error_message: std::ptr::null_mut(),
is_error: 0
}
}
32 changes: 19 additions & 13 deletions modules/c-wrapper/src/api/storage/meta.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use crate::state::STATE;
//! Defines the C API interface for interacting with the meta data of a SurML file.
// Standard library imports
use std::ffi::{CStr, CString};
use std::os::raw::c_char;

// External crate imports
use surrealml_core::storage::header::normalisers::wrapper::NormaliserType;
use crate::{process_string_for_empty_return, empty_return_safe_eject};

// Local module imports
use crate::state::STATE;
use crate::utils::EmptyReturn;
use std::os::raw::c_char;
use std::ffi::CString;
use std::ffi::CStr;
use crate::{empty_return_safe_eject, process_string_for_empty_return};



/// Adds a name to the SurMlFile struct.
Expand All @@ -17,7 +23,7 @@ pub extern "C" fn add_name(file_id_ptr: *const c_char, model_name_ptr: *const c_
let file_id = process_string_for_empty_return!(file_id_ptr, "file id");
let model_name = process_string_for_empty_return!(model_name_ptr, "model name");
let mut state = STATE.lock().unwrap();
let wrapped_file = state.get_mut(&file_id).unwrap();
let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
wrapped_file.header.add_name(model_name);
EmptyReturn::success()
}
Expand All @@ -33,7 +39,7 @@ pub extern "C" fn add_description(file_id_ptr: *const c_char, description_ptr: *
let file_id = process_string_for_empty_return!(file_id_ptr, "file id");
let description = process_string_for_empty_return!(description_ptr, "description");
let mut state = STATE.lock().unwrap();
let wrapped_file = state.get_mut(&file_id).unwrap();
let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
wrapped_file.header.add_description(description);
EmptyReturn::success()
}
Expand All @@ -49,7 +55,7 @@ pub extern "C" fn add_version(file_id: *const c_char, version: *const c_char) ->
let file_id = process_string_for_empty_return!(file_id, "file id");
let version = process_string_for_empty_return!(version, "version");
let mut state = STATE.lock().unwrap();
let wrapped_file = state.get_mut(&file_id).unwrap();
let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
let _ = wrapped_file.header.add_version(version);
EmptyReturn::success()
}
Expand All @@ -65,7 +71,7 @@ pub extern "C" fn add_column(file_id: *const c_char, column_name: *const c_char)
let file_id = process_string_for_empty_return!(file_id, "file id");
let column_name = process_string_for_empty_return!(column_name, "column name");
let mut state = STATE.lock().unwrap();
let wrapped_file = state.get_mut(&file_id).unwrap();
let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
wrapped_file.header.add_column(column_name);
EmptyReturn::success()
}
Expand All @@ -81,7 +87,7 @@ pub extern "C" fn add_author(file_id: *const c_char, author: *const c_char) -> E
let file_id = process_string_for_empty_return!(file_id, "file id");
let author = process_string_for_empty_return!(author, "author");
let mut state = STATE.lock().unwrap();
let wrapped_file = state.get_mut(&file_id).unwrap();
let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
wrapped_file.header.add_author(author);
EmptyReturn::success()
}
Expand All @@ -97,7 +103,7 @@ pub extern "C" fn add_origin(file_id: *const c_char, origin: *const c_char) -> E
let file_id = process_string_for_empty_return!(file_id, "file id");
let origin = process_string_for_empty_return!(origin, "origin");
let mut state = STATE.lock().unwrap();
let wrapped_file = state.get_mut(&file_id).unwrap();
let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
let _ = wrapped_file.header.add_origin(origin);
EmptyReturn::success()
}
Expand All @@ -113,7 +119,7 @@ pub extern "C" fn add_engine(file_id: *const c_char, engine: *const c_char) -> E
let file_id = process_string_for_empty_return!(file_id, "file id");
let engine = process_string_for_empty_return!(engine, "engine");
let mut state = STATE.lock().unwrap();
let wrapped_file = state.get_mut(&file_id).unwrap();
let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
wrapped_file.header.add_engine(engine);
EmptyReturn::success()
}
Expand Down Expand Up @@ -199,7 +205,7 @@ pub extern "C" fn add_normaliser(

let normaliser = NormaliserType::new(normaliser_label, one, two);
let mut state = STATE.lock().unwrap();
let file = state.get_mut(&file_id).unwrap();
let file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
let _ = file.header.normalisers.add_normaliser(normaliser, column_name, &file.header.keys);
EmptyReturn::success()
}
1 change: 1 addition & 0 deletions modules/c-wrapper/src/api/storage/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//! C Storage API
pub mod load_model;
pub mod save_model;
pub mod load_cached_raw_model;
Expand Down
14 changes: 9 additions & 5 deletions modules/c-wrapper/src/api/storage/save_model.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
//! Save a model to a file, deleting the file from the `STATE` in the process.
// Standard library imports
use std::ffi::{CStr, CString};
use std::os::raw::c_char;

// External crate imports
use surrealml_core::storage::surml_file::SurMlFile;

// Local module imports
use crate::state::STATE;
use std::ffi::CStr;
use std::ffi::CString;
use std::os::raw::c_char;
use crate::{process_string_for_empty_return, empty_return_safe_eject};
use crate::utils::EmptyReturn;
use crate::{empty_return_safe_eject, process_string_for_empty_return};


/// Saves a model to a file, deleting the file from the `PYTHON_STATE` in the process.
Expand All @@ -25,4 +29,4 @@ pub extern "C" fn save_model(file_path_ptr: *const c_char, file_id_ptr: *const c
empty_return_safe_eject!(file.write(&file_path_str));
state.remove(&file_id_str);
EmptyReturn::success()
}
}
13 changes: 9 additions & 4 deletions modules/c-wrapper/src/api/storage/to_bytes.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use crate::state::STATE;
//! convert the entire SurML file to bytes
// Standard library imports
use std::ffi::{CStr, CString};
use std::os::raw::c_char;

// Local module imports
use crate::state::STATE;
use crate::utils::VecU8Return;
use crate::process_string_for_vec_u8_return;
use std::ffi::CString;
use std::ffi::CStr;

/// Converts the entire file to bytes.


/// Converts the entire SurML file to bytes.
///
/// # Arguments
/// * `file_id` - The unique identifier for the SurMlFile struct.
Expand Down
Loading

0 comments on commit ac1c133

Please sign in to comment.