Skip to content

Commit

Permalink
Requested authn context (#50)
Browse files Browse the repository at this point in the history
* Add `RequestedAuthnContext` to `AuthnRequest`

* Update CHANGELOG

* Implement `FromStr` for Assertion and add a test for it

---------

Co-authored-by: Daniel Wiesenberg <[email protected]>
  • Loading branch information
Weasy666 and Daniel Wiesenberg authored May 29, 2024
1 parent d5d0fff commit 2acc01b
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

### Unreleased
- Add support for basic `RequestedAuthnContext` de-/serialization in `AuthnRequest`

### 0.0.15

- Updates dependencies
Expand Down
9 changes: 8 additions & 1 deletion src/schema/authn_request.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::schema::{Conditions, Issuer, NameIdPolicy, Subject};
use crate::schema::{Conditions, Issuer, NameIdPolicy, RequestedAuthnContext, Subject};
use crate::signature::Signature;
use chrono::prelude::*;
use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event};
Expand Down Expand Up @@ -36,6 +36,8 @@ pub struct AuthnRequest {
pub name_id_policy: Option<NameIdPolicy>,
#[serde(rename = "Conditions")]
pub conditions: Option<Conditions>,
#[serde(rename = "RequestedAuthnContext")]
pub requested_authn_context: Option<RequestedAuthnContext>,
#[serde(rename = "@ForceAuthn")]
pub force_authn: Option<bool>,
#[serde(rename = "@IsPassive")]
Expand Down Expand Up @@ -65,6 +67,7 @@ impl Default for AuthnRequest {
subject: None,
name_id_policy: None,
conditions: None,
requested_authn_context: None,
force_authn: None,
is_passive: None,
assertion_consumer_service_index: None,
Expand Down Expand Up @@ -219,6 +222,10 @@ impl TryFrom<&AuthnRequest> for Event<'_> {
let event: Event<'_> = conditions.try_into()?;
writer.write_event(event)?;
}
if let Some(requested_authn_context) = &value.requested_authn_context {
let event: Event<'_> = requested_authn_context.try_into()?;
writer.write_event(event)?;
}

writer.write_event(Event::End(BytesEnd::new(NAME)))?;

Expand Down
51 changes: 51 additions & 0 deletions src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ pub mod authn_request;
mod conditions;
mod issuer;
mod name_id_policy;
mod requested_authn_context;
mod response;
mod subject;

pub use authn_request::AuthnRequest;
pub use conditions::*;
pub use issuer::Issuer;
pub use name_id_policy::NameIdPolicy;
pub use requested_authn_context::{AuthnContextComparison, RequestedAuthnContext};
pub use response::Response;
pub use subject::*;

Expand Down Expand Up @@ -197,6 +199,14 @@ impl Assertion {
}
}

impl FromStr for Assertion {
type Err = Box<dyn std::error::Error>;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(quick_xml::de::from_str(s)?)
}
}

impl TryFrom<Assertion> for Event<'_> {
type Error = Box<dyn std::error::Error>;

Expand Down Expand Up @@ -483,6 +493,47 @@ impl TryFrom<&AuthnContextClassRef> for Event<'_> {
}
}

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct AuthnContextDeclRef {
#[serde(rename = "$value")]
pub value: Option<String>,
}

impl AuthnContextDeclRef {
fn name() -> &'static str {
"saml2:AuthnContextDeclRef"
}
}

impl TryFrom<AuthnContextDeclRef> for Event<'_> {
type Error = Box<dyn std::error::Error>;

fn try_from(value: AuthnContextDeclRef) -> Result<Self, Self::Error> {
(&value).try_into()
}
}

impl TryFrom<&AuthnContextDeclRef> for Event<'_> {
type Error = Box<dyn std::error::Error>;

