Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require single class zonefiles by default, and give context if possible on parsing errors. #477

Merged
merged 10 commits into from
Apr 1, 2025
3 changes: 3 additions & 0 deletions src/dnssec/validator/anchor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Create DNSSEC trust anchors.

use super::context::Error;
use crate::base::iana::Class;
use crate::base::name::{Chain, Name, ToName};
use crate::base::{Record, RelativeName};
use crate::rdata::ZoneRecordData;
Expand Down Expand Up @@ -118,6 +119,7 @@ impl TrustAnchors {
let mut new_self = Self(Vec::new());

let mut zonefile = Zonefile::new();
zonefile.set_default_class(Class::IN);
zonefile.extend_from_slice(str);
zonefile.extend_from_slice(b"\n");
for e in zonefile {
Expand All @@ -137,6 +139,7 @@ impl TrustAnchors {
/// zonefile format.
pub fn add_u8(&mut self, str: &[u8]) -> Result<(), Error> {
let mut zonefile = Zonefile::new();
zonefile.set_default_class(Class::IN);
zonefile.extend_from_slice(str);
zonefile.extend_from_slice("\n".as_bytes());
for e in zonefile {
Expand Down
2 changes: 1 addition & 1 deletion src/stelline/parse_stelline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ fn parse_section<Lines: Iterator<Item = Result<String, std::io::Error>>>(
origin = new_origin.to_string();
}
} else {
let mut zonefile = Zonefile::new();
let mut zonefile = Zonefile::new().allow_invalid();
zonefile.extend_from_slice(
format!("$ORIGIN {origin}\n").as_bytes(),
);
Expand Down
215 changes: 185 additions & 30 deletions src/zonefile/inplace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,14 @@ pub type ScannedString = Str<Bytes>;
/// into the memory buffer. The function [`load`][Self::load] can be used to
/// create a value directly from a reader.
///
/// Once data has been added, you can simply iterate over the value to
/// get entries. The [`next_entry`][Self::next_entry] method provides an
/// Once data has been added, you can simply iterate over the value to get
/// entries. The [`next_entry`][Self::next_entry] method provides an
/// alternative with a more question mark friendly signature.
///
/// By default RFC 1035 validity checks are enabled. At present only the first
/// check is implemented: "1. All RRs in the zonefile should have the same
/// class". To disable strict validation call [`allow_invalid()`] prior to
/// calling [`load()`].
#[derive(Clone, Debug)]
pub struct Zonefile {
/// This is where we keep the data of the next entry.
Expand All @@ -73,7 +78,11 @@ pub struct Zonefile {
last_ttl: Ttl,

/// The last class.
last_class: Class,
last_class: Option<Class>,

/// Whether the loaded zonefile should be required to pass RFC 1035
/// validity checks.
require_valid: bool,
}

impl Zonefile {
Expand All @@ -89,14 +98,21 @@ impl Zonefile {
)))
}

/// Disables RFC 1035 section 5.2 zonefile validity checks.
pub fn allow_invalid(mut self) -> Self {
self.require_valid = false;
self
}

/// Creates a new value using the given buffer.
fn with_buf(buf: SourceBuf) -> Self {
Zonefile {
buf,
origin: None,
last_owner: None,
last_ttl: Ttl::from_secs(3600),
last_class: Class::IN,
last_class: None,
require_valid: true,
}
}

Expand Down Expand Up @@ -168,7 +184,19 @@ impl Zonefile {
/// any relative names encountered will cause iteration to terminate with
/// a missing origin error.
pub fn set_origin(&mut self, origin: Name<Bytes>) {
self.origin = Some(origin)
self.origin = Some(origin);
}

/// Set a default class to use.
///
/// RFC 1035 does not define a default class for zone file records to use,
/// it only states that the class field for a record is optional with
/// omitted class values defaulting to the last explicitly stated value.
///
/// If no last explicitly stated value exists, the class passed to this
/// function will be used, otherwise an error will be raised.
pub fn set_default_class(&mut self, class: Class) {
self.last_class = Some(class);
}

/// Returns the next entry in the zonefile.
Expand Down Expand Up @@ -342,12 +370,38 @@ impl<'a> EntryScanner<'a> {
self.zonefile.last_owner = Some(owner.clone());
}

let class = match class {
Some(class) => {
self.zonefile.last_class = class;
let class = match (class, self.zonefile.last_class) {
// https://www.rfc-editor.org/rfc/rfc1035#section-5.2
// 5.2. Use of master files to define zones
// ..
// "1. All RRs in the file should have the same class."
(Some(class), Some(last_class)) => {
if self.zonefile.require_valid && class != last_class {
return Err(EntryError::different_class(
last_class, class,
));
}
class
}

// Record lacks a class but a last class is known, use it.
//
// https://www.rfc-editor.org/rfc/rfc1035#section-5.2
// 5.1. Format
// ..
// "Omitted class and TTL values are default to the last
// explicitly stated values."
(None, Some(last_class)) => last_class,

// Record specifies a class, use it.
(Some(class), None) => {
self.zonefile.last_class = Some(class);
class
}
None => self.zonefile.last_class,

// Record lacks a class and no last class is known, raise an
// error.
(None, None) => return Err(EntryError::missing_last_class()),
};

let ttl = match ttl {
Expand Down Expand Up @@ -472,7 +526,7 @@ impl<'a> EntryScanner<'a> {
self.zonefile.buf.require_line_feed()?;
Ok(ScannedEntry::Ttl(Ttl::from_secs(ttl)))
} else {
Err(EntryError::unknown_control())
Err(EntryError::unknown_control(ctrl))
}
}
}
Expand Down Expand Up @@ -1438,75 +1492,137 @@ enum ItemCat {

/// An error returned by the entry scanner.
#[derive(Clone, Debug)]
pub struct EntryError(&'static str);
pub struct EntryError {
msg: &'static str,

#[cfg(feature = "std")]
context: Option<std::string::String>,
}

impl EntryError {
fn bad_symbol(_err: SymbolOctetsError) -> Self {
EntryError("bad symbol")
EntryError {
msg: "bad symbol",
#[cfg(feature = "std")]
context: Some(format!("{}", _err)),
}
}

fn bad_charstr() -> Self {
EntryError("bad charstr")
EntryError {
msg: "bad charstr",
#[cfg(feature = "std")]
context: None,
}
}

fn bad_name() -> Self {
EntryError("bad name")
EntryError {
msg: "bad name",
#[cfg(feature = "std")]
context: None,
}
}

fn unbalanced_parens() -> Self {
EntryError("unbalanced parens")
EntryError {
msg: "unbalanced parens",
#[cfg(feature = "std")]
context: None,
}
}

fn missing_last_owner() -> Self {
EntryError("missing last owner")
EntryError {
msg: "missing last owner",
#[cfg(feature = "std")]
context: None,
}
}

fn missing_last_class() -> Self {
EntryError {
msg: "missing last class",
#[cfg(feature = "std")]
context: None,
}
}

fn missing_origin() -> Self {
EntryError("missing origin")
EntryError {
msg: "missing origin",
#[cfg(feature = "std")]
context: None,
}
}

fn expected_rtype() -> Self {
EntryError("expected rtype")
EntryError {
msg: "expected rtype",
#[cfg(feature = "std")]
context: None,
}
}

fn unknown_control() -> Self {
EntryError("unknown control")
fn unknown_control(ctrl: Str<Bytes>) -> Self {
EntryError {
msg: "unknown control",
#[cfg(feature = "std")]
context: Some(format!("{}", ctrl)),
}
}

fn different_class(expected_class: Class, found_class: Class) -> Self {
EntryError {
msg: "different class",
#[cfg(feature = "std")]
context: Some(format!("{found_class} != {expected_class}")),
}
}
}

impl ScannerError for EntryError {
fn custom(msg: &'static str) -> Self {
EntryError(msg)
EntryError {
msg,
#[cfg(feature = "std")]
context: None,
}
}

fn end_of_entry() -> Self {
Self("unexpected end of entry")
Self::custom("unexpected end of entry")
}

fn short_buf() -> Self {
Self("short buffer")
Self::custom("short buffer")
}

fn trailing_tokens() -> Self {
Self("trailing tokens")
Self::custom("trailing tokens")
}
}

impl From<SymbolOctetsError> for EntryError {
fn from(_: SymbolOctetsError) -> Self {
EntryError("symbol octets error")
fn from(err: SymbolOctetsError) -> Self {
Self::bad_symbol(err)
}
}

impl From<BadSymbol> for EntryError {
fn from(_: BadSymbol) -> Self {
EntryError("bad symbol")
Self::custom("bad symbol")
}
}

impl fmt::Display for EntryError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(self.0.as_ref())
f.write_str(self.msg)?;
#[cfg(feature = "std")]
if let Some(context) = &self.context {
write!(f, ": {}", context)?;
}
Ok(())
}
}

Expand Down Expand Up @@ -1601,16 +1717,31 @@ mod test {
#[allow(clippy::type_complexity)]
struct TestCase {
origin: Name<Bytes>,
default_class: Option<Class>,
zonefile: std::string::String,
result: Vec<Record<Name<Bytes>, ZoneRecordData<Bytes, Name<Bytes>>>>,
#[serde(default)]
allow_invalid: bool,
}

impl From<&str> for TestCase {
fn from(yaml: &str) -> Self {
serde_yaml::from_str(yaml).unwrap()
}
}

impl TestCase {
fn test(yaml: &str) {
let case = serde_yaml::from_str::<Self>(yaml).unwrap();
fn test<T: Into<TestCase>>(case: T) {
let case = case.into();
let mut input = case.zonefile.as_bytes();
let mut zone = Zonefile::load(&mut input).unwrap();
if case.allow_invalid {
zone = zone.allow_invalid();
}
zone.set_origin(case.origin);
if let Some(class) = case.default_class {
zone.set_default_class(class);
}
let mut result = case.result.as_slice();
while let Some(entry) = zone.next_entry().unwrap() {
match entry {
Expand Down Expand Up @@ -1672,6 +1803,30 @@ mod test {
));
}

#[test]
fn test_default_and_last_class() {
TestCase::test(include_str!(
"../../test-data/zonefiles/defaultclass.yaml"
));
}

#[test]
#[should_panic(expected = "different class")]
fn test_rfc1035_same_class_validity_check() {
TestCase::test(include_str!(
"../../test-data/zonefiles/mixedclass.yaml"
));
}

#[test]
fn test_rfc1035_validity_checks_override() {
let mut case = TestCase::from(include_str!(
"../../test-data/zonefiles/mixedclass.yaml"
));
case.allow_invalid = true;
TestCase::test(case);
}

#[test]
fn test_chrstr_decoding() {
TestCase::test(include_str!("../../test-data/zonefiles/strlen.yaml"));
Expand Down
Loading