Skip to content

Commit

Permalink
Upgrading to beta v2 of ort with optional GPU compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwellflitton committed Dec 24, 2024
1 parent e0cf6c4 commit 3f29dd5
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 123 deletions.
8 changes: 5 additions & 3 deletions modules/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ sklearn-tests = []
onnx-tests = []
torch-tests = []
tensorflow-tests = []
# Forward the crate's `gpu` feature to ort's CUDA execution provider so that
# GPU support is compiled only when explicitly requested (`--features gpu`).
gpu = ["ort/cuda"]

[dependencies]
regex = "1.9.3"
ort = { version = "1.16.2", features = ["load-dynamic"], default-features = false }
ndarray = "0.15.6"
# ort = { version = "1.16.2", features = ["load-dynamic"], default-features = false }
# "cuda" is intentionally NOT listed here: it is pulled in via the optional
# `gpu` feature so that CPU-only builds do not require the CUDA toolchain.
ort = { version = "2.0.0-rc.9", features = ["ndarray"] }
ndarray = "0.16.1"
once_cell = "1.18.0"
bytes = "1.5.0"
futures-util = "0.3.28"
futures-core = "0.3.28"
thiserror = "1.0.57"
thiserror = "2.0.9"
serde = { version = "1.0.197", features = ["derive"] }
axum = { version = "0.7.4", optional = true }
actix-web = { version = "4.5.1", optional = true }
Expand Down
202 changes: 103 additions & 99 deletions modules/core/build.rs
Original file line number Diff line number Diff line change
@@ -1,109 +1,113 @@
use std::env;
use std::fs;
use std::path::Path;
// use std::env;
// use std::fs;
// use std::path::Path;

