use crate::model::client::Client; use axum::{http::HeaderMap, routing::get, Router}; use axum_extra::extract::{ cookie::{Cookie, Expiration, Key}, PrivateCookieJar, }; use serde::{Deserialize, Serialize}; use sqlx::{pool::PoolOptions, sqlite::SqliteConnectOptions, SqlitePool}; use std::{ collections::HashSet, fmt::Display, fs, path::Path, str::FromStr, sync::{Arc, LazyLock}, }; use time::{Duration, OffsetDateTime}; use tower_http::services::ServeDir; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; use uuid::Uuid; #[macro_use] extern crate rust_i18n; i18n!("locales", fallback = "en"); mod game; mod index; pub(crate) mod language; pub(crate) mod model; mod page; pub(crate) mod random_names; pub(crate) enum Backend { Sqlite(SqlitePool), } #[derive(Debug)] enum Language { German, English, } impl Display for Language { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Language::German => f.write_str("de"), Language::English => f.write_str("en"), } } } impl Language { pub(crate) fn next_language(&self) -> Self { match self { Language::German => Language::English, Language::English => Language::German, } } fn to_locale(&self) -> &'static str { match self { Language::German => "de", Language::English => "en", } } } impl From for Language { fn from(value: String) -> Self { if value.starts_with("de") { Language::German } else { Language::English } } } #[derive(Debug)] struct Req { client: Client, lang: Language, } pub(crate) enum NameUpdateError { TooLong(usize, usize), TooShort(usize, usize), ContainsBadWord, } static BAD_WORDS: LazyLock> = LazyLock::new(|| { const BAD_WORDS_FILE: &str = include_str!("../bad/bad-list.txt"); BAD_WORDS_FILE .lines() .map(|line| line.trim()) .map(|word| word.to_lowercase()) .collect() }); fn contains_bad_word(text: &str) -> bool { // check parts of string, e.g. ABC_DEF checks both ABC and DEF let cleaned_text = text.to_lowercase(); if cleaned_text .split(|c: char| !c.is_alphabetic()) .any(|part| BAD_WORDS.iter().any(|bad| part == bad)) { return true; } // check string as a whole, e.g. ABC_DEF checks for ABCDEF let cleaned_text: String = text .to_lowercase() .chars() .filter(|c| c.is_alphabetic()) .collect(); BAD_WORDS.iter().any(|bad_word| &cleaned_text == bad_word) } #[cfg(test)] mod tests { use super::contains_bad_word; use std::fs; #[test] fn test_whitelist_words_are_not_flagged() { let whitelist_content = fs::read_to_string("bad/test-common-names.txt").expect("Failed to read file"); for (line_number, line) in whitelist_content.lines().enumerate() { let word = line.trim(); // Skip empty lines if word.is_empty() { continue; } assert!( !contains_bad_word(word), "Word '{}' on line {} should not be flagged as bad but was detected", word, line_number + 1 ); } } } impl Backend { async fn client(&self, cookies: PrivateCookieJar) -> (PrivateCookieJar, Client) { let existing_uuid = cookies .get("client_id") .and_then(|cookie| Uuid::parse_str(cookie.value()).ok()); match existing_uuid { Some(uuid) => (cookies, self.get_client(&uuid).await), None => { let new_id = Uuid::new_v4(); let expiration_date = OffsetDateTime::now_utc() + Duration::days(30); let mut cookie = Cookie::new("client_id", new_id.to_string()); cookie.set_expires(Expiration::DateTime(expiration_date)); cookie.set_http_only(true); cookie.set_secure(true); let updated_cookies = cookies.add(cookie); (updated_cookies, self.get_client(&new_id).await) } } } // Combined method for getting both client and language async fn client_full( &self, cookies: PrivateCookieJar, headers: &HeaderMap, ) -> (PrivateCookieJar, Req) { let (cookies, client) = self.client(cookies).await; let lang = language::language(&cookies, headers); (cookies, Req { client, lang }) } async fn set_client_name(&self, client: &Client, name: &str) -> Result<(), NameUpdateError> { if name.len() > 25 { return Err(NameUpdateError::TooLong(25, name.len())); } if name.len() < 3 { return Err(NameUpdateError::TooShort(3, name.len())); } if contains_bad_word(name) { return Err(NameUpdateError::ContainsBadWord); } match self { Backend::Sqlite(db) => { sqlx::query!( "UPDATE client SET name = ? WHERE uuid = ?;", name, client.uuid ) .execute(db) .await .unwrap(); } } Ok(()) } } #[derive(Clone)] pub struct AppState { pub(crate) backend: Arc, pub key: Key, } impl axum::extract::FromRef for Key { fn from_ref(state: &AppState) -> Self { state.key.clone() } } impl axum::extract::FromRef for Arc { fn from_ref(state: &AppState) -> Self { state.backend.clone() } } #[derive(Serialize, Deserialize)] struct Config { key: Vec, } impl Config { fn generate() -> Self { Self { key: Key::generate().master().to_vec(), } } } fn load_or_create_key() -> Result> { let config_path = "config.toml"; // Try to read existing config if Path::new(config_path).exists() { let content = fs::read_to_string(config_path)?; let config: Config = toml::from_str(&content)?; return Ok(Key::from(&config.key)); } // Create new config if file doesn't exist let config = Config::generate(); let toml_string = toml::to_string(&config)?; fs::write(config_path, toml_string)?; Ok(Key::from(&config.key)) } #[tokio::main] async fn main() { tracing_subscriber::registry() .with(tracing_subscriber::fmt::layer()) .with(EnvFilter::from_default_env()) .init(); let connection_options = SqliteConnectOptions::from_str("sqlite://db.sqlite").unwrap(); let db: SqlitePool = PoolOptions::new() .connect_with(connection_options) .await .unwrap(); let key = load_or_create_key().unwrap(); let state = AppState { backend: Arc::new(Backend::Sqlite(db)), key, }; let app = Router::new() .route("/", get(index::index)) .nest_service("/static", ServeDir::new("./static/serve")) .merge(game::routes()) .with_state(state); // run our app with hyper, listening globally on port 3000 let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); axum::serve(listener, app).await.unwrap(); }