feat: tag

This commit is contained in:
acx
2024-07-28 14:24:53 +00:00
parent 952a37892d
commit b3ee37fbe3
16 changed files with 469 additions and 7 deletions

2
src/ledger/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod category;
pub mod tag;

146
src/ledger/tag.rs Normal file
View File

@@ -0,0 +1,146 @@
// use std::sync::Arc;
use axum::routing::{get, post};
use axum::{
extract::{Path, State},
http::StatusCode,
Json, Router,
};
use axum_macros::debug_handler;
use diesel::prelude::*;
// use diesel::update;
use serde::{Deserialize, Serialize};
// use serde_json::to_string;
use crate::model::db_model;
use crate::model::schema;
use crate::util;
use crate::util::req::CommonResp;
use chrono::prelude::*;
use tracing::info;
use crate::middleware::auth;
use crate::middleware::auth::Claims;
#[derive(Deserialize)]
pub struct CreateTagRequest {
name: String,
}
#[derive(Serialize)]
pub struct CreateTagResponse {
id: i64,
name: String,
}
pub fn get_nest_handlers() -> Router<crate::AppState> {
Router::new()
.route("/", post(create_tag).get(get_all_tags))
.route("/:id", post(update_tag).get(get_tag))
}
#[debug_handler]
pub async fn create_tag(
State(app_state): State<crate::AppState>,
claims: Claims,
Json(payload): Json<CreateTagRequest>,
) -> Result<Json<db_model::Tag>, (StatusCode, String)> {
let uid: i64 = claims.uid.clone();
let conn = app_state
.db
.get()
.await
.map_err(util::req::internal_error)?;
let new_tag = db_model::TagForm {
name: payload.name,
uid,
};
let res = conn
.interact(move |conn| {
diesel::insert_into(schema::tags::table)
.values(&new_tag)
.returning(db_model::Tag::as_returning())
.get_result(conn)
})
.await
.map_err(util::req::internal_error)?
.map_err(util::req::internal_error)?;
Ok(Json(res))
}
pub async fn update_tag(
Path(id): Path<i64>,
State(app_state): State<crate::AppState>,
claims: Claims,
Json(payload): Json<CreateTagRequest>,
) -> Result<Json<CommonResp>, (StatusCode, String)> {
let uid: i64 = claims.uid.clone();
let conn = app_state
.db
.get()
.await
.map_err(util::req::internal_error)?;
let now = Utc::now().naive_utc();
let res = conn
.interact(move |conn| {
diesel::update(schema::tags::table)
.filter(schema::tags::id.eq(id))
.filter(schema::tags::uid.eq(uid))
.set((
schema::tags::name.eq(payload.name),
schema::tags::update_at.eq(now),
))
.execute(conn)
})
.await
.map_err(util::req::internal_error)?
.map_err(util::req::internal_error)?;
let resp = util::req::CommonResp { code: 0 };
Ok(Json(resp))
}
pub async fn get_tag(
Path(id): Path<i64>,
State(app_state): State<crate::AppState>,
claims: Claims,
) -> Result<Json<db_model::Tag>, (StatusCode, String)> {
let uid: i64 = claims.uid.clone();
let conn = app_state
.db
.get()
.await
.map_err(util::req::internal_error)?;
let res = conn
.interact(move |conn| {
schema::tags::table
.filter(schema::tags::id.eq(id))
.filter(schema::tags::uid.eq(uid))
.select(db_model::Tag::as_select())
.limit(1)
.get_result(conn)
})
.await
.map_err(util::req::internal_error)?
.map_err(util::req::internal_error)?;
Ok(Json(res))
}
pub async fn get_all_tags(
State(app_state): State<crate::AppState>,
claims: Claims,
) -> Result<Json<Vec<db_model::Tag>>, (StatusCode, String)> {
let uid: i64 = claims.uid.clone();
let conn = app_state
.db
.get()
.await
.map_err(util::req::internal_error)?;
let res = conn
.interact(move |conn| {
schema::tags::table
.filter(schema::tags::uid.eq(uid))
.select(db_model::Tag::as_select())
.load(conn)
})
.await
.map_err(util::req::internal_error)?
.map_err(util::req::internal_error)?;
Ok(Json(res))
}

