use std::{ collections::HashMap, sync::{Mutex, OnceLock}, }; use axum::http::{HeaderMap, StatusCode, header}; use chrono::{DateTime, Duration, Utc}; use loco_rs::{controller::ErrorDetail, prelude::*}; const DEFAULT_WINDOW_SECONDS: i64 = 5 * 60; const DEFAULT_MAX_REQUESTS_PER_WINDOW: u32 = 45; const DEFAULT_BAN_MINUTES: i64 = 30; const DEFAULT_BURST_LIMIT: u32 = 8; const DEFAULT_BURST_WINDOW_SECONDS: i64 = 30; const ENV_WINDOW_SECONDS: &str = "TERMI_PUBLIC_RATE_LIMIT_WINDOW_SECONDS"; const ENV_MAX_REQUESTS_PER_WINDOW: &str = "TERMI_PUBLIC_RATE_LIMIT_MAX"; const ENV_BAN_MINUTES: &str = "TERMI_PUBLIC_RATE_LIMIT_BAN_MINUTES"; const ENV_BURST_LIMIT: &str = "TERMI_PUBLIC_RATE_LIMIT_BURST_MAX"; const ENV_BURST_WINDOW_SECONDS: &str = "TERMI_PUBLIC_RATE_LIMIT_BURST_WINDOW_SECONDS"; #[derive(Clone, Debug)] struct AbuseGuardConfig { window_seconds: i64, max_requests_per_window: u32, ban_minutes: i64, burst_limit: u32, burst_window_seconds: i64, } #[derive(Clone, Debug)] struct AbuseGuardEntry { window_started_at: DateTime, request_count: u32, burst_window_started_at: DateTime, burst_count: u32, banned_until: Option>, last_reason: Option, } fn parse_env_i64(name: &str, fallback: i64, min: i64, max: i64) -> i64 { std::env::var(name) .ok() .and_then(|value| value.trim().parse::().ok()) .map(|value| value.clamp(min, max)) .unwrap_or(fallback) } fn parse_env_u32(name: &str, fallback: u32, min: u32, max: u32) -> u32 { std::env::var(name) .ok() .and_then(|value| value.trim().parse::().ok()) .map(|value| value.clamp(min, max)) .unwrap_or(fallback) } fn load_config() -> AbuseGuardConfig { AbuseGuardConfig { window_seconds: parse_env_i64(ENV_WINDOW_SECONDS, DEFAULT_WINDOW_SECONDS, 10, 24 * 60 * 60), max_requests_per_window: parse_env_u32( ENV_MAX_REQUESTS_PER_WINDOW, DEFAULT_MAX_REQUESTS_PER_WINDOW, 1, 50_000, ), ban_minutes: parse_env_i64(ENV_BAN_MINUTES, DEFAULT_BAN_MINUTES, 1, 7 * 24 * 60), burst_limit: parse_env_u32(ENV_BURST_LIMIT, DEFAULT_BURST_LIMIT, 1, 1_000), burst_window_seconds: parse_env_i64( ENV_BURST_WINDOW_SECONDS, DEFAULT_BURST_WINDOW_SECONDS, 5, 60 * 60, ), } } fn normalize_token(value: Option<&str>, max_chars: usize) -> Option { value.and_then(|item| { let trimmed = item.trim(); if trimmed.is_empty() { None } else { Some(trimmed.chars().take(max_chars).collect::()) } }) } fn normalize_ip(value: Option<&str>) -> Option { normalize_token(value, 96) } pub fn header_value<'a>(headers: &'a HeaderMap, key: header::HeaderName) -> Option<&'a str> { headers.get(key).and_then(|value| value.to_str().ok()) } fn first_forwarded_ip(value: &str) -> Option<&str> { value .split(',') .map(str::trim) .find(|item| !item.is_empty()) } pub fn detect_client_ip(headers: &HeaderMap) -> Option { let forwarded = header_value(headers, header::HeaderName::from_static("x-forwarded-for")) .and_then(first_forwarded_ip); let real_ip = header_value(headers, header::HeaderName::from_static("x-real-ip")); let cf_connecting_ip = header_value(headers, header::HeaderName::from_static("cf-connecting-ip")); let true_client_ip = header_value(headers, header::HeaderName::from_static("true-client-ip")); normalize_ip( forwarded .or(real_ip) .or(cf_connecting_ip) .or(true_client_ip), ) } fn abuse_store() -> &'static Mutex> { static STORE: OnceLock>> = OnceLock::new(); STORE.get_or_init(|| Mutex::new(HashMap::new())) } fn make_key(scope: &str, client_ip: Option<&str>, fingerprint: Option<&str>) -> String { let normalized_scope = scope.trim().to_ascii_lowercase(); let normalized_ip = normalize_ip(client_ip).unwrap_or_else(|| "unknown".to_string()); let normalized_fingerprint = normalize_token(fingerprint, 160).unwrap_or_default(); if normalized_fingerprint.is_empty() { format!("{normalized_scope}:{normalized_ip}") } else { format!("{normalized_scope}:{normalized_ip}:{normalized_fingerprint}") } } fn too_many_requests(message: impl Into) -> Error { let message = message.into(); Error::CustomError( StatusCode::TOO_MANY_REQUESTS, ErrorDetail::new("rate_limited".to_string(), message), ) } pub fn enforce_public_scope( scope: &str, client_ip: Option<&str>, fingerprint: Option<&str>, ) -> Result<()> { let config = load_config(); let key = make_key(scope, client_ip, fingerprint); let now = Utc::now(); let mut store = abuse_store() .lock() .map_err(|_| Error::InternalServerError)?; store.retain(|_, entry| { entry .banned_until .map(|until| until > now - Duration::days(1)) .unwrap_or_else(|| entry.window_started_at > now - Duration::days(1)) }); let entry = store.entry(key).or_insert_with(|| AbuseGuardEntry { window_started_at: now, request_count: 0, burst_window_started_at: now, burst_count: 0, banned_until: None, last_reason: None, }); if let Some(banned_until) = entry.banned_until { if banned_until > now { let retry_after = (banned_until - now).num_minutes().max(1); return Err(too_many_requests(format!( "请求过于频繁,请在 {retry_after} 分钟后重试" ))); } entry.banned_until = None; } if entry.window_started_at + Duration::seconds(config.window_seconds) <= now { entry.window_started_at = now; entry.request_count = 0; } if entry.burst_window_started_at + Duration::seconds(config.burst_window_seconds) <= now { entry.burst_window_started_at = now; entry.burst_count = 0; } entry.request_count += 1; entry.burst_count += 1; if entry.burst_count > config.burst_limit { entry.banned_until = Some(now + Duration::minutes(config.ban_minutes)); entry.last_reason = Some("burst_limit".to_string()); return Err(too_many_requests("短时间请求过多,已临时封禁,请稍后再试")); } if entry.request_count > config.max_requests_per_window { entry.banned_until = Some(now + Duration::minutes(config.ban_minutes)); entry.last_reason = Some("window_limit".to_string()); return Err(too_many_requests("请求过于频繁,已临时封禁,请稍后再试")); } Ok(()) }