From e0e15cac45fd4595c61f47e8e9251d4dae6196f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Wed, 28 Jan 2026 16:14:15 +0100 Subject: [PATCH] feat(oauth,oidc): support relative redirect URL --- .../shield-oauth/src/actions/sign_in.rs | 20 +++++++++---------- .../shield-oidc/src/actions/sign_in.rs | 20 +++++++++---------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/packages/methods/shield-oauth/src/actions/sign_in.rs b/packages/methods/shield-oauth/src/actions/sign_in.rs index 035cf3e..745c7fd 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in.rs @@ -17,8 +17,8 @@ use crate::{ #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SignInData { - pub redirect_origin: Option, - pub redirect_url: Option, + pub redirect_origin: Url, + pub redirect_url: Option, } pub struct OauthSignInAction { @@ -91,15 +91,13 @@ impl Action for OauthSignInAction { let data = serde_json::from_value::(request.form_data) .map_err(|err| ShieldError::Validation(err.to_string()))?; - let redirect_url = data.redirect_url.or_else(|| { - data.redirect_origin.and_then(|redirect_origin| { - redirect_origin.join(&self.options.sign_in_redirect).ok() - }) - }); + let redirect_url = data + .redirect_url + .map(|redirect_url| data.redirect_origin.join(&redirect_url)) + .unwrap_or_else(|| data.redirect_origin.join(&self.options.sign_in_redirect)) + .map_err(|err| ShieldError::Validation(format!("redirect URL parse error: {err}")))?; - if let Some(redirect_url) = &redirect_url - && let Some(redirect_origins) = &self.options.redirect_origins - { + if let Some(redirect_origins) = &self.options.redirect_origins { let redirect_origin = Url::parse(&redirect_url.origin().ascii_serialization()) .map_err(|err| { ShieldError::Validation(format!("redirect origin parse error: {err}")) @@ -148,7 +146,7 @@ impl Action for OauthSignInAction { Ok(Response::new(ResponseType::Redirect(auth_url.to_string())) .session_action(SessionAction::Unauthenticate) .session_action(SessionAction::data(OauthSession { - redirect_url, + redirect_url: Some(redirect_url), csrf: Some(csrf_token.secret().clone()), pkce_verifier: pkce_code_challenge .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), diff --git a/packages/methods/shield-oidc/src/actions/sign_in.rs b/packages/methods/shield-oidc/src/actions/sign_in.rs index 592cd24..034fd2f 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in.rs @@ -20,8 +20,8 @@ use crate::{ #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SignInData { - pub redirect_origin: Option, - pub redirect_url: Option, + pub redirect_origin: Url, + pub redirect_url: Option, } pub struct OidcSignInAction { @@ -94,15 +94,13 @@ impl Action for OidcSignInAction { let data = serde_json::from_value::(request.form_data) .map_err(|err| ShieldError::Validation(err.to_string()))?; - let redirect_url = data.redirect_url.or_else(|| { - data.redirect_origin.and_then(|redirect_origin| { - redirect_origin.join(&self.options.sign_in_redirect).ok() - }) - }); + let redirect_url = data + .redirect_url + .map(|redirect_url| data.redirect_origin.join(&redirect_url)) + .unwrap_or_else(|| data.redirect_origin.join(&self.options.sign_in_redirect)) + .map_err(|err| ShieldError::Validation(format!("redirect URL parse error: {err}")))?; - if let Some(redirect_url) = &redirect_url - && let Some(redirect_origins) = &self.options.redirect_origins - { + if let Some(redirect_origins) = &self.options.redirect_origins { let redirect_origin = Url::parse(&redirect_url.origin().ascii_serialization()) .map_err(|err| { ShieldError::Validation(format!("redirect origin parse error: {err}")) @@ -153,7 +151,7 @@ impl Action for OidcSignInAction { Ok(Response::new(ResponseType::Redirect(auth_url.to_string())) .session_action(SessionAction::unauthenticate()) .session_action(SessionAction::data(OidcSession { - redirect_url, + redirect_url: Some(redirect_url), csrf: Some(csrf_token.secret().clone()), nonce: Some(nonce.secret().clone()), pkce_verifier: pkce_code_challenge