Files
helios-server-rs/src/middleware/auth.rs
2025-06-08 22:24:44 +08:00

121 lines
3.0 KiB
Rust

use axum::{
extract::FromRequestParts,
http::{request::Parts, StatusCode},
response::{IntoResponse, Response},
Json, RequestPartsExt,
};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Display;
use std::sync::OnceLock;
/// JWT claim set decoded from a bearer token and yielded by the `Claims` request extractor.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
/// Token subject; the `Display` impl prints this value labelled as "Email".
sub: String,
// company: String,
/// Expiry claim.
/// NOTE(review): presumably seconds since the Unix epoch, as checked by jsonwebtoken's validation — confirm.
exp: usize,
/// NOTE(review): presumably the numeric id of the authenticated user — confirm against the token issuer.
pub uid: i64,
}
/// Serializable body pairing an access token with its token type.
#[derive(Debug, Serialize)]
struct AuthBody {
/// The token string passed to `AuthBody::new`.
access_token: String,
/// Token scheme; `AuthBody::new` always sets this to "Bearer".
token_type: String,
}
/// Deserializable pair of client credentials.
/// NOTE(review): not referenced in the visible part of this file; presumably the
/// request body of a token-issuing endpoint — confirm against the caller.
#[derive(Debug, Deserialize)]
struct AuthPayload {
/// Identifier supplied by the client.
client_id: String,
/// Secret supplied by the client.
client_secret: String,
}
/// Authentication failures. The `IntoResponse` impl in this file turns each variant
/// into an HTTP status plus a JSON body of the form `{"error": "<message>"}`.
#[derive(Debug)]
pub enum AuthError {
/// Rendered as 401 Unauthorized with message "Wrong credentials".
WrongCredentials,
/// Rendered as 400 Bad Request with message "Missing credentials".
MissingCredentials,
/// Rendered as 500 Internal Server Error with message "Token creation error".
TokenCreation,
/// Rendered as 400 Bad Request with message "Invalid token". Returned by the
/// `Claims` extractor when the bearer header cannot be extracted or the token fails to decode.
InvalidToken,
}
/// Process-wide JWT keys. Empty until `initialize_jwt_key` sets it; read by the
/// `Claims` extractor when decoding bearer tokens.
static KEYS: OnceLock<Keys> = OnceLock::new();
/// Installs the process-wide JWT keys, both built from the bytes of `key_str`.
///
/// # Panics
///
/// Panics with "jwt key initialize failed" if the keys have already been set,
/// i.e. when this function is called more than once.
pub fn initialize_jwt_key(key_str: String) {
    let keys = Keys::new(key_str.as_bytes());
    // `OnceLock::set` hands the value back as `Err` when the cell is already filled.
    if KEYS.set(keys).is_err() {
        panic!("jwt key initialize failed");
    }
}
/// Encoding and decoding keys that `Keys::new` builds from the same secret.
struct Keys {
/// Key for encoding tokens.
/// NOTE(review): not read anywhere in the visible part of this file — confirm where tokens are issued.
encoding: EncodingKey,
/// Key the `Claims` extractor uses to decode incoming bearer tokens.
decoding: DecodingKey,
}
impl Keys {
    /// Builds the encoding/decoding key pair from a single shared `secret`.
    fn new(secret: &[u8]) -> Self {
        let encoding = EncodingKey::from_secret(secret);
        let decoding = DecodingKey::from_secret(secret);
        Self { encoding, decoding }
    }
}
/// Human-readable rendering of a claim set: the subject prefixed with "Email: ".
impl Display for Claims {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written in two pieces; the output is the label followed by the raw subject.
        f.write_str("Email: ")?;
        f.write_str(&self.sub)
    }
}
impl AuthBody {
    /// Wraps `access_token` in a body whose token type is always "Bearer".
    fn new(access_token: String) -> Self {
        let token_type = String::from("Bearer");
        Self {
            access_token,
            token_type,
        }
    }
}
/// Request extractor that yields the [`Claims`] carried by the request's
/// `Authorization: Bearer <token>` header.
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    /// Reads the bearer token from `parts` and decodes it with the globally
    /// installed decoding key and `Validation::default()`.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::InvalidToken`] when the bearer authorization header
    /// cannot be extracted, or when the token fails to decode or validate.
    ///
    /// # Panics
    ///
    /// Panics if `initialize_jwt_key` has not been called before the first
    /// request is handled.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        // Extract the token from the authorization header.
        let TypedHeader(Authorization(bearer)) = parts
            .extract::<TypedHeader<Authorization<Bearer>>>()
            .await
            .map_err(|_| AuthError::InvalidToken)?;

        // Missing keys are a startup bug, not a client error: fail loudly with a
        // message that names the fix instead of an anonymous `unwrap` panic.
        let keys = KEYS
            .get()
            .expect("JWT keys not initialized: call initialize_jwt_key() before serving requests");

        // Decode the user data; any decode/validation failure is reported uniformly.
        decode::<Claims>(bearer.token(), &keys.decoding, &Validation::default())
            .map(|token_data| token_data.claims)
            .map_err(|_| AuthError::InvalidToken)
    }
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
};
let body = Json(json!({
"error": error_message,
}));
(status, body).into_response()
}
}