use std::ops::Deref; use argon2::{password_hash::SaltString, Argon2, PasswordHasher}; use rocket::{ async_trait, http::Status, request::{self, FromRequest, Outcome}, Request, }; use serde::{Deserialize, Serialize}; use sqlx::{FromRow, SqlitePool}; #[derive(FromRow, Debug, Serialize, Deserialize)] pub struct User { pub id: i64, pub name: String, pw: Option, pub is_cox: bool, is_admin: bool, is_guest: bool, } pub struct AdminUser { user: User, } impl TryFrom for AdminUser { type Error = LoginError; fn try_from(user: User) -> Result { if user.is_admin { Ok(AdminUser { user }) } else { Err(LoginError::NotAnAdmin) } } } pub struct CoxUser { user: User, } impl Deref for CoxUser { type Target = User; fn deref(&self) -> &Self::Target { &self.user } } impl TryFrom for CoxUser { type Error = LoginError; fn try_from(user: User) -> Result { if user.is_cox { Ok(CoxUser { user }) } else { Err(LoginError::NotACox) } } } #[derive(Debug)] pub enum LoginError { SqlxError(sqlx::Error), InvalidAuthenticationCombo, NotLoggedIn, NotAnAdmin, NotACox, NoPasswordSet(User), } impl From for LoginError { fn from(sqlx_error: sqlx::Error) -> Self { Self::SqlxError(sqlx_error) } } impl User { pub async fn update(&self, db: &SqlitePool, is_cox: bool, is_admin: bool, is_guest: bool) { sqlx::query!( "UPDATE user SET is_cox = ?, is_admin = ?, is_guest = ? where id = ?", is_cox, is_admin, is_guest, self.id ) .execute(db) .await .unwrap(); //TODO: fixme } pub async fn find_by_id(db: &SqlitePool, id: i32) -> Result { let user: User = sqlx::query_as!( User, " SELECT id, name, pw, is_cox, is_admin, is_guest FROM user WHERE id like ? ", id ) .fetch_one(db) .await?; Ok(user) } async fn find_by_name(db: &SqlitePool, name: String) -> Result { let user: User = sqlx::query_as!( User, " SELECT id, name, pw, is_cox, is_admin, is_guest FROM user WHERE name like ? ", name ) .fetch_one(db) .await?; Ok(user) } fn get_hashed_pw(pw: String) -> String { let salt = SaltString::from_b64("dS/X5/sPEKTj4Rzs/CuvzQ").unwrap(); let argon2 = Argon2::default(); argon2 .hash_password(&pw.as_bytes(), &salt) .unwrap() .to_string() } pub async fn login(db: &SqlitePool, name: String, pw: String) -> Result { let user = User::find_by_name(db, name).await?; match user.pw.clone() { Some(user_pw) => { let password_hash = Self::get_hashed_pw(pw); if password_hash == user_pw { return Ok(user); } Err(LoginError::InvalidAuthenticationCombo) } None => Err(LoginError::NoPasswordSet(user)), } } pub async fn all(db: &SqlitePool) -> Vec { sqlx::query_as!( User, " SELECT id, name, pw, is_cox, is_admin, is_guest FROM user " ) .fetch_all(db) .await .unwrap() //TODO: fixme } pub async fn reset_pw(&self, db: &SqlitePool) { sqlx::query!("UPDATE user SET pw = null where id = ?", self.id) .execute(db) .await .unwrap(); //TODO: fixme } pub async fn update_pw(&self, db: &SqlitePool, pw: String) { let pw = Self::get_hashed_pw(pw); sqlx::query!("UPDATE user SET pw = ? where id = ?", pw, self.id) .execute(db) .await .unwrap(); //TODO: fixme } } #[async_trait] impl<'r> FromRequest<'r> for User { type Error = LoginError; async fn from_request(req: &'r Request<'_>) -> request::Outcome { match req.cookies().get_private("loggedin_user") { Some(user) => { let user: User = serde_json::from_str(&user.value()).unwrap(); //TODO: fixme Outcome::Success(user) } None => Outcome::Failure((Status::Unauthorized, LoginError::NotLoggedIn)), } } } #[async_trait] impl<'r> FromRequest<'r> for AdminUser { type Error = LoginError; async fn from_request(req: &'r Request<'_>) -> request::Outcome { match req.cookies().get_private("loggedin_user") { Some(user) => { let user: User = serde_json::from_str(&user.value()).unwrap(); //TODO: fixme match user.try_into() { Ok(user) => Outcome::Success(user), Err(_) => Outcome::Failure((Status::Unauthorized, LoginError::NotAnAdmin)), } } None => Outcome::Failure((Status::Unauthorized, LoginError::NotLoggedIn)), } } } #[async_trait] impl<'r> FromRequest<'r> for CoxUser { type Error = LoginError; async fn from_request(req: &'r Request<'_>) -> request::Outcome { match req.cookies().get_private("loggedin_user") { Some(user) => { let user: User = serde_json::from_str(&user.value()).unwrap(); //TODO: fixme match user.try_into() { Ok(user) => Outcome::Success(user), Err(_) => Outcome::Failure((Status::Unauthorized, LoginError::NotAnAdmin)), } } None => Outcome::Failure((Status::Unauthorized, LoginError::NotLoggedIn)), } } } #[cfg(test)] mod test { use crate::testdb; use super::User; use sqlx::SqlitePool; #[sqlx::test] fn succ_login_with_test_db() { let pool = testdb!(); User::login(&pool, "admin".into(), "admin".into()) .await .unwrap(); } #[sqlx::test] fn wrong_pw() { let pool = testdb!(); assert!(User::login(&pool, "admin".into(), "admi".into()) .await .is_err()); } #[sqlx::test] fn wrong_username() { let pool = testdb!(); assert!(User::login(&pool, "admi".into(), "admin".into()) .await .is_err()); } }