From bf942e5114d955edba3facab8523a64c3a16348e Mon Sep 17 00:00:00 2001 From: Marvin Drescher Date: Thu, 20 Jun 2024 11:53:06 +0200 Subject: [PATCH] feat(rumqttd/broker): implemented filter for publish packets BREAKING CHANGE: `Router::new` now takes an list of `PublishFilterRef` as additional parameter --- rumqttd/src/lib.rs | 2 +- rumqttd/src/router/filter.rs | 134 ++++++++++++++++++++++++++++++++++ rumqttd/src/router/mod.rs | 2 + rumqttd/src/router/routing.rs | 23 ++++-- rumqttd/src/server/broker.rs | 10 ++- 5 files changed, 161 insertions(+), 10 deletions(-) create mode 100644 rumqttd/src/router/filter.rs diff --git a/rumqttd/src/lib.rs b/rumqttd/src/lib.rs index 405a1a243..fb8c073b6 100644 --- a/rumqttd/src/lib.rs +++ b/rumqttd/src/lib.rs @@ -21,7 +21,7 @@ use tracing_subscriber::{ pub use link::alerts; pub use link::local; pub use link::meters; -pub use router::{Alert, Forward, IncomingMeter, Meter, Notification, OutgoingMeter}; +pub use router::{Alert, Forward, IncomingMeter, Meter, Notification, OutgoingMeter, PublishFilter, PublishFilterRef}; use segments::Storage; pub use server::Broker; diff --git a/rumqttd/src/router/filter.rs b/rumqttd/src/router/filter.rs new file mode 100644 index 000000000..ebd3e2a76 --- /dev/null +++ b/rumqttd/src/router/filter.rs @@ -0,0 +1,134 @@ +use std::{fmt::Debug, ops::Deref, sync::Arc}; + +use crate::protocol::{Publish, PublishProperties}; + +/// Filter for [`Publish`] packets +pub trait PublishFilter { + /// Determines weather an [`Publish`] packet should be processed + /// Arguments: + /// * `packet`: to be published, may be modified if necessary + /// * `properties`: received along with the packet, may be `None` for older MQTT versions + /// Returns: [`bool`] indicating if the packet should be processed + fn filter(&self, packet: &mut Publish, properties: Option<&mut PublishProperties>) -> bool; +} + +/// Container for either an owned [`PublishFilter`] or an `'static` reference +#[derive(Clone)] +pub enum PublishFilterRef { + Owned(Arc), + Static(&'static (dyn PublishFilter + Send + Sync)), +} + +impl Debug for PublishFilterRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Owned(_arg0) => f.debug_tuple("Owned").finish(), + Self::Static(_arg0) => f.debug_tuple("Static").finish(), + } + } +} + +impl Deref for PublishFilterRef { + type Target = dyn PublishFilter; + + fn deref(&self) -> &Self::Target { + match self { + Self::Static(filter) => *filter, + Self::Owned(filter) => &**filter, + } + } +} + +/// Implements [`PublishFilter`] for any ordinary function +impl PublishFilter for F +where + F: Fn(&mut Publish, Option<&mut PublishProperties>) -> bool + Send + Sync, +{ + fn filter(&self, packet: &mut Publish, properties: Option<&mut PublishProperties>) -> bool { + self(packet, properties) + } +} + +/// Implements the conversion +/// ```rust +/// # use rumqttd::{protocol::{Publish, PublishProperties}, PublishFilterRef}; +/// fn filter_static(packet: &mut Publish, properties: Option<&mut PublishProperties>) -> bool { +/// todo!() +/// } +/// +/// let filter = PublishFilterRef::from(&filter_static); +/// # assert!(matches!(filter, PublishFilterRef::Static(_))); +/// ``` +impl From<&'static F> for PublishFilterRef +where + F: Fn(&mut Publish, Option<&mut PublishProperties>) -> bool + Send + Sync, +{ + fn from(value: &'static F) -> Self { + Self::Static(value) + } +} + +/// Implements the conversion +/// ```rust +/// # use std::boxed::Box; +/// # use rumqttd::{protocol::{Publish, PublishProperties}, PublishFilter, PublishFilterRef}; +/// #[derive(Clone)] +/// struct MyFilter {} +/// +/// impl PublishFilter for MyFilter { +/// fn filter(&self, packet: &mut Publish, properties: Option<&mut PublishProperties>) -> bool { +/// todo!() +/// } +/// } +/// let boxed: Box = Box::new(MyFilter {}); +/// +/// let filter = PublishFilterRef::from(boxed); +/// # assert!(matches!(filter, PublishFilterRef::Owned(_))); +/// ``` +impl From> for PublishFilterRef +where + T: PublishFilter + 'static + Send + Sync, +{ + fn from(value: Arc) -> Self { + Self::Owned(value) + } +} + +impl From> for PublishFilterRef +where + T: PublishFilter + 'static + Send + Sync, +{ + fn from(value: Box) -> Self { + Self::Owned(Arc::::from(value)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn filter_static(_packet: &mut Publish, _properties: Option<&mut PublishProperties>) -> bool { + true + } + struct Prejudiced(bool); + + impl PublishFilter for Prejudiced { + fn filter(&self, _packet: &mut Publish,_propertiess: Option<&mut PublishProperties>) -> bool { + self.0 + } + } + #[test] + fn static_filter() { + fn is_send(_: &T) {} + fn takes_static_filter(filter: impl Into) { + assert!(matches!(filter.into(), PublishFilterRef::Static(_))); + } + fn takes_owned_filter(filter: impl Into) { + assert!(matches!(filter.into(), PublishFilterRef::Owned(_))); + } + takes_static_filter(&filter_static); + let boxed: PublishFilterRef = Box::new(Prejudiced(false)).into(); + is_send(&boxed); + takes_owned_filter(boxed); + } +} diff --git a/rumqttd/src/router/mod.rs b/rumqttd/src/router/mod.rs index a75cad34f..d35f7babe 100644 --- a/rumqttd/src/router/mod.rs +++ b/rumqttd/src/router/mod.rs @@ -24,7 +24,9 @@ mod routing; mod scheduler; pub(crate) mod shared_subs; mod waiters; +mod filter; +pub use filter::{PublishFilter, PublishFilterRef}; pub use alertlog::Alert; pub use connection::Connection; pub use routing::Router; diff --git a/rumqttd/src/router/routing.rs b/rumqttd/src/router/routing.rs index ebf362a71..f4a076227 100644 --- a/rumqttd/src/router/routing.rs +++ b/rumqttd/src/router/routing.rs @@ -70,6 +70,8 @@ pub struct Router { connections: Slab, /// Connection map from device id to connection id connection_map: HashMap, + /// Filters to be applied to an [`Publish`] packets payload + publish_filters: Vec, /// Subscription map to interested connection ids subscription_map: HashMap>, /// Incoming data grouped by connection @@ -105,7 +107,7 @@ pub struct Router { } impl Router { - pub fn new(router_id: RouterId, config: RouterConfig) -> Router { + pub fn new(router_id: RouterId, publish_filters: Vec, config: RouterConfig) -> Router { let (router_tx, router_rx) = bounded(1000); let meters = Slab::with_capacity(10); @@ -129,6 +131,7 @@ impl Router { alerts, connections, connection_map: Default::default(), + publish_filters, subscription_map: Default::default(), ibufs, obufs, @@ -557,13 +560,18 @@ impl Router { for packet in packets.drain(0..) { match packet { - Packet::Publish(publish, properties) => { + Packet::Publish(mut publish, mut properties) => { + println!("publish: {publish:?} payload: {:?}", publish.payload.to_vec()); let span = tracing::error_span!("publish", topic = ?publish.topic, pkid = publish.pkid); let _guard = span.enter(); let qos = publish.qos; let pkid = publish.pkid; - + + // Decide weather to keep or discard this packet + // Packet will be discard if *at least one* filter returns *false* + let keep = self.publish_filters.iter().fold(true,|keep,f| keep && f.filter(&mut publish, properties.as_mut())) ; + // Prepare acks for the above publish // If any of the publish in the batch results in force flush, // set global force flush flag. Force flush is triggered when the @@ -577,12 +585,11 @@ impl Router { // coordinate using multiple offsets, and we don't have any idea how to do so right now. // Currently as we don't have replication, we just use a single offset, even when appending to // multiple commit logs. - match qos { QoS::AtLeastOnce => { let puback = PubAck { pkid, - reason: PubAckReason::Success, + reason: if keep { PubAckReason::Success } else { PubAckReason::PayloadFormatInvalid }, }; let ackslog = self.ackslog.get_mut(id).unwrap(); @@ -592,7 +599,7 @@ impl Router { QoS::ExactlyOnce => { let pubrec = PubRec { pkid, - reason: PubRecReason::Success, + reason: if keep { PubRecReason::Success } else { PubRecReason::PayloadFormatInvalid }, }; let ackslog = self.ackslog.get_mut(id).unwrap(); @@ -604,7 +611,9 @@ impl Router { // Do nothing } }; - + if !keep { + break; + } self.router_meters.total_publishes += 1; // Try to append publish to commitlog diff --git a/rumqttd/src/server/broker.rs b/rumqttd/src/server/broker.rs index 9886541c9..f5cc331a8 100644 --- a/rumqttd/src/server/broker.rs +++ b/rumqttd/src/server/broker.rs @@ -35,7 +35,7 @@ use std::{io, thread}; use crate::link::console; use crate::link::local::{self, LinkRx, LinkTx}; -use crate::router::{Event, Router}; +use crate::router::{Event, PublishFilterRef, Router}; use crate::{Config, ConnectionId, ServerSettings}; use tokio::net::{TcpListener, TcpStream}; @@ -71,9 +71,13 @@ pub struct Broker { impl Broker { pub fn new(config: Config) -> Broker { + Self::with_filter(config, Vec::new()) + } + + pub fn with_filter(config: Config, publish_filters: Vec) -> Broker { let config = Arc::new(config); let router_config = config.router.clone(); - let router: Router = Router::new(config.id, router_config); + let router: Router = Router::new(config.id, publish_filters, router_config); // Setup cluster if cluster settings are configured. match config.cluster.clone() { @@ -96,6 +100,8 @@ impl Broker { } } + + // pub fn new_local_cluster( // config: Config, // node_id: NodeId,