/// works out where the `onnxruntime` library is in the build target and copies the library to the root
/// of the crate so the core library can find it and load it into the binary using `include_bytes!()`.
///
/// # Notes
/// This is a workaround for the fact that `onnxruntime` doesn't support `cargo` yet. This build step
/// is reliant on the `ort` crate downloading and building the `onnxruntime` library. This is
/// why the following dependency is required in `Cargo.toml`:
/// ```toml
/// [build-dependencies]
/// ort = { version = "1.16.2", default-features = true }
/// ```
/// Here we can see that the `default-features` is set to `true`. This is because the `ort` crate will download
/// the correct package and build it for the target platform by default. In the main part of our dependencies
/// we have the following:
/// ```toml
/// [dependencies]
/// ort = { version = "1.16.2", features = ["load-dynamic"], default-features = false }
/// ```
/// Here we can see that the `default-features` is set to `false`. This is because we don't want the `ort` crate
/// to download and build the `onnxruntime` library again. Instead we want to use the one that was built in the
/// build step. We also set the `load-dynamic` feature to `true` so that the `ort` crate will load the `onnxruntime`
/// library dynamically at runtime. This is because we don't want to statically link the `onnxruntime`. Our `onnxruntime`
/// is embedded into the binary using `include_bytes!()` and we want to load it dynamically at runtime. This means that
/// we do not need to move the `onnxruntime` library around with the binary, and there is no complicated setup required
/// or linking.
fn unpack_onnx() -> std::io::Result<()> {
let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set");
let out_path = Path::new(&out_dir);
let build_dir = out_path
.ancestors() // This gives an iterator over all ancestors of the path
.nth(3) // 'nth(3)' gets the fourth ancestor (counting from 0), which should be the debug directory
.expect("Failed to find debug directory");
// /// works out where the `onnxruntime` library is in the build target and copies the library to the root
// /// of the crate so the core library can find it and load it into the binary using `include_bytes!()`.
// ///
// /// # Notes
// /// This is a workaround for the fact that `onnxruntime` doesn't support `cargo` yet. This build step
// /// is reliant on the `ort` crate downloading and building the `onnxruntime` library. This is
// /// why the following dependency is required in `Cargo.toml`:
// /// ```toml
// /// [build-dependencies]
// /// ort = { version = "1.16.2", default-features = true }
// /// ```
// /// Here we can see that the `default-features` is set to `true`. This is because the `ort` crate will download
// /// the correct package and build it for the target platform by default. In the main part of our dependencies
// /// we have the following:
// /// ```toml
// /// [dependencies]
// /// ort = { version = "1.16.2", features = ["load-dynamic"], default-features = false }
// /// ```
// /// Here we can see that the `default-features` is set to `false`. This is because we don't want the `ort` crate
// /// to download and build the `onnxruntime` library again. Instead we want to use the one that was built in the
// /// build step. We also set the `load-dynamic` feature to `true` so that the `ort` crate will load the `onnxruntime`
// /// library dynamically at runtime. This is because we don't want to statically link the `onnxruntime`. Our `onnxruntime`
// /// is embedded into the binary using `include_bytes!()` and we want to load it dynamically at runtime. This means that
// /// we do not need to move the `onnxruntime` library around with the binary, and there is no complicated setup required
// /// or linking.
// fn unpack_onnx() -> std::io::Result<()> {
// let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set");
// let out_path = Path::new(&out_dir);
// let build_dir = out_path
// .ancestors() // This gives an iterator over all ancestors of the path
// .nth(3) // 'nth(3)' gets the fourth ancestor (counting from 0), which should be the debug directory
// .expect("Failed to find debug directory");

match std::env::var("ONNXRUNTIME_LIB_PATH") {
Ok(onnx_path) => {
println!("Surrealml Core Debug: ONNXRUNTIME_LIB_PATH set at: {}", onnx_path);
println!("cargo:rustc-cfg=onnx_runtime_env_var_set");
}
Err(_) => {
println!("Surrealml Core Debug: ONNXRUNTIME_LIB_PATH not set");
let target_lib = match env::var("CARGO_CFG_TARGET_OS").unwrap() {
ref s if s.contains("linux") => "libonnxruntime.so",
ref s if s.contains("macos") => "libonnxruntime.dylib",
ref s if s.contains("windows") => "onnxruntime.dll",
// ref s if s.contains("android") => "android", => not building for android
_ => panic!("Unsupported target os"),
};
// match std::env::var("ONNXRUNTIME_LIB_PATH") {
// Ok(onnx_path) => {
// println!("Surrealml Core Debug: ONNXRUNTIME_LIB_PATH set at: {}", onnx_path);
// println!("cargo:rustc-cfg=onnx_runtime_env_var_set");
// }
// Err(_) => {
// println!("Surrealml Core Debug: ONNXRUNTIME_LIB_PATH not set");
// let target_lib = match env::var("CARGO_CFG_TARGET_OS").unwrap() {
// ref s if s.contains("linux") => "libonnxruntime.so",
// ref s if s.contains("macos") => "libonnxruntime.dylib",
// ref s if s.contains("windows") => "onnxruntime.dll",
// // ref s if s.contains("android") => "android", => not building for android
// _ => panic!("Unsupported target os"),
// };

let lib_path = build_dir.join(target_lib);
let lib_path = lib_path.to_str().unwrap();
println!("Surrealml Core Debug: lib_path={}", lib_path);
// let lib_path = build_dir.join(target_lib);
// let lib_path = lib_path.to_str().unwrap();
// println!("Surrealml Core Debug: lib_path={}", lib_path);

// Check if the path exists
if fs::metadata(lib_path).is_ok() {
println!("Surrealml Core Debug: lib_path exists");
} else {
println!("Surrealml Core Debug: lib_path does not exist");
// Extract the directory path
if let Some(parent) = std::path::Path::new(lib_path).parent() {
// Print the contents of the directory
match fs::read_dir(parent) {
Ok(entries) => {
println!("Surrealml Core Debug: content of directory {}", parent.display());
for entry in entries {
if let Ok(entry) = entry {
println!("{}", entry.path().display());
}
}
}
Err(e) => {
println!("Surrealml Core Debug: Failed to read directory {}: {}", parent.display(), e);
}
}
} else {
println!("Surrealml Core Debug: Could not determine the parent directory of the path.");
}
}
// // Check if the path exists
// if fs::metadata(lib_path).is_ok() {
// println!("Surrealml Core Debug: lib_path exists");
// } else {
// println!("Surrealml Core Debug: lib_path does not exist");
// // Extract the directory path
// if let Some(parent) = std::path::Path::new(lib_path).parent() {
// // Print the contents of the directory
// match fs::read_dir(parent) {
// Ok(entries) => {
// println!("Surrealml Core Debug: content of directory {}", parent.display());
// for entry in entries {
// if let Ok(entry) = entry {
// println!("{}", entry.path().display());
// }
// }
// }
// Err(e) => {
// println!("Surrealml Core Debug: Failed to read directory {}: {}", parent.display(), e);
// }
// }
// } else {
// println!("Surrealml Core Debug: Could not determine the parent directory of the path.");
// }
// }

// put it next to the file of the embedding
let destination = Path::new(target_lib);
fs::copy(lib_path, destination)?;
println!("Surrealml Core Debug: onnx lib copied from {} to {}", lib_path, destination.display());
}
}
Ok(())
}
// // put it next to the file of the embedding
// let destination = Path::new(target_lib);
// fs::copy(lib_path, destination)?;
// println!("Surrealml Core Debug: onnx lib copied from {} to {}", lib_path, destination.display());
// }
// }
// Ok(())
// }

