Files
aef-website/src/main.rs
2025-08-13 14:57:32 +02:00

288 lines
7.3 KiB
Rust

use crate::model::client::Client;
use axum::{http::HeaderMap, routing::get, Router};
use axum_extra::extract::{
cookie::{Cookie, Expiration, Key},
PrivateCookieJar,
};
use serde::{Deserialize, Serialize};
use sqlx::{pool::PoolOptions, sqlite::SqliteConnectOptions, SqlitePool};
use std::{
collections::HashSet,
fmt::Display,
fs,
path::Path,
str::FromStr,
sync::{Arc, LazyLock},
};
use time::{Duration, OffsetDateTime};
use tower_http::services::ServeDir;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use uuid::Uuid;
#[macro_use]
extern crate rust_i18n;
i18n!("locales", fallback = "en");
mod game;
mod index;
pub(crate) mod language;
pub(crate) mod model;
mod page;
pub(crate) mod random_names;
/// Storage backend used by the web application.
///
/// Only SQLite exists today. `Backend` methods `match` on the variant, so a new
/// backend needs a new variant plus a match arm in each of those methods.
pub(crate) enum Backend {
    /// Connection pool for the SQLite database.
    Sqlite(SqlitePool),
}
/// UI language resolved for a request.
///
/// Its `Display` form and `to_locale()` are the two-letter locale codes
/// `de` / `en`.
#[derive(Debug)]
enum Language {
    German,
    English,
}
impl Display for Language {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Language::German => f.write_str("de"),
Language::English => f.write_str("en"),
}
}
}
impl Language {
pub(crate) fn next_language(&self) -> Self {
match self {
Language::German => Language::English,
Language::English => Language::German,
}
}
fn to_locale(&self) -> &'static str {
match self {
Language::German => "de",
Language::English => "en",
}
}
}
/// Picks the UI language from a language-tag string.
///
/// The function inspects only the primary subtag. That is the run of ASCII
/// letters at the start of the left-trimmed string. It is compared
/// case-insensitively with `de`, because language tags are case-insensitive
/// (BCP 47).
///
/// These all map to German: `de`, `DE`, `de-AT`, `de_CH`, and a raw header
/// such as `de-DE,de;q=0.9,en;q=0.8`. Everything else falls back to English.
/// That includes tags that only start with the letters "de" but name another
/// language (e.g. `del`), and the empty string.
///
/// NOTE(review): the exact format of `value` depends on the caller
/// (presumably a locale cookie or an `Accept-Language` value — confirm in
/// `language`).
impl From<String> for Language {
    fn from(value: String) -> Self {
        // `split` always yields at least one item, so `next()` is `Some`; the
        // default only guards the type signature.
        let primary_subtag = value
            .trim_start()
            .split(|c: char| !c.is_ascii_alphabetic())
            .next()
            .unwrap_or_default();
        if primary_subtag.eq_ignore_ascii_case("de") {
            Language::German
        } else {
            Language::English
        }
    }
}
/// Per-request context: the resolved client together with its UI language.
///
/// Produced by `Backend::client_full`.
#[derive(Debug)]
struct Req {
    /// Client identified by the `client_id` cookie.
    client: Client,
    /// Language chosen by `language::language` from cookies and request headers.
    lang: Language,
}
/// Reason `Backend::set_client_name` rejected a requested name.
pub(crate) enum NameUpdateError {
    /// Name is too long: `(maximum_allowed, actual_length)`.
    TooLong(usize, usize),
    /// Name is too short: `(minimum_required, actual_length)`.
    TooShort(usize, usize),
    /// Name was flagged by `contains_bad_word`.
    ContainsBadWord,
}
/// Lower-cased words that client names may not contain.
///
/// The list is embedded at compile time from `bad/bad-list.txt` (one word per
/// line, surrounding whitespace trimmed) and parsed on first access.
///
/// Blank lines are skipped. An empty entry could otherwise match the empty
/// fragments produced when a name is split on non-letters (e.g. `"Alice_"`),
/// flagging harmless names.
static BAD_WORDS: LazyLock<HashSet<String>> = LazyLock::new(|| {
    const BAD_WORDS_FILE: &str = include_str!("../bad/bad-list.txt");
    BAD_WORDS_FILE
        .lines()
        .map(str::trim)
        .filter(|word| !word.is_empty())
        .map(str::to_lowercase)
        .collect()
});
fn contains_bad_word(text: &str) -> bool {
// check parts of string, e.g. ABC_DEF checks both ABC and DEF
let cleaned_text = text.to_lowercase();
if cleaned_text
.split(|c: char| !c.is_alphabetic())
.any(|part| BAD_WORDS.iter().any(|bad| part == bad))
{
return true;
}
// check string as a whole, e.g. ABC_DEF checks for ABCDEF
let cleaned_text: String = text
.to_lowercase()
.chars()
.filter(|c| c.is_alphabetic())
.collect();
BAD_WORDS.iter().any(|bad_word| &cleaned_text == bad_word)
}
#[cfg(test)]
mod tests {
    use super::contains_bad_word;
    use std::fs;

