Skip to content

Commit

Permalink
[Auth] Cover the auth middleware with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MohamedBassem committed Jul 31, 2023
1 parent 17ec962 commit adcb71c
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 13 deletions.
3 changes: 2 additions & 1 deletion cronback-services/src/api/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ impl From<AuthError> for ApiError {
}
}

#[derive(Clone)]
pub struct Authenticator {
store: AuthStore,
}
Expand Down Expand Up @@ -190,7 +191,7 @@ impl FromStr for SecretApiKey {
}

impl SecretApiKey {
fn generate() -> Self {
pub fn generate() -> Self {
Self {
key_id: Uuid::new_v4().simple().to_string(),
plain_secret: Uuid::new_v4().simple().to_string(),
Expand Down
291 changes: 284 additions & 7 deletions cronback-services/src/api/auth_middleware.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
use std::sync::Arc;

use axum::extract::State;
use axum::extract::{FromRef, State};
use axum::http::{self, HeaderMap, HeaderValue, Request};
use axum::middleware::Next;
use axum::response::IntoResponse;
use lib::prelude::*;

use super::auth::{AuthError, SecretApiKey};
use super::auth::{AuthError, Authenticator, SecretApiKey};
use super::errors::ApiError;
use super::AppState;

const ON_BEHALF_OF_HEADER_NAME: &str = "X-On-Behalf-Of";

// Partial state from the main app state to facilitate writing tests for the
// middleware.
#[derive(Clone)]
pub struct AuthenticationState {
authenticator: Authenticator,
config: super::config::ApiSvcConfig,
}

impl FromRef<Arc<AppState>> for AuthenticationState {
fn from_ref(input: &Arc<AppState>) -> Self {
Self {
authenticator: input.authenticator.clone(),
config: input.context.service_config(),
}
}
}

enum AuthenticationStatus {
Unauthenticated,
Authenticated(ValidShardedId<ProjectId>),
Expand Down Expand Up @@ -61,14 +78,14 @@ fn get_auth_key(
}

async fn get_auth_status<B>(
state: &AppState,
state: &AuthenticationState,
req: &Request<B>,
) -> Result<AuthenticationStatus, ApiError> {
let auth_key = get_auth_key(req.headers())?;
let Some(auth_key) = auth_key else {
return Ok(AuthenticationStatus::Unauthenticated);
};
let config = state.context.service_config();
let config = &state.config;
let admin_keys = &config.admin_api_keys;
if admin_keys.contains(&auth_key) {
let project: Option<ValidShardedId<ProjectId>> = req
Expand Down Expand Up @@ -98,7 +115,10 @@ async fn get_auth_status<B>(
return Ok(AuthenticationStatus::Unauthenticated);
};

let project = state.authenicator.authenticate(&user_provided_secret).await;
let project = state
.authenticator
.authenticate(&user_provided_secret)
.await;
match project {
| Ok(project_id) => Ok(AuthenticationStatus::Authenticated(project_id)),
| Err(AuthError::AuthFailed(_)) => {
Expand Down Expand Up @@ -178,11 +198,11 @@ pub async fn ensure_admin<B>(
/// of the other "ensure_*" middlewares in this module to enforce the expected
/// AuthenticationStatus for a certain route.
pub async fn authenticate<B>(
State(state): State<Arc<AppState>>,
State(state): State<AuthenticationState>,
mut req: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, ApiError> {
let auth_status = get_auth_status(state.as_ref(), &req).await?;
let auth_status = get_auth_status(&state, &req).await?;

let project_id = auth_status.project_id();
req.extensions_mut().insert(auth_status);
Expand All @@ -200,3 +220,260 @@ pub async fn authenticate<B>(

Ok(resp)
}

#[cfg(test)]
mod tests {

use std::collections::HashSet;
use std::fmt::Debug;

use axum::routing::get;
use axum::{middleware, Router};
use cronback_api_model::admin::CreateAPIkeyRequest;
use hyper::{Body, StatusCode};
use tower::ServiceExt;

use super::*;
use crate::api::auth_store::AuthStore;
use crate::api::config::ApiSvcConfig;
use crate::api::ApiService;

async fn make_state() -> AuthenticationState {
let mut set = HashSet::new();
set.insert("adminkey1".to_string());
set.insert("adminkey2".to_string());

let config = ApiSvcConfig {
address: String::new(),
port: 123,
database_uri: String::new(),
admin_api_keys: set,
log_request_body: false,
log_response_body: false,
};

let db = ApiService::in_memory_database().await.unwrap();
let auth_store = AuthStore::new(db);
let authenticator = Authenticator::new(auth_store);

AuthenticationState {
authenticator,
config,
}
}

struct TestInput {
app: Router,
auth_header: Option<String>,
on_behalf_on_header: Option<String>,
expected_status: StatusCode,
}

impl Debug for TestInput {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TestInput")
.field("auth_header", &self.auth_header)
.field("on_behalf_on_header", &self.on_behalf_on_header)
.field("expected_status", &self.expected_status)
.finish()
}
}

struct TestExpectations {
unauthenticated: StatusCode,
authenticated: StatusCode,
admin_no_project: StatusCode,
admin_with_project: StatusCode,
unknown_secret_key: StatusCode,
}

async fn run_tests(
app: Router,
state: AuthenticationState,
expectations: TestExpectations,
) -> anyhow::Result<()> {
// Define one project and generate a key for it.
let prj1 = ProjectId::generate();
let key = state
.authenticator
.gen_key(
CreateAPIkeyRequest {
key_name: "test".to_string(),
metadata: Default::default(),
},
&prj1,
)
.await?;

let inputs = vec![
// Unauthenticated user
TestInput {
app: app.clone(),
auth_header: None,
on_behalf_on_header: None,
expected_status: expectations.unauthenticated,
},
// Authenticated user
TestInput {
app: app.clone(),
auth_header: Some(format!("Bearer {}", key.unsafe_to_string())),
on_behalf_on_header: None,
expected_status: expectations.authenticated,
},
// Admin without project
TestInput {
app: app.clone(),
auth_header: Some("Bearer adminkey1".to_string()),
on_behalf_on_header: None,
expected_status: expectations.admin_no_project,
},
// Admin with project
TestInput {
app: app.clone(),
auth_header: Some("Bearer adminkey1".to_string()),
on_behalf_on_header: Some(prj1.to_string()),
expected_status: expectations.admin_with_project,
},
// Unknown secret key
TestInput {
app: app.clone(),
auth_header: Some(format!(
"Bearer {}",
SecretApiKey::generate().unsafe_to_string()
)),
on_behalf_on_header: Some(prj1.to_string()),
expected_status: expectations.unknown_secret_key,
},
// Malformed secret key should be treated as an unknown secret key
TestInput {
app: app.clone(),
auth_header: Some("Bearer wrong key".to_string()),
on_behalf_on_header: Some("wrong_project".to_string()),
expected_status: expectations.unknown_secret_key,
},
// Malformed authorization header
TestInput {
app: app.clone(),
auth_header: Some(format!("Token {}", key.unsafe_to_string())),
on_behalf_on_header: Some(prj1.to_string()),
expected_status: StatusCode::BAD_REQUEST,
},
// Malformed on-behalf-on project id
TestInput {
app: app.clone(),
auth_header: Some("Bearer adminkey1".to_string()),
on_behalf_on_header: Some("wrong_project".to_string()),
expected_status: StatusCode::BAD_REQUEST,
},
];

for input in inputs {
let input_str = format!("{:?}", input);

let mut req = Request::builder();
if let Some(v) = input.auth_header {
req = req.header("Authorization", v);
}
if let Some(v) = input.on_behalf_on_header {
req = req.header(ON_BEHALF_OF_HEADER_NAME, v);
}

let resp = input
.app
.oneshot(req.uri("/").body(Body::empty()).unwrap())
.await?;

assert_eq!(
resp.status(),
input.expected_status,
"Input: {}",
input_str
);
}
Ok(())
}

#[tokio::test]
async fn test_ensure_authenticated() -> anyhow::Result<()> {
let state = make_state().await;

let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.layer(middleware::from_fn(super::ensure_authenticated))
.layer(middleware::from_fn_with_state(
state.clone(),
super::authenticate,
));

run_tests(
app,
state,
TestExpectations {
unauthenticated: StatusCode::UNAUTHORIZED,
authenticated: StatusCode::OK,
admin_no_project: StatusCode::BAD_REQUEST,
admin_with_project: StatusCode::OK,
unknown_secret_key: StatusCode::UNAUTHORIZED,
},
)
.await?;

Ok(())
}

#[tokio::test]
async fn test_ensure_admin() -> anyhow::Result<()> {
let state = make_state().await;

let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.layer(middleware::from_fn(super::ensure_admin))
.layer(middleware::from_fn_with_state(
state.clone(),
super::authenticate,
));

run_tests(
app,
state,
TestExpectations {
unauthenticated: StatusCode::UNAUTHORIZED,
authenticated: StatusCode::FORBIDDEN,
admin_no_project: StatusCode::OK,
admin_with_project: StatusCode::OK,
unknown_secret_key: StatusCode::UNAUTHORIZED,
},
)
.await?;

Ok(())
}

#[tokio::test]
async fn test_ensure_admin_for_project() -> anyhow::Result<()> {
let state = make_state().await;

let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.layer(middleware::from_fn(super::ensure_admin_for_project))
.layer(middleware::from_fn_with_state(
state.clone(),
super::authenticate,
));

run_tests(
app,
state,
TestExpectations {
unauthenticated: StatusCode::UNAUTHORIZED,
authenticated: StatusCode::FORBIDDEN,
admin_no_project: StatusCode::BAD_REQUEST,
admin_with_project: StatusCode::OK,
unknown_secret_key: StatusCode::UNAUTHORIZED,
},
)
.await?;

Ok(())
}
}
6 changes: 3 additions & 3 deletions cronback-services/src/api/handlers/admin/api_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub(crate) async fn create(
Extension(project): Extension<ValidShardedId<ProjectId>>,
ValidatedJson(req): ValidatedJson<CreateAPIkeyRequest>,
) -> Result<Json<CreateAPIKeyResponse>, ApiError> {
let key = state.authenicator.gen_key(req, &project).await?;
let key = state.authenticator.gen_key(req, &project).await?;

// This is the only legitimate place where this function should be used.
let key_str = key.unsafe_to_string();
Expand All @@ -38,7 +38,7 @@ pub(crate) async fn list(
Extension(project): Extension<ValidShardedId<ProjectId>>,
) -> Result<Paginated<ApiKey>, ApiError> {
let keys = state
.authenicator
.authenticator
.list_keys(&project)
.await
.map_err(|e| AppStateError::DatabaseError(e.to_string()))?
Expand Down Expand Up @@ -72,7 +72,7 @@ pub(crate) async fn revoke(
Extension(project): Extension<ValidShardedId<ProjectId>>,
) -> Result<StatusCode, ApiError> {
let deleted = state
.authenicator
.authenticator
.revoke_key(&id, &project)
.await
.map_err(|e| AppStateError::DatabaseError(e.to_string()))?;
Expand Down
Loading

0 comments on commit adcb71c

Please sign in to comment.