feat: add tracing and auth middleware

This commit is contained in:
brian
2025-06-08 21:11:45 +08:00
parent 27c94f4276
commit 93b7e46655
7 changed files with 383 additions and 33 deletions

View File

@@ -1,17 +1,26 @@
use axum::Router;
use axum::{http::Method, Router};
use clap::Parser;
use sea_orm::{Database, DatabaseConnection};
use serde::Deserialize;
use tower::ServiceBuilder;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
// Project modules
mod api;
mod middleware;
mod model;
mod util;
#[tokio::main]
async fn main() {
dotenvy::dotenv().unwrap();
// initialize tracing
tracing_subscriber::fmt::init();
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.init();
let cli = Cli::parse();
match cli.command {
Command::Serve { config_path } => {
@@ -76,24 +85,30 @@ async fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
// ====== Commands ======
// start http server
async fn start_server(config: &Config){
// Define the router
// let app = Router.new()
// .nest();
async fn start_server(config: &Config) {
let conn = Database::connect(&config.database.connection)
.await
.await
.expect("Database connection failed.");
let state = AppState{conn };
let state = AppState { conn };
// Build router
let cors_layer = CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_origin(Any);
let global_layer = ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(cors_layer);
let app = Router::new()
.nest("/api/v1/book", api::book::get_nest_handlers())
.with_state(state);
.with_state(state)
.layer(global_layer);
let host = config.service.host.clone();
let port = config.service.port;
let server_url = format!("{host}:{port}");
let listener = tokio::net::TcpListener::bind(&server_url).await.unwrap();
axum::serve(listener, app).await.expect("Service panic happened");
axum::serve(listener, app)
.await
.expect("Service panic happened");
}

115
src/middleware/auth.rs Normal file
View File

@@ -0,0 +1,115 @@
use axum::{
extract::FromRequestParts,
http::{
request::Parts,
StatusCode,
},
Json, RequestPartsExt,
response::{IntoResponse, Response},
};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use jsonwebtoken::{decode,encode, DecodingKey, EncodingKey, Header, Validation};
use std::fmt::Display;
use once_cell::sync::Lazy;
use crate::util;
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
sub: String,
// company: String,
exp: usize,
pub uid: i64,
}
#[derive(Debug, Serialize)]
struct AuthBody {
access_token: String,
token_type: String,
}
#[derive(Debug, Deserialize)]
struct AuthPayload {
client_id: String,
client_secret: String,
}
#[derive(Debug)]
pub enum AuthError {
WrongCredentials,
MissingCredentials,
TokenCreation,
InvalidToken,
}
static KEYS: Lazy<Keys> = Lazy::new(|| {
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
Keys::new(secret.as_bytes())
});
struct Keys {
encoding: EncodingKey,
decoding: DecodingKey,
}
impl Keys {
fn new(secret: &[u8]) -> Self {
Self {
encoding: EncodingKey::from_secret(secret),
decoding: DecodingKey::from_secret(secret),
}
}
}
impl Display for Claims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Email: {}", self.sub)
}
}
impl AuthBody {
fn new(access_token: String) -> Self {
Self {
access_token,
token_type: "Bearer".to_string(),
}
}
}
impl<S> FromRequestParts<S> for Claims
where
S: Send + Sync,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Extract the token from the authorization header
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| AuthError::InvalidToken)?;
// Decode the user data
let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
.map_err(|_| AuthError::InvalidToken)?;
Ok(token_data.claims)
}
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
};
let body = Json(json!({
"error": error_message,
}));
(status, body).into_response()
}
}

1
src/middleware/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod auth;

1
src/util/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod pass;

16
src/util/pass.rs Normal file
View File

@@ -0,0 +1,16 @@
use std::error::Error;
use pbkdf2::{
password_hash::{
rand_core::OsRng,
PasswordHash,SaltString,
},
Pbkdf2,
};
use pbkdf2::password_hash::PasswordHasher;
pub fn get_pbkdf2_from_psw(password:String) -> Result<String, pbkdf2::password_hash::Error> {
let salt = SaltString::generate(&mut OsRng);
let password_hash = Pbkdf2.hash_password(password.as_bytes(), &salt)?.to_string();
println!("{}",password_hash);
return Ok(password_hash)
}