fn main() -> std::io::Result<()> {
if std::env::var("DOCS_RS").is_ok() {
// we are not going to be anything here for docs.rs, because we are merely building the docs. When we are just building
// the docs, the onnx environment variable will not look for the `onnxruntime` library, so we don't need to unpack it.
return Ok(());
}
// fn main() -> std::io::Result<()> {
// if std::env::var("DOCS_RS").is_ok() {
// // we are not going to be anything here for docs.rs, because we are merely building the docs. When we are just building
// // the docs, the onnx environment variable will not look for the `onnxruntime` library, so we don't need to unpack it.
// return Ok(());
// }

if env::var("ORT_STRATEGY").as_deref() == Ok("system") {
// If the ORT crate is built with the `system` strategy, then the crate will take care of statically linking the library.
// No need to do anything here.
println!("cargo:rustc-cfg=onnx_statically_linked");
// if env::var("ORT_STRATEGY").as_deref() == Ok("system") {
// // If the ORT crate is built with the `system` strategy, then the crate will take care of statically linking the library.
// // No need to do anything here.
// println!("cargo:rustc-cfg=onnx_statically_linked");

return Ok(());
}
// return Ok(());
// }

unpack_onnx()?;
Ok(())
// unpack_onnx()?;
// Ok(())
// }

fn main() {
    // Intentionally a no-op. With ort 2.0 the `ort` crate manages downloading
    // and loading the onnxruntime library itself, so the previous
    // unpack/copy logic (commented out above) is presumably no longer needed.
    // NOTE(review): confirm the ort 2.0 strategy covers all targets, then
    // delete the large block of dead commented-out code above.
}
45 changes: 25 additions & 20 deletions modules/core/src/execution/compute.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
//! Defines the operations around performing computations on a loaded model.
use crate::storage::surml_file::SurMlFile;
use std::collections::HashMap;
use ndarray::{ArrayD, CowArray};
use ort::{SessionBuilder, Value, session::Input};
use ndarray::ArrayD;
use ort::value::ValueType;
use ort::session::Session;

use super::onnx_environment::ENVIRONMENT;
use crate::safe_eject;
use crate::errors::error::{SurrealError, SurrealErrorStatus};
use crate::execution::session::get_session;


/// A wrapper for the loaded machine learning model so we can perform computations on the loaded model.
Expand Down Expand Up @@ -39,15 +40,21 @@ impl <'a>ModelComputation<'a> {
///
/// # Returns
/// A vector of dimensions for the input tensor to be reshaped into from the loaded model.
fn process_input_dims(input_dims: &Input) -> Vec<usize> {
let mut buffer = Vec::new();
for dim in input_dims.dimensions() {
match dim {
Some(dim) => buffer.push(dim as usize),
None => buffer.push(1)
fn process_input_dims(session_ref: &Session) -> Vec<usize> {
let some_dims = match &session_ref.inputs[0].input_type {
ValueType::Tensor{ ty: _, dimensions: new_dims, dimension_symbols: _ } => Some(new_dims),
_ => None
};
let mut dims_cache = Vec::new();
for dim in some_dims.unwrap() {
if dim < &0 {
dims_cache.push((dim * -1) as usize);
}
else {
dims_cache.push(*dim as usize);
}
}
buffer
dims_cache
}

/// Creates a Vector that can be used manipulated with other operations such as normalisation from a hashmap of keys and values.
Expand Down Expand Up @@ -79,26 +86,24 @@ impl <'a>ModelComputation<'a> {
/// # Returns
/// The computed output tensor from the loaded model.
pub fn raw_compute(&self, tensor: ArrayD<f32>, _dims: Option<(i32, i32)>) -> Result<Vec<f32>, SurrealError> {
let session = safe_eject!(SessionBuilder::new(&ENVIRONMENT), SurrealErrorStatus::Unknown);
let session = safe_eject!(session.with_model_from_memory(&self.surml_file.model), SurrealErrorStatus::Unknown);
let unwrapped_dims = ModelComputation::process_input_dims(&session.inputs[0]);
let tensor = safe_eject!(tensor.into_shape(unwrapped_dims), SurrealErrorStatus::Unknown);

let x = CowArray::from(tensor).into_dyn();
let input_values = safe_eject!(Value::from_array(session.allocator(), &x), SurrealErrorStatus::Unknown);
let outputs = safe_eject!(session.run(vec![input_values]), SurrealErrorStatus::Unknown);
let session = get_session(self.surml_file.model.clone())?;
let dims_cache = ModelComputation::process_input_dims(&session);
let tensor = tensor.into_shape_with_order(dims_cache).unwrap();
let tensor = ort::value::Tensor::from_array(tensor).unwrap();
let x = ort::inputs![tensor].unwrap();
let outputs = safe_eject!(session.run(x), SurrealErrorStatus::Unknown);

let mut buffer: Vec<f32> = Vec::new();

// extract the output tensor converting the values to f32 if they are i64
match outputs[0].try_extract::<f32>() {
match outputs[0].try_extract_tensor::<f32>() {
Ok(y) => {
for i in y.view().clone().into_iter() {
buffer.push(*i);
}
},
Err(_) => {
for i in safe_eject!(outputs[0].try_extract::<i64>(), SurrealErrorStatus::Unknown).view().clone().into_iter() {
for i in safe_eject!(outputs[0].try_extract_tensor::<i64>(), SurrealErrorStatus::Unknown).view().clone().into_iter() {
buffer.push(*i as f32);
}
}
Expand Down
3 changes: 2 additions & 1 deletion modules/core/src/execution/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//! Defines operations around performing computations on a loaded model.
pub mod compute;
pub mod session;
21 changes: 21 additions & 0 deletions modules/core/src/execution/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use ort::session::Session;
use crate::errors::error::{SurrealError, SurrealErrorStatus};
use crate::safe_eject;


/// Builds an ONNX Runtime session from in-memory model bytes.
///
/// When the `gpu` feature is enabled, registration of the CUDA execution
/// provider is attempted on a best-effort basis; on failure the session
/// silently falls back to CPU execution rather than erroring out.
///
/// # Arguments
/// * `model_bytes` - the raw bytes of the ONNX model to load.
///
/// # Returns
/// A ready-to-run [`Session`], or a `SurrealError` with
/// `SurrealErrorStatus::Unknown` if building or committing the model fails.
pub fn get_session(model_bytes: Vec<u8>) -> Result<Session, SurrealError> {
    let builder = safe_eject!(Session::builder(), SurrealErrorStatus::Unknown);

    #[cfg(feature = "gpu")]
    {
        // Both the provider type and the `ExecutionProvider` trait (which
        // supplies `register`) must be in scope here, otherwise this arm
        // fails to compile when the `gpu` feature is enabled.
        use ort::execution_providers::{CUDAExecutionProvider, ExecutionProvider};

        let cuda = CUDAExecutionProvider::default();
        // Best-effort: log and continue on CPU instead of aborting.
        if let Err(e) = cuda.register(&builder) {
            eprintln!("Failed to register CUDA: {:?}. Falling back to CPU.", e);
        } else {
            println!("CUDA registered successfully");
        }
    }

    let session: Session = safe_eject!(builder.commit_from_memory(&model_bytes), SurrealErrorStatus::Unknown);
    Ok(session)
}

0 comments on commit 3f29dd5

Please sign in to comment.