Skip to content

Commit

Permalink
Merge pull request #2840 from fermyon/simplify-default-sqlite
Browse files Browse the repository at this point in the history
Simplify default database resolution in sqlite
  • Loading branch information
rylev authored Sep 18, 2024
2 parents caacf55 + 4d50c8e commit 21c66f4
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 169 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion crates/factor-sqlite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ tracing = { workspace = true }

[dev-dependencies]
spin-factors-test = { path = "../factors-test" }
spin-sqlite = { path = "../sqlite" }
tokio = { version = "1", features = ["macros", "rt"] }

[lints]
Expand Down
30 changes: 15 additions & 15 deletions crates/factor-sqlite/src/host.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use async_trait::async_trait;
Expand All @@ -14,35 +14,28 @@ use crate::{Connection, ConnectionCreator};

pub struct InstanceState {
allowed_databases: Arc<HashSet<String>>,
/// A resource table of connections.
connections: table::Table<Box<dyn Connection>>,
get_connection_creator: ConnectionCreatorGetter,
/// A map from database label to connection creators.
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
}

impl InstanceState {
pub fn allowed_databases(&self) -> &HashSet<String> {
&self.allowed_databases
}
}

/// A function that takes a database label and returns a connection creator, if one exists.
pub type ConnectionCreatorGetter =
Arc<dyn Fn(&str) -> Option<Arc<dyn ConnectionCreator>> + Send + Sync>;

impl InstanceState {
/// Create a new `InstanceState`
///
/// Takes the list of allowed databases, and a function for getting a connection creator given a database label.
pub fn new(
allowed_databases: Arc<HashSet<String>>,
get_connection_creator: ConnectionCreatorGetter,
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
) -> Self {
Self {
allowed_databases,
connections: table::Table::new(256),
get_connection_creator,
connection_creators,
}
}

/// Get a connection for a given database label.
fn get_connection(
&self,
connection: Resource<v2::Connection>,
Expand All @@ -52,6 +45,11 @@ impl InstanceState {
.map(|conn| conn.as_ref())
.ok_or(v2::Error::InvalidConnection)
}

/// Get the set of allowed databases.
pub fn allowed_databases(&self) -> &HashSet<String> {
&self.allowed_databases
}
}

impl SelfInstanceBuilder for InstanceState {}
Expand All @@ -69,7 +67,9 @@ impl v2::HostConnection for InstanceState {
if !self.allowed_databases.contains(&database) {
return Err(v2::Error::AccessDenied);
}
let conn = (self.get_connection_creator)(&database)
let conn = self
.connection_creators
.get(&database)
.ok_or(v2::Error::NoSuchDatabase)?
.create_connection(&database)
.await?;
Expand Down
50 changes: 16 additions & 34 deletions crates/factor-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,15 @@ use spin_world::v2::sqlite as v2;

pub use runtime_config::RuntimeConfig;

#[derive(Default)]
pub struct SqliteFactor {
default_label_resolver: Arc<dyn DefaultLabelResolver>,
_priv: (),
}

impl SqliteFactor {
/// Create a new `SqliteFactor`
///
/// Takes a `default_label_resolver` for how to handle when a database label doesn't
/// have a corresponding runtime configuration.
pub fn new(default_label_resolver: impl DefaultLabelResolver + 'static) -> Self {
Self {
default_label_resolver: Arc::new(default_label_resolver),
}
pub fn new() -> Self {
Self { _priv: () }
}
}

Expand All @@ -50,8 +46,8 @@ impl Factor for SqliteFactor {
) -> anyhow::Result<Self::AppState> {
let connection_creators = ctx
.take_runtime_config()
.map(|r| r.connection_creators)
.unwrap_or_default();
.unwrap_or_default()
.connection_creators;

let allowed_databases = ctx
.app()
Expand All @@ -69,19 +65,12 @@ impl Factor for SqliteFactor {
))
})
.collect::<anyhow::Result<HashMap<_, _>>>()?;
let resolver = self.default_label_resolver.clone();
let get_connection_creator: host::ConnectionCreatorGetter = Arc::new(move |label| {
connection_creators
.get(label)
.cloned()
.or_else(|| resolver.default(label))
});

ensure_allowed_databases_are_configured(&allowed_databases, |label| {
get_connection_creator(label).is_some()
connection_creators.contains_key(label)
})?;

Ok(AppState::new(allowed_databases, get_connection_creator))
Ok(AppState::new(allowed_databases, connection_creators))
}

fn prepare<T: spin_factors::RuntimeFactors>(
Expand All @@ -94,10 +83,9 @@ impl Factor for SqliteFactor {
.get(ctx.app_component().id())
.cloned()
.unwrap_or_default();
let get_connection_creator = ctx.app_state().get_connection_creator.clone();
Ok(InstanceState::new(
allowed_databases,
get_connection_creator,
ctx.app_state().connection_creators.clone(),
))
}
}
Expand Down Expand Up @@ -138,31 +126,23 @@ fn ensure_allowed_databases_are_configured(
/// Metadata key for a list of allowed databases for a component.
pub const ALLOWED_DATABASES_KEY: MetadataKey<Vec<String>> = MetadataKey::new("databases");

/// Resolves a label to a default connection creator.
pub trait DefaultLabelResolver: Send + Sync {
/// If there is no runtime configuration for a given database label, return a default connection creator.
///
/// If `Option::None` is returned, the database is not allowed.
fn default(&self, label: &str) -> Option<Arc<dyn ConnectionCreator>>;
}

#[derive(Clone)]
pub struct AppState {
/// A map from component id to a set of allowed database labels.
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
/// A function for mapping from database name to a connection creator.
get_connection_creator: host::ConnectionCreatorGetter,
/// A mapping from database label to a connection creator.
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
}

impl AppState {
/// Create a new `AppState`
pub fn new(
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
get_connection_creator: host::ConnectionCreatorGetter,
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
) -> Self {
Self {
allowed_databases,
get_connection_creator,
connection_creators,
}
}

Expand All @@ -173,7 +153,9 @@ impl AppState {
&self,
label: &str,
) -> Option<Result<Box<dyn Connection>, v2::Error>> {
let connection = (self.get_connection_creator)(label)?
let connection = self
.connection_creators
.get(label)?
.create_connection(label)
.await;
Some(connection)
Expand Down
1 change: 1 addition & 0 deletions crates/factor-sqlite/src/runtime_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::ConnectionCreator;
/// A runtime configuration for SQLite databases.
///
/// Maps database labels to connection creators.
#[derive(Default)]
pub struct RuntimeConfig {
pub connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
}
Loading

0 comments on commit 21c66f4

Please sign in to comment.