116 lines
2.9 KiB
Rust
116 lines
2.9 KiB
Rust
|
|
use axum::{
|
||
|
|
extract::FromRequestParts,
|
||
|
|
http::{
|
||
|
|
request::Parts,
|
||
|
|
StatusCode,
|
||
|
|
},
|
||
|
|
Json, RequestPartsExt,
|
||
|
|
response::{IntoResponse, Response},
|
||
|
|
};
|
||
|
|
use axum_extra::{
|
||
|
|
headers::{authorization::Bearer, Authorization},
|
||
|
|
TypedHeader,
|
||
|
|
};
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use serde_json::json;
|
||
|
|
use jsonwebtoken::{decode,encode, DecodingKey, EncodingKey, Header, Validation};
|
||
|
|
use std::fmt::Display;
|
||
|
|
use once_cell::sync::Lazy;
|
||
|
|
use crate::util;
|
||
|
|
|
||
|
|
#[derive(Debug, Serialize, Deserialize)]
|
||
|
|
pub struct Claims {
|
||
|
|
sub: String,
|
||
|
|
// company: String,
|
||
|
|
exp: usize,
|
||
|
|
pub uid: i64,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Serialize)]
|
||
|
|
struct AuthBody {
|
||
|
|
access_token: String,
|
||
|
|
token_type: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Deserialize)]
|
||
|
|
struct AuthPayload {
|
||
|
|
client_id: String,
|
||
|
|
client_secret: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Failure modes of the authentication layer; each variant is turned into an
/// HTTP response by this type's `IntoResponse` impl.
#[derive(Debug)]
pub enum AuthError {
    /// Supplied credentials did not match.
    WrongCredentials,
    /// Required credentials were not supplied.
    MissingCredentials,
    /// A token could not be produced.
    TokenCreation,
    /// The bearer token was absent, malformed, or failed validation.
    InvalidToken,
}
|
||
|
|
|
||
|
|
static KEYS: Lazy<Keys> = Lazy::new(|| {
|
||
|
|
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||
|
|
Keys::new(secret.as_bytes())
|
||
|
|
});
|
||
|
|
|
||
|
|
struct Keys {
|
||
|
|
encoding: EncodingKey,
|
||
|
|
decoding: DecodingKey,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Keys {
|
||
|
|
fn new(secret: &[u8]) -> Self {
|
||
|
|
Self {
|
||
|
|
encoding: EncodingKey::from_secret(secret),
|
||
|
|
decoding: DecodingKey::from_secret(secret),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Display for Claims {
|
||
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
|
|
write!(f, "Email: {}", self.sub)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl AuthBody {
|
||
|
|
fn new(access_token: String) -> Self {
|
||
|
|
Self {
|
||
|
|
access_token,
|
||
|
|
token_type: "Bearer".to_string(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl<S> FromRequestParts<S> for Claims
|
||
|
|
where
|
||
|
|
S: Send + Sync,
|
||
|
|
{
|
||
|
|
type Rejection = AuthError;
|
||
|
|
|
||
|
|
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||
|
|
// Extract the token from the authorization header
|
||
|
|
let TypedHeader(Authorization(bearer)) = parts
|
||
|
|
.extract::<TypedHeader<Authorization<Bearer>>>()
|
||
|
|
.await
|
||
|
|
.map_err(|_| AuthError::InvalidToken)?;
|
||
|
|
// Decode the user data
|
||
|
|
let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
|
||
|
|
.map_err(|_| AuthError::InvalidToken)?;
|
||
|
|
|
||
|
|
Ok(token_data.claims)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
impl IntoResponse for AuthError {
|
||
|
|
fn into_response(self) -> Response {
|
||
|
|
let (status, error_message) = match self {
|
||
|
|
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
|
||
|
|
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
|
||
|
|
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
|
||
|
|
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
|
||
|
|
};
|
||
|
|
let body = Json(json!({
|
||
|
|
"error": error_message,
|
||
|
|
}));
|
||
|
|
(status, body).into_response()
|
||
|
|
}
|
||
|
|
}
|