119 lines
3.4 KiB
Rust
119 lines
3.4 KiB
Rust
use anyhow::{Context, Result};
|
|
use reqwest::blocking::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
/// Title/year hints recovered from a raw filename by the LLM.
///
/// `Default` yields "no hints": every field empty. `PartialEq`/`Eq` are
/// derived so callers and tests can compare hint sets by value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LlmHints {
    /// Movie title reported by the model, if any.
    pub title: Option<String>,
    /// Release year reported by the model, if any.
    pub year: Option<i32>,
    /// Alternative titles reported by the model; empty when none were given.
    pub alt_titles: Vec<String>,
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct LlmClient {
|
|
endpoint: String,
|
|
model: String,
|
|
max_tokens: Option<u32>,
|
|
client: Client,
|
|
}
|
|
|
|
impl LlmClient {
|
|
pub fn new(endpoint: String, model: String, timeout_seconds: u64, max_tokens: Option<u32>) -> Result<Self> {
|
|
let client = Client::builder()
|
|
.timeout(std::time::Duration::from_secs(timeout_seconds))
|
|
.build()
|
|
.context("failed to build HTTP client for LLM")?;
|
|
Ok(Self {
|
|
endpoint,
|
|
model,
|
|
max_tokens,
|
|
client,
|
|
})
|
|
}
|
|
|
|
pub fn parse_filename(&self, raw: &str) -> Result<LlmHints> {
|
|
let prompt = build_prompt(raw);
|
|
let request = OllamaRequest {
|
|
model: self.model.clone(),
|
|
prompt,
|
|
stream: false,
|
|
format: Some("json".to_string()),
|
|
options: Some(OllamaOptions {
|
|
num_predict: self.max_tokens,
|
|
temperature: 0.0,
|
|
}),
|
|
};
|
|
|
|
let url = format!("{}/api/generate", self.endpoint.trim_end_matches('/'));
|
|
let response = self
|
|
.client
|
|
.post(url)
|
|
.json(&request)
|
|
.send()
|
|
.context("LLM request failed")?;
|
|
|
|
let status = response.status();
|
|
if !status.is_success() {
|
|
return Err(anyhow::anyhow!("LLM returned HTTP {status}"));
|
|
}
|
|
|
|
let body: OllamaResponse = response.json().context("failed to parse LLM response")?;
|
|
let hints = parse_hints(&body.response).unwrap_or_default();
|
|
Ok(hints)
|
|
}
|
|
}
|
|
|
|
/// Builds the instruction prompt sent to the model for the filename `raw`.
///
/// The prompt consists of a fixed rule preamble followed by a
/// `Filename: <raw>` line terminated by a newline. `raw` is inserted
/// verbatim.
fn build_prompt(raw: &str) -> String {
    // Fixed preamble, one source line per prompt line for readability.
    const PREAMBLE: &str = concat!(
        "You are a strict parser. Extract the full movie title and year from the filename below.\n\n",
        "Rules:\n",
        "- Output JSON only.\n",
        "- Title must include all words of the movie name in order (no partial tokens).\n",
        "- Strip release metadata (resolution, codec, source, group tags).\n",
        "- Year must be a 4-digit number if present.\n",
        "- If unsure, use null for fields and empty array for alt_titles.\n",
        "- Do NOT invent data.\n\n",
        "Return JSON with keys: title, year, alt_titles.\n\n",
    );
    const LABEL: &str = "Filename: ";

    let mut prompt = String::with_capacity(PREAMBLE.len() + LABEL.len() + raw.len() + 1);
    prompt.push_str(PREAMBLE);
    prompt.push_str(LABEL);
    prompt.push_str(raw);
    prompt.push('\n');
    prompt
}
|
|
|
|
#[derive(Serialize)]
|
|
struct OllamaRequest {
|
|
model: String,
|
|
prompt: String,
|
|
stream: bool,
|
|
format: Option<String>,
|
|
options: Option<OllamaOptions>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct OllamaOptions {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
num_predict: Option<u32>,
|
|
temperature: f32,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OllamaResponse {
|
|
response: String,
|
|
}
|
|
|
|
#[derive(Deserialize, Default)]
|
|
struct LlmHintsRaw {
|
|
title: Option<String>,
|
|
year: Option<YearValue>,
|
|
alt_titles: Option<Vec<String>>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(untagged)]
|
|
enum YearValue {
|
|
Number(i32),
|
|
String(String),
|
|
}
|
|
|
|
fn parse_hints(raw: &str) -> Option<LlmHints> {
|
|
let parsed: LlmHintsRaw = serde_json::from_str(raw).ok()?;
|
|
let year = parsed.year.and_then(|value| match value {
|
|
YearValue::Number(num) => Some(num),
|
|
YearValue::String(s) => s.chars().filter(|c| c.is_ascii_digit()).collect::<String>().parse().ok(),
|
|
});
|
|
Some(LlmHints {
|
|
title: parsed.title,
|
|
year,
|
|
alt_titles: parsed.alt_titles.unwrap_or_default(),
|
|
})
|
|
}
|