diff --git a/src/lib.rs b/src/lib.rs index c8ae0e4..dd4f853 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,2 +1,19 @@ pub mod model; pub mod rest; + +#[cfg(test)] +#[macro_export] +macro_rules! testdb { + () => {{ + let pool = SqlitePool::connect(":memory:").await.unwrap(); + sqlx::query_file!("./migration.sql") + .execute(&pool) + .await + .unwrap(); + sqlx::query_file!("./seeds.sql") + .execute(&pool) + .await + .unwrap(); + pool + }}; +} diff --git a/src/model/user.rs b/src/model/user.rs index b3b6db2..8faf63d 100644 --- a/src/model/user.rs +++ b/src/model/user.rs @@ -1,8 +1,14 @@ use argon2::{password_hash::SaltString, Argon2, PasswordHasher}; -use serde::Serialize; +use rocket::{ + async_trait, + http::Status, + request::{self, FromRequest, Outcome}, + Request, +}; +use serde::{Deserialize, Serialize}; use sqlx::{FromRow, SqlitePool}; -#[derive(FromRow, Debug, Serialize)] +#[derive(FromRow, Debug, Serialize, Deserialize)] pub struct User { id: i64, name: String, @@ -16,6 +22,7 @@ pub struct User { pub enum LoginError { SqlxError(sqlx::Error), InvalidAuthenticationCombo, + NotLoggedIn, } impl From for LoginError { @@ -58,28 +65,31 @@ WHERE name like ? } } +#[async_trait] +impl<'r> FromRequest<'r> for User { + type Error = LoginError; + + async fn from_request(req: &'r Request<'_>) -> request::Outcome { + match req.cookies().get_private("loggedin_user") { + Some(user) => { + let user: User = serde_json::from_str(&user.value()).unwrap(); //TODO: fixme + Outcome::Success(user) + } + None => Outcome::Failure((Status::Unauthorized, LoginError::NotLoggedIn)), + } + } +} + #[cfg(test)] mod test { + use crate::testdb; + use super::User; use sqlx::SqlitePool; - async fn setup() -> SqlitePool { - let pool = SqlitePool::connect(":memory:").await.unwrap(); - sqlx::query_file!("./migration.sql") - .execute(&pool) - .await - .unwrap(); - sqlx::query_file!("./seeds.sql") - .execute(&pool) - .await - .unwrap(); - - pool - } - #[sqlx::test] fn succ_login_with_test_db() { - let pool = setup().await; + let pool = testdb!(); User::login(&pool, "admin".into(), "admin".into()) .await .unwrap(); @@ -87,7 +97,7 @@ mod test { #[sqlx::test] fn wrong_pw() { - let pool = setup().await; + let pool = testdb!(); assert!(User::login(&pool, "admin".into(), "admi".into()) .await .is_err()); @@ -95,7 +105,7 @@ mod test { #[sqlx::test] fn wrong_username() { - let pool = setup().await; + let pool = testdb!(); assert!(User::login(&pool, "admi".into(), "admin".into()) .await .is_err()); diff --git a/src/rest/auth.rs b/src/rest/auth.rs index 823d803..83682af 100644 --- a/src/rest/auth.rs +++ b/src/rest/auth.rs @@ -38,7 +38,7 @@ async fn login( ) -> Flash { let user = User::login(db, login.name.clone(), login.password.clone()).await; - //TODO: be able to use for find_by_name. This would get rid of the following match clause. + //TODO: be able to use ? for login. This would get rid of the following match clause. let user = match user { Ok(user) => user, Err(_) => { @@ -47,7 +47,7 @@ async fn login( }; let user_json: String = format!("{}", json!(user)); - cookies.add_private(Cookie::new("user", user_json)); + cookies.add_private(Cookie::new("loggedin_user", user_json)); Flash::success(Redirect::to("/"), "Login erfolgreich") } diff --git a/src/rest/mod.rs b/src/rest/mod.rs index 39a4a7a..0d2f670 100644 --- a/src/rest/mod.rs +++ b/src/rest/mod.rs @@ -1,48 +1,51 @@ -use rocket::{get, routes, Build, Rocket}; +use rocket::{catch, catchers, get, response::Redirect, routes, Build, Rocket}; use rocket_dyn_templates::{context, Template}; use sqlx::SqlitePool; +use crate::model::user::User; + mod auth; #[get("/")] -fn index() -> Template { +fn index(_user: User) -> Template { Template::render("index", context! {}) } +#[catch(401)] //unauthorized +fn unauthorized_error() -> Redirect { + Redirect::to("/auth") +} + pub fn start(db: SqlitePool) -> Rocket { rocket::build() .manage(db) .mount("/", routes![index]) .mount("/auth", auth::routes()) + .register("/", catchers![unauthorized_error]) .attach(Template::fairing()) } -//#[cfg(test)] -//mod test { -// use super::start; -// use rocket::http::Status; -// use rocket::local::asynchronous::Client; -// use rocket::uri; -// use sqlx::SqlitePool; -// -// #[sqlx::test] -// fn hello_world() { -// let pool = SqlitePool::connect(":memory:").await.unwrap(); -// sqlx::query_file!("./migration.sql") -// .execute(&pool) -// .await -// .unwrap(); -// sqlx::query_file!("./seeds.sql") -// .execute(&pool) -// .await -// .unwrap(); -// -// let client = Client::tracked(start()) -// .await -// .expect("valid rocket instance"); -// let response = client.get(uri!(super::index)).dispatch().await; -// -// assert_eq!(response.status(), Status::Ok); -// assert_eq!(response.into_string().await, Some("Hello, world!".into())); -// } -//} +#[cfg(test)] +mod test { + use crate::testdb; + + use super::start; + use rocket::http::Status; + use rocket::local::asynchronous::Client; + use rocket::uri; + use sqlx::SqlitePool; + + #[sqlx::test] + fn test_not_logged_in() { + let pool = testdb!(); + + let client = Client::tracked(start(pool)) + .await + .expect("valid rocket instance"); + let response = client.get(uri!(super::index)).dispatch().await; + + assert_eq!(response.status(), Status::SeeOther); + let location = response.headers().get("Location").next().unwrap(); + assert_eq!(location, "/auth"); + } +}