    /// Every non-blank entry of the common-names fixture must pass the filter;
    /// reported line numbers are 1-based and count blank lines too.
    #[test]
    fn test_whitelist_words_are_not_flagged() {
        let fixture =
            fs::read_to_string("bad/test-common-names.txt").expect("Failed to read file");
        let entries = fixture
            .lines()
            .map(str::trim)
            .zip(1usize..)
            .filter(|(entry, _)| !entry.is_empty());
        for (entry, line_no) in entries {
            assert!(
                !contains_bad_word(entry),
                "Word '{}' on line {} should not be flagged as bad but was detected",
                entry,
                line_no
            );
        }
    }
}
impl Backend {
    /// Resolves the [`Client`] for a request from the private `client_id` cookie.
    ///
    /// If the jar holds a parsable UUID, `get_client` is called with it.
    /// Otherwise the method generates a fresh v4 UUID and adds a new
    /// `client_id` cookie to the jar (path `/`, HTTP-only, secure, expires in
    /// 30 days). It then calls `get_client` with the new UUID. That presumably
    /// creates the client record on first sight — confirm in `get_client`.
    ///
    /// Returns the (possibly updated) jar, which the handler must send back so
    /// a newly issued cookie reaches the browser, together with the client.
    async fn client(&self, cookies: PrivateCookieJar) -> (PrivateCookieJar, Client) {
        let existing_uuid = cookies
            .get("client_id")
            .and_then(|cookie| Uuid::parse_str(cookie.value()).ok());
        match existing_uuid {
            Some(uuid) => (cookies, self.get_client(&uuid).await),
            None => {
                let new_id = Uuid::new_v4();
                let expiration_date = OffsetDateTime::now_utc() + Duration::days(30);
                let mut cookie = Cookie::new("client_id", new_id.to_string());
                cookie.set_expires(Expiration::DateTime(expiration_date));
                // Without an explicit Path, browsers scope the cookie to the
                // directory of the request that set it (RFC 6265 §5.1.4).
                // A cookie first issued under a sub-path would then not be
                // sent to `/`, giving the visitor a second identity there.
                cookie.set_path("/");
                cookie.set_http_only(true);
                cookie.set_secure(true);
                let updated_cookies = cookies.add(cookie);
                (updated_cookies, self.get_client(&new_id).await)
            }
        }
    }

    /// Resolves both the client (see [`Backend::client`]) and the request
    /// language, which `language::language` derives from the cookie jar and
    /// request headers.
    async fn client_full(
        &self,
        cookies: PrivateCookieJar,
        headers: &HeaderMap,
    ) -> (PrivateCookieJar, Req) {
        let (cookies, client) = self.client(cookies).await;
        let lang = language::language(&cookies, headers);
        (cookies, Req { client, lang })
    }

