diff --git a/src/stelline/parse_stelline.rs b/src/stelline/parse_stelline.rs index 4df5f34b2..b0eee64c1 100644 --- a/src/stelline/parse_stelline.rs +++ b/src/stelline/parse_stelline.rs @@ -512,7 +512,7 @@ fn parse_section>>( origin = new_origin.to_string(); } } else { - let mut zonefile = Zonefile::new(); + let mut zonefile = Zonefile::new().allow_invalid(); zonefile.extend_from_slice( format!("$ORIGIN {origin}\n").as_bytes(), ); diff --git a/src/zonefile/inplace.rs b/src/zonefile/inplace.rs index 98451d8ef..b59dfbbcf 100644 --- a/src/zonefile/inplace.rs +++ b/src/zonefile/inplace.rs @@ -55,9 +55,14 @@ pub type ScannedString = Str; /// into the memory buffer. The function [`load`][Self::load] can be used to /// create a value directly from a reader. /// -/// Once data has been added, you can simply iterate over the value to -/// get entries. The [`next_entry`][Self::next_entry] method provides an +/// Once data has been added, you can simply iterate over the value to get +/// entries. The [`next_entry`][Self::next_entry] method provides an /// alternative with a more question mark friendly signature. +/// +/// By default RFC 1035 validity checks are enabled. At present only the first +/// check is implemented: "1. All RRs in the zonefile should have the same +/// class". To disable strict validation call [`allow_invalid()`] prior to +/// calling [`load()`]. #[derive(Clone, Debug)] pub struct Zonefile { /// This is where we keep the data of the next entry. @@ -73,7 +78,11 @@ pub struct Zonefile { last_ttl: Ttl, /// The last class. - last_class: Class, + last_class: Option, + + /// Whether the loaded zonefile should be required to pass RFC 1035 + /// validity checks. + require_valid: bool, } impl Zonefile { @@ -89,6 +98,12 @@ impl Zonefile { ))) } + /// Disables RFC 1035 section 5.2 zonefile validity checks. + pub fn allow_invalid(mut self) -> Self { + self.require_valid = false; + self + } + /// Creates a new value using the given buffer. fn with_buf(buf: SourceBuf) -> Self { Zonefile { @@ -96,7 +111,8 @@ impl Zonefile { origin: None, last_owner: None, last_ttl: Ttl::from_secs(3600), - last_class: Class::IN, + last_class: None, + require_valid: true, } } @@ -342,12 +358,27 @@ impl<'a> EntryScanner<'a> { self.zonefile.last_owner = Some(owner.clone()); } - let class = match class { - Some(class) => { - self.zonefile.last_class = class; + let class = match (class, self.zonefile.last_class) { + (Some(class), Some(last_class)) => { + if self.zonefile.require_valid && class != last_class { + return Err(EntryError::different_class( + last_class, class, + )); + } + class + } + + (Some(class), None) => { + self.zonefile.last_class = Some(class); class } - None => self.zonefile.last_class, + + (None, Some(last_class)) => last_class, + + (None, None) => { + self.zonefile.last_class = Some(Class::IN); + Class::IN + } }; let ttl = match ttl { @@ -472,7 +503,7 @@ impl<'a> EntryScanner<'a> { self.zonefile.buf.require_line_feed()?; Ok(ScannedEntry::Ttl(Ttl::from_secs(ttl))) } else { - Err(EntryError::unknown_control()) + Err(EntryError::unknown_control(ctrl)) } } } @@ -1438,75 +1469,136 @@ enum ItemCat { /// An error returned by the entry scanner. #[derive(Clone, Debug)] -pub struct EntryError(&'static str); +pub struct EntryError { + msg: &'static str, + + #[cfg(feature = "std")] + context: Option, +} impl EntryError { fn bad_symbol(_err: SymbolOctetsError) -> Self { - EntryError("bad symbol") + EntryError { + msg: "bad symbol", + #[cfg(feature = "std")] + context: Some(format!("{}", _err)), + } } fn bad_charstr() -> Self { - EntryError("bad charstr") + EntryError { + msg: "bad charstr", + #[cfg(feature = "std")] + context: None, + } } fn bad_name() -> Self { - EntryError("bad name") + EntryError { + msg: "bad name", + #[cfg(feature = "std")] + context: None, + } } fn unbalanced_parens() -> Self { - EntryError("unbalanced parens") + EntryError { + msg: "unbalanced parens", + #[cfg(feature = "std")] + context: None, + } } fn missing_last_owner() -> Self { - EntryError("missing last owner") + EntryError { + msg: "missing last owner", + #[cfg(feature = "std")] + context: None, + } } fn missing_origin() -> Self { - EntryError("missing origin") + EntryError { + msg: "missing origin", + #[cfg(feature = "std")] + context: None, + } } fn expected_rtype() -> Self { - EntryError("expected rtype") + EntryError { + msg: "expected rtype", + #[cfg(feature = "std")] + context: None, + } } - fn unknown_control() -> Self { - EntryError("unknown control") + fn unknown_control(ctrl: Str) -> Self { + EntryError { + msg: "unknown control", + #[cfg(feature = "std")] + context: Some(format!("{}", ctrl)), + } + } + + fn different_class(expected_class: Class, found_class: Class) -> Self { + EntryError { + msg: "different class", + #[cfg(feature = "std")] + context: Some(format!("{found_class} != {expected_class}")), + } } } impl ScannerError for EntryError { fn custom(msg: &'static str) -> Self { - EntryError(msg) + EntryError { + msg, + #[cfg(feature = "std")] + context: None, + } } fn end_of_entry() -> Self { - Self("unexpected end of entry") + Self::custom("unexpected end of entry") } fn short_buf() -> Self { - Self("short buffer") + Self::custom("short buffer") } fn trailing_tokens() -> Self { - Self("trailing tokens") + Self::custom("trailing tokens") } } impl From for EntryError { - fn from(_: SymbolOctetsError) -> Self { - EntryError("symbol octets error") + fn from(err: SymbolOctetsError) -> Self { + Self::bad_symbol(err) } } impl From for EntryError { fn from(_: BadSymbol) -> Self { - EntryError("bad symbol") + Self::custom("bad symbol") } } impl fmt::Display for EntryError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(self.0.as_ref()) + #[cfg(not(feature = "std"))] + { + f.write_str(self.msg) + } + + #[cfg(feature = "std")] + { + if let Some(context) = &self.context { + f.write_fmt(format_args!("{}: {}", self.msg, context)) + } else { + f.write_str(self.msg) + } + } } }