121 lines
3.0 KiB
Rust
121 lines
3.0 KiB
Rust
use axum::{
|
|
extract::FromRequestParts,
|
|
http::{request::Parts, StatusCode},
|
|
response::{IntoResponse, Response},
|
|
Json, RequestPartsExt,
|
|
};
|
|
use axum_extra::{
|
|
headers::{authorization::Bearer, Authorization},
|
|
TypedHeader,
|
|
};
|
|
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::json;
|
|
use std::fmt::Display;
|
|
use std::sync::OnceLock;
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct Claims {
|
|
sub: String,
|
|
// company: String,
|
|
exp: usize,
|
|
pub uid: i64,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct AuthBody {
|
|
access_token: String,
|
|
token_type: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct AuthPayload {
|
|
client_id: String,
|
|
client_secret: String,
|
|
}
|
|
|
|
/// Failures that can occur while authenticating a request.
///
/// Each variant is turned into an HTTP response with a JSON
/// `{"error": "..."}` body by its `IntoResponse` impl.
#[derive(Debug)]
pub enum AuthError {
    /// Supplied credentials did not match; answered with 401.
    WrongCredentials,
    /// Required credentials were absent; answered with 400.
    MissingCredentials,
    /// A token could not be produced; answered with 500.
    TokenCreation,
    /// Bearer header absent/malformed or token failed validation; answered with 400.
    InvalidToken,
}
|
|
|
|
static KEYS: OnceLock<Keys> = OnceLock::new();
|
|
|
|
pub fn initialize_jwt_key(key_str: String) {
|
|
let res = KEYS.set(Keys::new(key_str.as_bytes()));
|
|
match res {
|
|
Ok(_) => {}
|
|
Err(_) => panic!("jwt key initialize failed"),
|
|
}
|
|
}
|
|
|
|
struct Keys {
|
|
encoding: EncodingKey,
|
|
decoding: DecodingKey,
|
|
}
|
|
|
|
impl Keys {
|
|
fn new(secret: &[u8]) -> Self {
|
|
Self {
|
|
encoding: EncodingKey::from_secret(secret),
|
|
decoding: DecodingKey::from_secret(secret),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Display for Claims {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "Email: {}", self.sub)
|
|
}
|
|
}
|
|
|
|
impl AuthBody {
|
|
fn new(access_token: String) -> Self {
|
|
Self {
|
|
access_token,
|
|
token_type: "Bearer".to_string(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S> FromRequestParts<S> for Claims
|
|
where
|
|
S: Send + Sync,
|
|
{
|
|
type Rejection = AuthError;
|
|
|
|
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
|
// Extract the token from the authorization header
|
|
let TypedHeader(Authorization(bearer)) = parts
|
|
.extract::<TypedHeader<Authorization<Bearer>>>()
|
|
.await
|
|
.map_err(|_| AuthError::InvalidToken)?;
|
|
// Decode the user data
|
|
let token_data = decode::<Claims>(
|
|
bearer.token(),
|
|
&KEYS.get().unwrap().decoding,
|
|
&Validation::default(),
|
|
)
|
|
.map_err(|_| AuthError::InvalidToken)?;
|
|
|
|
Ok(token_data.claims)
|
|
}
|
|
}
|
|
impl IntoResponse for AuthError {
|
|
fn into_response(self) -> Response {
|
|
let (status, error_message) = match self {
|
|
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
|
|
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
|
|
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
|
|
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
|
|
};
|
|
let body = Json(json!({
|
|
"error": error_message,
|
|
}));
|
|
(status, body).into_response()
|
|
}
|
|
}
|