use axum::http::HeaderMap;
use loco_rs::prelude::*;
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::Value;
use std::{collections::HashSet, time::Instant};

use crate::{
    controllers::site_settings,
    models::_entities::posts,
    services::{abuse_guard, analytics, content},
};

// Accepts common truthy/falsy spellings ("1", "true", "yes", "on", ...) so
// query strings like `?preview=yes` deserialize into `Option<bool>`.
fn deserialize_boolish_option<'de, D>(
    deserializer: D,
) -> std::result::Result<Option<bool>, D::Error>
where
    D: Deserializer<'de>,
{
    let raw = Option::<String>::deserialize(deserializer)?;
    raw.map(|value| match value.trim().to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => Ok(true),
        "0" | "false" | "no" | "off" => Ok(false),
        other => Err(serde::de::Error::custom(format!(
            "invalid boolean value `{other}`"
        ))),
    })
    .transpose()
}

fn normalize_text(value: &str) -> String {
    value
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
        .trim()
        .to_ascii_lowercase()
}

fn tokenize(value: &str) -> Vec<String> {
    value
        .split(|ch: char| !ch.is_alphanumeric() && ch != '-' && ch != '_')
        .map(normalize_text)
        .filter(|item| !item.is_empty())
        .collect()
}

// Classic edit distance with a rolling single-row DP to keep memory at O(n).
fn levenshtein_distance(left: &str, right: &str) -> usize {
    if left == right {
        return 0;
    }
    if left.is_empty() {
        return right.chars().count();
    }
    if right.is_empty() {
        return left.chars().count();
    }

    let right_chars = right.chars().collect::<Vec<_>>();
    let mut prev = (0..=right_chars.len()).collect::<Vec<_>>();
    for (i, left_ch) in left.chars().enumerate() {
        let mut curr = vec![i + 1; right_chars.len() + 1];
        for (j, right_ch) in right_chars.iter().enumerate() {
            let cost = usize::from(left_ch != *right_ch);
            curr[j + 1] = (curr[j] + 1).min(prev[j + 1] + 1).min(prev[j] + cost);
        }
        prev = curr;
    }
    prev[right_chars.len()]
}

// Each settings entry may use "a => b", "a | b", or "a, b" separators; all
// are normalized to comma-separated groups of lowercase tokens.
fn parse_synonym_groups(value: &Option<Value>) -> Vec<Vec<String>> {
    value
        .as_ref()
        .and_then(Value::as_array)
        .cloned()
        .unwrap_or_default()
        .into_iter()
        .filter_map(|item| item.as_str().map(ToString::to_string))
        .map(|item| {
            let normalized = item.replace("=>", ",").replace('|', ",");
            normalized
                .split(',')
                .map(normalize_text)
                .filter(|token| !token.is_empty())
                .collect::<Vec<_>>()
        })
        .filter(|group| !group.is_empty())
        .collect()
}

fn expand_search_terms(query: &str, synonym_groups: &[Vec<String>]) -> Vec<String> {
    let normalized_query = normalize_text(query);
    let query_tokens = tokenize(query);
    let mut expanded = Vec::new();
    let mut seen = HashSet::new();

    if !normalized_query.is_empty() && seen.insert(normalized_query.clone()) {
        expanded.push(normalized_query.clone());
    }
    for token in &query_tokens {
        if seen.insert(token.clone()) {
            expanded.push(token.clone());
        }
    }
    for group in synonym_groups {
        let matched = group.iter().any(|item| {
            *item == normalized_query
                || query_tokens.iter().any(|token| token == item)
                || normalized_query.contains(item)
        });
        if matched {
            for token in group {
                if seen.insert(token.clone()) {
                    expanded.push(token.clone());
                }
            }
        }
    }
    expanded
}

fn candidate_terms(posts: &[posts::Model]) -> Vec<String> {
    let mut seen = HashSet::new();
    let mut candidates = Vec::new();
    for post in posts {
        for source in [
            post.title.as_deref().unwrap_or_default(),
            post.category.as_deref().unwrap_or_default(),
            &post.slug,
        ] {
            for token in tokenize(source) {
                if token.len() >= 3 && seen.insert(token.clone()) {
                    candidates.push(token);
                }
            }
        }
        if let Some(tags) = post.tags.as_ref().and_then(Value::as_array) {
            for token in tags.iter().filter_map(Value::as_str).flat_map(tokenize) {
                if token.len() >= 2 && seen.insert(token.clone()) {
                    candidates.push(token);
                }
            }
        }
    }
    candidates
}
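/// "Did you mean?"-style fallback used when the initial scoring pass returns
/// nothing: takes the first query token (three or more characters), collects
/// candidate terms from titles, categories, slugs, and tags within a
/// Levenshtein distance of 2, and expands the three nearest candidates
/// through the configured synonym groups.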
fn find_spelling_fallback(
    query: &str,
    posts: &[posts::Model],
    synonym_groups: &[Vec<String>],
) -> Vec<String> {
    let primary_token = tokenize(query).into_iter().next().unwrap_or_default();
    if primary_token.len() < 3 {
        return Vec::new();
    }
    let mut nearest = candidate_terms(posts)
        .into_iter()
        .map(|candidate| {
            let distance = levenshtein_distance(&primary_token, &candidate);
            (candidate, distance)
        })
        .filter(|(_, distance)| *distance <= 2)
        .collect::<Vec<_>>();
    nearest.sort_by(|left, right| left.1.cmp(&right.1).then_with(|| left.0.cmp(&right.0)));
    nearest
        .into_iter()
        .take(3)
        .flat_map(|(candidate, _)| expand_search_terms(&candidate, synonym_groups))
        .collect()
}

fn post_has_tag(post: &posts::Model, wanted_tag: &str) -> bool {
    let wanted = normalize_text(wanted_tag);
    post.tags
        .as_ref()
        .and_then(Value::as_array)
        .map(|tags| {
            tags.iter()
                .filter_map(Value::as_str)
                .map(normalize_text)
                .any(|tag| tag == wanted)
        })
        .unwrap_or(false)
}

// Weighted field matching: the full query phrase scores highest, then each
// expanded term, with titles weighted above descriptions, slugs, categories,
// tags, and finally body content.
fn score_post(post: &posts::Model, query: &str, terms: &[String]) -> f64 {
    let normalized_query = normalize_text(query);
    let title = normalize_text(post.title.as_deref().unwrap_or_default());
    let description = normalize_text(post.description.as_deref().unwrap_or_default());
    let content_text = normalize_text(post.content.as_deref().unwrap_or_default());
    let category = normalize_text(post.category.as_deref().unwrap_or_default());
    let slug = normalize_text(&post.slug);
    let tags = post
        .tags
        .as_ref()
        .and_then(Value::as_array)
        .cloned()
        .unwrap_or_default()
        .into_iter()
        .filter_map(|item| item.as_str().map(normalize_text))
        .collect::<Vec<_>>();

    let mut score = 0.0;
    if !normalized_query.is_empty() {
        if title.contains(&normalized_query) {
            score += 6.0;
        }
        if description.contains(&normalized_query) {
            score += 4.0;
        }
        if slug.contains(&normalized_query) {
            score += 4.0;
        }
        if category.contains(&normalized_query) {
            score += 3.0;
        }
        if tags.iter().any(|tag| tag.contains(&normalized_query)) {
            score += 4.0;
        }
        if content_text.contains(&normalized_query) {
            score += 2.0;
        }
    }
    for term in terms {
        if term.is_empty() {
            continue;
        }
        if title.contains(term) {
            score += 3.5;
        }
        if description.contains(term) {
            score += 2.2;
        }
        if slug.contains(term) {
            score += 2.0;
        }
        if category.contains(term) {
            score += 1.8;
        }
        if tags.iter().any(|tag| tag == term) {
            score += 2.5;
        } else if tags.iter().any(|tag| tag.contains(term)) {
            score += 1.5;
        }
        if content_text.contains(term) {
            score += 0.8;
        }
    }
    score
}

fn is_preview_search(query: &SearchQuery, headers: &HeaderMap) -> bool {
    query.preview.unwrap_or(false)
        || headers
            .get("x-termi-search-mode")
            .and_then(|value| value.to_str().ok())
            .map(|value| value.eq_ignore_ascii_case("preview"))
            .unwrap_or(false)
}

fn normalize_search_sort_by(value: Option<&str>) -> String {
    match value
        .map(str::trim)
        .unwrap_or_default()
        .to_ascii_lowercase()
        .as_str()
    {
        "newest" | "created_at" => "newest".to_string(),
        "oldest" => "oldest".to_string(),
        "title" => "title".to_string(),
        _ => "relevance".to_string(),
    }
}

fn normalize_sort_order(value: Option<&str>, sort_by: &str) -> String {
    match value
        .map(str::trim)
        .unwrap_or_default()
        .to_ascii_lowercase()
        .as_str()
    {
        "asc" => "asc".to_string(),
        "desc" => "desc".to_string(),
        _ if sort_by == "title" => "asc".to_string(),
        _ => "desc".to_string(),
    }
}
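/// Orders results in place. `relevance` (the default), `newest`, and
/// `oldest` have a fixed direction; only `title` honors `sort_order`, with
/// the slug as a stable tie-breaker. Relevance ties are broken by recency.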
fn sort_search_results(items: &mut [SearchResult], sort_by: &str, sort_order: &str) {
    items.sort_by(|left, right| {
        let ordering = match sort_by {
            "newest" => right.created_at.cmp(&left.created_at),
            "oldest" => left.created_at.cmp(&right.created_at),
            "title" => left
                .title
                .as_deref()
                .unwrap_or(&left.slug)
                .to_ascii_lowercase()
                .cmp(
                    &right
                        .title
                        .as_deref()
                        .unwrap_or(&right.slug)
                        .to_ascii_lowercase(),
                ),
            _ => right
                .rank
                .partial_cmp(&left.rank)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| right.created_at.cmp(&left.created_at)),
        };
        if sort_by == "relevance" || sort_by == "newest" || sort_by == "oldest" {
            return ordering;
        }
        let ordering = if sort_order == "asc" {
            ordering
        } else {
            ordering.reverse()
        };
        ordering.then_with(|| left.slug.cmp(&right.slug))
    });
}

#[derive(Clone, Debug, Default, Deserialize)]
pub struct SearchQuery {
    pub q: Option<String>,
    pub limit: Option<u64>,
    pub category: Option<String>,
    pub tag: Option<String>,
    #[serde(alias = "type")]
    pub post_type: Option<String>,
    #[serde(default, deserialize_with = "deserialize_boolish_option")]
    pub preview: Option<bool>,
}

#[derive(Clone, Debug, Default, Deserialize)]
pub struct SearchPageQuery {
    #[serde(flatten)]
    pub search: SearchQuery,
    pub page: Option<u64>,
    #[serde(alias = "page_size")]
    pub page_size: Option<u64>,
    pub sort_by: Option<String>,
    pub sort_order: Option<String>,
}

#[derive(Clone, Debug, Serialize)]
pub struct SearchResult {
    pub id: i32,
    pub title: Option<String>,
    pub slug: String,
    pub description: Option<String>,
    pub content: Option<String>,
    pub category: Option<String>,
    pub tags: Option<Value>,
    pub post_type: Option<String>,
    pub image: Option<String>,
    pub pinned: Option<bool>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub updated_at: chrono::DateTime<chrono::FixedOffset>,
    pub rank: f64,
}

#[derive(Clone, Debug, Serialize)]
pub struct PagedSearchResponse {
    pub query: String,
    pub items: Vec<SearchResult>,
    pub page: u64,
    pub page_size: u64,
    pub total: usize,
    pub total_pages: u64,
    pub sort_by: String,
    pub sort_order: String,
}
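/// Shared pipeline behind both handlers: trims the query, applies the abuse
/// guard for non-preview traffic, filters posts by public visibility,
/// category, tag, and post type, scores them against synonym-expanded terms,
/// and retries with the spelling fallback when nothing matches.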
async fn build_search_results(
    ctx: &AppContext,
    query: &SearchQuery,
    headers: &HeaderMap,
) -> Result<(String, bool, Vec<SearchResult>)> {
    let preview_search = is_preview_search(query, headers);
    let q = query.q.clone().unwrap_or_default().trim().to_string();
    if q.is_empty() {
        return Ok((q, preview_search, Vec::new()));
    }
    if !preview_search {
        abuse_guard::enforce_public_scope(
            "search",
            abuse_guard::detect_client_ip(headers).as_deref(),
            Some(&q),
        )?;
    }

    let settings = site_settings::load_current(ctx).await.ok();
    let synonym_groups = settings
        .as_ref()
        .map(|item| parse_synonym_groups(&item.search_synonyms))
        .unwrap_or_default();

    let mut all_posts = posts::Entity::find()
        .all(&ctx.db)
        .await?
        .into_iter()
        .filter(|post| {
            preview_search
                || content::is_post_listed_publicly(post, chrono::Utc::now().fixed_offset())
        })
        .collect::<Vec<_>>();

    if let Some(category) = query
        .category
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
    {
        all_posts.retain(|post| {
            post.category
                .as_deref()
                .map(|value| value.eq_ignore_ascii_case(category))
                .unwrap_or(false)
        });
    }
    if let Some(tag) = query
        .tag
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
    {
        all_posts.retain(|post| post_has_tag(post, tag));
    }
    if let Some(post_type) = query
        .post_type
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
    {
        all_posts.retain(|post| {
            post.post_type
                .as_deref()
                .map(|value| value.eq_ignore_ascii_case(post_type))
                .unwrap_or(false)
        });
    }

    let mut expanded_terms = expand_search_terms(&q, &synonym_groups);
    // Maps a scored post into the serializable result shape; shared by the
    // initial pass and the spelling-fallback pass below.
    let to_result = |post: &posts::Model, rank: f64| SearchResult {
        id: post.id,
        title: post.title.clone(),
        slug: post.slug.clone(),
        description: post.description.clone(),
        content: post.content.clone(),
        category: post.category.clone(),
        tags: post.tags.clone(),
        post_type: post.post_type.clone(),
        image: post.image.clone(),
        pinned: post.pinned,
        created_at: post.created_at.into(),
        updated_at: post.updated_at.into(),
        rank,
    };
    let mut results = all_posts
        .iter()
        .map(|post| (post, score_post(post, &q, &expanded_terms)))
        .filter(|(_, rank)| *rank > 0.0)
        .map(|(post, rank)| to_result(post, rank))
        .collect::<Vec<_>>();

    if results.is_empty() {
        expanded_terms = find_spelling_fallback(&q, &all_posts, &synonym_groups);
        if !expanded_terms.is_empty() {
            results = all_posts
                .iter()
                .map(|post| (post, score_post(post, &q, &expanded_terms)))
                .filter(|(_, rank)| *rank > 0.0)
                .map(|(post, rank)| to_result(post, rank))
                .collect::<Vec<_>>();
        }
    }

    sort_search_results(&mut results, "relevance", "desc");
    Ok((q, preview_search, results))
}

#[debug_handler]
pub async fn search(
    Query(query): Query<SearchQuery>,
    State(ctx): State<AppContext>,
    headers: HeaderMap,
) -> Result<Response> {
    let started_at = Instant::now();
    let limit = query.limit.unwrap_or(20).clamp(1, 100) as usize;
    let (q, preview_search, mut results) = build_search_results(&ctx, &query, &headers).await?;
    if q.is_empty() {
        return format::json(Vec::<SearchResult>::new());
    }
    results.truncate(limit);
    if !preview_search {
        analytics::record_search_event(
            &ctx,
            &q,
            results.len(),
            &headers,
            started_at.elapsed().as_millis() as i64,
        )
        .await;
    }
    format::json(results)
}
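// Illustrative request shapes (hypothetical values; paths follow the
// `api/search/` prefix registered in `routes()` below):
//   GET /api/search/?q=rust&limit=5&tag=backend
//   GET /api/search/page?q=rust&page=2&page_size=10&sort_by=title&sort_order=asc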
#[debug_handler]
pub async fn search_page(
    Query(query): Query<SearchPageQuery>,
    State(ctx): State<AppContext>,
    headers: HeaderMap,
) -> Result<Response> {
    let started_at = Instant::now();
    let page_size = query.page_size.unwrap_or(20).clamp(1, 100);
    let sort_by = normalize_search_sort_by(query.sort_by.as_deref());
    let sort_order = normalize_sort_order(query.sort_order.as_deref(), &sort_by);
    let (q, preview_search, mut results) =
        build_search_results(&ctx, &query.search, &headers).await?;
    if q.is_empty() {
        return format::json(PagedSearchResponse {
            query: q,
            items: Vec::new(),
            page: 1,
            page_size,
            total: 0,
            total_pages: 1,
            sort_by,
            sort_order,
        });
    }
    sort_search_results(&mut results, &sort_by, &sort_order);
    let total = results.len();
    let total_pages = std::cmp::max(1, ((total as u64) + page_size - 1) / page_size);
    let page = query.page.unwrap_or(1).clamp(1, total_pages);
    let start = ((page - 1) * page_size) as usize;
    let end = std::cmp::min(start + page_size as usize, total);
    let items = if start >= total {
        Vec::new()
    } else {
        results[start..end].to_vec()
    };
    if !preview_search {
        analytics::record_search_event(
            &ctx,
            &q,
            total,
            &headers,
            started_at.elapsed().as_millis() as i64,
        )
        .await;
    }
    format::json(PagedSearchResponse {
        query: q,
        items,
        page,
        page_size,
        total,
        total_pages,
        sort_by,
        sort_order,
    })
}

pub fn routes() -> Routes {
    Routes::new()
        .prefix("api/search/")
        .add("page", get(search_page))
        .add("/", get(search))
}
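// A minimal test sketch for the pure helpers above, assuming this module has
// no other test coverage; the expected values follow directly from the
// implementations and are not taken from the project's own test suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalizes_whitespace_and_case() {
        assert_eq!(normalize_text("  Hello   World "), "hello world");
        assert_eq!(tokenize("Rust, async/await!"), vec!["rust", "async", "await"]);
    }

    #[test]
    fn levenshtein_matches_known_distances() {
        assert_eq!(levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(levenshtein_distance("", "abc"), 3);
        assert_eq!(levenshtein_distance("same", "same"), 0);
    }

    #[test]
    fn expands_query_through_synonym_groups() {
        let groups = parse_synonym_groups(&Some(serde_json::json!(["js => javascript"])));
        let terms = expand_search_terms("js", &groups);
        assert!(terms.contains(&"javascript".to_string()));
    }
}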