diff --git a/Cargo.lock b/Cargo.lock index 0621f61bb..d9c2d9782 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,8 +232,10 @@ version = "0.10.3" dependencies = [ "arbitrary", "arc-swap", + "bumpalo", "bytes", "chrono", + "domain-macros", "futures-util", "hashbrown 0.14.5", "heapless", @@ -271,6 +273,15 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "domain-macros" +version = "0.10.3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "either" version = "1.13.0" @@ -1441,9 +1452,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "11cd88e12b17c6494200a9c1b683a04fcac9573ed74cd1b62aeb2727c5592243" [[package]] name = "unsafe-libyaml" @@ -1459,9 +1470,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "uuid" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4" +checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" dependencies = [ "getrandom", ] diff --git a/Cargo.toml b/Cargo.toml index ad0fc1e0b..0a79794e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,7 @@ +[workspace] +resolver = "2" +members = [".", "./macros"] + [package] name = "domain" version = "0.10.3" @@ -15,12 +19,11 @@ readme = "README.md" keywords = ["DNS", "domain"] license = "BSD-3-Clause" -[lib] -name = "domain" -path = "src/lib.rs" - [dependencies] +domain-macros = { path = "./macros", version = "0.10.3" } + arbitrary = { version = "1.4.1", optional = true, features = ["derive"] } +bumpalo = { version = "3.12", optional = true } octseq = { version = "0.5.2", default-features = false } time = { version = "0.3.1", default-features = false } rand = { version = 
"0.8", optional = true } @@ -52,11 +55,12 @@ tracing-subscriber = { version = "0.3.18", optional = true, features = ["env-fil default = ["std", "rand"] # Support for libraries +bumpalo = ["dep:bumpalo", "std"] bytes = ["dep:bytes", "octseq/bytes"] heapless = ["dep:heapless", "octseq/heapless"] serde = ["dep:serde", "octseq/serde"] smallvec = ["dep:smallvec", "octseq/smallvec"] -std = ["dep:hashbrown", "bytes?/std", "octseq/std", "time/std"] +std = ["dep:hashbrown", "bumpalo?/std", "bytes?/std", "octseq/std", "time/std"] tracing = ["dep:log", "dep:tracing"] # Cryptographic backends diff --git a/macros/Cargo.toml b/macros/Cargo.toml new file mode 100644 index 000000000..7060a61eb --- /dev/null +++ b/macros/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "domain-macros" + +# Copied from 'domain'. +version = "0.10.3" +rust-version = "1.68.2" +edition = "2021" + +authors = ["NLnet Labs "] +description = "Procedural macros for the `domain` crate." +documentation = "https://docs.rs/domain-macros" +homepage = "https://github.com/nlnetlabs/domain/" +repository = "https://github.com/nlnetlabs/domain/" +keywords = ["DNS", "domain"] +license = "BSD-3-Clause" + +[lib] +proc-macro = true + +[dependencies.proc-macro2] +version = "1.0" + +[dependencies.syn] +version = "2.0" +features = ["full", "visit"] + +[dependencies.quote] +version = "1.0" diff --git a/macros/src/data.rs b/macros/src/data.rs new file mode 100644 index 000000000..ee0c52baf --- /dev/null +++ b/macros/src/data.rs @@ -0,0 +1,159 @@ +//! Working with structs, enums, and unions. + +use std::ops::Deref; + +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{spanned::Spanned, Field, Fields, Ident, Index, Member, Token}; + +//----------- Struct --------------------------------------------------------- + +/// A defined 'struct'. +pub struct Struct { + /// The identifier for this 'struct'. + ident: Ident, + + /// The fields in this 'struct'. 
+ fields: Fields, +} + +impl Struct { + /// Construct a [`Struct`] for a 'Self'. + pub fn new_as_self(fields: &Fields) -> Self { + Self { + ident: ::default().into(), + fields: fields.clone(), + } + } + + /// Whether this 'struct' has no fields. + pub fn is_empty(&self) -> bool { + self.fields.is_empty() + } + + /// The number of fields in this 'struct'. + pub fn num_fields(&self) -> usize { + self.fields.len() + } + + /// The fields of this 'struct'. + pub fn fields(&self) -> impl Iterator + '_ { + self.fields.iter() + } + + /// The sized fields of this 'struct' (all but the last; none if empty). + pub fn sized_fields(&self) -> impl Iterator + '_ { + self.fields().take(self.num_fields().saturating_sub(1)) + } + + /// The unsized (i.e. last) field of this 'struct', if it has any fields. + pub fn unsized_field(&self) -> Option<&Field> { + self.fields.iter().next_back() + } + + /// The names of the fields of this 'struct'. + pub fn members(&self) -> impl Iterator + '_ { + self.fields + .iter() + .enumerate() + .map(|(i, f)| make_member(i, f)) + } + + /// The names of the sized fields (all but the last; none if empty). + pub fn sized_members(&self) -> impl Iterator + '_ { + self.members().take(self.num_fields().saturating_sub(1)) + } + + /// The name of the last field of this 'struct'. + pub fn unsized_member(&self) -> Option { + self.fields + .iter() + .next_back() + .map(|f| make_member(self.num_fields() - 1, f)) + } + + /// Construct a builder for this 'struct'. + pub fn builder Ident>( + &self, + f: F, + ) -> StructBuilder<'_, F> { + StructBuilder { + target: self, + var_fn: f, + } + } +} + +/// Construct a [`Member`] from a field and index. +fn make_member(index: usize, field: &Field) -> Member { + match &field.ident { + Some(ident) => Member::Named(ident.clone()), + None => Member::Unnamed(Index { + index: index as u32, + span: field.ty.span(), + }), + } +} + +//----------- StructBuilder -------------------------------------------------- + +/// A means of constructing a 'struct'. +pub struct StructBuilder<'a, F: Fn(Member) -> Ident> { + /// The 'struct' being constructed. 
+ target: &'a Struct, + + /// A map from field names to constructing variables. + var_fn: F, +} + +impl Ident> StructBuilder<'_, F> { + /// The initializing variables for this 'struct'. + pub fn init_vars(&self) -> impl Iterator + '_ { + self.members().map(&self.var_fn) + } + + /// The initializing variables for the sized fields of this 'struct'. + pub fn sized_init_vars(&self) -> impl Iterator + '_ { + self.sized_members().map(&self.var_fn) + } + + /// The initializing variable for the last field of this 'struct'. + pub fn unsized_init_var(&self) -> Option { + self.unsized_member().map(&self.var_fn) + } +} + +impl Ident> Deref for StructBuilder<'_, F> { + type Target = Struct; + + fn deref(&self) -> &Self::Target { + self.target + } +} + +impl Ident> ToTokens for StructBuilder<'_, F> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.ident; + match self.fields { + Fields::Named(_) => { + let members = self.members(); + let init_vars = self.init_vars(); + quote! { + #ident { #(#members: #init_vars),* } + } + } + + Fields::Unnamed(_) => { + let init_vars = self.init_vars(); + quote! { + #ident ( #(#init_vars),* ) + } + } + + Fields::Unit => { + quote! { #ident } + } + } + .to_tokens(tokens); + } +} diff --git a/macros/src/impls.rs b/macros/src/impls.rs new file mode 100644 index 000000000..4c0971998 --- /dev/null +++ b/macros/src/impls.rs @@ -0,0 +1,274 @@ +//! Helpers for generating `impl` blocks. + +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + punctuated::Punctuated, visit::Visit, ConstParam, GenericArgument, + GenericParam, Ident, Lifetime, LifetimeParam, Token, TypeParam, + TypeParamBound, WhereClause, WherePredicate, +}; + +//----------- ImplSkeleton --------------------------------------------------- + +/// The skeleton of an `impl` block. +pub struct ImplSkeleton { + /// Lifetime parameters for the `impl` block. + pub lifetimes: Vec, + + /// Type parameters for the `impl` block. 
+ pub types: Vec, + + /// Const generic parameters for the `impl` block. + pub consts: Vec, + + /// Whether the `impl` is unsafe. + pub unsafety: Option, + + /// The trait being implemented. + pub bound: Option, + + /// The type being implemented on. + pub subject: syn::Path, + + /// The where clause of the `impl` block. + pub where_clause: WhereClause, + + /// The contents of the `impl`. + pub contents: syn::Block, + + /// A `const` block for asserting requirements. + pub requirements: syn::Block, +} + +impl ImplSkeleton { + /// Construct an [`ImplSkeleton`] for a [`DeriveInput`]. + pub fn new(input: &syn::DeriveInput, unsafety: bool) -> Self { + let mut lifetimes = Vec::new(); + let mut types = Vec::new(); + let mut consts = Vec::new(); + let mut subject_args = Punctuated::new(); + + for param in &input.generics.params { + match param { + GenericParam::Lifetime(value) => { + lifetimes.push(value.clone()); + let id = value.lifetime.clone(); + subject_args.push(GenericArgument::Lifetime(id)); + } + + GenericParam::Type(value) => { + types.push(value.clone()); + let id = value.ident.clone(); + let id = syn::TypePath { + qself: None, + path: syn::Path { + leading_colon: None, + segments: [syn::PathSegment { + ident: id, + arguments: syn::PathArguments::None, + }] + .into_iter() + .collect(), + }, + }; + subject_args.push(GenericArgument::Type(id.into())); + } + + GenericParam::Const(value) => { + consts.push(value.clone()); + let id = value.ident.clone(); + let id = syn::TypePath { + qself: None, + path: syn::Path { + leading_colon: None, + segments: [syn::PathSegment { + ident: id, + arguments: syn::PathArguments::None, + }] + .into_iter() + .collect(), + }, + }; + subject_args.push(GenericArgument::Type(id.into())); + } + } + } + + let unsafety = unsafety.then_some(::default()); + + let subject = syn::Path { + leading_colon: None, + segments: [syn::PathSegment { + ident: input.ident.clone(), + arguments: syn::PathArguments::AngleBracketed( + 
syn::AngleBracketedGenericArguments { + colon2_token: None, + lt_token: Default::default(), + args: subject_args, + gt_token: Default::default(), + }, + ), + }] + .into_iter() + .collect(), + }; + + let where_clause = + input.generics.where_clause.clone().unwrap_or(WhereClause { + where_token: Default::default(), + predicates: Punctuated::new(), + }); + + let contents = syn::Block { + brace_token: Default::default(), + stmts: Vec::new(), + }; + + let requirements = syn::Block { + brace_token: Default::default(), + stmts: Vec::new(), + }; + + Self { + lifetimes, + types, + consts, + unsafety, + bound: None, + subject, + where_clause, + contents, + requirements, + } + } + + /// Require a bound for a type. + /// + /// If the type is concrete, a verifying statement is added for it. + /// Otherwise, it is added to the where clause. + pub fn require_bound( + &mut self, + target: syn::Type, + bound: TypeParamBound, + ) { + let mut visitor = ConcretenessVisitor { + skeleton: self, + is_concrete: true, + }; + + // Concreteness applies to both the type and the bound. + visitor.visit_type(&target); + visitor.visit_type_param_bound(&bound); + + if visitor.is_concrete { + // Add a concrete requirement for this bound. + self.requirements.stmts.push(syn::parse_quote! { + const _: fn() = || { + fn assert_impl() {} + assert_impl::<#target>(); + }; + }); + } else { + // Add this bound to the `where` clause. + let mut bounds = Punctuated::new(); + bounds.push(bound); + let pred = WherePredicate::Type(syn::PredicateType { + lifetimes: None, + bounded_ty: target, + colon_token: Default::default(), + bounds, + }); + self.where_clause.predicates.push(pred); + } + } + + /// Generate a unique lifetime with the given prefix. 
+ pub fn new_lifetime(&self, prefix: &str) -> Lifetime { + [format_ident!("{}", prefix)] + .into_iter() + .chain((0u32..).map(|i| format_ident!("{}_{}", prefix, i))) + .find(|id| self.lifetimes.iter().all(|l| l.lifetime.ident != *id)) + .map(|ident| Lifetime { + apostrophe: Span::call_site(), + ident, + }) + .unwrap() + } + + /// Generate a unique lifetime parameter with the given prefix and bounds. + pub fn new_lifetime_param( + &self, + prefix: &str, + bounds: impl IntoIterator, + ) -> (Lifetime, LifetimeParam) { + let lifetime = self.new_lifetime(prefix); + let mut bounds = bounds.into_iter().peekable(); + let param = if bounds.peek().is_some() { + syn::parse_quote! { #lifetime: #(#bounds)+* } + } else { + syn::parse_quote! { #lifetime } + }; + (lifetime, param) + } +} + +impl ToTokens for ImplSkeleton { + fn to_tokens(&self, tokens: &mut TokenStream) { + let Self { + lifetimes, + types, + consts, + unsafety, + bound, + subject, + where_clause, + contents, + requirements, + } = self; + + let target = match bound { + Some(bound) => quote!(#bound for #subject), + None => quote!(#subject), + }; + + quote! { + #unsafety + impl<#(#lifetimes,)* #(#types,)* #(#consts,)*> + #target + #where_clause + #contents + } + .to_tokens(tokens); + + if !requirements.stmts.is_empty() { + quote! { + const _: () = #requirements; + } + .to_tokens(tokens); + } + } +} + +//----------- ConcretenessVisitor -------------------------------------------- + +struct ConcretenessVisitor<'a> { + /// The `impl` skeleton being added to. + skeleton: &'a ImplSkeleton, + + /// Whether the visited type is concrete. 
+ is_concrete: bool, +} + +impl<'ast> Visit<'ast> for ConcretenessVisitor<'_> { + fn visit_lifetime(&mut self, i: &'ast Lifetime) { + self.is_concrete = self.is_concrete + && self.skeleton.lifetimes.iter().all(|l| l.lifetime != *i); + } + + fn visit_ident(&mut self, i: &'ast Ident) { + self.is_concrete = self.is_concrete + && self.skeleton.types.iter().all(|t| t.ident != *i); + self.is_concrete = self.is_concrete + && self.skeleton.consts.iter().all(|c| c.ident != *i); + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs new file mode 100644 index 000000000..74ce48337 --- /dev/null +++ b/macros/src/lib.rs @@ -0,0 +1,724 @@ +//! Procedural macros for [`domain`]. +//! +//! [`domain`]: https://docs.rs/domain + +use proc_macro as pm; +use proc_macro2::TokenStream; +use quote::{format_ident, ToTokens}; +use syn::{Error, Ident, Result}; + +mod impls; +use impls::ImplSkeleton; + +mod data; +use data::Struct; + +mod repr; +use repr::Repr; + +//----------- SplitBytes ----------------------------------------------------- + +#[proc_macro_derive(SplitBytes)] +pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: syn::DeriveInput) -> Result { + let data = match &input.data { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'SplitBytes' can only be 'derive'd for 'struct's", + )); + } + syn::Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'SplitBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, false); + + // Add the parsing lifetime to the 'impl'. 
+ let (lifetime, param) = skeleton.new_lifetime_param( + "bytes", + skeleton.lifetimes.iter().map(|l| l.lifetime.clone()), + ); + skeleton.lifetimes.push(param); + skeleton.bound = Some( + syn::parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), + ); + + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + let builder = data.builder(field_prefixed); + + // Establish bounds on the fields. + for field in data.fields() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), + ); + } + + // Define 'split_bytes()'. + let init_vars = builder.init_vars(); + let tys = data.fields().map(|f| &f.ty); + skeleton.contents.stmts.push(syn::parse_quote! { + fn split_bytes( + bytes: & #lifetime [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (Self, & #lifetime [::domain::__core::primitive::u8]), + ::domain::new_base::wire::ParseError, + > { + #(let (#init_vars, bytes) = + <#tys as ::domain::new_base::wire::SplitBytes<#lifetime>> + ::split_bytes(bytes)?;)* + Ok((#builder, bytes)) + } + }); + + Ok(skeleton.into_token_stream()) + } + + let input = syn::parse_macro_input!(input as syn::DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- ParseBytes ----------------------------------------------------- + +#[proc_macro_derive(ParseBytes)] +pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: syn::DeriveInput) -> Result { + let data = match &input.data { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'ParseBytes' can only be 'derive'd for 'struct's", + )); + } + syn::Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'ParseBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. 
+ let mut skeleton = ImplSkeleton::new(&input, false); + + // Add the parsing lifetime to the 'impl'. + let (lifetime, param) = skeleton.new_lifetime_param( + "bytes", + skeleton.lifetimes.iter().map(|l| l.lifetime.clone()), + ); + skeleton.lifetimes.push(param); + skeleton.bound = Some( + syn::parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), + ); + + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + let builder = data.builder(field_prefixed); + + // Establish bounds on the fields. + for field in data.sized_fields() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), + ); + } + if let Some(field) = data.unsized_field() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), + ); + } + + // Finish early if the 'struct' has no fields. + if data.is_empty() { + skeleton.contents.stmts.push(syn::parse_quote! { + fn parse_bytes( + bytes: & #lifetime [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + Self, + ::domain::new_base::wire::ParseError, + > { + if bytes.is_empty() { + Ok(#builder) + } else { + Err(::domain::new_base::wire::ParseError) + } + } + }); + + return Ok(skeleton.into_token_stream()); + } + + // Define 'parse_bytes()'. + let init_vars = builder.sized_init_vars(); + let tys = builder.sized_fields().map(|f| &f.ty); + let unsized_ty = &builder.unsized_field().unwrap().ty; + let unsized_init_var = builder.unsized_init_var().unwrap(); + skeleton.contents.stmts.push(syn::parse_quote! 
{ + fn parse_bytes( + bytes: & #lifetime [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + Self, + ::domain::new_base::wire::ParseError, + > { + #(let (#init_vars, bytes) = + <#tys as ::domain::new_base::wire::SplitBytes<#lifetime>> + ::split_bytes(bytes)?;)* + let #unsized_init_var = + <#unsized_ty as ::domain::new_base::wire::ParseBytes<#lifetime>> + ::parse_bytes(bytes)?; + Ok(#builder) + } + }); + + Ok(skeleton.into_token_stream()) + } + + let input = syn::parse_macro_input!(input as syn::DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- SplitBytesByRef ------------------------------------------------ + +#[proc_macro_derive(SplitBytesByRef)] +pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: syn::DeriveInput) -> Result { + let data = match &input.data { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'SplitBytesByRef' can only be 'derive'd for 'struct's", + )); + } + syn::Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'SplitBytesByRef' can only be 'derive'd for 'struct's", + )); + } + }; + + let _ = Repr::determine(&input.attrs, "SplitBytesByRef")?; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = Some(syn::parse_quote!( + ::domain::new_base::wire::SplitBytesByRef + )); + + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + + // Establish bounds on the fields. + for field in data.fields() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::new_base::wire::SplitBytesByRef), + ); + } + + // Finish early if the 'struct' has no fields. + if data.is_empty() { + skeleton.contents.stmts.push(syn::parse_quote! 
{ + fn split_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&Self, &[::domain::__core::primitive::u8]), + ::domain::new_base::wire::ParseError, + > { + Ok(( + // SAFETY: 'Self' is a 'struct' with no fields, + // and so has size 0 and alignment 1. It can be + // constructed at any address. + unsafe { &*bytes.as_ptr().cast::() }, + bytes, + )) + } + }); + + return Ok(skeleton.into_token_stream()); + } + + // Define 'split_bytes_by_ref()'. + let tys = data.sized_fields().map(|f| &f.ty); + let unsized_ty = &data.unsized_field().unwrap().ty; + skeleton.contents.stmts.push(syn::parse_quote! { + fn split_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&Self, &[::domain::__core::primitive::u8]), + ::domain::new_base::wire::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_ref(bytes)?;)* + let (last, rest) = + <#unsized_ty as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_ref(bytes)?; + let ptr = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. + Ok((unsafe { &*(ptr as *const Self) }, rest)) + } + }); + + // Define 'split_bytes_by_mut()'. + let tys = data.sized_fields().map(|f| &f.ty); + skeleton.contents.stmts.push(syn::parse_quote! 
{ + fn split_bytes_by_mut( + bytes: &mut [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&mut Self, &mut [::domain::__core::primitive::u8]), + ::domain::new_base::wire::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_mut(bytes)?;)* + let (last, rest) = + <#unsized_ty as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_mut(bytes)?; + let ptr = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. 
+ Ok((unsafe { &mut *(ptr as *const Self as *mut Self) }, rest)) + } + }); + + Ok(skeleton.into_token_stream()) + } + + let input = syn::parse_macro_input!(input as syn::DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- ParseBytesByRef ------------------------------------------------ + +#[proc_macro_derive(ParseBytesByRef)] +pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: syn::DeriveInput) -> Result { + let data = match &input.data { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'ParseBytesByRef' can only be 'derive'd for 'struct's", + )); + } + syn::Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'ParseBytesByRef' can only be 'derive'd for 'struct's", + )); + } + }; + + let _ = Repr::determine(&input.attrs, "ParseBytesByRef")?; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = Some(syn::parse_quote!( + ::domain::new_base::wire::ParseBytesByRef + )); + + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + + // Establish bounds on the fields. + for field in data.sized_fields() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::new_base::wire::SplitBytesByRef), + ); + } + if let Some(field) = data.unsized_field() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::new_base::wire::ParseBytesByRef), + ); + } + + // Finish early if the 'struct' has no fields. + if data.is_empty() { + skeleton.contents.stmts.push(syn::parse_quote! 
{ + fn parse_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &Self, + ::domain::new_base::wire::ParseError, + > { + if bytes.is_empty() { + // SAFETY: 'Self' is a 'struct' with no fields, + // and so has size 0 and alignment 1. It can be + // constructed at any address. + Ok(unsafe { &*bytes.as_ptr().cast::() }) + } else { + Err(::domain::new_base::wire::ParseError) + } + } + }); + + skeleton.contents.stmts.push(syn::parse_quote! { + fn ptr_with_address( + &self, + addr: *const (), + ) -> *const Self { + addr.cast() + } + }); + + return Ok(skeleton.into_token_stream()); + } + + // Define 'parse_bytes_by_ref()'. + let tys = data.sized_fields().map(|f| &f.ty); + let unsized_ty = &data.unsized_field().unwrap().ty; + skeleton.contents.stmts.push(syn::parse_quote! { + fn parse_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &Self, + ::domain::new_base::wire::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_ref(bytes)?;)* + let last = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::parse_bytes_by_ref(bytes)?; + let ptr = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. + Ok(unsafe { &*(ptr as *const Self) }) + } + }); + + // Define 'parse_bytes_by_mut()'. 
+ let tys = data.sized_fields().map(|f| &f.ty); + skeleton.contents.stmts.push(syn::parse_quote! { + fn parse_bytes_by_mut( + bytes: &mut [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &mut Self, + ::domain::new_base::wire::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_mut(bytes)?;)* + let last = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::parse_bytes_by_mut(bytes)?; + let ptr = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. + Ok(unsafe { &mut *(ptr as *const Self as *mut Self) }) + } + }); + + // Define 'ptr_with_address()'. + let unsized_member = data.unsized_member(); + skeleton.contents.stmts.push(syn::parse_quote! 
{ + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::ptr_with_address(&self.#unsized_member, addr) + as *const Self + } + }); + + Ok(skeleton.into_token_stream()) + } + + let input = syn::parse_macro_input!(input as syn::DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- BuildBytes ----------------------------------------------------- + +#[proc_macro_derive(BuildBytes)] +pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: syn::DeriveInput) -> Result { + let data = match &input.data { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'BuildBytes' can only be 'derive'd for 'struct's", + )); + } + syn::Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'BuildBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, false); + skeleton.bound = + Some(syn::parse_quote!(::domain::new_base::wire::BuildBytes)); + + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + + // Get a lifetime for the input buffer. + let lifetime = skeleton.new_lifetime("bytes"); + + // Establish bounds on the fields. + for field in data.fields() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::new_base::wire::BuildBytes), + ); + } + + // Define 'build_bytes()'. + let members = data.members(); + let tys = data.fields().map(|f| &f.ty); + skeleton.contents.stmts.push(syn::parse_quote! 
{ + fn build_bytes<#lifetime>( + &self, + mut bytes: & #lifetime mut [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + & #lifetime mut [::domain::__core::primitive::u8], + ::domain::new_base::wire::TruncationError, + > { + #(bytes = <#tys as ::domain::new_base::wire::BuildBytes> + ::build_bytes(&self.#members, bytes)?;)* + Ok(bytes) + } + }); + + Ok(skeleton.into_token_stream()) + } + + let input = syn::parse_macro_input!(input as syn::DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- AsBytes -------------------------------------------------------- + +#[proc_macro_derive(AsBytes)] +pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: syn::DeriveInput) -> Result { + let data = match &input.data { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'AsBytes' can only be 'derive'd for 'struct's", + )); + } + syn::Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'AsBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + let _ = Repr::determine(&input.attrs, "AsBytes")?; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = + Some(syn::parse_quote!(::domain::new_base::wire::AsBytes)); + + // Establish bounds on the fields. + for field in data.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::new_base::wire::AsBytes), + ); + } + + // The default implementation of 'as_bytes()' works perfectly. 
+ + Ok(skeleton.into_token_stream()) + } + + let input = syn::parse_macro_input!(input as syn::DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- UnsizedClone --------------------------------------------------- + +#[proc_macro_derive(UnsizedClone)] +pub fn derive_unsized_clone(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: syn::DeriveInput) -> Result { + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = + Some(syn::parse_quote!(::domain::utils::UnsizedClone)); + + let struct_data = match &input.data { + syn::Data::Struct(data) if !data.fields.is_empty() => { + let data = Struct::new_as_self(&data.fields); + for field in data.sized_fields() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::__core::clone::Clone), + ); + } + + skeleton.require_bound( + data.unsized_field().unwrap().ty.clone(), + syn::parse_quote!(::domain::utils::UnsizedClone), + ); + + Some(data) + } + + syn::Data::Struct(_) => None, + + syn::Data::Enum(data) => { + for variant in data.variants.iter() { + for field in variant.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::__core::clone::Clone), + ); + } + } + + None + } + + syn::Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'UnsizedClone' cannot be 'derive'd for 'union's", + )); + } + }; + + if let Some(data) = struct_data { + let sized_members = data.sized_members(); + let unsized_member = data.unsized_member().unwrap(); + + skeleton.contents.stmts.push(syn::parse_quote! 
{ + unsafe fn unsized_clone(&self, dst: *mut ()) { + let dst = ::domain::utils::UnsizedClone::ptr_with_address(self, dst); + unsafe { + #(::domain::__core::ptr::write( + ::domain::__core::ptr::addr_of_mut!((*dst).#sized_members), + ::domain::__core::clone::Clone::clone(&self.#sized_members), + );)* + ::domain::utils::UnsizedClone::unsized_clone( + &self.#unsized_member, + ::domain::__core::ptr::addr_of_mut!((*dst).#unsized_member) as *mut (), + ); + } + } + }); + + skeleton.contents.stmts.push(syn::parse_quote! { + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self { + ::domain::utils::UnsizedClone::ptr_with_address( + &self.#unsized_member, + addr, + ) as *mut Self + } + }); + } else { + skeleton.contents.stmts.push(syn::parse_quote! { + unsafe fn unsized_clone(&self, dst: *mut ()) { + let dst = dst as *mut Self; + let this = ::domain::__core::clone::Clone::clone(self); + unsafe { + ::domain::__core::ptr::write(dst as *mut Self, this); + } + } + }); + + skeleton.contents.stmts.push(syn::parse_quote! { + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self { + addr as *mut Self + } + }); + } + + Ok(skeleton.into_token_stream()) + } + + let input = syn::parse_macro_input!(input as syn::DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- Utility Functions ---------------------------------------------- + +/// Add a `field_` prefix to member names. +fn field_prefixed(member: syn::Member) -> Ident { + format_ident!("field_{}", member) +} diff --git a/macros/src/repr.rs b/macros/src/repr.rs new file mode 100644 index 000000000..b699b571b --- /dev/null +++ b/macros/src/repr.rs @@ -0,0 +1,91 @@ +//! Determining the memory layout of a type. + +use proc_macro2::Span; +use syn::{ + punctuated::Punctuated, spanned::Spanned, Attribute, Error, LitInt, Meta, + Token, +}; + +//----------- Repr ----------------------------------------------------------- + +/// The memory representation of a type. 
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum Repr { + /// Transparent to an underlying field. + Transparent, + + /// Compatible with C. + C, +} + +impl Repr { + /// Determine the representation for a type from its attributes. + /// + /// This will fail if a stable representation cannot be found. + pub fn determine( + attrs: &[Attribute], + bound: &str, + ) -> Result { + let mut repr = None; + for attr in attrs { + if !attr.path().is_ident("repr") { + continue; + } + + let nested = attr.parse_args_with( + Punctuated::::parse_terminated, + )?; + + // We don't check for consistency in the 'repr' attributes, since + // the compiler should be doing that for us anyway. This lets us + // ignore conflicting 'repr's entirely. + for meta in nested { + match meta { + Meta::Path(p) if p.is_ident("transparent") => { + repr = Some(Repr::Transparent); + } + + Meta::Path(p) if p.is_ident("C") => { + repr = Some(Repr::C); + } + + Meta::Path(p) if p.is_ident("Rust") => { + return Err(Error::new_spanned(p, + format!("repr(Rust) is not stable, cannot derive {bound} for it"))); + } + + Meta::Path(p) if p.is_ident("packed") => { + // The alignment can be set to 1 safely. 
+ } + + Meta::List(meta) + if meta.path.is_ident("packed") + || meta.path.is_ident("align") => + { + // 'packed(N)' and 'align(N)' are only accepted when N is 1, + // since other values do not guarantee an unaligned 'Self'. + let span = meta.span(); + let lit: LitInt = syn::parse2(meta.tokens)?; + let n: usize = lit.base10_parse()?; + if n != 1 { + return Err(Error::new(span, + format!("'Self' must be unaligned to derive {bound}"))); + } + } + + meta => { + // We still need to error out here, in case a future + // version of Rust introduces more memory layout data + return Err(Error::new_spanned( + meta, + "unrecognized repr attribute", + )); + } + } + } + } + + repr.ok_or_else(|| { + Error::new(Span::call_site(), + "repr(C) or repr(transparent) must be specified to derive this") + }) + } +} diff --git a/src/lib.rs b/src/lib.rs index ff1b81e00..e38ebcaa7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -184,8 +184,17 @@ #[macro_use] extern crate std; -#[macro_use] -extern crate core; +// The 'domain-macros' crate introduces 'derive' macros which can be used by +// users of the 'domain' crate, but also by the 'domain' crate itself. Within +// those macros, references to declarations in the 'domain' crate are written +// as '::domain::*' ... but this doesn't work when those proc macros are used +// in the 'domain' crate itself. The alias introduced here fixes this: now +// '::domain' means the same thing within this crate as in dependents of it. +extern crate self as domain; + +// Re-export 'core' for use in macros. +#[doc(hidden)] +pub use core as __core; pub mod base; pub mod dep; @@ -200,3 +209,7 @@ pub mod validate; pub mod validator; pub mod zonefile; pub mod zonetree; + +pub mod new_base; +pub mod new_edns; +pub mod new_rdata; diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs new file mode 100644 index 000000000..175c8bd78 --- /dev/null +++ b/src/new_base/build/builder.rs @@ -0,0 +1,441 @@ +//! A builder for DNS messages. 
+ +use core::{ + cell::UnsafeCell, + mem::ManuallyDrop, + ptr::{self}, + slice, +}; + +use crate::new_base::{ + name::RevName, + wire::{BuildBytes, TruncationError}, +}; + +use super::{BuildCommitted, BuilderContext}; + +//----------- Builder -------------------------------------------------------- + +/// A DNS wire format serializer. +/// +/// This can be used to write arbitrary bytes and (compressed) domain names to +/// a buffer containing a DNS message. It is a low-level interface, providing +/// the foundations for high-level builder types. +/// +/// In order to build a regular DNS message, users would typically look to +/// [`MessageBuilder`](super::MessageBuilder). This offers the high-level +/// interface (with methods to append questions and records) that most users +/// need. +/// +/// # Committing and Delegation +/// +/// [`Builder`] provides an "atomic" interface: if a function fails while +/// building a DNS message using a [`Builder`], any partial content added by +/// the [`Builder`] will be reverted. The content of a [`Builder`] is only +/// confirmed when [`Builder::commit()`] is called. +/// +/// It is useful to first describe what "building functions" look like. While +/// they may take additional arguments, their signatures are usually: +/// +/// ```no_run +/// # use domain::new_base::build::{Builder, BuildResult}; +/// +/// fn foo(mut builder: Builder<'_>) -> BuildResult { +/// // Append to the message using 'builder'. +/// +/// // Commit all appended content and return successfully. +/// Ok(builder.commit()) +/// } +/// ``` +/// +/// Note that the builder is taken by value; if an error occurs, and the +/// function returns early, `builder` will be dropped, and its drop code will +/// revert all uncommitted changes. However, if building is successful, the +/// appended content is committed, and so will not be reverted. 
+/// +/// If `foo` were to call another function with the same signature, it would +/// need to create a new [`Builder`] to pass in by value. This [`Builder`] +/// should refer to the same message buffer, but should not report any +/// uncommitted content (so that only the content added by the called function +/// will be reverted on failure). For this, we have [`delegate()`]. +/// +/// [`delegate()`]: Self::delegate() +pub struct Builder<'b> { + /// The contents of the built message. + /// + /// The buffer is divided into three parts: + /// + /// - Committed message contents (borrowed *immutably* by this type). + /// - Appended message contents (borrowed mutably by this type). + /// - Uninitialized message contents (borrowed mutably by this type). + contents: &'b UnsafeCell<[u8]>, + + /// Context for building. + context: &'b mut BuilderContext, + + /// The start point of this builder. + /// + /// Message contents up to this point are committed and cannot be removed + /// by this builder. Message contents following this (up to the size in + /// the builder context) are appended but uncommitted. + start: usize, +} + +impl<'b> Builder<'b> { + /// Construct a [`Builder`] from raw parts. + /// + /// # Safety + /// + /// The expression `from_raw_parts(contents, context, start)` is sound if + /// and only if all of the following conditions are satisfied: + /// + /// - `contents[..start]` is immutably borrowed for `'b`. + /// - `contents[start..]` is mutably borrowed for `'b`. + /// + /// - `contents` and `context` originate from the same builder. + /// - `start <= context.size <= contents.len()`. + pub unsafe fn from_raw_parts( + contents: &'b UnsafeCell<[u8]>, + context: &'b mut BuilderContext, + start: usize, + ) -> Self { + Self { + contents, + context, + start, + } + } +} + +/// # Inspection +/// +/// A [`Builder`] references a message buffer to write into. 
That buffer is +/// broken down into the following segments: +/// +/// ```text +/// name | position +/// --------------+--------- +/// committed | 0 .. start +/// appended | start .. offset +/// uninitialized | offset .. limit +/// inaccessible | limit .. +/// ``` +/// +/// The committed content of the builder is immutable, and is available to +/// reference, through [`committed()`], for the lifetime `'b`. +/// +/// [`committed()`]: Self::committed() +/// +/// The appended but uncommitted content of the builder is made available via +/// [`uncommitted_mut()`]. It is content that has been added by this builder, +/// but that has not yet been committed. When the [`Builder`] is dropped, +/// this content is removed (it becomes uninitialized). Appended content can +/// be modified, but any compressed names within it have to be handled with +/// great care; they can only be modified by removing them entirely (by +/// rewinding the builder, using [`rewind()`]) and building them again. When +/// compressed names are guaranteed to not be modified, [`uncommitted_mut()`] +/// can be used. +/// +/// [`appended()`]: Self::appended() +/// [`rewind()`]: Self::rewind() +/// [`uncommitted_mut()`]: Self::uncommitted_mut() +/// +/// The uninitialized space in the builder will be written to when appending +/// new content. It can be accessed directly, in case that is more efficient +/// for building, using [`uninitialized()`]. [`mark_appended()`] can be used +/// to specify how many bytes were initialized. +/// +/// [`uninitialized()`]: Self::uninitialized() +/// [`mark_appended()`]: Self::mark_appended() +/// +/// The inaccessible space of a builder cannot be written to. While it exists +/// in the underlying message buffer, it has been made inaccessible so that +/// the built message fits within certain size constraints. 
A message's size +/// can be limited using [`limit_to()`], but this only applies to the current +/// builder (and its delegates); parent builders are unaffected by it. +/// +/// [`limit_to()`]: Self::limit_to() +impl<'b> Builder<'b> { + /// Committed message contents. + pub fn committed(&self) -> &'b [u8] { + let message = self.contents.get().cast_const().cast(); + // SAFETY: 'message[..start]' is immutably borrowed. + unsafe { slice::from_raw_parts(message, self.start) } + } + + /// Appended (and committed) message contents. + pub fn appended(&self) -> &[u8] { + let message = self.contents.get().cast_const().cast(); + // SAFETY: 'message[..offset]' is (im)mutably borrowed. + unsafe { slice::from_raw_parts(message, self.context.size) } + } + + /// The appended but uncommitted contents of the message. + /// + /// The builder can modify or rewind these contents, so they are offered + /// with a short lifetime. + pub fn uncommitted(&self) -> &[u8] { + let message = self.contents.get().cast::().cast_const(); + // SAFETY: It is guaranteed that 'start <= message.len()'. + let message = unsafe { message.add(self.start) }; + let size = self.context.size - self.start; + // SAFETY: 'message[start..]' is mutably borrowed. + unsafe { slice::from_raw_parts(message, size) } + } + + /// The appended but uncommitted contents of the message, mutably. + /// + /// # Safety + /// + /// The caller must not modify any compressed names among these bytes. + /// This can invalidate name compression state. + pub unsafe fn uncommitted_mut(&mut self) -> &mut [u8] { + let message = self.contents.get().cast::(); + // SAFETY: It is guaranteed that 'start <= message.len()'. + let message = unsafe { message.add(self.start) }; + let size = self.context.size - self.start; + // SAFETY: 'message[start..]' is mutably borrowed. + unsafe { slice::from_raw_parts_mut(message, size) } + } + + /// Uninitialized space in the message buffer. 
+ /// + /// When the first `n` bytes of the returned buffer are initialized, and + /// should be treated as appended content in the message, call + /// [`self.mark_appended(n)`](Self::mark_appended()). + pub fn uninitialized(&mut self) -> &mut [u8] { + let message = self.contents.get().cast::(); + // SAFETY: It is guaranteed that 'size <= message.len()'. + let message = unsafe { message.add(self.context.size) }; + let size = self.max_size() - self.context.size; + // SAFETY: 'message[size..]' is mutably borrowed. + unsafe { slice::from_raw_parts_mut(message, size) } + } + + /// The builder context. + pub fn context(&self) -> &BuilderContext { + &*self.context + } + + /// The start point of this builder. + /// + /// This is the offset into the message contents at which this builder was + /// initialized. The content before this point has been committed and is + /// immutable. The builder can be rewound up to this point. + pub fn start(&self) -> usize { + self.start + } + + /// The append point of this builder. + /// + /// This is the offset into the message contents at which new data will be + /// written. The content after this point is uninitialized. + pub fn offset(&self) -> usize { + self.context.size + } + + /// The size limit of this builder. + /// + /// This is the maximum size the message contents can grow to; beyond it, + /// [`TruncationError`]s will occur. The limit can be tightened using + /// [`limit_to()`](Self::limit_to()). + pub fn max_size(&self) -> usize { + // SAFETY: We can cast 'contents' to another slice type and the + // pointer representation is unchanged. By using a slice type of ZST + // elements, aliasing is impossible, and it can be dereferenced + // safely. + unsafe { &*(self.contents.get() as *mut [()]) }.len() + } + + /// Decompose this builder into raw parts. + /// + /// This returns three components: + /// + /// - The message buffer. 
The committed contents of the message (the + /// first `commit` bytes of the message contents) are borrowed immutably + /// for the lifetime `'b`. The remainder of the message buffer is + /// borrowed mutably for the lifetime `'b`. + /// + /// - Context for this builder. + /// + /// - The amount of data committed in the message (`commit`). + /// + /// The builder can be recomposed with [`Self::from_raw_parts()`]. + pub fn into_raw_parts( + self, + ) -> (&'b UnsafeCell<[u8]>, &'b mut BuilderContext, usize) { + // NOTE: The context has to be moved out carefully. + let (contents, start) = (self.contents, self.start); + let this = ManuallyDrop::new(self); + let this = (&*this) as *const Self; + // SAFETY: 'this' is a valid object that can be moved out of. + let context = unsafe { ptr::read(ptr::addr_of!((*this).context)) }; + (contents, context, start) + } +} + +/// # Interaction +/// +/// There are several ways to build up a DNS message using a [`Builder`]. +/// +/// When directly adding content, use [`append_bytes()`] or [`append_name()`]. +/// The former will add the bytes as-is, while the latter will compress domain +/// names. +/// +/// [`append_bytes()`]: Self::append_bytes() +/// [`append_name()`]: Self::append_name() +/// +/// When delegating to another builder method, use [`delegate()`]. This will +/// construct a new [`Builder`] that borrows from the current one. When the +/// method returns, the content it has committed will be registered as content +/// appended (but not committed) by the outer builder. If the method fails, +/// any content it tried to add will be removed automatically, and the outer +/// builder will be left unaffected. +/// +/// [`delegate()`]: Self::delegate() +/// +/// After all data is appended, call [`commit()`]. This will return a marker +/// type, [`BuildCommitted`], that may need to be returned to the caller. 
+/// +/// [`commit()`]: Self::commit() +/// +/// Some lower-level building methods are also available in the interest of +/// efficiency. Use [`append_with()`] if the amount of data to be written is +/// known upfront; it takes a closure to fill that space in the buffer. The +/// most general and efficient technique is to write into [`uninitialized()`] +/// and to mark the number of initialized bytes using [`mark_appended()`]. +/// +/// [`append_with()`]: Self::append_with() +/// [`uninitialized()`]: Self::uninitialized() +/// [`mark_appended()`]: Self::mark_appended() +impl Builder<'_> { + /// Rewind the builder, removing all uncommitted content. + pub fn rewind(&mut self) { + self.context.size = self.start; + } + + /// Commit the changes made by this builder. + /// + /// For convenience, a unit type [`BuildCommitted`] is returned; it is + /// used as the return type of build functions to remind users to call + /// this method on success paths. + pub fn commit(mut self) -> BuildCommitted { + // Update 'start' so that the rewind performed on drop is a no-op. + self.start = self.context.size; + BuildCommitted + } + + /// Limit this builder to the given size. + /// + /// This builder, and all its delegates, will not allow the message + /// contents (i.e. excluding the 12-byte message header) to exceed the + /// specified size in bytes. If the message has already crossed that + /// limit, a [`TruncationError`] is returned. + /// + /// The limit can only be tightened: a size beyond the current limit + /// leaves it unchanged. + pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { + if self.context.size <= size { + // Clamp the request; growing the slice past the current limit + // would expose bytes this builder was never given access to. + let size = size.min(self.max_size()); + let message = self.contents.get().cast::<u8>(); + // SAFETY: 'size <= self.max_size()', so the new slice is a + // prefix of the buffer already borrowed for the builder's + // lifetime. + self.contents = unsafe { + &*(ptr::slice_from_raw_parts_mut(message, size) as *const _) + }; + Ok(()) + } else { + Err(TruncationError) + } + } + + /// Mark bytes in the buffer as initialized. 
+ /// + /// The given number of bytes from the beginning of + /// [`Self::uninitialized()`] will be marked as initialized, and will be + /// treated as appended content in the buffer. + /// + /// # Panics + /// + /// Panics if the uninitialized buffer is smaller than the given number of + /// initialized bytes. + pub fn mark_appended(&mut self, amount: usize) { + assert!(self.max_size() - self.context.size >= amount); + self.context.size += amount; + } + + /// Delegate to a new builder. + /// + /// Any content committed by the builder will be added as uncommitted + /// content for this builder. + pub fn delegate(&mut self) -> Builder<'_> { + let commit = self.context.size; + unsafe { + Builder::from_raw_parts(self.contents, &mut *self.context, commit) + } + } + + /// Append data of a known size using a closure. + /// + /// All the requested bytes must be initialized. If not enough free space + /// could be obtained, a [`TruncationError`] is returned. + pub fn append_with( + &mut self, + size: usize, + fill: impl FnOnce(&mut [u8]), + ) -> Result<(), TruncationError> { + self.uninitialized() + .get_mut(..size) + .ok_or(TruncationError) + .map(fill) + .map(|()| self.context.size += size) + } + + /// Append some bytes. + /// + /// No name compression will be performed. + pub fn append_bytes( + &mut self, + bytes: &[u8], + ) -> Result<(), TruncationError> { + self.append_with(bytes.len(), |buffer| buffer.copy_from_slice(bytes)) + } + + /// Serialize an object into bytes and append it. + /// + /// No name compression will be performed. + pub fn append_built_bytes( + &mut self, + object: &impl BuildBytes, + ) -> Result<(), TruncationError> { + let rest = object.build_bytes(self.uninitialized())?.len(); + let appended = self.uninitialized().len() - rest; + self.mark_appended(appended); + Ok(()) + } + + /// Compress and append a domain name. + pub fn append_name( + &mut self, + name: &RevName, + ) -> Result<(), TruncationError> { + // TODO: Perform name compression. 
+ name.build_bytes(self.uninitialized())?; + self.mark_appended(name.len()); + Ok(()) + } +} + +//--- Drop + +impl Drop for Builder<'_> { + fn drop(&mut self) { + // Drop uncommitted content. + self.rewind(); + } +} + +//--- Send, Sync + +// SAFETY: The parts of the referenced message that can be accessed mutably +// are not accessible by any reference other than `self`. +unsafe impl Send for Builder<'_> {} + +// SAFETY: Only parts of the referenced message that are borrowed immutably +// can be accessed through an immutable reference to `self`. +unsafe impl Sync for Builder<'_> {} diff --git a/src/new_base/build/context.rs b/src/new_base/build/context.rs new file mode 100644 index 000000000..e62ad265b --- /dev/null +++ b/src/new_base/build/context.rs @@ -0,0 +1,182 @@ +//! Context for building DNS messages. + +//----------- BuilderContext ------------------------------------------------- + +use crate::new_base::SectionCounts; + +/// Context for building a DNS message. +/// +/// This type holds auxiliary information necessary for building DNS messages, +/// e.g. name compression state. To construct it, call [`default()`]. +/// +/// [`default()`]: Self::default() +#[derive(Clone, Debug, Default)] +pub struct BuilderContext { + // TODO: Name compression. + /// The current size of the message contents. + pub size: usize, + + /// The state of the DNS message. + pub state: MessageState, +} + +//----------- MessageState --------------------------------------------------- + +/// The state of a DNS message being built. +/// +/// A DNS message consists of a header, questions, answers, authorities, and +/// additionals. [`MessageState`] remembers the start position of the last +/// question or record in the message, allowing it to be modified or removed +/// (for additional flexibility in the building process). +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub enum MessageState { + /// Questions are being built. 
+ /// + /// The message already contains zero or more DNS questions. If there is + /// a last DNS question, its start position is unknown, so it cannot be + /// modified or removed. + /// + /// This is the default state for an empty message. + #[default] + Questions, + + /// A question is being built. + /// + /// The message contains one or more DNS questions. The last question can + /// be modified or truncated. + MidQuestion { + /// The offset of the question name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + }, + + /// Answer records are being built. + /// + /// The message already contains zero or more DNS answer records. If + /// there is a last DNS record, its start position is unknown, so it + /// cannot be modified or removed. + Answers, + + /// An answer record is being built. + /// + /// The message contains one or more DNS answer records. The last record + /// can be modified or truncated. + MidAnswer { + /// The offset of the record name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + + /// The offset of the record data. + /// + /// The offset is measured from the start of the message contents. + data: u16, + }, + + /// Authority records are being built. + /// + /// The message already contains zero or more DNS authority records. If + /// there is a last DNS record, its start position is unknown, so it + /// cannot be modified or removed. + Authorities, + + /// An authority record is being built. + /// + /// The message contains one or more DNS authority records. The last + /// record can be modified or truncated. + MidAuthority { + /// The offset of the record name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + + /// The offset of the record data. + /// + /// The offset is measured from the start of the message contents. + data: u16, + }, + + /// Additional records are being built. 
+ /// + /// The message already contains zero or more DNS additional records. If + /// there is a last DNS record, its start position is unknown, so it + /// cannot be modified or removed. + Additionals, + + /// An additional record is being built. + /// + /// The message contains one or more DNS additional records. The last + /// record can be modified or truncated. + MidAdditional { + /// The offset of the record name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + + /// The offset of the record data. + /// + /// The offset is measured from the start of the message contents. + data: u16, + }, +} + +impl MessageState { + /// The current section index. + /// + /// Questions, answers, authorities, and additionals are mapped to 0, 1, + /// 2, and 3, respectively. + pub const fn section_index(&self) -> u8 { + match self { + Self::Questions | Self::MidQuestion { .. } => 0, + Self::Answers | Self::MidAnswer { .. } => 1, + Self::Authorities | Self::MidAuthority { .. } => 2, + Self::Additionals | Self::MidAdditional { .. } => 3, + } + } + + /// Whether a question or record is being built. + pub const fn mid_component(&self) -> bool { + matches!( + self, + Self::MidQuestion { .. } + | Self::MidAnswer { .. } + | Self::MidAuthority { .. } + | Self::MidAdditional { .. } + ) + } + + /// Commit a question or record and update the section counts. + pub fn commit(&mut self, counts: &mut SectionCounts) { + match self { + Self::MidQuestion { .. } => { + counts.questions += 1; + *self = Self::Questions; + } + Self::MidAnswer { .. } => { + counts.answers += 1; + *self = Self::Answers; + } + Self::MidAuthority { .. } => { + counts.authorities += 1; + *self = Self::Authorities; + } + Self::MidAdditional { .. } => { + counts.additional += 1; + *self = Self::Additionals; + } + _ => {} + } + } + + /// Cancel a question or record. + pub fn cancel(&mut self) { + match self { + Self::MidQuestion { .. 
} => *self = Self::Questions, + Self::MidAnswer { .. } => *self = Self::Answers, + Self::MidAuthority { .. } => *self = Self::Authorities, + Self::MidAdditional { .. } => *self = Self::Additionals, + _ => {} + } + } +} diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs new file mode 100644 index 000000000..de38af4fa --- /dev/null +++ b/src/new_base/build/message.rs @@ -0,0 +1,489 @@ +//! Building whole DNS messages. + +use core::cell::UnsafeCell; + +use crate::new_base::{ + wire::{ParseBytesByRef, TruncationError}, + Header, Message, Question, Record, +}; + +use super::{ + BuildIntoMessage, Builder, BuilderContext, MessageState, QuestionBuilder, + RecordBuilder, +}; + +//----------- MessageBuilder ------------------------------------------------- + +/// A builder for a whole DNS message. +/// +/// This is a high-level building interface, offering methods to put together +/// entire questions and records. It directly writes into an allocated buffer +/// (on the stack or the heap). +pub struct MessageBuilder<'b, 'c> { + /// The message being constructed. + pub(super) message: &'b mut Message, + + /// Context for building. + pub(super) context: &'c mut BuilderContext, +} + +//--- Initialization + +impl<'b, 'c> MessageBuilder<'b, 'c> { + /// Initialize an empty [`MessageBuilder`]. + /// + /// The message header is left uninitialized. Use [`Self::header_mut()`] + /// to initialize it. + /// + /// # Panics + /// + /// Panics if the buffer is less than 12 bytes long (which is the minimum + /// possible size for a DNS message). + #[must_use] + pub fn new( + buffer: &'b mut [u8], + context: &'c mut BuilderContext, + ) -> Self { + let message = Message::parse_bytes_by_mut(buffer) + .expect("The caller's buffer is at least 12 bytes big"); + *context = BuilderContext::default(); + Self { message, context } + } +} + +//--- Inspection + +impl MessageBuilder<'_, '_> { + /// The message header. 
+ #[must_use] + pub fn header(&self) -> &Header { + &self.message.header + } + + /// The message header, mutably. + #[must_use] + pub fn header_mut(&mut self) -> &mut Header { + &mut self.message.header + } + + /// The message built thus far. + #[must_use] + pub fn message(&self) -> &Message { + self.message.slice_to(self.context.size) + } + + /// The message built thus far, mutably. + /// + /// # Safety + /// + /// The caller must not modify any compressed names among these bytes. + /// This can invalidate name compression state. + #[must_use] + pub unsafe fn message_mut(&mut self) -> &mut Message { + self.message.slice_to_mut(self.context.size) + } + + /// The builder context. + #[must_use] + pub fn context(&self) -> &BuilderContext { + self.context + } +} + +//--- Interaction + +impl<'b> MessageBuilder<'b, '_> { + /// End the builder, returning the built message. + /// + /// The returned message is valid, but it can be modified by the caller + /// arbitrarily; avoid modifying the message beyond the header. + #[must_use] + pub fn finish(self) -> &'b mut Message { + self.message.slice_to_mut(self.context.size) + } + + /// Reborrow the builder with a shorter lifetime. + #[must_use] + pub fn reborrow(&mut self) -> MessageBuilder<'_, '_> { + MessageBuilder { + message: self.message, + context: self.context, + } + } + + /// Limit the total message size. + /// + /// The message will not be allowed to exceed the given size, in bytes. + /// Only the message header and contents are counted; the enclosing UDP + /// or TCP packet size is not considered. If the message already exceeds + /// this size, a [`TruncationError`] is returned; since the 12-byte header + /// is always counted, this includes any size below 12 bytes. + /// + /// A size larger than the current buffer leaves the limit unchanged. + pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { + if 12 + self.context.size <= size { + // Move out of 'message' so that the full lifetime is available. + // See the 'replace_with' and 'take_mut' crates. 
+ // Clamp to the current buffer so that the slicing below stays in + // bounds even when the requested limit exceeds it. + let size = size.min(12 + self.message.contents.len()); + // SAFETY: The duplicated '&mut' is consumed by 'slice_to_mut()' + // and its result is written back before 'self.message' is used. + let message = unsafe { core::ptr::read(&self.message) }; + // NOTE: 'size' is within '12 ..= 12 + contents.len()' here. + let message = message.slice_to_mut(size - 12); + unsafe { core::ptr::write(&mut self.message, message) }; + Ok(()) + } else { + Err(TruncationError) + } + } + + /// Truncate the message. + /// + /// This will remove all message contents and mark it as truncated. + pub fn truncate(&mut self) { + self.message.header.flags.set_truncated(true); + *self.context = BuilderContext::default(); + } + + /// Obtain a [`Builder`]. + #[must_use] + pub(super) fn builder(&mut self, start: usize) -> Builder<'_> { + debug_assert!(start <= self.context.size); + unsafe { + let contents = &mut self.message.contents; + let contents = contents as *mut [u8] as *const UnsafeCell<[u8]>; + Builder::from_raw_parts(&*contents, self.context, start) + } + } + + /// Build a question. + /// + /// If a question is already being built, it will be finished first. If + /// an answer, authority, or additional record has been added, [`None`] is + /// returned instead. + pub fn build_question( + &mut self, + question: &Question, + ) -> Result>, TruncationError> { + let state = &mut self.context.state; + if state.section_index() > 0 { + // We've progressed into a later section. + return Ok(None); + } + + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } + + *state = MessageState::Questions; + QuestionBuilder::build(self.reborrow(), question).map(Some) + } + + /// Resume building a question. + /// + /// If a question was built (using [`build_question()`]) but the returned + /// builder was neither committed nor canceled, the question builder will + /// be recovered and returned. 
+ /// + /// [`build_question()`]: Self::build_question() + pub fn resume_question(&mut self) -> Option> { + let MessageState::MidQuestion { name } = self.context.state else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + QuestionBuilder::from_raw_parts(self.reborrow(), name) + }) + } + + /// Build an answer record. + /// + /// If a question or answer is already being built, it will be finished + /// first. If an authority or additional record has been added, [`None`] + /// is returned instead. + pub fn build_answer( + &mut self, + record: &Record, + ) -> Result>, TruncationError> { + let state = &mut self.context.state; + if state.section_index() > 1 { + // We've progressed into a later section. + return Ok(None); + } + + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } + + *state = MessageState::Answers; + RecordBuilder::build(self.reborrow(), record).map(Some) + } + + /// Resume building an answer record. + /// + /// If an answer record was built (using [`build_answer()`]) but the + /// returned builder was neither committed nor canceled, the record + /// builder will be recovered and returned. + /// + /// [`build_answer()`]: Self::build_answer() + pub fn resume_answer(&mut self) -> Option> { + let MessageState::MidAnswer { name, data } = self.context.state + else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + RecordBuilder::from_raw_parts(self.reborrow(), name, data) + }) + } + + /// Build an authority record. + /// + /// If a question, answer, or authority is already being built, it will be + /// finished first. If an additional record has been added, [`None`] is + /// returned instead. 
+ pub fn build_authority( + &mut self, + record: &Record, + ) -> Result>, TruncationError> { + let state = &mut self.context.state; + if state.section_index() > 2 { + // We've progressed into a later section. + return Ok(None); + } + + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } + + *state = MessageState::Authorities; + RecordBuilder::build(self.reborrow(), record).map(Some) + } + + /// Resume building an authority record. + /// + /// If an authority record was built (using [`build_authority()`]) but + /// the returned builder was neither committed nor canceled, the record + /// builder will be recovered and returned. + /// + /// [`build_authority()`]: Self::build_authority() + pub fn resume_authority(&mut self) -> Option> { + let MessageState::MidAuthority { name, data } = self.context.state + else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + RecordBuilder::from_raw_parts(self.reborrow(), name, data) + }) + } + + /// Build an additional record. + /// + /// If a question or record is already being built, it will be finished + /// first. Note that it is always possible to add an additional record to + /// a message. + pub fn build_additional( + &mut self, + record: &Record, + ) -> Result, TruncationError> { + let state = &mut self.context.state; + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } + + *state = MessageState::Additionals; + RecordBuilder::build(self.reborrow(), record) + } + + /// Resume building an additional record. + /// + /// If an additional record was built (using [`build_additional()`]) but + /// the returned builder was neither committed nor canceled, the record + /// builder will be recovered and returned. 
+ /// + /// [`build_additional()`]: Self::build_additional() + pub fn resume_additional(&mut self) -> Option> { + let MessageState::MidAdditional { name, data } = self.context.state + else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + RecordBuilder::from_raw_parts(self.reborrow(), name, data) + }) + } +} + +//============ Tests ========================================================= + +#[cfg(test)] +mod test { + use crate::{ + new_base::{ + build::{BuilderContext, MessageState}, + name::RevName, + wire::U16, + QClass, QType, Question, RClass, RType, Record, SectionCounts, + TTL, + }, + new_rdata::A, + }; + + use super::MessageBuilder; + + const WWW_EXAMPLE_ORG: &RevName = unsafe { + RevName::from_bytes_unchecked(b"\x00\x03org\x07example\x03www") + }; + + #[test] + fn new() { + let mut buffer = [0u8; 12]; + let mut context = BuilderContext::default(); + + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + assert_eq!(&builder.message().contents, &[] as &[u8]); + assert_eq!(unsafe { &builder.message_mut().contents }, &[] as &[u8]); + assert_eq!(builder.context().size, 0); + assert_eq!(builder.context().state, MessageState::Questions); + } + + #[test] + fn build_question() { + let mut buffer = [0u8; 33]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + let question = Question { + qname: WWW_EXAMPLE_ORG, + qtype: QType::A, + qclass: QClass::IN, + }; + let qb = builder.build_question(&question).unwrap().unwrap(); + + assert_eq!(qb.qname().as_bytes(), b"\x03www\x07example\x03org\x00"); + assert_eq!(qb.qtype(), question.qtype); + assert_eq!(qb.qclass(), question.qclass); + + let state = MessageState::MidQuestion { name: 0 }; + assert_eq!(builder.context().state, state); + assert_eq!(builder.message().header.counts, SectionCounts::default()); + let contents = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01"; + 
assert_eq!(&builder.message().contents, contents); + } + + #[test] + fn resume_question() { + let mut buffer = [0u8; 33]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + let question = Question { + qname: WWW_EXAMPLE_ORG, + qtype: QType::A, + qclass: QClass::IN, + }; + let _ = builder.build_question(&question).unwrap().unwrap(); + + let qb = builder.resume_question().unwrap(); + assert_eq!(qb.qname().as_bytes(), b"\x03www\x07example\x03org\x00"); + assert_eq!(qb.qtype(), question.qtype); + assert_eq!(qb.qclass(), question.qclass); + + qb.commit(); + assert_eq!( + builder.message().header.counts, + SectionCounts { + questions: U16::new(1), + ..Default::default() + } + ); + } + + #[test] + fn build_record() { + let mut buffer = [0u8; 43]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + let record = Record { + rname: WWW_EXAMPLE_ORG, + rtype: RType::A, + rclass: RClass::IN, + ttl: TTL::from(42), + rdata: b"", + }; + + { + let mut rb = builder.build_answer(&record).unwrap().unwrap(); + + assert_eq!( + rb.rname().as_bytes(), + b"\x03www\x07example\x03org\x00" + ); + assert_eq!(rb.rtype(), record.rtype); + assert_eq!(rb.rclass(), record.rclass); + assert_eq!(rb.ttl(), record.ttl); + assert_eq!(rb.rdata(), b""); + + assert!(rb.delegate().append_bytes(&[0u8; 5]).is_err()); + + { + let mut builder = rb.delegate(); + builder + .append_built_bytes(&A { + octets: [127, 0, 0, 1], + }) + .unwrap(); + builder.commit(); + } + assert_eq!(rb.rdata(), b"\x7F\x00\x00\x01"); + } + + let state = MessageState::MidAnswer { name: 0, data: 27 }; + assert_eq!(builder.context().state, state); + assert_eq!(builder.message().header.counts, SectionCounts::default()); + let contents = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01\x00\x00\x00\x2A\x00\x04\x7F\x00\x00\x01"; + assert_eq!(&builder.message().contents, contents.as_slice()); + } + + #[test] + fn 
resume_record() { + let mut buffer = [0u8; 39]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + let record = Record { + rname: WWW_EXAMPLE_ORG, + rtype: RType::A, + rclass: RClass::IN, + ttl: TTL::from(42), + rdata: b"", + }; + let _ = builder.build_answer(&record).unwrap().unwrap(); + + let rb = builder.resume_answer().unwrap(); + assert_eq!(rb.rname().as_bytes(), b"\x03www\x07example\x03org\x00"); + assert_eq!(rb.rtype(), record.rtype); + assert_eq!(rb.rclass(), record.rclass); + assert_eq!(rb.ttl(), record.ttl); + assert_eq!(rb.rdata(), b""); + + rb.commit(); + assert_eq!( + builder.message().header.counts, + SectionCounts { + answers: U16::new(1), + ..Default::default() + } + ); + } +} diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs new file mode 100644 index 000000000..a028e3b3d --- /dev/null +++ b/src/new_base/build/mod.rs @@ -0,0 +1,168 @@ +//! Building DNS messages in the wire format. +//! +//! The [`wire`](super::wire) module provides basic serialization capability, +//! but it is not specialized to DNS messages. This module provides that +//! specialization within an ergonomic interface. +//! +//! The core of the high-level interface is [`MessageBuilder`]. It provides +//! the most intuitive methods for appending whole questions and records. +//! +//! ``` +//! use domain::new_base::{Header, HeaderFlags, Question, QType, QClass}; +//! use domain::new_base::build::{BuilderContext, MessageBuilder, BuildIntoMessage}; +//! use domain::new_base::name::RevName; +//! use domain::new_base::wire::U16; +//! +//! // Initialize a DNS message builder. +//! let mut buffer = [0u8; 512]; +//! let mut context = BuilderContext::default(); +//! let mut builder = MessageBuilder::new(&mut buffer, &mut context); +//! +//! // Initialize the message header. +//! let header = builder.header_mut(); +//! *builder.header_mut() = Header { +//! // Select a randomized ID here. +//! 
id: U16::new(1234), +//! // A recursive query for authoritative data. +//! flags: *HeaderFlags::default() +//! .query(0) +//! .set_authoritative(true) +//! .request_recursion(true), +//! counts: Default::default(), +//! }; +//! +//! // Add a question for an A record. +//! // TODO: Use a more ergonomic way to make a name. +//! let name = b"\x00\x03org\x07example\x03www"; +//! let name = unsafe { RevName::from_bytes_unchecked(name) }; +//! let question = Question { +//! qname: name, +//! qtype: QType::A, +//! qclass: QClass::IN, +//! }; +//! let _ = builder.build_question(&question).unwrap().unwrap(); +//! +//! // Use the built message. +//! let message = builder.message(); +//! # let _ = message; +//! ``` + +mod builder; +pub use builder::Builder; + +mod context; +pub use context::{BuilderContext, MessageState}; + +mod message; +pub use message::MessageBuilder; + +mod question; +pub use question::QuestionBuilder; + +mod record; +pub use record::RecordBuilder; + +use super::wire::TruncationError; + +//----------- Message-aware building traits ---------------------------------- + +/// Building into a DNS message. +pub trait BuildIntoMessage { + /// Append this value to the DNS message. + /// + /// If the builder has enough capacity to fit this value, it is appended + /// and committed. Otherwise, a [`TruncationError`] is returned.
+ fn build_into_message(&self, builder: Builder<'_>) -> BuildResult; +} + +impl BuildIntoMessage for &T { + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult { + (**self).build_into_message(builder) + } +} + +impl BuildIntoMessage for u8 { + fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { + builder.append_bytes(&[*self])?; + Ok(builder.commit()) + } +} + +impl BuildIntoMessage for [T] { + fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { + for elem in self { + elem.build_into_message(builder.delegate())?; + } + Ok(builder.commit()) + } +} + +impl BuildIntoMessage for [T; N] { + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult { + self.as_slice().build_into_message(builder) + } +} + +//----------- BuildResult ---------------------------------------------------- + +/// The result of building into a DNS message. +/// +/// This is used in [`BuildIntoMessage::build_into_message()`]. +pub type BuildResult = Result; + +//----------- BuildCommitted ------------------------------------------------- + +/// The output of [`Builder::commit()`]. +/// +/// This is a simple marker type, produced by [`Builder::commit()`]. Certain +/// trait methods (e.g. [`BuildIntoMessage::build_into_message()`]) require it +/// in the return type, as a way to remind users to commit their builders. 
+/// +/// # Examples +/// +/// If `build_into_message()` simply returned a unit type, an example impl may +/// look like: +/// +/// ```compile_fail +/// # use domain::new_base::name::RevName; +/// # use domain::new_base::build::{BuildIntoMessage, Builder, BuildResult}; +/// # use domain::new_base::wire::AsBytes; +/// +/// struct Foo<'a>(&'a RevName, u8); +/// +/// impl BuildIntoMessage for Foo<'_> { +/// fn build_into_message( +/// &self, +/// mut builder: Builder<'_>, +/// ) -> BuildResult { +/// builder.append_name(self.0)?; +/// builder.append_bytes(self.1.as_bytes()); +/// Ok(()) +/// } +/// } +/// ``` +/// +/// This code is incorrect: since the appended content is not committed, the +/// builder will remove it when it is dropped (at the end of the function), +/// and so nothing gets written. Instead, users have to write: +/// +/// ``` +/// # use domain::new_base::name::RevName; +/// # use domain::new_base::build::{BuildIntoMessage, Builder, BuildResult}; +/// # use domain::new_base::wire::AsBytes; +/// +/// struct Foo<'a>(&'a RevName, u8); +/// +/// impl BuildIntoMessage for Foo<'_> { +/// fn build_into_message( +/// &self, +/// mut builder: Builder<'_>, +/// ) -> BuildResult { +/// builder.append_name(self.0)?; +/// builder.append_bytes(self.1.as_bytes()); +/// Ok(builder.commit()) +/// } +/// } +/// ``` +#[derive(Debug)] +pub struct BuildCommitted; diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs new file mode 100644 index 000000000..8f6f04c3c --- /dev/null +++ b/src/new_base/build/question.rs @@ -0,0 +1,113 @@ +//! Building DNS questions. + +use crate::new_base::{ + name::UnparsedName, + parse::ParseMessageBytes, + wire::{ParseBytes, TruncationError}, + QClass, QType, Question, +}; + +use super::{BuildCommitted, BuildIntoMessage, MessageBuilder, MessageState}; + +//----------- QuestionBuilder ------------------------------------------------ + +/// A DNS question builder. 
+/// +/// A [`QuestionBuilder`] provides control over a DNS question that has been +/// appended to a message (using a [`MessageBuilder`]). It can be used to +/// inspect the question's fields, to replace it with a new question, and to +/// commit (finish building) or cancel (remove) the question. +#[must_use = "A 'QuestionBuilder' must be explicitly committed, else all added content will be lost"] +pub struct QuestionBuilder<'b> { + /// The underlying message builder. + builder: MessageBuilder<'b, 'b>, + + /// The offset of the question name. + name: u16, +} + +//--- Construction + +impl<'b> QuestionBuilder<'b> { + /// Build a [`Question`]. + /// + /// The provided builder must be empty (i.e. must not have uncommitted + /// content). + pub(super) fn build( + mut builder: MessageBuilder<'b, 'b>, + question: &Question, + ) -> Result { + // TODO: Require that the QNAME serialize correctly? + let start = builder.context.size; + question.build_into_message(builder.builder(start))?; + let name = start.try_into().expect("Messages are at most 64KiB"); + builder.context.state = MessageState::MidQuestion { name }; + Ok(Self { builder, name }) + } + + /// Reconstruct a [`QuestionBuilder`] from raw parts. + /// + /// # Safety + /// + /// `builder.message().contents[name..]` must represent a valid + /// [`Question`] in the wire format. + pub unsafe fn from_raw_parts( + builder: MessageBuilder<'b, 'b>, + name: u16, + ) -> Self { + Self { builder, name } + } +} + +//--- Inspection + +impl<'b> QuestionBuilder<'b> { + /// The (unparsed) question name. + pub fn qname(&self) -> &UnparsedName { + let contents = &self.builder.message().contents; + // NOTE: 'parse_message_bytes()' takes an offset relative to the start + // of the message contents, so only the trailing QTYPE and QCLASS (4 + // bytes) are cut off here; the name offset is applied exactly once. + let contents = &contents[..contents.len() - 4]; + <&UnparsedName>::parse_message_bytes(contents, self.name.into()) + .expect("The question was serialized correctly") + } + + /// The question type.
+ pub fn qtype(&self) -> QType { + let contents = &self.builder.message().contents; + QType::parse_bytes(&contents[contents.len() - 4..contents.len() - 2]) + .expect("The question was serialized correctly") + } + + /// The question class. + pub fn qclass(&self) -> QClass { + let contents = &self.builder.message().contents; + QClass::parse_bytes(&contents[contents.len() - 2..]) + .expect("The question was serialized correctly") + } + + /// Deconstruct this [`QuestionBuilder`] into its raw parts. + pub fn into_raw_parts(self) -> (MessageBuilder<'b, 'b>, u16) { + (self.builder, self.name) + } +} + +//--- Interaction + +impl QuestionBuilder<'_> { + /// Commit this question. + /// + /// The builder will be consumed, and the question will be committed so + /// that it can no longer be removed. + pub fn commit(self) -> BuildCommitted { + self.builder.context.state = MessageState::Questions; + self.builder.message.header.counts.questions += 1; + BuildCommitted + } + + /// Stop building and remove this question. + /// + /// The builder will be consumed, and the question will be removed. + pub fn cancel(self) { + self.builder.context.size = self.name.into(); + self.builder.context.state = MessageState::Questions; + } +} diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs new file mode 100644 index 000000000..327f289d8 --- /dev/null +++ b/src/new_base/build/record.rs @@ -0,0 +1,214 @@ +//! Building DNS records. + +use core::{mem::ManuallyDrop, ptr}; + +use crate::new_base::{ + name::UnparsedName, + parse::ParseMessageBytes, + wire::{AsBytes, ParseBytes, SizePrefixed, TruncationError, U16}, + RClass, RType, Record, TTL, +}; + +use super::{ + BuildCommitted, BuildIntoMessage, Builder, MessageBuilder, MessageState, +}; + +//----------- RecordBuilder ------------------------------------------------ + +/// A DNS record builder. +/// +/// A [`RecordBuilder`] provides access to a record that has been appended to +/// a DNS message (using a [`MessageBuilder`]). 
It can be used to inspect the +/// record, to (re)write the record data, and to commit (finish building) or +/// cancel (remove) the record. +#[must_use = "A 'RecordBuilder' must be explicitly committed, else all added content will be lost"] +pub struct RecordBuilder<'b> { + /// The underlying message builder. + builder: MessageBuilder<'b, 'b>, + + /// The offset of the record name. + name: u16, + + /// The offset of the record data. + data: u16, +} + +//--- Construction + +impl<'b> RecordBuilder<'b> { + /// Build a [`Record`]. + /// + /// The provided builder must be empty (i.e. must not have uncommitted + /// content). + pub(super) fn build( + mut builder: MessageBuilder<'b, 'b>, + record: &Record, + ) -> Result + where + N: BuildIntoMessage, + D: BuildIntoMessage, + { + // Build the record and remember important positions. + let start = builder.context.size; + let (name, data) = { + let name = start.try_into().expect("Messages are at most 64KiB"); + let mut b = builder.builder(start); + record.rname.build_into_message(b.delegate())?; + b.append_bytes(record.rtype.as_bytes())?; + b.append_bytes(record.rclass.as_bytes())?; + b.append_bytes(record.ttl.as_bytes())?; + let size = b.context().size; + SizePrefixed::::new(&record.rdata) + .build_into_message(b.delegate())?; + let data = + (size + 2).try_into().expect("Messages are at most 64KiB"); + b.commit(); + (name, data) + }; + + // Update the message state. + match builder.context.state { + ref mut state @ MessageState::Answers => { + *state = MessageState::MidAnswer { name, data }; + } + + ref mut state @ MessageState::Authorities => { + *state = MessageState::MidAuthority { name, data }; + } + + ref mut state @ MessageState::Additionals => { + *state = MessageState::MidAdditional { name, data }; + } + + _ => unreachable!(), + } + + Ok(Self { + builder, + name, + data, + }) + } + + /// Reconstruct a [`RecordBuilder`] from raw parts. 
+ /// + /// # Safety + /// + /// `builder.message().contents[name..]` must represent a valid + /// [`Record`] in the wire format. `contents[data..]` must represent the + /// record data (i.e. immediately after the record data size field). + pub unsafe fn from_raw_parts( + builder: MessageBuilder<'b, 'b>, + name: u16, + data: u16, + ) -> Self { + Self { + builder, + name, + data, + } + } +} + +//--- Inspection + +impl<'b> RecordBuilder<'b> { + /// The (unparsed) record name. + pub fn rname(&self) -> &UnparsedName { + let contents = &self.builder.message().contents; + // NOTE: 'parse_message_bytes()' takes an offset relative to the start + // of the message contents, so only the bytes from the record type + // onwards (10 bytes before the record data) are cut off here; the + // name offset is applied exactly once. + let contents = &contents[..usize::from(self.data) - 10]; + <&UnparsedName>::parse_message_bytes(contents, self.name.into()) + .expect("The record was serialized correctly") + } + + /// The record type. + pub fn rtype(&self) -> RType { + let contents = &self.builder.message().contents; + let contents = &contents[usize::from(self.data) - 10..]; + RType::parse_bytes(&contents[0..2]) + .expect("The record was serialized correctly") + } + + /// The record class. + pub fn rclass(&self) -> RClass { + let contents = &self.builder.message().contents; + let contents = &contents[usize::from(self.data) - 10..]; + RClass::parse_bytes(&contents[2..4]) + .expect("The record was serialized correctly") + } + + /// The TTL. + pub fn ttl(&self) -> TTL { + let contents = &self.builder.message().contents; + let contents = &contents[usize::from(self.data) - 10..]; + TTL::parse_bytes(&contents[4..8]) + .expect("The record was serialized correctly") + } + + /// The record data built thus far. + pub fn rdata(&self) -> &[u8] { + &self.builder.message().contents[usize::from(self.data)..] + } + + /// Deconstruct this [`RecordBuilder`] into its raw parts. + pub fn into_raw_parts(self) -> (MessageBuilder<'b, 'b>, u16, u16) { + let (name, data) = (self.name, self.data); + let this = ManuallyDrop::new(self); + let this = (&*this) as *const Self; + // SAFETY: 'this' is a valid object that can be moved out of.
+ let builder = unsafe { ptr::read(ptr::addr_of!((*this).builder)) }; + (builder, name, data) + } +} + +//--- Interaction + +impl RecordBuilder<'_> { + /// Commit this record. + /// + /// The builder will be consumed, and the record will be committed so that + /// it can no longer be removed. + pub fn commit(self) -> BuildCommitted { + self.builder + .context + .state + .commit(&mut self.builder.message.header.counts); + + // NOTE: The record data size will be fixed on drop. + BuildCommitted + } + + /// Stop building and remove this record. + /// + /// The builder will be consumed, and the record will be removed. + pub fn cancel(self) { + self.builder.context.size = self.name.into(); + self.builder.context.state.cancel(); + + // NOTE: The drop glue is a no-op. + } + + /// Delegate further building of the record data to a new [`Builder`]. + pub fn delegate(&mut self) -> Builder<'_> { + let offset = self.builder.context.size; + self.builder.builder(offset) + } +} + +//--- Drop + +impl Drop for RecordBuilder<'_> { + fn drop(&mut self) { + // Fixup the record data size so the overall message builder is valid. + let size = self.builder.context.size as u16; + if self.data <= size { + // SAFETY: Only the record data size field is being modified. + let message = unsafe { self.builder.message_mut() }; + let data = usize::from(self.data); + let size = size - self.data; + message.contents[data - 2..data] + .copy_from_slice(&size.to_be_bytes()); + } + } +} diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs new file mode 100644 index 000000000..6e9c97bc6 --- /dev/null +++ b/src/new_base/charstr.rs @@ -0,0 +1,194 @@ +//! DNS "character strings". 
+ +use core::fmt; + +use domain_macros::UnsizedClone; + +use super::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, +}; + +//----------- CharStr -------------------------------------------------------- + +/// A DNS "character string". +#[derive(UnsizedClone)] +#[repr(transparent)] +pub struct CharStr { + /// The underlying octets. + /// + /// This is at most 255 bytes. It does not include the length octet that + /// precedes the character string when serialized in the wire format. + pub octets: [u8], +} + +//--- Inspection + +impl CharStr { + /// The length of the [`CharStr`]. + /// + /// This is always less than 256 -- it is guaranteed to fit in a [`u8`]. + pub const fn len(&self) -> usize { + self.octets.len() + } + + /// Whether the [`CharStr`] is empty. + pub const fn is_empty(&self) -> bool { + self.octets.is_empty() + } +} + +//--- Parsing from DNS messages + +impl<'a> SplitMessageBytes<'a> for &'a CharStr { + fn split_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result<(Self, usize), ParseError> { + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) + } +} + +impl<'a> ParseMessageBytes<'a> for &'a CharStr { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + Self::parse_bytes(&contents[start..]) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for CharStr { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> BuildResult { + builder.append_bytes(&[self.octets.len() as u8])?; + builder.append_bytes(&self.octets)?; + Ok(builder.commit()) + } +} + +//--- Parsing from bytes + +impl<'a> SplitBytes<'a> for &'a CharStr { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (&length, rest) = bytes.split_first().ok_or(ParseError)?; + if length as usize > rest.len() { + return 
Err(ParseError); + } + let (bytes, rest) = rest.split_at(length as usize); + + // SAFETY: 'CharStr' is 'repr(transparent)' to '[u8]'. + Ok((unsafe { core::mem::transmute::<&[u8], Self>(bytes) }, rest)) + } +} + +impl<'a> ParseBytes<'a> for &'a CharStr { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (&length, rest) = bytes.split_first().ok_or(ParseError)?; + if length as usize != rest.len() { + return Err(ParseError); + } + + // SAFETY: 'CharStr' is 'repr(transparent)' to '[u8]'. + Ok(unsafe { core::mem::transmute::<&[u8], Self>(rest) }) + } +} + +//--- Building into byte strings + +impl BuildBytes for CharStr { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + let (length, bytes) = + bytes.split_first_mut().ok_or(TruncationError)?; + *length = self.octets.len() as u8; + self.octets.build_bytes(bytes) + } +} + +//--- Equality + +impl PartialEq for CharStr { + fn eq(&self, other: &Self) -> bool { + self.octets.eq_ignore_ascii_case(&other.octets) + } +} + +impl Eq for CharStr {} + +//--- Formatting + +impl fmt::Debug for CharStr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use fmt::Write; + + struct Native<'a>(&'a [u8]); + impl fmt::Debug for Native<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("b\"")?; + for &b in self.0 { + f.write_str(match b { + b'"' => "\\\"", + b' ' => " ", + b'\n' => "\\n", + b'\r' => "\\r", + b'\t' => "\\t", + b'\\' => "\\\\", + + _ => { + if b.is_ascii_graphic() { + f.write_char(b as char)?; + } else { + write!(f, "\\x{:02X}", b)?; + } + continue; + } + })?; + } + f.write_char('"')?; + Ok(()) + } + } + + f.debug_struct("CharStr") + .field("content", &Native(&self.octets)) + .finish() + } +} + +//============ Tests ========================================================= + +#[cfg(test)] +mod test { + use super::CharStr; + + use crate::new_base::wire::{ + BuildBytes, ParseBytes, ParseError, SplitBytes, + }; + + #[test] + fn 
parse_build() { + let bytes = b"\x05Hello!"; + let (charstr, rest) = <&CharStr>::split_bytes(bytes).unwrap(); + assert_eq!(&charstr.octets, b"Hello"); + assert_eq!(rest, b"!"); + + assert_eq!(<&CharStr>::parse_bytes(bytes), Err(ParseError)); + assert!(<&CharStr>::parse_bytes(&bytes[..6]).is_ok()); + + let mut buffer = [0u8; 6]; + assert_eq!( + charstr.build_bytes(&mut buffer), + Ok(&mut [] as &mut [u8]) + ); + assert_eq!(buffer, &bytes[..6]); + } +} diff --git a/src/new_base/message.rs b/src/new_base/message.rs new file mode 100644 index 000000000..358a75e2a --- /dev/null +++ b/src/new_base/message.rs @@ -0,0 +1,365 @@ +//! DNS message headers. + +use core::fmt; + +use domain_macros::*; + +use super::wire::{AsBytes, ParseBytesByRef, U16}; + +//----------- Message -------------------------------------------------------- + +/// A DNS message. +#[derive(AsBytes, BuildBytes, ParseBytesByRef, UnsizedClone)] +#[repr(C, packed)] +pub struct Message { + /// The message header. + pub header: Header, + + /// The message contents. + pub contents: [u8], +} + +//--- Inspection + +impl Message { + /// Represent this as a mutable byte sequence. + /// + /// Given `&mut self`, it is already possible to individually modify the + /// message header and contents; since neither has invalid instances, it + /// is safe to represent the entire object as mutable bytes. + pub fn as_bytes_mut(&mut self) -> &mut [u8] { + // SAFETY: + // - 'Self' has no padding bytes and no interior mutability. + // - Its size in memory is exactly 'size_of_val(self)'. + unsafe { + core::slice::from_raw_parts_mut( + self as *mut Self as *mut u8, + core::mem::size_of_val(self), + ) + } + } +} + +//--- Interaction + +impl Message { + /// Truncate the contents of this message to the given size. + /// + /// The returned value will have a `contents` field of the given size. 
+ pub fn slice_to(&self, size: usize) -> &Self { + let bytes = &self.as_bytes()[..12 + size]; + Self::parse_bytes_by_ref(bytes) + .expect("A 12-or-more byte string is a valid 'Message'") + } + + /// Truncate the contents of this message to the given size, mutably. + /// + /// The returned value will have a `contents` field of the given size. + pub fn slice_to_mut(&mut self, size: usize) -> &mut Self { + let bytes = &mut self.as_bytes_mut()[..12 + size]; + Self::parse_bytes_by_mut(bytes) + .expect("A 12-or-more byte string is a valid 'Message'") + } + + /// Truncate the contents of this message to the given size, by pointer. + /// + /// The returned value will have a `contents` field of the given size. + /// + /// # Safety + /// + /// This method uses `pointer::offset()`: `self` must be "derived from a + /// pointer to some allocated object". There must be at least 12 bytes + /// between `self` and the end of that allocated object. A reference to + /// `Message` will always result in a pointer satisfying this. + pub unsafe fn ptr_slice_to(this: *mut Message, size: usize) -> *mut Self { + let bytes = unsafe { core::ptr::addr_of_mut!((*this).contents) }; + let len = unsafe { &*(bytes as *mut [()]) }.len(); + debug_assert!(size <= len); + core::ptr::slice_from_raw_parts_mut(this.cast::(), size) + as *mut Self + } +} + +//----------- Header --------------------------------------------------------- + +/// A DNS message header. +#[derive( + Copy, + Clone, + Debug, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(C)] +pub struct Header { + /// A unique identifier for the message. + pub id: U16, + + /// Properties of the message. + pub flags: HeaderFlags, + + /// Counts of objects in the message. 
+ pub counts: SectionCounts, +} + +//--- Formatting + +impl fmt::Display for Header { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} of ID {:04X} ({})", + self.flags, + self.id.get(), + self.counts + ) + } +} + +//----------- HeaderFlags ---------------------------------------------------- + +/// DNS message header flags. +#[derive( + Copy, + Clone, + Default, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct HeaderFlags { + inner: U16, +} + +//--- Interaction + +impl HeaderFlags { + /// Get the specified flag bit. + fn get_flag(&self, pos: u32) -> bool { + self.inner.get() & (1 << pos) != 0 + } + + /// Set the specified flag bit. + fn set_flag(&mut self, pos: u32, value: bool) -> &mut Self { + self.inner &= !(1 << pos); + self.inner |= (value as u16) << pos; + self + } + + /// The raw flags bits. + pub fn bits(&self) -> u16 { + self.inner.get() + } + + /// Whether this is a query. + pub fn is_query(&self) -> bool { + !self.get_flag(15) + } + + /// Whether this is a response. + pub fn is_response(&self) -> bool { + self.get_flag(15) + } + + /// The operation code. + pub fn opcode(&self) -> u8 { + (self.inner.get() >> 11) as u8 & 0xF + } + + /// The response code. + pub fn rcode(&self) -> u8 { + self.inner.get() as u8 & 0xF + } + + /// Construct a query. + pub fn query(&mut self, opcode: u8) -> &mut Self { + assert!(opcode < 16); + self.inner &= !(0xF << 11); + self.inner |= (opcode as u16) << 11; + self.set_flag(15, false) + } + + /// Construct a response. + pub fn respond(&mut self, rcode: u8) -> &mut Self { + assert!(rcode < 16); + self.inner &= !0xF; + self.inner |= rcode as u16; + self.set_flag(15, true) + } + + /// Whether this is an authoritative answer. + pub fn is_authoritative(&self) -> bool { + self.get_flag(10) + } + + /// Mark this as an authoritative answer. 
+ pub fn set_authoritative(&mut self, value: bool) -> &mut Self { + self.set_flag(10, value) + } + + /// Whether this message is truncated. + pub fn is_truncated(&self) -> bool { + self.get_flag(9) + } + + /// Mark this message as truncated. + pub fn set_truncated(&mut self, value: bool) -> &mut Self { + self.set_flag(9, value) + } + + /// Whether the server should query recursively. + pub fn should_recurse(&self) -> bool { + self.get_flag(8) + } + + /// Direct the server to query recursively. + pub fn request_recursion(&mut self, value: bool) -> &mut Self { + self.set_flag(8, value) + } + + /// Whether the server supports recursion. + pub fn can_recurse(&self) -> bool { + self.get_flag(7) + } + + /// Indicate support for recursive queries. + pub fn support_recursion(&mut self, value: bool) -> &mut Self { + self.set_flag(7, value) + } +} + +//--- Formatting + +impl fmt::Debug for HeaderFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HeaderFlags") + .field("is_response (qr)", &self.is_response()) + .field("opcode", &self.opcode()) + .field("is_authoritative (aa)", &self.is_authoritative()) + .field("is_truncated (tc)", &self.is_truncated()) + .field("should_recurse (rd)", &self.should_recurse()) + .field("can_recurse (ra)", &self.can_recurse()) + .field("rcode", &self.rcode()) + .field("bits", &self.bits()) + .finish() + } +} + +impl fmt::Display for HeaderFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_query() { + if self.should_recurse() { + f.write_str("recursive ")?; + } + write!(f, "query (opcode {})", self.opcode())?; + } else { + if self.is_authoritative() { + f.write_str("authoritative ")?; + } + if self.should_recurse() && self.can_recurse() { + f.write_str("recursive ")?; + } + write!(f, "response (rcode {})", self.rcode())?; + } + + if self.is_truncated() { + f.write_str(" (message truncated)")?; + } + + Ok(()) + } +} + +//----------- SectionCounts 
-------------------------------------------------- + +/// Counts of objects in a DNS message. +#[derive( + Copy, + Clone, + Debug, + Default, + PartialEq, + Eq, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(C)] +pub struct SectionCounts { + /// The number of questions in the message. + pub questions: U16, + + /// The number of answer records in the message. + pub answers: U16, + + /// The number of name server records in the message. + pub authorities: U16, + + /// The number of additional records in the message. + pub additional: U16, +} + +//--- Interaction + +impl SectionCounts { + /// Represent these counts as an array. + pub fn as_array(&self) -> &[U16; 4] { + // SAFETY: 'SectionCounts' has the same layout as '[U16; 4]'. + unsafe { core::mem::transmute(self) } + } + + /// Represent these counts as a mutable array. + pub fn as_array_mut(&mut self) -> &mut [U16; 4] { + // SAFETY: 'SectionCounts' has the same layout as '[U16; 4]'. + unsafe { core::mem::transmute(self) } + } +} + +//--- Formatting + +impl fmt::Display for SectionCounts { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut some = false; + + for (num, single, many) in [ + (self.questions.get(), "question", "questions"), + (self.answers.get(), "answer", "answers"), + (self.authorities.get(), "authority", "authorities"), + (self.additional.get(), "additional", "additional"), + ] { + // Add a comma if we have printed something before. + if some && num > 0 { + f.write_str(", ")?; + } + + // Print a count of this section. + match num { + 0 => {} + 1 => write!(f, "1 {single}")?, + n => write!(f, "{n} {many}")?, + } + + some |= num > 0; + } + + if !some { + f.write_str("empty")?; + } + + Ok(()) + } +} diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs new file mode 100644 index 000000000..3e22e27fc --- /dev/null +++ b/src/new_base/mod.rs @@ -0,0 +1,186 @@ +//! Basic DNS. +//! +//! 
This module provides the essential types and functionality for working +//! with DNS. Most importantly, it provides functionality for parsing and +//! building DNS messages on the wire. + +//--- DNS messages + +mod message; +pub use message::{Header, HeaderFlags, Message, SectionCounts}; + +mod question; +pub use question::{QClass, QType, Question, UnparsedQuestion}; + +mod record; +pub use record::{ + CanonicalRecordData, ParseRecordData, RClass, RType, Record, + UnparsedRecord, UnparsedRecordData, TTL, +}; + +//--- Elements of DNS messages + +pub mod name; + +mod charstr; +pub use charstr::CharStr; + +mod serial; +pub use serial::Serial; + +//--- Wire format + +pub mod build; +pub mod parse; +pub mod wire; + +//--- Compatibility exports + +/// A compatibility module with [`domain::base`]. +/// +/// This re-exports a large part of the `new_base` API surface using the same +/// import paths as the old `base` module. It is a stopgap measure to help +/// users port existing code over to `new_base`. Every export comes with a +/// deprecation message to help users switch to the right tools. 
pub mod compat {
    // Everything in here is deprecated on purpose; silence the warnings
    // that the internal re-exports would otherwise trigger.
    #![allow(deprecated)]

    // Root-level re-exports.  Each one forwards to an item in one of the
    // submodules declared further down in this module.

    #[deprecated = "use 'crate::new_base::HeaderFlags' instead."]
    pub use header::Flags;

    #[deprecated = "use 'crate::new_base::Header' instead."]
    pub use header::HeaderSection;

    #[deprecated = "use 'crate::new_base::SectionCounts' instead."]
    pub use header::HeaderCounts;

    #[deprecated = "use 'crate::new_base::RType' instead."]
    pub use iana::rtype::Rtype;

    #[deprecated = "use 'crate::new_base::name::Label' instead."]
    pub use name::Label;

    #[deprecated = "use 'crate::new_base::name::Name' instead."]
    pub use name::Name;

    #[deprecated = "use 'crate::new_base::Question' instead."]
    pub use question::Question;

    #[deprecated = "use 'crate::new_base::ParseRecordData' instead."]
    pub use rdata::ParseRecordData;

    #[deprecated = "use 'crate::new_rdata::UnknownRecordData' instead."]
    pub use rdata::UnknownRecordData;

    #[deprecated = "use 'crate::new_base::Record' instead."]
    pub use record::Record;

    #[deprecated = "use 'crate::new_base::TTL' instead."]
    pub use record::Ttl;

    #[deprecated = "use 'crate::new_base::Serial' instead."]
    pub use serial::Serial;

    /// Header types, re-exported under their old names.
    pub mod header {
        #[deprecated = "use 'crate::new_base::HeaderFlags' instead."]
        pub use crate::new_base::HeaderFlags as Flags;

        #[deprecated = "use 'crate::new_base::Header' instead."]
        pub use crate::new_base::Header as HeaderSection;

        #[deprecated = "use 'crate::new_base::SectionCounts' instead."]
        pub use crate::new_base::SectionCounts as HeaderCounts;
    }

    /// Parameter types, re-exported under their old names.
    ///
    /// Each item is available both directly in this module and in a
    /// per-type submodule.
    pub mod iana {
        #[deprecated = "use 'crate::new_base::RClass' instead."]
        pub use class::Class;

        #[deprecated = "use 'crate::new_rdata::DigestType' instead."]
        pub use digestalg::DigestAlg;

        #[deprecated = "use 'crate::new_rdata::NSec3HashAlg' instead."]
        pub use nsec3::Nsec3HashAlg;

        #[deprecated = "use 'crate::new_edns::OptionCode' instead."]
        pub use opt::OptionCode;

        #[deprecated = "for now, just use 'u8', but a better API is coming."]
        pub use rcode::Rcode;

        #[deprecated = "use 'crate::new_base::RType' instead."]
        pub use rtype::Rtype;

        #[deprecated = "use 'crate::new_rdata::SecAlg' instead."]
        pub use secalg::SecAlg;

        /// Alias for the record class type.
        pub mod class {
            #[deprecated = "use 'crate::new_base::RClass' instead."]
            pub use crate::new_base::RClass as Class;
        }

        /// Alias for the digest type.
        pub mod digestalg {
            #[deprecated = "use 'crate::new_rdata::DigestType' instead."]
            pub use crate::new_rdata::DigestType as DigestAlg;
        }

        /// Alias for the NSEC3 hash algorithm type.
        pub mod nsec3 {
            #[deprecated = "use 'crate::new_rdata::NSec3HashAlg' instead."]
            pub use crate::new_rdata::NSec3HashAlg as Nsec3HashAlg;
        }

        /// Re-export of the EDNS option code type.
        pub mod opt {
            #[deprecated = "use 'crate::new_edns::OptionCode' instead."]
            pub use crate::new_edns::OptionCode;
        }

        /// Alias for response codes, which are plain `u8`s for now.
        pub mod rcode {
            #[deprecated = "for now, just use 'u8', but a better API is coming."]
            pub use u8 as Rcode;
        }

        /// Alias for the record type type.
        pub mod rtype {
            #[deprecated = "use 'crate::new_base::RType' instead."]
            pub use crate::new_base::RType as Rtype;
        }

        /// Re-export of the security algorithm type.
        pub mod secalg {
            #[deprecated = "use 'crate::new_rdata::SecAlg' instead."]
            pub use crate::new_rdata::SecAlg;
        }
    }

    /// Domain name types.
    pub mod name {
        #[deprecated = "use 'crate::new_base::name::Label' instead."]
        pub use crate::new_base::name::Label;

        #[deprecated = "use 'crate::new_base::name::Name' instead."]
        pub use crate::new_base::name::Name;
    }

    /// Question types.
    pub mod question {
        #[deprecated = "use 'crate::new_base::Question' instead."]
        pub use crate::new_base::Question;
    }

    /// Record data types.
    pub mod rdata {
        #[deprecated = "use 'crate::new_base::ParseRecordData' instead."]
        pub use crate::new_base::ParseRecordData;

        #[deprecated = "use 'crate::new_rdata::UnknownRecordData' instead."]
        pub use crate::new_rdata::UnknownRecordData;
    }

    /// Record types.
    pub mod record {
        #[deprecated = "use 'crate::new_base::Record' instead."]
        pub use crate::new_base::Record;

        #[deprecated = "use 'crate::new_base::TTL' instead."]
        pub use crate::new_base::TTL as Ttl;
    }

    /// Serial number types.
    pub mod serial {
        #[deprecated = "use 'crate::new_base::Serial' instead."]
        pub use crate::new_base::Serial;
    }
}
'crate::new_base::Serial' instead."] + pub use crate::new_base::Serial; + } +} diff --git a/src/new_base/name/absolute.rs b/src/new_base/name/absolute.rs new file mode 100644 index 000000000..0b14263c5 --- /dev/null +++ b/src/new_base/name/absolute.rs @@ -0,0 +1,588 @@ +//! Absolute domain names. + +use core::{ + borrow::{Borrow, BorrowMut}, + cmp::Ordering, + fmt, + hash::{Hash, Hasher}, + ops::{Deref, DerefMut}, +}; + +use domain_macros::*; + +use crate::{ + new_base::{ + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{ + BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError, + }, + }, + utils::CloneFrom, +}; + +use super::{CanonicalName, LabelIter}; + +//----------- Name ----------------------------------------------------------- + +/// An absolute domain name. +#[derive(AsBytes, BuildBytes, UnsizedClone)] +#[repr(transparent)] +pub struct Name([u8]); + +//--- Constants + +impl Name { + /// The maximum size of a domain name. + pub const MAX_SIZE: usize = 255; + + /// The root name. + pub const ROOT: &'static Self = { + // SAFETY: A root label is the shortest valid name. + unsafe { Self::from_bytes_unchecked(&[0u8]) } + }; +} + +//--- Construction + +impl Name { + /// Assume a byte string is a valid [`Name`]. + /// + /// # Safety + /// + /// The byte string must represent a valid uncompressed domain name in the + /// conventional wire format (a sequence of labels terminating with a root + /// label, totalling 255 bytes or less). + pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'Name' is 'repr(transparent)' to '[u8]', so casting a + // '[u8]' into a 'Name' is sound. + core::mem::transmute(bytes) + } + + /// Assume a mutable byte string is a valid [`Name`]. + /// + /// # Safety + /// + /// The byte string must represent a valid uncompressed domain name in the + /// conventional wire format (a sequence of labels terminating with a root + /// label, totalling 255 bytes or less). 
+ pub unsafe fn from_bytes_unchecked_mut(bytes: &mut [u8]) -> &mut Self { + // SAFETY: 'Name' is 'repr(transparent)' to '[u8]', so casting a + // '[u8]' into a 'Name' is sound. + core::mem::transmute(bytes) + } +} + +//--- Inspection + +impl Name { + /// The size of this name in the wire format. + #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Whether this is the root label. + pub const fn is_root(&self) -> bool { + self.0.len() == 1 + } + + /// A byte representation of the [`Name`]. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } + + /// The labels in the [`Name`]. + /// + /// Note that labels appear in reverse order to the _conventional_ format + /// (it thus starts with the root label). + pub const fn labels(&self) -> LabelIter<'_> { + // SAFETY: A 'Name' always contains valid encoded labels. + unsafe { LabelIter::new_unchecked(self.as_bytes()) } + } +} + +//--- Canonical operations + +impl CanonicalName for Name { + fn cmp_composed(&self, other: &Self) -> Ordering { + self.as_bytes().cmp(other.as_bytes()) + } + + fn cmp_lowercase_composed(&self, other: &Self) -> Ordering { + self.as_bytes() + .iter() + .map(u8::to_ascii_lowercase) + .cmp(other.as_bytes().iter().map(u8::to_ascii_lowercase)) + } +} + +//--- Parsing from bytes + +impl<'a> ParseBytes<'a> for &'a Name { + fn parse_bytes(bytes: &'a [u8]) -> Result { + match Self::split_bytes(bytes) { + Ok((this, &[])) => Ok(this), + _ => Err(ParseError), + } + } +} + +impl<'a> SplitBytes<'a> for &'a Name { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let mut offset = 0usize; + while offset < 255 { + match *bytes.get(offset..).ok_or(ParseError)? { + [0, ..] => { + // Found the root, stop. + let (name, rest) = bytes.split_at(offset + 1); + + // SAFETY: 'name' follows the wire format and is 255 bytes + // or shorter. + let name = unsafe { Name::from_bytes_unchecked(name) }; + return Ok((name, rest)); + } + + [l, ..] 
if l < 64 => { + // This looks like a regular label. + + if bytes.len() < offset + 1 + l as usize { + // The input doesn't contain the whole label. + return Err(ParseError); + } + + offset += 1 + l as usize; + } + + _ => return Err(ParseError), + } + } + + Err(ParseError) + } +} + +//--- Equality + +impl PartialEq for Name { + fn eq(&self, that: &Self) -> bool { + // Instead of iterating labels, blindly iterate bytes. The locations + // of labels don't matter since we're testing everything for equality. + + // NOTE: Label lengths (which are less than 64) aren't affected by + // 'to_ascii_lowercase', so this method can be applied uniformly. + let this = self.as_bytes().iter().map(u8::to_ascii_lowercase); + let that = that.as_bytes().iter().map(u8::to_ascii_lowercase); + + this.eq(that) + } +} + +impl Eq for Name {} + +//--- Comparison + +impl PartialOrd for Name { + fn partial_cmp(&self, that: &Self) -> Option { + Some(self.cmp(that)) + } +} + +impl Ord for Name { + fn cmp(&self, that: &Self) -> Ordering { + // We wish to compare the labels in these names in reverse order. + // Unfortunately, labels in absolute names cannot be traversed + // backwards efficiently. We need to try harder. + // + // Consider two names that are not equal. This means that one name is + // a strict suffix of the other, or that the two had different labels + // at some position. Following this mismatched label is a suffix of + // labels that both names do agree on. + // + // We traverse the bytes in the names in reverse order and find the + // length of their shared suffix. The actual shared suffix, in units + // of labels, may be shorter than this (because the last bytes of the + // mismatched labels could be the same). + // + // Then, we traverse the labels of both names in forward order, until + // we hit the shared suffix territory. We try to match up the names + // in order to discover the real shared suffix. 
Once the suffix is + // found, the immediately preceding label (if there is one) contains + // the inequality, and can be compared as usual. + + let suffix_len = core::iter::zip( + self.as_bytes().iter().rev().map(u8::to_ascii_lowercase), + that.as_bytes().iter().rev().map(u8::to_ascii_lowercase), + ) + .position(|(a, b)| a != b); + + let Some(suffix_len) = suffix_len else { + // 'iter::zip()' simply ignores unequal iterators, stopping when + // either iterator finishes. Even though the two names had no + // mismatching bytes, one could be longer than the other. + return self.len().cmp(&that.len()); + }; + + // Prepare for forward traversal. + let (mut lhs, mut rhs) = (self.labels(), that.labels()); + // SAFETY: There is at least one unequal byte, and it cannot be the + // root label, so both names have at least one additional label. + let mut prev = unsafe { + (lhs.next().unwrap_unchecked(), rhs.next().unwrap_unchecked()) + }; + + // Traverse both names in lockstep, trying to match their lengths. + loop { + let (llen, rlen) = (lhs.remaining().len(), rhs.remaining().len()); + if llen == rlen && llen <= suffix_len { + // We're in shared suffix territory, and 'lhs' and 'rhs' have + // the same length. Thus, they must be identical, and we have + // found the shared suffix. + break prev.0.cmp(prev.1); + } else if llen > rlen { + // Try to match the lengths by shortening 'lhs'. + + // SAFETY: 'llen > rlen >= 1', thus 'lhs' contains at least + // one additional label before the root. + prev.0 = unsafe { lhs.next().unwrap_unchecked() }; + } else { + // Try to match the lengths by shortening 'rhs'. + + // SAFETY: Either: + // - '1 <= llen < rlen', thus 'rhs' contains at least one + // additional label before the root. + // - 'llen == rlen > suffix_len >= 1', thus 'rhs' contains at + // least one additional label before the root. 
+ prev.1 = unsafe { rhs.next().unwrap_unchecked() }; + } + } + } +} + +//--- Hashing + +impl Hash for Name { + fn hash(&self, state: &mut H) { + for byte in self.as_bytes() { + // NOTE: Label lengths (which are less than 64) aren't affected by + // 'to_ascii_lowercase', so this method can be applied uniformly. + state.write_u8(byte.to_ascii_lowercase()) + } + } +} + +//--- Formatting + +impl fmt::Display for Name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut first = true; + self.labels().try_for_each(|label| { + if !first { + f.write_str(".")?; + } else { + first = false; + } + + label.fmt(f) + }) + } +} + +impl fmt::Debug for Name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Name({})", self) + } +} + +//----------- NameBuf -------------------------------------------------------- + +/// A 256-byte buffer containing a [`Name`]. +#[derive(Clone)] +#[repr(C)] // make layout compatible with '[u8; 256]' +pub struct NameBuf { + /// The size of the encoded name. + size: u8, + + /// The buffer containing the [`Name`]. + buffer: [u8; 255], +} + +//--- Construction + +impl NameBuf { + /// Construct an empty, invalid buffer. + const fn empty() -> Self { + Self { + size: 0, + buffer: [0; 255], + } + } + + /// Copy a [`Name`] into a buffer. + pub fn copy_from(name: &Name) -> Self { + let mut buffer = [0u8; 255]; + buffer[..name.len()].copy_from_slice(name.as_bytes()); + Self { + size: name.len() as u8, + buffer, + } + } +} + +impl CloneFrom for NameBuf { + fn clone_from(value: &Self::Target) -> Self { + Self::copy_from(value) + } +} + +//--- Parsing from DNS messages + +impl<'a> SplitMessageBytes<'a> for NameBuf { + fn split_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result<(Self, usize), ParseError> { + // NOTE: The input may be controlled by an attacker. Compression + // pointers can be arranged to cause loops or to access every byte in + // the message in random order. 
Instead of performing complex loop + // detection, which would probably perform allocations, we simply + // disallow a name to point to data _after_ it. Standard name + // compressors will never generate such pointers. + + let mut buffer = Self::empty(); + + // Perform the first iteration early, to catch the end of the name. + let bytes = contents.get(start..).ok_or(ParseError)?; + let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; + let orig_end = contents.len() - rest.len(); + + // Traverse compression pointers. + let mut old_start = start; + while let Some(start) = pointer.map(usize::from) { + // Ensure the referenced position comes earlier. + if start >= old_start { + return Err(ParseError); + } + + // Keep going, from the referenced position. + let start = start.checked_sub(12).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; + (pointer, _) = parse_segment(bytes, &mut buffer)?; + old_start = start; + continue; + } + + // Stop and return the original end. + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been appended into it). + Ok((buffer, orig_end)) + } +} + +impl<'a> ParseMessageBytes<'a> for NameBuf { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + // See 'split_from_message()' for details. The only differences are + // in the range of the first iteration, and the check that the first + // iteration exactly covers the input range. + + let mut buffer = Self::empty(); + + // Perform the first iteration early, to catch the end of the name. + let bytes = contents.get(start..).ok_or(ParseError)?; + let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; + + if !rest.is_empty() { + // The name didn't reach the end of the input range, fail. + return Err(ParseError); + } + + // Traverse compression pointers. 
+ let mut old_start = start; + while let Some(start) = pointer.map(usize::from) { + // Ensure the referenced position comes earlier. + if start >= old_start { + return Err(ParseError); + } + + // Keep going, from the referenced position. + let start = start.checked_sub(12).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; + (pointer, _) = parse_segment(bytes, &mut buffer)?; + old_start = start; + continue; + } + + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been appended into it). + Ok(buffer) + } +} + +/// Parse an encoded and potentially-compressed domain name, without +/// following any compression pointer. +fn parse_segment<'a>( + mut bytes: &'a [u8], + buffer: &mut NameBuf, +) -> Result<(Option, &'a [u8]), ParseError> { + loop { + match *bytes { + [0, ref rest @ ..] => { + // Found the root, stop. + buffer.append_bytes(&[0u8]); + return Ok((None, rest)); + } + + [l, ..] if l < 64 => { + // This looks like a regular label. + + if bytes.len() < 1 + l as usize { + // The input doesn't contain the whole label. + return Err(ParseError); + } else if 255 - buffer.size < 2 + l { + // The output name would exceed 254 bytes (this isn't + // the root label, so it can't fill the 255th byte). + return Err(ParseError); + } + + let (label, rest) = bytes.split_at(1 + l as usize); + buffer.append_bytes(label); + bytes = rest; + } + + [hi, lo, ref rest @ ..] if hi >= 0xC0 => { + let pointer = u16::from_be_bytes([hi, lo]); + + // NOTE: We don't verify the pointer here, that's left to + // the caller (since they have to actually use it). 
+ return Ok((Some(pointer & 0x3FFF), rest)); + } + + _ => return Err(ParseError), + } + } +} + +//--- Parsing from bytes + +impl<'a> SplitBytes<'a> for NameBuf { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + <&Name>::split_bytes(bytes) + .map(|(name, rest)| (NameBuf::copy_from(name), rest)) + } +} + +impl<'a> ParseBytes<'a> for NameBuf { + fn parse_bytes(bytes: &'a [u8]) -> Result { + <&Name>::parse_bytes(bytes).map(NameBuf::copy_from) + } +} + +//--- Building into byte strings + +impl BuildBytes for NameBuf { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_bytes(bytes) + } +} + +//--- Interaction + +impl NameBuf { + /// Append bytes to this buffer. + /// + /// This is an internal convenience function used while building buffers. + fn append_bytes(&mut self, bytes: &[u8]) { + self.buffer[self.size as usize..][..bytes.len()] + .copy_from_slice(bytes); + self.size += bytes.len() as u8; + } +} + +//--- Access to the underlying 'Name' + +impl Deref for NameBuf { + type Target = Name; + + fn deref(&self) -> &Self::Target { + let name = &self.buffer[..self.size as usize]; + // SAFETY: A 'NameBuf' always contains a valid 'Name'. + unsafe { Name::from_bytes_unchecked(name) } + } +} + +impl DerefMut for NameBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + let name = &mut self.buffer[..self.size as usize]; + // SAFETY: A 'NameBuf' always contains a valid 'Name'. 
+ unsafe { Name::from_bytes_unchecked_mut(name) } + } +} + +impl Borrow for NameBuf { + fn borrow(&self) -> &Name { + self + } +} + +impl BorrowMut for NameBuf { + fn borrow_mut(&mut self) -> &mut Name { + self + } +} + +impl AsRef for NameBuf { + fn as_ref(&self) -> &Name { + self + } +} + +impl AsMut for NameBuf { + fn as_mut(&mut self) -> &mut Name { + self + } +} + +//--- Forwarding equality, comparison, hashing, and formatting + +impl PartialEq for NameBuf { + fn eq(&self, that: &Self) -> bool { + **self == **that + } +} + +impl Eq for NameBuf {} + +impl PartialOrd for NameBuf { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for NameBuf { + fn cmp(&self, other: &Self) -> Ordering { + (**self).cmp(&**other) + } +} + +impl Hash for NameBuf { + fn hash(&self, state: &mut H) { + (**self).hash(state) + } +} + +impl fmt::Display for NameBuf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl fmt::Debug for NameBuf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs new file mode 100644 index 000000000..f76fca4c4 --- /dev/null +++ b/src/new_base/name/label.rs @@ -0,0 +1,508 @@ +//! Labels in domain names. + +use core::{ + borrow::{Borrow, BorrowMut}, + cmp::Ordering, + fmt, + hash::{Hash, Hasher}, + iter::FusedIterator, + ops::{Deref, DerefMut}, +}; + +use domain_macros::{AsBytes, UnsizedClone}; + +use crate::{ + new_base::{ + build::{BuildIntoMessage, BuildResult, Builder}, + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{ + BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError, + }, + }, + utils::CloneFrom, +}; + +//----------- Label ---------------------------------------------------------- + +/// A label in a domain name. +/// +/// A label contains up to 63 bytes of arbitrary data. 
+#[derive(AsBytes, UnsizedClone)] +#[repr(transparent)] +pub struct Label([u8]); + +//--- Associated Constants + +impl Label { + /// The root label. + pub const ROOT: &'static Self = { + // SAFETY: All slices of 63 bytes or less are valid. + unsafe { Self::from_bytes_unchecked(b"") } + }; + + /// The wildcard label. + pub const WILDCARD: &'static Self = { + // SAFETY: All slices of 63 bytes or less are valid. + unsafe { Self::from_bytes_unchecked(b"*") } + }; +} + +//--- Construction + +impl Label { + /// Assume a byte slice is a valid label. + /// + /// # Safety + /// + /// The byte slice must have length 63 or less. + pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'Label' is 'repr(transparent)' to '[u8]'. + unsafe { core::mem::transmute(bytes) } + } + + /// Assume a mutable byte slice is a valid label. + /// + /// # Safety + /// + /// The byte slice must have length 63 or less. + pub unsafe fn from_bytes_unchecked_mut(bytes: &mut [u8]) -> &mut Self { + // SAFETY: 'Label' is 'repr(transparent)' to '[u8]'. 
+ unsafe { core::mem::transmute(bytes) } + } +} + +//--- Parsing from DNS messages + +impl<'a> ParseMessageBytes<'a> for &'a Label { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + Self::parse_bytes(&contents[start..]) + } +} + +impl<'a> SplitMessageBytes<'a> for &'a Label { + fn split_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result<(Self, usize), ParseError> { + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Label { + fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { + builder.append_with(self.len() + 1, |buf| { + buf[0] = self.len() as u8; + buf[1..].copy_from_slice(self.as_bytes()); + })?; + Ok(builder.commit()) + } +} + +//--- Parsing from bytes + +impl<'a> SplitBytes<'a> for &'a Label { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (&size, rest) = bytes.split_first().ok_or(ParseError)?; + if size < 64 && rest.len() >= size as usize { + let (label, rest) = rest.split_at(size as usize); + // SAFETY: 'label' is 'size < 64' bytes in size. + Ok((unsafe { Label::from_bytes_unchecked(label) }, rest)) + } else { + Err(ParseError) + } + } +} + +impl<'a> ParseBytes<'a> for &'a Label { + fn parse_bytes(bytes: &'a [u8]) -> Result { + match Self::split_bytes(bytes) { + Ok((this, &[])) => Ok(this), + _ => Err(ParseError), + } + } +} + +//--- Building into byte strings + +impl BuildBytes for Label { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + let (size, data) = bytes.split_first_mut().ok_or(TruncationError)?; + let rest = self.as_bytes().build_bytes(data)?; + *size = self.len() as u8; + Ok(rest) + } +} + +//--- Inspection + +impl Label { + /// The length of this label, in bytes. 
+ #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Whether this is the root label. + pub const fn is_root(&self) -> bool { + self.0.is_empty() + } + + /// Whether this is a wildcard label. + pub const fn is_wildcard(&self) -> bool { + // NOTE: '==' for byte slices is not 'const'. + self.0.len() == 1 && self.0[0] == b'*' + } + + /// The bytes making up this label. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +//--- Access to the underlying bytes + +impl AsRef<[u8]> for Label { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl<'a> From<&'a Label> for &'a [u8] { + fn from(value: &'a Label) -> Self { + &value.0 + } +} + +//--- Comparison + +impl PartialEq for Label { + /// Compare two labels for equality. + /// + /// Labels are compared ASCII-case-insensitively. + fn eq(&self, other: &Self) -> bool { + let this = self.as_bytes().iter().map(u8::to_ascii_lowercase); + let that = other.as_bytes().iter().map(u8::to_ascii_lowercase); + this.eq(that) + } +} + +impl Eq for Label {} + +//--- Ordering + +impl PartialOrd for Label { + /// Determine the order between labels. + /// + /// Any uppercase ASCII characters in the labels are treated as if they + /// were lowercase. The first unequal byte between two labels determines + /// its ordering: the label with the smaller byte value is the lesser. If + /// two labels have all the same bytes, the shorter label is lesser; if + /// they are the same length, they are equal. + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Label { + /// Determine the order between labels. + /// + /// Any uppercase ASCII characters in the labels are treated as if they + /// were lowercase. The first unequal byte between two labels determines + /// its ordering: the label with the smaller byte value is the lesser. 
If + /// two labels have all the same bytes, the shorter label is lesser; if + /// they are the same length, they are equal. + fn cmp(&self, other: &Self) -> Ordering { + let this = self.as_bytes().iter().map(u8::to_ascii_lowercase); + let that = other.as_bytes().iter().map(u8::to_ascii_lowercase); + this.cmp(that) + } +} + +//--- Hashing + +impl Hash for Label { + /// Hash this label. + /// + /// All uppercase ASCII characters are lowercased beforehand. This way, + /// the hash of a label is case-independent, consistent with how labels + /// are compared and ordered. + /// + /// The label is hashed as if it were a name containing a single label -- + /// the length octet is thus included. This makes the hashing consistent + /// between names and tuples (not slices!) of labels. + fn hash(&self, state: &mut H) { + state.write_u8(self.len() as u8); + for &byte in self.as_bytes() { + state.write_u8(byte.to_ascii_lowercase()) + } + } +} + +//--- Formatting + +impl fmt::Display for Label { + /// Print a label. + /// + /// The label is printed in the conventional zone file format, with bytes + /// outside printable ASCII formatted as `\\DDD` (a backslash followed by + /// three zero-padded decimal digits), and `.` and `\\` simply escaped by + /// a backslash. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.as_bytes().iter().try_for_each(|&byte| { + if b".\\".contains(&byte) { + write!(f, "\\{}", byte as char) + } else if byte.is_ascii_graphic() { + write!(f, "{}", byte as char) + } else { + write!(f, "\\{:03}", byte) + } + }) + } +} + +impl fmt::Debug for Label { + /// Print a label for debugging purposes. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Label") + .field(&format_args!("{}", self)) + .finish() + } +} + +//----------- LabelBuf ------------------------------------------------------- + +/// A 64-byte buffer holding a [`Label`]. 
+#[derive(Clone)] +#[repr(C)] // make layout compatible with '[u8; 64]' +pub struct LabelBuf { + /// The size of the label, in bytes. + /// + /// This value is guaranteed to be in the range '0..64'. + size: u8, + + /// The underlying label data. + data: [u8; 63], +} + +//--- Construction + +impl LabelBuf { + /// Copy a [`Label`] into a buffer. + pub fn copy_from(label: &Label) -> Self { + let size = label.len() as u8; + let mut data = [0u8; 63]; + data[..size as usize].copy_from_slice(label.as_bytes()); + Self { size, data } + } +} + +impl CloneFrom for LabelBuf { + fn clone_from(value: &Self::Target) -> Self { + Self::copy_from(value) + } +} + +//--- Parsing from DNS messages + +impl ParseMessageBytes<'_> for LabelBuf { + fn parse_message_bytes( + contents: &'_ [u8], + start: usize, + ) -> Result { + Self::parse_bytes(&contents[start..]) + } +} + +impl SplitMessageBytes<'_> for LabelBuf { + fn split_message_bytes( + contents: &'_ [u8], + start: usize, + ) -> Result<(Self, usize), ParseError> { + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for LabelBuf { + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult { + (**self).build_into_message(builder) + } +} + +//--- Parsing from byte strings + +impl ParseBytes<'_> for LabelBuf { + fn parse_bytes(bytes: &[u8]) -> Result { + <&Label>::parse_bytes(bytes).map(Self::copy_from) + } +} + +impl SplitBytes<'_> for LabelBuf { + fn split_bytes(bytes: &'_ [u8]) -> Result<(Self, &'_ [u8]), ParseError> { + <&Label>::split_bytes(bytes) + .map(|(label, rest)| (Self::copy_from(label), rest)) + } +} + +//--- Building into byte strings + +impl BuildBytes for LabelBuf { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_bytes(bytes) + } +} + +//--- Access to the underlying 'Label' + +impl Deref for LabelBuf { + type Target = Label; 
+ + fn deref(&self) -> &Self::Target { + let label = &self.data[..self.size as usize]; + // SAFETY: A 'LabelBuf' always contains a valid 'Label'. + unsafe { Label::from_bytes_unchecked(label) } + } +} + +impl DerefMut for LabelBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + let label = &mut self.data[..self.size as usize]; + // SAFETY: A 'LabelBuf' always contains a valid 'Label'. + unsafe { Label::from_bytes_unchecked_mut(label) } + } +} + +impl Borrow