From f66c7dfbd4901e808a79ae7c42920f7818a59149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Wed, 28 Jan 2026 15:22:39 +0100 Subject: [PATCH] feat(oauth,oidc): add redirect URL --- .../shield-oauth/src/actions/sign_in.rs | 30 +++++++++++++------ .../src/actions/sign_in_callback.rs | 9 ++++-- packages/methods/shield-oauth/src/session.rs | 2 +- .../shield-oidc/src/actions/sign_in.rs | 30 +++++++++++++------ .../src/actions/sign_in_callback.rs | 12 ++------ packages/methods/shield-oidc/src/session.rs | 2 +- 6 files changed, 54 insertions(+), 31 deletions(-) diff --git a/packages/methods/shield-oauth/src/actions/sign_in.rs b/packages/methods/shield-oauth/src/actions/sign_in.rs index eb87254..145a272 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in.rs @@ -18,6 +18,7 @@ use crate::{ #[serde(rename_all = "camelCase")] pub struct SignInData { pub redirect_origin: Option, + pub redirect_url: Option, } pub struct OauthSignInAction { @@ -74,15 +75,26 @@ impl Action for OauthSignInAction { let data = serde_json::from_value::(request.form_data) .map_err(|err| ShieldError::Validation(err.to_string()))?; - let redirect_origin = if let Some(redirect_origins) = &self.options.redirect_origins - && let Some(redirect_origin) = data.redirect_origin - // TODO: Consider returning an error when redirect origin is not allowed. - && redirect_origins.contains(&redirect_origin) + let redirect_url = data.redirect_url.or_else(|| { + data.redirect_origin.and_then(|redirect_origin| { + redirect_origin.join(&self.options.sign_in_redirect).ok() + }) + }); + + if let Some(redirect_url) = &redirect_url + && let Some(redirect_origins) = &self.options.redirect_origins { - Some(redirect_origin) - } else { - None - }; + let redirect_origin = Url::parse(&redirect_url.origin().ascii_serialization()) + .map_err(|err| { + ShieldError::Validation(format!("redirect origin parse error: {err}")) + })?; + + if !redirect_origins.contains(&redirect_origin) { + return Err(ShieldError::Validation(format!( + "redirect origin `{redirect_origin}` not allowed" + ))); + } + } let client = provider.oauth_client().await?; @@ -120,7 +132,7 @@ impl Action for OauthSignInAction { Ok(Response::new(ResponseType::Redirect(auth_url.to_string())) .session_action(SessionAction::Unauthenticate) .session_action(SessionAction::data(OauthSession { - redirect_origin, + redirect_url, csrf: Some(csrf_token.secret().clone()), pkce_verifier: pkce_code_challenge .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), diff --git a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs index 0623d08..f03e03e 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs @@ -257,11 +257,16 @@ impl Action for OauthSignInCallb }; Ok(Response::new(ResponseType::Redirect( - self.options.sign_in_redirect.clone(), + session + .method + .redirect_url + .as_ref() + .map(ToString::to_string) + .unwrap_or_else(|| self.options.sign_in_redirect.clone()), )) .session_action(SessionAction::authenticate(user)) .session_action(SessionAction::data(OauthSession { - redirect_origin: None, + redirect_url: None, csrf: None, pkce_verifier: None, oauth_connection_id: Some(connection.id), diff --git a/packages/methods/shield-oauth/src/session.rs b/packages/methods/shield-oauth/src/session.rs index 4a70be4..565a737 100644 --- a/packages/methods/shield-oauth/src/session.rs +++ b/packages/methods/shield-oauth/src/session.rs @@ -3,7 +3,7 @@ use url::Url; #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct OauthSession { - pub redirect_origin: Option, + pub redirect_url: Option, pub csrf: Option, pub pkce_verifier: Option, pub oauth_connection_id: Option, diff --git a/packages/methods/shield-oidc/src/actions/sign_in.rs b/packages/methods/shield-oidc/src/actions/sign_in.rs index 25d9ccd..8d087e8 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in.rs @@ -21,6 +21,7 @@ use crate::{ #[serde(rename_all = "camelCase")] pub struct SignInData { pub redirect_origin: Option, + pub redirect_url: Option, } pub struct OidcSignInAction { @@ -85,15 +86,26 @@ impl Action for OidcSignInAction { let data = serde_json::from_value::(request.form_data) .map_err(|err| ShieldError::Validation(err.to_string()))?; - let redirect_origin = if let Some(redirect_origins) = &self.options.redirect_origins - && let Some(redirect_origin) = data.redirect_origin - // TODO: Consider returning an error when redirect origin is not allowed. - && redirect_origins.contains(&redirect_origin) + let redirect_url = data.redirect_url.or_else(|| { + data.redirect_origin.and_then(|redirect_origin| { + redirect_origin.join(&self.options.sign_in_redirect).ok() + }) + }); + + if let Some(redirect_url) = &redirect_url + && let Some(redirect_origins) = &self.options.redirect_origins { - Some(redirect_origin) - } else { - None - }; + let redirect_origin = Url::parse(&redirect_url.origin().ascii_serialization()) + .map_err(|err| { + ShieldError::Validation(format!("redirect origin parse error: {err}")) + })?; + + if !redirect_origins.contains(&redirect_origin) { + return Err(ShieldError::Validation(format!( + "redirect origin `{redirect_origin}` not allowed" + ))); + } + } let client = provider.oidc_client().await?; @@ -133,7 +145,7 @@ impl Action for OidcSignInAction { Ok(Response::new(ResponseType::Redirect(auth_url.to_string())) .session_action(SessionAction::unauthenticate()) .session_action(SessionAction::data(OidcSession { - redirect_origin, + redirect_url, csrf: Some(csrf_token.secret().clone()), nonce: Some(nonce.secret().clone()), pkce_verifier: pkce_code_challenge diff --git a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs index 2fc98d9..9b3b229 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs @@ -293,20 +293,14 @@ impl Action for OidcSignInCallback Ok(Response::new(ResponseType::Redirect( session .method - .redirect_origin + .redirect_url .as_ref() - .and_then(|redirect_origin| { - redirect_origin - .join(&self.options.sign_in_redirect) - .as_ref() - .map(ToString::to_string) - .ok() - }) + .map(ToString::to_string) .unwrap_or_else(|| self.options.sign_in_redirect.clone()), )) .session_action(SessionAction::authenticate(user)) .session_action(SessionAction::data(OidcSession { - redirect_origin: None, + redirect_url: None, csrf: None, nonce: None, pkce_verifier: None, diff --git a/packages/methods/shield-oidc/src/session.rs b/packages/methods/shield-oidc/src/session.rs index 5f57509..79e7ec9 100644 --- a/packages/methods/shield-oidc/src/session.rs +++ b/packages/methods/shield-oidc/src/session.rs @@ -3,7 +3,7 @@ use url::Url; #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct OidcSession { - pub redirect_origin: Option, + pub redirect_url: Option, pub csrf: Option, pub nonce: Option, pub pkce_verifier: Option,