diff --git a/backend/src/services/ai.rs b/backend/src/services/ai.rs index 3b68b19..b5ad1b0 100644 --- a/backend/src/services/ai.rs +++ b/backend/src/services/ai.rs @@ -14,6 +14,8 @@ use serde_json::{Value, json}; use std::fs; use std::path::{Path, PathBuf}; use std::sync::{Mutex, OnceLock}; +use std::thread; +use std::time::{Duration, Instant}; use uuid::Uuid; use crate::{ @@ -40,6 +42,8 @@ pub(crate) const REINDEX_EMBEDDING_BATCH_SIZE: usize = 4; const EMBEDDING_DIMENSION: usize = 384; const LOCAL_EMBEDDING_MODEL_LABEL: &str = "fastembed / local all-MiniLM-L6-v2"; const LOCAL_EMBEDDING_CACHE_DIR: &str = "storage/ai_embedding_models/all-minilm-l6-v2"; +const LOCAL_EMBEDDING_IDLE_TIMEOUT_SECS: u64 = 300; +const LOCAL_EMBEDDING_REAPER_INTERVAL_SECS: u64 = 30; const LOCAL_EMBEDDING_BASE_URL: &str = "https://huggingface.co/Qdrant/all-MiniLM-L6-v2-onnx/resolve/main"; const LOCAL_EMBEDDING_FILES: [&str; 5] = [ @@ -50,7 +54,13 @@ const LOCAL_EMBEDDING_FILES: [&str; 5] = [ "tokenizer_config.json", ]; -static TEXT_EMBEDDING_MODEL: OnceLock> = OnceLock::new(); +static TEXT_EMBEDDING_MODEL: OnceLock>> = OnceLock::new(); +static TEXT_EMBEDDING_REAPER_STARTED: OnceLock<()> = OnceLock::new(); + +struct LocalEmbeddingRuntime { + model: TextEmbedding, + last_used_at: Instant, +} #[derive(Clone, Debug)] struct AiImageRuntimeSettings { @@ -403,18 +413,78 @@ fn load_local_embedding_model() -> Result { .map_err(|error| Error::BadRequest(format!("本地 embedding 模型初始化失败: {error}"))) } -fn local_embedding_engine() -> Result<&'static Mutex> { - if let Some(model) = TEXT_EMBEDDING_MODEL.get() { - return Ok(model); +fn local_embedding_state() -> &'static Mutex> { + TEXT_EMBEDDING_MODEL.get_or_init(|| Mutex::new(None)) +} + +fn ensure_local_embedding_reaper_started() { + TEXT_EMBEDDING_REAPER_STARTED.get_or_init(|| { + if let Err(error) = thread::Builder::new() + .name("local-embedding-reaper".to_string()) + .spawn(|| { + let idle_timeout = Duration::from_secs(LOCAL_EMBEDDING_IDLE_TIMEOUT_SECS); + let check_interval = Duration::from_secs(LOCAL_EMBEDDING_REAPER_INTERVAL_SECS); + + loop { + thread::sleep(check_interval); + + let Some(state) = TEXT_EMBEDDING_MODEL.get() else { + continue; + }; + + let mut guard = match state.lock() { + Ok(guard) => guard, + Err(_) => { + tracing::warn!("failed to lock local embedding model for idle cleanup"); + continue; + } + }; + + let should_unload = guard + .as_ref() + .is_some_and(|runtime| runtime.last_used_at.elapsed() >= idle_timeout); + + if should_unload { + *guard = None; + tracing::info!( + "unloaded local embedding model after {} seconds of inactivity", + LOCAL_EMBEDDING_IDLE_TIMEOUT_SECS + ); + } + } + }) + { + tracing::warn!("failed to start local embedding reaper thread: {error}"); + } + }); +} + +fn with_local_embedding_engine( + operation: impl FnOnce(&mut TextEmbedding) -> Result, +) -> Result { + ensure_local_embedding_reaper_started(); + + let state = local_embedding_state(); + let mut guard = state + .lock() + .map_err(|_| Error::BadRequest("本地 embedding 模型当前不可用,请稍后重试".to_string()))?; + + if guard.is_none() { + tracing::info!("loading local embedding model into memory"); + *guard = Some(LocalEmbeddingRuntime { + model: load_local_embedding_model()?, + last_used_at: Instant::now(), + }); } - let model = load_local_embedding_model()?; + let runtime = guard + .as_mut() + .ok_or_else(|| Error::BadRequest("本地 embedding 模型未能成功缓存".to_string()))?; + runtime.last_used_at = Instant::now(); - let _ = TEXT_EMBEDDING_MODEL.set(Mutex::new(model)); - - TEXT_EMBEDDING_MODEL - .get() - .ok_or_else(|| Error::BadRequest("本地 embedding 模型未能成功缓存".to_string())) + let result = operation(&mut runtime.model); + runtime.last_used_at = Instant::now(); + result } fn vector_literal(embedding: &[f64]) -> Result { @@ -793,24 +863,21 @@ async fn embed_texts_locally_with_batch_size( batch_size: usize, ) -> Result>> { tokio::task::spawn_blocking(move || { - let model = local_embedding_engine()?; let prepared = inputs .iter() .map(|item| prepare_embedding_text(kind, item)) .collect::>(); - let mut guard = model.lock().map_err(|_| { - Error::BadRequest("本地 embedding 模型当前不可用,请稍后重试".to_string()) - })?; + with_local_embedding_engine(|model| { + let embeddings = model + .embed(prepared, Some(batch_size.max(1))) + .map_err(|error| Error::BadRequest(format!("本地 embedding 生成失败: {error}")))?; - let embeddings = guard - .embed(prepared, Some(batch_size.max(1))) - .map_err(|error| Error::BadRequest(format!("本地 embedding 生成失败: {error}")))?; - - Ok(embeddings - .into_iter() - .map(|embedding| embedding.into_iter().map(f64::from).collect::>()) - .collect::>()) + Ok(embeddings + .into_iter() + .map(|embedding| embedding.into_iter().map(f64::from).collect::>()) + .collect::>()) + }) }) .await .map_err(|error| Error::BadRequest(format!("本地 embedding 任务执行失败: {error}")))?