-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
upgrading to beta V 2 of ort with optional compilation of GPU
- Loading branch information
1 parent
e0cf6c4
commit 3f29dd5
Showing
5 changed files
with
156 additions
and
123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,109 +1,113 @@ | ||
use std::env; | ||
use std::fs; | ||
use std::path::Path; | ||
// use std::env; | ||
// use std::fs; | ||
// use std::path::Path; | ||
|
||
/// works out where the `onnxruntime` library is in the build target and copies the library to the root | ||
/// of the crate so the core library can find it and load it into the binary using `include_bytes!()`. | ||
/// | ||
/// # Notes | ||
/// This is a workaround for the fact that `onnxruntime` doesn't support `cargo` yet. This build step | ||
/// is reliant on the `ort` crate downloading and building the `onnxruntime` library. This is | ||
/// why the following dependency is required in `Cargo.toml`: | ||
/// ```toml | ||
/// [build-dependencies] | ||
/// ort = { version = "1.16.2", default-features = true } | ||
/// ``` | ||
/// Here we can see that the `default-features` is set to `true`. This is because the `ort` crate will download | ||
/// the correct package and build it for the target platform by default. In the main part of our dependencies | ||
/// we have the following: | ||
/// ```toml | ||
/// [dependencies] | ||
/// ort = { version = "1.16.2", features = ["load-dynamic"], default-features = false } | ||
/// ``` | ||
/// Here we can see that the `default-features` is set to `false`. This is because we don't want the `ort` crate | ||
/// to download and build the `onnxruntime` library again. Instead we want to use the one that was built in the | ||
/// build step. We also set the `load-dynamic` feature to `true` so that the `ort` crate will load the `onnxruntime` | ||
/// library dynamically at runtime. This is because we don't want to statically link the `onnxruntime`. Our `onnxruntime` | ||
/// is embedded into the binary using `include_bytes!()` and we want to load it dynamically at runtime. This means that | ||
/// we do not need to move the `onnxruntime` library around with the binary, and there is no complicated setup required | ||
/// or linking. | ||
fn unpack_onnx() -> std::io::Result<()> { | ||
let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set"); | ||
let out_path = Path::new(&out_dir); | ||
let build_dir = out_path | ||
.ancestors() // This gives an iterator over all ancestors of the path | ||
.nth(3) // 'nth(3)' gets the fourth ancestor (counting from 0), which should be the debug directory | ||
.expect("Failed to find debug directory"); | ||
// /// works out where the `onnxruntime` library is in the build target and copies the library to the root | ||
// /// of the crate so the core library can find it and load it into the binary using `include_bytes!()`. | ||
// /// | ||
// /// # Notes | ||
// /// This is a workaround for the fact that `onnxruntime` doesn't support `cargo` yet. This build step | ||
// /// is reliant on the `ort` crate downloading and building the `onnxruntime` library. This is | ||
// /// why the following dependency is required in `Cargo.toml`: | ||
// /// ```toml | ||
// /// [build-dependencies] | ||
// /// ort = { version = "1.16.2", default-features = true } | ||
// /// ``` | ||
// /// Here we can see that the `default-features` is set to `true`. This is because the `ort` crate will download | ||
// /// the correct package and build it for the target platform by default. In the main part of our dependencies | ||
// /// we have the following: | ||
// /// ```toml | ||
// /// [dependencies] | ||
// /// ort = { version = "1.16.2", features = ["load-dynamic"], default-features = false } | ||
// /// ``` | ||
// /// Here we can see that the `default-features` is set to `false`. This is because we don't want the `ort` crate | ||
// /// to download and build the `onnxruntime` library again. Instead we want to use the one that was built in the | ||
// /// build step. We also set the `load-dynamic` feature to `true` so that the `ort` crate will load the `onnxruntime` | ||
// /// library dynamically at runtime. This is because we don't want to statically link the `onnxruntime`. Our `onnxruntime` | ||
// /// is embedded into the binary using `include_bytes!()` and we want to load it dynamically at runtime. This means that | ||
// /// we do not need to move the `onnxruntime` library around with the binary, and there is no complicated setup required | ||
// /// or linking. | ||
// fn unpack_onnx() -> std::io::Result<()> { | ||
// let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set"); | ||
// let out_path = Path::new(&out_dir); | ||
// let build_dir = out_path | ||
// .ancestors() // This gives an iterator over all ancestors of the path | ||
// .nth(3) // 'nth(3)' gets the fourth ancestor (counting from 0), which should be the debug directory | ||
// .expect("Failed to find debug directory"); | ||
|
||
match std::env::var("ONNXRUNTIME_LIB_PATH") { | ||
Ok(onnx_path) => { | ||
println!("Surrealml Core Debug: ONNXRUNTIME_LIB_PATH set at: {}", onnx_path); | ||
println!("cargo:rustc-cfg=onnx_runtime_env_var_set"); | ||
} | ||
Err(_) => { | ||
println!("Surrealml Core Debug: ONNXRUNTIME_LIB_PATH not set"); | ||
let target_lib = match env::var("CARGO_CFG_TARGET_OS").unwrap() { | ||
ref s if s.contains("linux") => "libonnxruntime.so", | ||
ref s if s.contains("macos") => "libonnxruntime.dylib", | ||
ref s if s.contains("windows") => "onnxruntime.dll", | ||
// ref s if s.contains("android") => "android", => not building for android | ||
_ => panic!("Unsupported target os"), | ||
}; | ||
// match std::env::var("ONNXRUNTIME_LIB_PATH") { | ||
// Ok(onnx_path) => { | ||
// println!("Surrealml Core Debug: ONNXRUNTIME_LIB_PATH set at: {}", onnx_path); | ||
// println!("cargo:rustc-cfg=onnx_runtime_env_var_set"); | ||
// } | ||
// Err(_) => { | ||
// println!("Surrealml Core Debug: ONNXRUNTIME_LIB_PATH not set"); | ||
// let target_lib = match env::var("CARGO_CFG_TARGET_OS").unwrap() { | ||
// ref s if s.contains("linux") => "libonnxruntime.so", | ||
// ref s if s.contains("macos") => "libonnxruntime.dylib", | ||
// ref s if s.contains("windows") => "onnxruntime.dll", | ||
// // ref s if s.contains("android") => "android", => not building for android | ||
// _ => panic!("Unsupported target os"), | ||
// }; | ||
|
||
let lib_path = build_dir.join(target_lib); | ||
let lib_path = lib_path.to_str().unwrap(); | ||
println!("Surrealml Core Debug: lib_path={}", lib_path); | ||
// let lib_path = build_dir.join(target_lib); | ||
// let lib_path = lib_path.to_str().unwrap(); | ||
// println!("Surrealml Core Debug: lib_path={}", lib_path); | ||
|
||
// Check if the path exists | ||
if fs::metadata(lib_path).is_ok() { | ||
println!("Surrealml Core Debug: lib_path exists"); | ||
} else { | ||
println!("Surrealml Core Debug: lib_path does not exist"); | ||
// Extract the directory path | ||
if let Some(parent) = std::path::Path::new(lib_path).parent() { | ||
// Print the contents of the directory | ||
match fs::read_dir(parent) { | ||
Ok(entries) => { | ||
println!("Surrealml Core Debug: content of directory {}", parent.display()); | ||
for entry in entries { | ||
if let Ok(entry) = entry { | ||
println!("{}", entry.path().display()); | ||
} | ||
} | ||
} | ||
Err(e) => { | ||
println!("Surrealml Core Debug: Failed to read directory {}: {}", parent.display(), e); | ||
} | ||
} | ||
} else { | ||
println!("Surrealml Core Debug: Could not determine the parent directory of the path."); | ||
} | ||
} | ||
// // Check if the path exists | ||
// if fs::metadata(lib_path).is_ok() { | ||
// println!("Surrealml Core Debug: lib_path exists"); | ||
// } else { | ||
// println!("Surrealml Core Debug: lib_path does not exist"); | ||
// // Extract the directory path | ||
// if let Some(parent) = std::path::Path::new(lib_path).parent() { | ||
// // Print the contents of the directory | ||
// match fs::read_dir(parent) { | ||
// Ok(entries) => { | ||
// println!("Surrealml Core Debug: content of directory {}", parent.display()); | ||
// for entry in entries { | ||
// if let Ok(entry) = entry { | ||
// println!("{}", entry.path().display()); | ||
// } | ||
// } | ||
// } | ||
// Err(e) => { | ||
// println!("Surrealml Core Debug: Failed to read directory {}: {}", parent.display(), e); | ||
// } | ||
// } | ||
// } else { | ||
// println!("Surrealml Core Debug: Could not determine the parent directory of the path."); | ||
// } | ||
// } | ||
|
||
// put it next to the file of the embedding | ||
let destination = Path::new(target_lib); | ||
fs::copy(lib_path, destination)?; | ||
println!("Surrealml Core Debug: onnx lib copied from {} to {}", lib_path, destination.display()); | ||
} | ||
} | ||
Ok(()) | ||
} | ||
// // put it next to the file of the embedding | ||
// let destination = Path::new(target_lib); | ||
// fs::copy(lib_path, destination)?; | ||
// println!("Surrealml Core Debug: onnx lib copied from {} to {}", lib_path, destination.display()); | ||
// } | ||
// } | ||
// Ok(()) | ||
// } | ||
|
||
fn main() -> std::io::Result<()> { | ||
if std::env::var("DOCS_RS").is_ok() { | ||
// we are not going to be anything here for docs.rs, because we are merely building the docs. When we are just building | ||
// the docs, the onnx environment variable will not look for the `onnxruntime` library, so we don't need to unpack it. | ||
return Ok(()); | ||
} | ||
// fn main() -> std::io::Result<()> { | ||
// if std::env::var("DOCS_RS").is_ok() { | ||
// // we are not going to be anything here for docs.rs, because we are merely building the docs. When we are just building | ||
// // the docs, the onnx environment variable will not look for the `onnxruntime` library, so we don't need to unpack it. | ||
// return Ok(()); | ||
// } | ||
|
||
if env::var("ORT_STRATEGY").as_deref() == Ok("system") { | ||
// If the ORT crate is built with the `system` strategy, then the crate will take care of statically linking the library. | ||
// No need to do anything here. | ||
println!("cargo:rustc-cfg=onnx_statically_linked"); | ||
// if env::var("ORT_STRATEGY").as_deref() == Ok("system") { | ||
// // If the ORT crate is built with the `system` strategy, then the crate will take care of statically linking the library. | ||
// // No need to do anything here. | ||
// println!("cargo:rustc-cfg=onnx_statically_linked"); | ||
|
||
return Ok(()); | ||
} | ||
// return Ok(()); | ||
// } | ||
|
||
unpack_onnx()?; | ||
Ok(()) | ||
// unpack_onnx()?; | ||
// Ok(()) | ||
// } | ||
|
||
fn main() { | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
//! Defines operations around performing computations on a loaded model. | ||
pub mod compute; | ||
pub mod onnx_environment; | ||
// pub mod onnx_environment; | ||
pub mod session; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
use ort::session::Session; | ||
use crate::errors::error::{SurrealError, SurrealErrorStatus}; | ||
use crate::safe_eject; | ||
|
||
|
||
pub fn get_session(model_bytes: Vec<u8>) -> Result<Session, SurrealError> { | ||
let builder = safe_eject!(Session::builder(), SurrealErrorStatus::Unknown); | ||
|
||
#[cfg(feature = "gpu")] | ||
{ | ||
let cuda = CUDAExecutionProvider::default(); | ||
if let Err(e) = cuda.register(&builder) { | ||
eprintln!("Failed to register CUDA: {:?}. Falling back to CPU.", e); | ||
} else { | ||
println!("CUDA registered successfully"); | ||
} | ||
} | ||
let session: Session = safe_eject!(builder | ||
.commit_from_memory(&model_bytes), SurrealErrorStatus::Unknown); | ||
Ok(session) | ||
} |