    /// Validates `name` and stores it as `client`'s name in the database.
    ///
    /// Length is counted in characters (Unicode scalar values), not bytes.
    ///
    /// # Errors
    /// - [`NameUpdateError::TooLong`] if the name has more than 25 characters.
    /// - [`NameUpdateError::TooShort`] if it has fewer than 3 characters.
    ///   Both variants carry `(limit, actual_length)`.
    /// - [`NameUpdateError::ContainsBadWord`] if `contains_bad_word` flags it.
    ///
    /// # Panics
    /// Panics if the database `UPDATE` fails.
    async fn set_client_name(&self, client: &Client, name: &str) -> Result<(), NameUpdateError> {
        const MIN_NAME_CHARS: usize = 3;
        const MAX_NAME_CHARS: usize = 25;

        // `str::len` counts UTF-8 bytes, so "Jürgen" would measure 7 and names
        // with umlauts or other multibyte letters would hit the limit early.
        let name_chars = name.chars().count();
        if name_chars > MAX_NAME_CHARS {
            return Err(NameUpdateError::TooLong(MAX_NAME_CHARS, name_chars));
        }
        if name_chars < MIN_NAME_CHARS {
            return Err(NameUpdateError::TooShort(MIN_NAME_CHARS, name_chars));
        }
        if contains_bad_word(name) {
            return Err(NameUpdateError::ContainsBadWord);
        }
        match self {
            Backend::Sqlite(db) => {
                sqlx::query!(
                    "UPDATE client SET name = ? WHERE uuid = ?;",
                    name,
                    client.uuid
                )
                .execute(db)
                .await
                .expect("failed to update client name in SQLite");
            }
        }
        Ok(())
    }
}
/// Shared application state handed to every axum handler.
///
/// Cloned per request by axum; the backend is shared behind an `Arc`.
/// Handlers can extract individual fields through the `FromRef` impls.
#[derive(Clone)]
pub struct AppState {
    /// Storage backend shared by all handlers.
    pub(crate) backend: Arc<Backend>,
    /// Key that `PrivateCookieJar` uses to encrypt and decrypt cookies.
    pub key: Key,
}
impl axum::extract::FromRef<AppState> for Key {
fn from_ref(state: &AppState) -> Self {
state.key.clone()
}
}
/// Makes the shared [`Backend`] extractable from [`AppState`]; cloning only
/// bumps the `Arc` reference count.
impl axum::extract::FromRef<AppState> for Arc<Backend> {
    fn from_ref(state: &AppState) -> Self {
        Arc::clone(&state.backend)
    }
}
/// Contents of `config.toml`, (de)serialized with serde/toml by
/// `load_or_create_key`.
#[derive(Serialize, Deserialize)]
struct Config {
    /// Raw master-key bytes used to build the cookie [`Key`].
    key: Vec<u8>,
}
impl Config {
fn generate() -> Self {
Self {
key: Key::generate().master().to_vec(),
}
}
}
/// Loads the cookie key from `config.toml`. If the file does not exist yet, it
/// creates the file with a freshly generated key.
///
/// # Errors
/// Returns an error if:
/// - the file cannot be read or written;
/// - its contents are not valid TOML for [`Config`];
/// - the stored key is shorter than the 64 bytes `Key::from` requires.
fn load_or_create_key() -> Result<Key, Box<dyn std::error::Error>> {
    /// Minimum master-key length; `Key::from` panics on anything shorter.
    const MIN_KEY_LEN: usize = 64;
    let config_path = "config.toml";
    // Try to read existing config
    if Path::new(config_path).exists() {
        let content = fs::read_to_string(config_path)?;
        let config: Config = toml::from_str(&content)?;
        // A truncated or hand-edited key would make `Key::from` panic.
        // Report it through the declared error type instead.
        if config.key.len() < MIN_KEY_LEN {
            return Err(format!(
                "{config_path}: key is {} bytes long, at least {MIN_KEY_LEN} are required",
                config.key.len()
            )
            .into());
        }
        return Ok(Key::from(&config.key));
    }
    // Create new config if file doesn't exist
    let config = Config::generate();
    let toml_string = toml::to_string(&config)?;
    // NOTE(review): the secret is written with default file permissions;
    // consider restricting access to the file.
    fs::write(config_path, toml_string)?;
    Ok(Key::from(&config.key))
}
/// Starts the web server.
///
/// Startup steps:
/// 1. Initialise tracing (filter taken from the environment, e.g. `RUST_LOG`).
/// 2. Open the SQLite pool on `db.sqlite`.
/// 3. Load or create the cookie key.
/// 4. Serve the index page, the `/static` files and the game routes on
///    `0.0.0.0:3000`.
///
/// # Panics
/// Panics with a descriptive message if any startup step fails, or if the
/// server exits with an error.
#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(tracing_subscriber::fmt::layer())
        .with(EnvFilter::from_default_env())
        .init();
    let connection_options = SqliteConnectOptions::from_str("sqlite://db.sqlite")
        .expect("hard-coded SQLite URL `sqlite://db.sqlite` must parse");
    let db: SqlitePool = PoolOptions::new()
        .connect_with(connection_options)
        .await
        .expect("failed to open SQLite database `db.sqlite`");
    let key = load_or_create_key().expect("failed to load or create cookie key in `config.toml`");
    let state = AppState {
        backend: Arc::new(Backend::Sqlite(db)),
        key,
    };
    let app = Router::new()
        .route("/", get(index::index))
        .nest_service("/static", ServeDir::new("./static/serve"))
        .merge(game::routes())
        .with_state(state);
    // run our app with hyper, listening globally on port 3000
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
        .await
        .expect("failed to bind TCP listener on 0.0.0.0:3000");
    axum::serve(listener, app)
        .await
        .expect("HTTP server terminated with an error");
}