diff --git a/Cargo.lock b/Cargo.lock index 6e7d6b2..ce9643d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5948,6 +5948,7 @@ dependencies = [ "bon", "chrono", "oauth2", + "regex", "secrecy", "serde", "serde_json", @@ -5964,6 +5965,7 @@ dependencies = [ "chrono", "oauth2", "openidconnect", + "regex", "secrecy", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index b6581ae..c63eac5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ leptos_actix = "0.8.3" leptos_axum = "0.8.3" leptos_meta = "0.8.3" leptos_router = "0.8.3" +regex = "1.12.2" sea-orm = "1.1.2" sea-orm-migration = "1.1.2" secrecy = "0.10.3" diff --git a/packages/core/shield/src/options.rs b/packages/core/shield/src/options.rs index ddeda26..1599fa9 100644 --- a/packages/core/shield/src/options.rs +++ b/packages/core/shield/src/options.rs @@ -2,12 +2,7 @@ use bon::Builder; #[derive(Builder, Clone, Debug)] #[builder(on(String, into), state_mod(vis = "pub(crate)"))] -pub struct ShieldOptions { - #[builder(default = "/")] - pub sign_in_redirect: String, - #[builder(default = "/")] - pub sign_out_redirect: String, -} +pub struct ShieldOptions {} impl Default for ShieldOptions { fn default() -> Self { diff --git a/packages/integrations/shield-axum/src/middleware.rs b/packages/integrations/shield-axum/src/middleware.rs index 0932961..f388adc 100644 --- a/packages/integrations/shield-axum/src/middleware.rs +++ b/packages/integrations/shield-axum/src/middleware.rs @@ -1,11 +1,11 @@ use axum::{ extract::Request, middleware::Next, - response::{IntoResponse, Redirect, Response}, + response::{IntoResponse, Response}, }; use shield::{ShieldError, User}; -use crate::{ExtractShield, ExtractUser, error::RouteError}; +use crate::{ExtractUser, error::RouteError}; pub async fn auth_required( ExtractUser(user): ExtractUser, @@ -17,15 +17,3 @@ pub async fn auth_required( None => RouteError::from(ShieldError::Unauthorized).into_response(), } } - -pub async fn auth_required_redirect( - ExtractShield(shield): ExtractShield, - ExtractUser(user): ExtractUser, - request: Request, - next: Next, -) -> Response { - match user { - Some(_) => next.run(request).await, - None => Redirect::to(&shield.options().sign_in_redirect).into_response(), - } -} diff --git a/packages/methods/shield-oauth/Cargo.toml b/packages/methods/shield-oauth/Cargo.toml index d97906f..df85772 100644 --- a/packages/methods/shield-oauth/Cargo.toml +++ b/packages/methods/shield-oauth/Cargo.toml @@ -24,6 +24,7 @@ oauth2 = { version = "5.0.0", default-features = false, features = [ "pkce-plain", "reqwest", ] } +regex.workspace = true secrecy.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/packages/methods/shield-oauth/src/actions/sign_in.rs b/packages/methods/shield-oauth/src/actions/sign_in.rs index 745c7fd..94dd855 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in.rs @@ -110,6 +110,18 @@ impl Action for OauthSignInAction { } } + if let Some(redirect_patterns) = &self.options.redirect_patterns { + let redirect_url_str = redirect_url.to_string(); + if !redirect_patterns + .iter() + .any(|pattern| pattern.is_match(&redirect_url_str)) + { + return Err(ShieldError::Validation(format!( + "redirect URL `{redirect_url}` not allowed" + ))); + } + } + let client = provider.oauth_client().await?; let mut authorization_request = client diff --git a/packages/methods/shield-oauth/src/options.rs b/packages/methods/shield-oauth/src/options.rs index bf6b4b9..4222ab0 100644 --- a/packages/methods/shield-oauth/src/options.rs +++ b/packages/methods/shield-oauth/src/options.rs @@ -1,4 +1,5 @@ use bon::Builder; +use regex::Regex; use url::Url; #[derive(Builder, Clone, Debug)] @@ -9,6 +10,9 @@ pub struct OauthOptions { #[builder(with = FromIterator::from_iter)] pub(crate) redirect_origins: Option>, + + #[builder(with = FromIterator::from_iter)] + pub(crate) redirect_patterns: Option>, } impl Default for OauthOptions { diff --git a/packages/methods/shield-oidc/Cargo.toml b/packages/methods/shield-oidc/Cargo.toml index 024bd0d..cd448ad 100644 --- a/packages/methods/shield-oidc/Cargo.toml +++ b/packages/methods/shield-oidc/Cargo.toml @@ -26,6 +26,7 @@ oauth2 = { version = "5.0.0", default-features = false, features = [ openidconnect = { version = "4.0.0", default-features = false, features = [ "reqwest", ] } +regex.workspace = true secrecy.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/packages/methods/shield-oidc/src/actions/sign_in.rs b/packages/methods/shield-oidc/src/actions/sign_in.rs index 034fd2f..d62f7d8 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in.rs @@ -113,6 +113,18 @@ impl Action for OidcSignInAction { } } + if let Some(redirect_patterns) = &self.options.redirect_patterns { + let redirect_url_str = redirect_url.to_string(); + if !redirect_patterns + .iter() + .any(|pattern| pattern.is_match(&redirect_url_str)) + { + return Err(ShieldError::Validation(format!( + "redirect URL `{redirect_url}` not allowed" + ))); + } + } + let client = provider.oidc_client().await?; let mut authorization_request = client.authorize_url( diff --git a/packages/methods/shield-oidc/src/options.rs b/packages/methods/shield-oidc/src/options.rs index 0e2ce27..4b6ba0b 100644 --- a/packages/methods/shield-oidc/src/options.rs +++ b/packages/methods/shield-oidc/src/options.rs @@ -1,4 +1,5 @@ use bon::Builder; +use regex::Regex; use url::Url; #[derive(Builder, Clone, Debug)] @@ -9,6 +10,9 @@ pub struct OidcOptions { #[builder(with = FromIterator::from_iter)] pub(crate) redirect_origins: Option>, + + #[builder(with = FromIterator::from_iter)] + pub(crate) redirect_patterns: Option>, } impl Default for OidcOptions {