feat: add tracing and auth middleware
This commit is contained in:
39
src/main.rs
39
src/main.rs
@@ -1,17 +1,26 @@
|
||||
use axum::Router;
|
||||
use axum::{http::Method, Router};
|
||||
use clap::Parser;
|
||||
use sea_orm::{Database, DatabaseConnection};
|
||||
use serde::Deserialize;
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
// Project modules
|
||||
mod api;
|
||||
mod middleware;
|
||||
mod model;
|
||||
mod util;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenvy::dotenv().unwrap();
|
||||
// initialize tracing
|
||||
tracing_subscriber::fmt::init();
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
let cli = Cli::parse();
|
||||
match cli.command {
|
||||
Command::Serve { config_path } => {
|
||||
@@ -76,24 +85,30 @@ async fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
|
||||
|
||||
// ====== Commands ======
|
||||
|
||||
|
||||
// start http server
|
||||
async fn start_server(config: &Config){
|
||||
// Define the router
|
||||
// let app = Router.new()
|
||||
// .nest();
|
||||
async fn start_server(config: &Config) {
|
||||
let conn = Database::connect(&config.database.connection)
|
||||
.await
|
||||
.await
|
||||
.expect("Database connection failed.");
|
||||
|
||||
let state = AppState{conn };
|
||||
let state = AppState { conn };
|
||||
// Build router
|
||||
let cors_layer = CorsLayer::new()
|
||||
.allow_methods([Method::GET, Method::POST])
|
||||
.allow_origin(Any);
|
||||
let global_layer = ServiceBuilder::new()
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(cors_layer);
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/api/v1/book", api::book::get_nest_handlers())
|
||||
.with_state(state);
|
||||
.with_state(state)
|
||||
.layer(global_layer);
|
||||
let host = config.service.host.clone();
|
||||
let port = config.service.port;
|
||||
let server_url = format!("{host}:{port}");
|
||||
let listener = tokio::net::TcpListener::bind(&server_url).await.unwrap();
|
||||
axum::serve(listener, app).await.expect("Service panic happened");
|
||||
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.expect("Service panic happened");
|
||||
}
|
||||
|
||||
115
src/middleware/auth.rs
Normal file
115
src/middleware/auth.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
http::{
|
||||
request::Parts,
|
||||
StatusCode,
|
||||
},
|
||||
Json, RequestPartsExt,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use axum_extra::{
|
||||
headers::{authorization::Bearer, Authorization},
|
||||
TypedHeader,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use jsonwebtoken::{decode,encode, DecodingKey, EncodingKey, Header, Validation};
|
||||
use std::fmt::Display;
|
||||
use once_cell::sync::Lazy;
|
||||
use crate::util;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
sub: String,
|
||||
// company: String,
|
||||
exp: usize,
|
||||
pub uid: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct AuthBody {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AuthPayload {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AuthError {
|
||||
WrongCredentials,
|
||||
MissingCredentials,
|
||||
TokenCreation,
|
||||
InvalidToken,
|
||||
}
|
||||
|
||||
static KEYS: Lazy<Keys> = Lazy::new(|| {
|
||||
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
Keys::new(secret.as_bytes())
|
||||
});
|
||||
|
||||
struct Keys {
|
||||
encoding: EncodingKey,
|
||||
decoding: DecodingKey,
|
||||
}
|
||||
|
||||
impl Keys {
|
||||
fn new(secret: &[u8]) -> Self {
|
||||
Self {
|
||||
encoding: EncodingKey::from_secret(secret),
|
||||
decoding: DecodingKey::from_secret(secret),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Claims {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Email: {}", self.sub)
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthBody {
|
||||
fn new(access_token: String) -> Self {
|
||||
Self {
|
||||
access_token,
|
||||
token_type: "Bearer".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for Claims
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = AuthError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Extract the token from the authorization header
|
||||
let TypedHeader(Authorization(bearer)) = parts
|
||||
.extract::<TypedHeader<Authorization<Bearer>>>()
|
||||
.await
|
||||
.map_err(|_| AuthError::InvalidToken)?;
|
||||
// Decode the user data
|
||||
let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
|
||||
.map_err(|_| AuthError::InvalidToken)?;
|
||||
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
}
|
||||
impl IntoResponse for AuthError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_message) = match self {
|
||||
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
|
||||
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
|
||||
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
|
||||
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
|
||||
};
|
||||
let body = Json(json!({
|
||||
"error": error_message,
|
||||
}));
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
1
src/middleware/mod.rs
Normal file
1
src/middleware/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod auth;
|
||||
1
src/util/mod.rs
Normal file
1
src/util/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod pass;
|
||||
16
src/util/pass.rs
Normal file
16
src/util/pass.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use std::error::Error;
|
||||
use pbkdf2::{
|
||||
password_hash::{
|
||||
rand_core::OsRng,
|
||||
PasswordHash,SaltString,
|
||||
},
|
||||
Pbkdf2,
|
||||
};
|
||||
use pbkdf2::password_hash::PasswordHasher;
|
||||
|
||||
pub fn get_pbkdf2_from_psw(password:String) -> Result<String, pbkdf2::password_hash::Error> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let password_hash = Pbkdf2.hash_password(password.as_bytes(), &salt)?.to_string();
|
||||
println!("{}",password_hash);
|
||||
return Ok(password_hash)
|
||||
}
|
||||
Reference in New Issue
Block a user