feat: add jwt auth

This commit is contained in:
acx
2024-07-18 15:39:09 +00:00
parent 7270399f35
commit 952a37892d
10 changed files with 527 additions and 38 deletions

View File

@@ -1,16 +1,24 @@
// use std::sync::Arc;
use axum::{http::StatusCode, extract::{State,Path}, Json, Router};
use axum::routing::{get, post};
use axum::{
extract::{Path, State},
http::StatusCode,
Json, Router,
};
use axum_macros::debug_handler;
use diesel::prelude::*;
// use diesel::update;
use serde::{Deserialize, Serialize};
// use serde_json::to_string;
use crate::util;
use crate::model::schema;
use crate::model::db_model;
use crate::model::schema;
use crate::util;
// use crate::model::schema::categories::dsl::categories;
use crate::util::req::CommonResp;
use chrono::prelude::*;
use tracing::info;
use crate::middleware::auth;
use crate::middleware::auth::Claims;
#[derive(Serialize)]
pub struct CreateCategoryResponse {
@@ -29,16 +37,22 @@ pub struct CreateCategoryRequest {
name: String,
}
#[debug_handler]
pub async fn create_category(
State(app_state): State<crate::AppState>,
Json(Payload): Json<CreateCategoryRequest>,
claims: Claims,
Json(payload): Json<CreateCategoryRequest>,
) -> Result<Json<db_model::Category>, (StatusCode, String)> {
let uid: i64 = 123214124; // TODO replace with actual user id.
// let ret = CreateCategoryResponse{id: 134132413541, name: "24532452".to_string()};
let conn = app_state.db.get().await.map_err(util::req::internal_error)?;
let new_category = db_model::CategoryForm{
name: Payload.name,
uid: uid,
let uid: i64 = claims.uid.clone(); // TODO replace with actual user id.
// let ret = CreateCategoryResponse{id: 134132413541, name: "24532452".to_string()};
let conn = app_state
.db
.get()
.await
.map_err(util::req::internal_error)?;
let new_category = db_model::CategoryForm {
name: payload.name,
uid,
};
let res = conn
.interact(move |conn| {
@@ -57,11 +71,16 @@ pub async fn create_category(
pub async fn update_category(
Path(id): Path<i64>,
State(app_state): State<crate::AppState>,
Json(Payload): Json<CreateCategoryRequest>,
claims: Claims,
Json(payload): Json<CreateCategoryRequest>,
) -> Result<Json<CommonResp>, (StatusCode, String)> {
let uid: i64 = 123214124; // TODO replace with actual user id.
// let ret = CreateCategoryResponse{id: 134132413541, name: "24532452".to_string()};
let conn = app_state.db.get().await.map_err(util::req::internal_error)?;
let uid: i64 = claims.uid.clone(); // TODO replace with actual user id.
// let ret = CreateCategoryResponse{id: 134132413541, name: "24532452".to_string()};
let conn = app_state
.db
.get()
.await
.map_err(util::req::internal_error)?;
let now = Utc::now().naive_utc();
let res = conn
.interact(move |conn| {
@@ -69,8 +88,8 @@ pub async fn update_category(
.filter(schema::categories::id.eq(id))
.filter(schema::categories::uid.eq(uid))
.set((
schema::categories::name.eq(Payload.name),
schema::categories::update_at.eq(now),
schema::categories::name.eq(payload.name),
schema::categories::update_at.eq(now),
))
.execute(conn)
})
@@ -78,18 +97,21 @@ pub async fn update_category(
.map_err(util::req::internal_error)?
.map_err(util::req::internal_error)?;
// let ret = CreateCategoryResponse{id: res.id, name: res.name};
let resp = util::req::CommonResp{
code: 0,
};
let resp = util::req::CommonResp { code: 0 };
Ok(Json(resp))
}
pub async fn get_category(
Path(id): Path<i64>,
State(app_state): State<crate::AppState>,
) -> Result<Json<db_model::Category>, (StatusCode, String)>{
let uid: i64 = 123214124; // TODO replace with actual user id.
let conn = app_state.db.get().await.map_err(util::req::internal_error)?;
claims: Claims,
) -> Result<Json<db_model::Category>, (StatusCode, String)> {
let uid: i64 = claims.uid.clone();
let conn = app_state
.db
.get()
.await
.map_err(util::req::internal_error)?;
let res = conn
.interact(move |conn| {
schema::categories::table
@@ -107,9 +129,14 @@ pub async fn get_category(
pub async fn get_all_categories(
State(app_state): State<crate::AppState>,
) -> Result<Json<Vec<db_model::Category>>, (StatusCode, String)>{
let uid: i64 = 123214124; // TODO replace with actual user id.
let conn = app_state.db.get().await.map_err(util::req::internal_error)?;
claims: Claims,
) -> Result<Json<Vec<db_model::Category>>, (StatusCode, String)> {
let uid: i64 = claims.uid.clone();
let conn = app_state
.db
.get()
.await
.map_err(util::req::internal_error)?;
let res = conn
.interact(move |conn| {
schema::categories::table

View File

@@ -1 +1 @@
pub mod handler;
pub mod handler;

View File

@@ -4,25 +4,32 @@ use axum::{
// Json,
Router,
};
use axum::http::Method;
use serde::{Deserialize, Serialize};
use tower::ServiceBuilder;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
// Project modules
mod category;
mod middleware;
mod model;
mod util;
// Passed App State
#[derive(Clone)]
pub struct AppState{
pub struct AppState {
db: deadpool_diesel::postgres::Pool,
}
#[tokio::main]
async fn main() {
dotenvy::dotenv().unwrap();
tracing_subscriber::registry().with(tracing_subscriber::fmt::layer()).init();
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.init();
// initialize db connection
let db_url = std::env::var("DATABASE_URL").unwrap();
@@ -31,14 +38,22 @@ async fn main() {
.build()
.unwrap();
let shared_state = AppState {db: pool,};
let shared_state = AppState { db: pool };
// Register routers
let cors_layer = CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_origin(Any);
let global_layer = ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(cors_layer);
let app = Router::new()
// V1 apis
.nest("/api/v1/category", category::handler::get_nest_handlers())
.with_state(shared_state);
.nest("/api/v1/v2", category::handler::get_nest_handlers())
.with_state(shared_state)
.layer(global_layer);
let listener = tokio::net::TcpListener::bind("0.0.0.0:8987").await.unwrap();
info!("starting server on 0.0.0.0:8987");

118
src/middleware/auth.rs Normal file
View File

@@ -0,0 +1,118 @@
use axum::{
async_trait,
extract::FromRequestParts,
http::{
request::Parts,
StatusCode,
},
Json, RequestPartsExt,
response::{IntoResponse, Response},
};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use jsonwebtoken::{decode,encode, DecodingKey, EncodingKey, Header, Validation};
use std::fmt::Display;
use once_cell::sync::Lazy;
use crate::util;
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
sub: String,
company: String,
exp: usize,
pub uid: i64,
}
#[derive(Debug, Serialize)]
struct AuthBody {
access_token: String,
token_type: String,
}
#[derive(Debug, Deserialize)]
struct AuthPayload {
client_id: String,
client_secret: String,
}
#[derive(Debug)]
pub enum AuthError {
WrongCredentials,
MissingCredentials,
TokenCreation,
InvalidToken,
}
static KEYS: Lazy<Keys> = Lazy::new(|| {
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
Keys::new(secret.as_bytes())
});
struct Keys {
encoding: EncodingKey,
decoding: DecodingKey,
}
impl Keys {
fn new(secret: &[u8]) -> Self {
Self {
encoding: EncodingKey::from_secret(secret),
decoding: DecodingKey::from_secret(secret),
}
}
}
impl Display for Claims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
}
}
impl AuthBody {
fn new(access_token: String) -> Self {
Self {
access_token,
token_type: "Bearer".to_string(),
}
}
}
#[async_trait]
impl<S> FromRequestParts<S> for Claims
where
S: Send + Sync,
{
type Rejection = (StatusCode, String);
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Extract the token from the authorization header
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(util::req::internal_error)?;
// Decode the user data
let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
.map_err(util::req::internal_error)?;
Ok(token_data.claims)
}
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
};
let body = Json(json!({
"error": error_message,
}));
(status, body).into_response()
}
}

1
src/middleware/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod auth;

View File

@@ -1,8 +1,7 @@
use diesel::prelude::*;
use crate::model::schema;
use diesel::prelude::*;
#[derive(Queryable, Selectable)]
#[derive(serde::Serialize, serde::Deserialize)]
#[derive(Queryable, Selectable, serde::Serialize, serde::Deserialize)]
#[diesel(table_name = schema::categories)]
#[diesel(check_for_backend(diesel::pg::Pg))]
pub struct Category {

View File

@@ -1,2 +1,2 @@
pub mod db_model;
pub mod schema;
pub mod db_model;

View File

@@ -1 +1 @@
pub mod req;
pub mod req;