fn try_from(value: &AuthnContextDeclRef) -> Result<Self, Self::Error> {
if let Some(value) = &value.value {
let mut write_buf = Vec::new();
let mut writer = Writer::new(Cursor::new(&mut write_buf));
let root = BytesStart::new(AuthnContextDeclRef::name());

writer.write_event(Event::Start(root))?;
writer.write_event(Event::Text(BytesText::from_escaped(value)))?;
writer.write_event(Event::End(BytesEnd::new(AuthnContextDeclRef::name())))?;
Ok(Event::Text(BytesText::from_escaped(String::from_utf8(
write_buf,
)?)))
} else {
Ok(Event::Text(BytesText::from_escaped(String::new())))
}
}
}

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct Status {
#[serde(rename = "StatusCode")]
Expand Down
118 changes: 118 additions & 0 deletions src/schema/requested_authn_context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use crate::schema::{AuthnContextClassRef, AuthnContextDeclRef};
use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::Writer;
use serde::Deserialize;
use std::io::Cursor;
use std::str::FromStr;

const NAME: &str = "saml2p:RequestedAuthnContext";
const SCHEMA: (&str, &str) = ("xmlns:saml2", "urn:oasis:names:tc:SAML:2.0:assertion");

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct RequestedAuthnContext {
#[serde(rename = "AuthnContextClassRef")]
pub authn_context_class_refs: Option<Vec<AuthnContextClassRef>>,
#[serde(rename = "AuthnContextDeclRef")]
pub authn_context_decl_refs: Option<Vec<AuthnContextDeclRef>>,
#[serde(rename = "@Comparison")]
pub comparison: Option<AuthnContextComparison>,
}

impl TryFrom<RequestedAuthnContext> for Event<'_> {
type Error = Box<dyn std::error::Error>;

fn try_from(value: RequestedAuthnContext) -> Result<Self, Self::Error> {
(&value).try_into()
}
}

impl TryFrom<&RequestedAuthnContext> for Event<'_> {
type Error = Box<dyn std::error::Error>;

fn try_from(value: &RequestedAuthnContext) -> Result<Self, Self::Error> {
let mut write_buf = Vec::new();
let mut writer = Writer::new(Cursor::new(&mut write_buf));
let mut root = BytesStart::from_content(NAME, NAME.len());
root.push_attribute(SCHEMA);

if let Some(comparison) = &value.comparison {
root.push_attribute(("Comparison", comparison.value()));
}
writer.write_event(Event::Start(root))?;

if let Some(authn_context_class_refs) = &value.authn_context_class_refs {
for authn_context_class_ref in authn_context_class_refs {
let event: Event<'_> = authn_context_class_ref.try_into()?;
writer.write_event(event)?;
}
} else if let Some(authn_context_decl_refs) = &value.authn_context_decl_refs {
for authn_context_decl_ref in authn_context_decl_refs {
let event: Event<'_> = authn_context_decl_ref.try_into()?;
writer.write_event(event)?;
}
}

writer.write_event(Event::End(BytesEnd::new(NAME)))?;
Ok(Event::Text(BytesText::from_escaped(String::from_utf8(
write_buf,
)?)))
}
}

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[serde(rename_all = "lowercase")]
pub enum AuthnContextComparison {
Exact,
Minimum,
Maximum,
Better,
}

impl AuthnContextComparison {
pub fn value(&self) -> &'static str {
match self {
AuthnContextComparison::Exact => "exact",
AuthnContextComparison::Minimum => "minimum",
AuthnContextComparison::Maximum => "maximum",
AuthnContextComparison::Better => "better",
}
}
}

impl FromStr for AuthnContextComparison {
type Err = quick_xml::DeError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"exact" => AuthnContextComparison::Exact,
"minimum" => AuthnContextComparison::Minimum,
"maximum" => AuthnContextComparison::Maximum,
"better" => AuthnContextComparison::Better,
_ => {
return Err(quick_xml::DeError::Custom("Illegal comparison! Must be one of `exact`, `minimum`, `maximum` or `better`".to_string()));
}
})
}
}

#[cfg(test)]
mod test {
use crate::traits::ToXml;

use super::*;

#[test]
pub fn test_deserialize_serialize_requested_authn_context() {
let xml_context = r#"<saml2p:RequestedAuthnContext xmlns:saml2="urn:oasis:names:tc:SAML:2.0:assertion" Comparison="exact"><saml2:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml2:AuthnContextClassRef></saml2p:RequestedAuthnContext>"#;

let expected_context: RequestedAuthnContext = quick_xml::de::from_str(xml_context)
.expect("failed to parse RequestedAuthnContext");
let serialized_context = expected_context
.to_xml()
.expect("failed to convert RequestedAuthnContext to xml");
let actual_context: RequestedAuthnContext = quick_xml::de::from_str(&serialized_context)
.expect("failed to re-parse RequestedAuthnContext");

assert_eq!(expected_context, actual_context);
}
}

0 comments on commit 2acc01b

Please sign in to comment.