Files
mov-renamarr/src/llm.rs
2025-12-30 10:52:59 -05:00

119 lines
3.4 KiB
Rust

use anyhow::{Context, Result};
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
/// Structured hints extracted from a media filename by the LLM.
///
/// All fields are optional/empty by default; callers treat a default value
/// as "the model produced no usable hint".
#[derive(Debug, Clone, Default)]
pub struct LlmHints {
/// Movie title as parsed by the model, if any.
pub title: Option<String>,
/// Release year as parsed by the model, if any.
pub year: Option<i32>,
/// Alternate titles proposed by the model; empty when none were given.
pub alt_titles: Vec<String>,
}
/// Blocking client for an Ollama-style `/api/generate` endpoint.
#[derive(Clone)]
pub struct LlmClient {
/// Base URL of the LLM server; a trailing '/' is tolerated (trimmed on use).
endpoint: String,
/// Model name sent with every generate request.
model: String,
/// Optional cap on generated tokens (forwarded as Ollama's `num_predict`).
max_tokens: Option<u32>,
/// Reusable blocking HTTP client, built with the configured timeout.
client: Client,
}
impl LlmClient {
    /// Builds a client for an Ollama-compatible endpoint.
    ///
    /// `timeout_seconds` bounds the whole request (connect + response), and
    /// `max_tokens` optionally caps generation length (`num_predict`).
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client cannot be constructed.
    pub fn new(
        endpoint: String,
        model: String,
        timeout_seconds: u64,
        max_tokens: Option<u32>,
    ) -> Result<Self> {
        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(timeout_seconds))
            .build()
            .context("failed to build HTTP client for LLM")?;
        Ok(Self {
            endpoint,
            model,
            max_tokens,
            client,
        })
    }

    /// Asks the model to extract title/year hints from the raw filename.
    ///
    /// Malformed hint JSON from the model degrades to `LlmHints::default()`
    /// (best effort) rather than failing.
    ///
    /// # Errors
    /// Returns an error when the request cannot be sent, the server replies
    /// with a non-success status, or the response envelope is not valid JSON.
    pub fn parse_filename(&self, raw: &str) -> Result<LlmHints> {
        let prompt = build_prompt(raw);
        let request = OllamaRequest {
            model: self.model.clone(),
            prompt,
            stream: false,
            // Ask the server to constrain output to JSON.
            format: Some("json".to_string()),
            options: Some(OllamaOptions {
                num_predict: self.max_tokens,
                // Deterministic decoding for a parsing task.
                temperature: 0.0,
            }),
        };
        // trim_end_matches avoids a double slash when the endpoint ends in '/'.
        let url = format!("{}/api/generate", self.endpoint.trim_end_matches('/'));
        let response = self
            .client
            .post(url)
            .json(&request)
            .send()
            .context("LLM request failed")?;
        let status = response.status();
        if !status.is_success() {
            // Fix: previously the response body was discarded, losing the
            // server's explanation of the failure. Read it (best effort) and
            // surface it alongside the status code.
            let body = response.text().unwrap_or_default();
            return Err(anyhow::anyhow!(
                "LLM returned HTTP {status}: {}",
                body.trim()
            ));
        }
        let body: OllamaResponse = response.json().context("failed to parse LLM response")?;
        // Best effort: unparseable hints become empty defaults, not an error.
        Ok(parse_hints(&body.response).unwrap_or_default())
    }
}
/// Renders the instruction prompt sent to the model for filename `raw`.
///
/// The prompt demands strict JSON output with keys `title`, `year`, and
/// `alt_titles`, then appends the filename on the final line.
fn build_prompt(raw: &str) -> String {
    const INSTRUCTIONS: &str = "You are a strict parser. Extract the full movie title and year from the filename below.\n\nRules:\n- Output JSON only.\n- Title must include all words of the movie name in order (no partial tokens).\n- Strip release metadata (resolution, codec, source, group tags).\n- Year must be a 4-digit number if present.\n- If unsure, use null for fields and empty array for alt_titles.\n- Do NOT invent data.\n\nReturn JSON with keys: title, year, alt_titles.\n\nFilename: ";
    let mut prompt = String::with_capacity(INSTRUCTIONS.len() + raw.len() + 1);
    prompt.push_str(INSTRUCTIONS);
    prompt.push_str(raw);
    prompt.push('\n');
    prompt
}
#[derive(Serialize)]
struct OllamaRequest {
model: String,
prompt: String,
stream: bool,
format: Option<String>,
options: Option<OllamaOptions>,
}
/// Generation options nested under the request's `options` object.
#[derive(Serialize)]
struct OllamaOptions {
/// Maximum number of tokens to generate; omitted from the JSON when unset.
#[serde(skip_serializing_if = "Option::is_none")]
num_predict: Option<u32>,
/// Sampling temperature; this client always sends 0.0 for determinism.
temperature: f32,
}
/// Minimal view of the generate response: only the model's output text.
#[derive(Deserialize)]
struct OllamaResponse {
/// The model's generated text (expected to itself be JSON — see parse_hints).
response: String,
}
/// Loosely-typed mirror of the JSON the model emits, before normalization
/// into [`LlmHints`]. Every field is optional to tolerate partial output.
#[derive(Deserialize, Default)]
struct LlmHintsRaw {
title: Option<String>,
/// The model may emit the year as a number or a string; see [`YearValue`].
year: Option<YearValue>,
alt_titles: Option<Vec<String>>,
}
/// Year value as emitted by the model: either a JSON number or a string.
/// `untagged` lets serde try each representation in declaration order.
#[derive(Deserialize)]
#[serde(untagged)]
enum YearValue {
Number(i32),
String(String),
}
/// Parses the model's JSON payload into structured [`LlmHints`].
///
/// Returns `None` when `raw` is not valid JSON of the expected shape; the
/// caller treats that as "no hints". Fix: titles and alt-title entries are
/// now trimmed, and blank ones dropped, so downstream consumers never see
/// empty strings as "valid" hints.
fn parse_hints(raw: &str) -> Option<LlmHints> {
    let parsed: LlmHintsRaw = serde_json::from_str(raw).ok()?;
    // Accept either a JSON number or a digit-bearing string for the year.
    let year = parsed.year.and_then(|value| match value {
        YearValue::Number(num) => Some(num),
        YearValue::String(s) => s
            .chars()
            .filter(|c| c.is_ascii_digit())
            .collect::<String>()
            .parse()
            .ok(),
    });
    // A whitespace-only title is no title at all.
    let title = parsed
        .title
        .map(|t| t.trim().to_string())
        .filter(|t| !t.is_empty());
    let alt_titles = parsed
        .alt_titles
        .unwrap_or_default()
        .into_iter()
        .map(|t| t.trim().to_string())
        .filter(|t| !t.is_empty())
        .collect();
    Some(LlmHints {
        title,
        year,
        alt_titles,
    })
}