View File

@@ -1,3 +1,4 @@
use std::env;
use axum::{
// http::StatusCode,
// routing::{get, post},
@@ -5,18 +6,21 @@ use axum::{
Router,
};
use axum::http::Method;
use serde::{Deserialize, Serialize};
// use pbkdf2::password_hash::Error;
// use serde::{Deserialize, Serialize};
use tower::ServiceBuilder;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use crate::util::pass::get_pbkdf2_from_psw;
// Project modules
mod category;
mod ledger;
mod middleware;
mod model;
mod util;
mod user;
// Passed App State
#[derive(Clone)]
@@ -30,6 +34,12 @@ async fn main() {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.init();
let args: Vec<String> = env::args().collect();
if args.len() <= 1 {
return;
}
// initialize db connection
let db_url = std::env::var("DATABASE_URL").unwrap();
@@ -39,6 +49,37 @@ async fn main() {
.unwrap();
let shared_state = AppState { db: pool };
let cmd = args[1].clone();
match cmd.as_str() {
"add_user" => {
println!("adding user");
if args.len() <= 4 {
println!("insufficient arg number");
return;
}
let user = args[2].clone();
let psw = args[3].clone();
let mail = args[4].clone();
println!("adding user {}", user);
let hashed = get_pbkdf2_from_psw(psw);
let mut hash_psw = "".to_string();
match hashed {
Ok(val) => {
println!("get hash {}", val);
hash_psw=val;
}
Err(_) => {}
}
let res = user::dal::add_user(shared_state, user, hash_psw, mail)
.await;
return;
}
_ => {
println!("unknown command {}", cmd);
}
}
// Register routers
let cors_layer = CorsLayer::new()
@@ -50,8 +91,9 @@ async fn main() {
let app = Router::new()
// V1 apis
.nest("/api/v1/category", category::handler::get_nest_handlers())
.nest("/api/v1/v2", category::handler::get_nest_handlers())
.nest("/api/v1/category", ledger::category::get_nest_handlers())
.nest("/api/v1/tag", ledger::tag::get_nest_handlers())
.nest("/api/v1/user", user::handler::get_nest_handlers())
.with_state(shared_state)
.layer(global_layer);

View File

@@ -22,7 +22,7 @@ use crate::util;
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
sub: String,
company: String,
// company: String,
exp: usize,
pub uid: i64,
}
@@ -68,7 +68,7 @@ impl Keys {
impl Display for Claims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
write!(f, "Email: {}", self.sub)
}
}

View File

@@ -19,3 +19,40 @@ pub struct CategoryForm {
pub uid: i64,
pub name: String,
}
#[derive(Queryable, Selectable, serde::Serialize, serde::Deserialize)]
#[diesel(table_name = schema::tags)]
#[diesel(check_for_backend(diesel::pg::Pg))]
pub struct Tag {
id: i64,
uid: i64,
name: String,
is_delete: bool,
create_at: chrono::NaiveDateTime,
update_at: chrono::NaiveDateTime,
}
#[derive(serde::Deserialize, Insertable)]
#[diesel(table_name = schema::tags)]
pub struct TagForm {
pub uid: i64,
pub name: String,
}
#[derive(Queryable, Selectable, serde::Serialize)]
#[diesel(table_name = schema::users)]
pub struct User {
pub id: i64,
pub username: String,
pub password: String,
pub mail: String,
pub is_delete: bool,
}
#[derive(Insertable)]
#[diesel(table_name = schema::users)]
pub struct UserForm {
pub username: String,
pub password: String,
pub mail: String,
}

View File

@@ -1,2 +1,3 @@
pub mod db_model;
pub mod schema;
pub mod req;

0
src/model/req.rs Normal file
View File

118
src/user/dal.rs Normal file
View File

@@ -0,0 +1,118 @@
use diesel::prelude::*;
use crate::model::{db_model, schema};
use std::error::Error;
use std::fmt::Debug;
use pbkdf2::password_hash::{PasswordHash, PasswordVerifier};
use pbkdf2::Pbkdf2;
use serde_json::json;
pub async fn add_user(app_state: crate::AppState, username: String, password: String, mail: String) -> Result<(), ()> {
let conn = app_state
.db
.get()
.await
.map_err(|_| {
println!("fail to get db connection");
()
})?;
let target_username = username.clone();
// 1. check if current username exists.
let res = conn.interact(
move |conn| {
schema::users::table
.filter(schema::users::username.eq(target_username.clone()))
.count()
.get_result::<i64>(conn)
})
.await
.map_err(|res| {
()
})?
.map_err(|res| {
()
})?;
println!("ret {}", res);
if res > 0 {
println!("user already exists.");
return Ok(());
}
let new_user_form = db_model::UserForm {
username: username.clone(),
password: password.clone(),
mail: mail.clone(),
};
// 2. adding user
let add_res = conn.interact(
move |conn| {
diesel::insert_into(schema::users::table)
.values(&new_user_form)
.returning(db_model::User::as_returning())
.get_result(conn)
})
.await
.map_err(|e| {
()
})?
.map_err(|e| {
()
})?;
let out = json!(add_res);
println!("new user {}", out.to_string());
Ok(())
}
pub async fn check_user_psw(app_state: crate::AppState, username: String, password: String) -> bool {
let conn_res = app_state
.db
.get()
.await
.map_err(|_| {
println!("fail to get db connection");
()
});
let conn = match conn_res {
Ok(res) => res,
Err(err) => { return false; }
};
// 1. get psw hash
let query_username = username.clone();
let user_rr = conn.interact(
|conn| {
schema::users::table
.filter(schema::users::username.eq(query_username))
.select(db_model::User::as_select())
.get_results(conn)
})
.await;
let user_res = match user_rr {
Ok(res) => res,
Err(_) => return false,
};
println!("get user_res success");
let user = match user_res {
Ok(u) => u,
Err(_) => return false,
};
println!("get user success");
if user.len() != 1 {
return false;
}
println!("get uniq user success");
let cur_user = user.get(0);
let psw = match cur_user {
Some(usr) => usr.password.clone(),
None => "".to_string(),
};
println!("comparing psw, get {}, stored {}.", password.clone(), psw.clone());
let hash_res = PasswordHash::new(psw.as_str());
let hash = match hash_res {
Ok(rs) => rs,
Err(_) => return false,
};
let check_res = Pbkdf2.verify_password(password.as_bytes(), &hash);
return check_res.is_ok();
}

29
src/user/handler.rs Normal file
View File

@@ -0,0 +1,29 @@
use axum::{
extract::State, http::StatusCode, routing::post, Json, Router
};
use axum_macros::debug_handler;
use crate::middleware::auth::Claims;
use super::dal::check_user_psw;
pub fn get_nest_handlers() -> Router<crate::AppState> {
Router::new()
.route("/login", post(login))
}
#[derive(serde::Deserialize)]
pub struct LoginCredentialRequest {
pub username: String,
pub password: String,
}
#[debug_handler]
pub async fn login(
State(app_state): State<crate::AppState>,
Json(payload): Json<LoginCredentialRequest>,
) -> Result<(), (StatusCode, String)> {
let res = check_user_psw(app_state, payload.username.clone(), payload.password.clone()).await;
if !res {
return Err((StatusCode::UNAUTHORIZED, "invalid credentials".to_string()));
}
Ok(())
}

View File

@@ -1 +1,2 @@
pub mod dal;
pub mod handler;

View File

@@ -1 +1,2 @@
pub mod req;
pub mod pass;

16
src/util/pass.rs Normal file
View File

@@ -0,0 +1,16 @@
use std::error::Error;
use pbkdf2::{
password_hash::{
rand_core::OsRng,
PasswordHash,SaltString,
},
Pbkdf2,
};
use pbkdf2::password_hash::PasswordHasher;
pub fn get_pbkdf2_from_psw(password:String) -> Result<String, pbkdf2::password_hash::Error> {
let salt = SaltString::generate(&mut OsRng);
let password_hash = Pbkdf2.hash_password(password.as_bytes(), &salt)?.to_string();
println!("{}",password_hash);
return Ok(password_hash)
}