"""Update, backup and rollback logic for the pikit web UI and API.

Release manifests are fetched over HTTP(S) (with Gitea API fallbacks), bundles
are downloaded, optionally sha256-verified, staged under TMP_ROOT and copied
over the installed files; the current install is backed up under BACKUP_ROOT
first so a failed update can be reverted.
"""

import datetime
import fcntl
import json
import os
import pathlib
import shutil
import subprocess
import tarfile
import urllib.error
import urllib.parse
import urllib.request
from typing import Any, Dict, List, Optional

from .constants import (
    API_DIR,
    API_PACKAGE_DIR,
    API_PATH,
    AUTH_TOKEN,
    BACKUP_ROOT,
    DEFAULT_MANIFEST_URL,
    TMP_ROOT,
    UPDATE_LOCK,
    UPDATE_STATE,
    UPDATE_STATE_DIR,
    VERSION_FILE,
    WEB_ROOT,
    WEB_VERSION_FILE,
)
from .diagnostics import diag_log
from .helpers import default_host, ensure_dir, sha256_file

# systemd units restarted after installing or restoring files.
_RESTART_SERVICES = ("pikit-api.service", "dietpi-dashboard-frontend.service")


def _utc_now() -> datetime.datetime:
    """Return the current UTC time as a naive datetime (same shape ``utcnow()`` gave)."""
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


def _utc_timestamp() -> str:
    """Return the current UTC time as ISO-8601 text with a trailing ``Z``."""
    return _utc_now().isoformat() + "Z"


def read_current_version() -> str:
    """Return the installed version string, or ``"unknown"``.

    VERSION_FILE (plain text) takes precedence; WEB_VERSION_FILE (JSON with a
    ``"version"`` key) is the fallback.
    """
    if VERSION_FILE.exists():
        return VERSION_FILE.read_text().strip()
    if WEB_VERSION_FILE.exists():
        try:
            return json.loads(WEB_VERSION_FILE.read_text()).get("version", "unknown")
        except Exception:
            return "unknown"
    return "unknown"


def load_update_state() -> Dict[str, Any]:
    """Load the persisted update state, or build a fresh default one.

    An unreadable/corrupt state file is ignored and the defaults are returned.
    """
    UPDATE_STATE_DIR.mkdir(parents=True, exist_ok=True)
    if UPDATE_STATE.exists():
        try:
            return json.loads(UPDATE_STATE.read_text())
        except Exception:
            pass
    return {
        "current_version": read_current_version(),
        "latest_version": None,
        "last_check": None,
        "status": "unknown",
        "message": "",
        "auto_check": False,
        "in_progress": False,
        "progress": None,
        "channel": os.environ.get("PIKIT_CHANNEL", "dev"),
    }


def save_update_state(state: Dict[str, Any]) -> None:
    """Persist *state* as JSON, atomically.

    The file is written by the background updater while other processes may
    read it, so write a per-process temp file and rename it into place: a
    reader never sees a half-written document.
    """
    UPDATE_STATE_DIR.mkdir(parents=True, exist_ok=True)
    tmp = UPDATE_STATE.with_name(f"{UPDATE_STATE.name}.{os.getpid()}.tmp")
    try:
        tmp.write_text(json.dumps(state, indent=2))
        os.replace(tmp, UPDATE_STATE)
    except Exception:
        tmp.unlink(missing_ok=True)
        raise


def _auth_token() -> Optional[str]:
    """Return the Gitea token: PIKIT_AUTH_TOKEN env var, else the AUTH_TOKEN constant."""
    return os.environ.get("PIKIT_AUTH_TOKEN") or AUTH_TOKEN


def _authed_request(url: str) -> urllib.request.Request:
    """Build a GET request for *url*, adding ``Authorization: token …`` when a token is set."""
    req = urllib.request.Request(url)
    token = _auth_token()
    if token:
        req.add_header("Authorization", f"token {token}")
    return req


def _fetch_json(url: str, timeout: int = 10) -> Any:
    """GET *url* with auth and decode the JSON body. No 404 fallback here."""
    with urllib.request.urlopen(_authed_request(url), timeout=timeout) as resp:
        return json.loads(resp.read().decode())


def _gitea_releases_api(url: str) -> Optional[str]:
    """Derive the Gitea releases API endpoint from a release URL.

    For ``scheme://host/<owner>/<repo>/releases/...`` returns
    ``scheme://host/api/v1/repos/<owner>/<repo>/releases``; returns None when
    *url* has no ``releases`` segment preceded by two path parts.
    """
    parts = url.split("/")
    if "releases" not in parts:
        return None
    idx = parts.index("releases")
    if idx < 2:
        return None
    base = "/".join(parts[:3])
    owner, repo = parts[idx - 2], parts[idx - 1]
    return f"{base}/api/v1/repos/{owner}/{repo}/releases"


def _gitea_latest_manifest(target: str):
    """
    Fallback: when a manifest URL 404s, try hitting the Gitea API to grab the
    latest release asset named manifest.json.

    Returns the decoded manifest, or None if anything goes wrong.
    """
    api = _gitea_releases_api(target)
    if api is None:
        return None
    try:
        rel = _fetch_json(f"{api}/latest")
        assets = rel.get("assets") or []
        manifest_asset = next((a for a in assets if a.get("name") == "manifest.json"), None)
        if manifest_asset and manifest_asset.get("browser_download_url"):
            # Fetch the asset directly: going through fetch_manifest() would
            # re-enter this fallback on another 404 and recurse indefinitely.
            return _fetch_json(manifest_asset["browser_download_url"])
    except Exception:
        return None
    return None


def fetch_manifest(url: str | None = None):
    """Fetch and decode a release manifest.

    *url* defaults to PIKIT_MANIFEST_URL, then DEFAULT_MANIFEST_URL. On HTTP 404
    the latest Gitea release's manifest.json is tried; otherwise the error is
    re-raised.
    """
    target = url or os.environ.get("PIKIT_MANIFEST_URL") or DEFAULT_MANIFEST_URL
    try:
        return _fetch_json(target)
    except urllib.error.HTTPError as e:
        if e.code == 404:
            alt = _gitea_latest_manifest(target)
            if alt:
                return alt
        raise


def fetch_manifest_for_channel(channel: str):
    """
    For stable: use normal manifest (latest non-prerelease).
    For dev: try normal manifest; if it points to stable, fetch latest
    prerelease manifest via Gitea API.

    Raises RuntimeError when no manifest can be found at all.
    """
    channel = channel or "dev"
    base_manifest_url = os.environ.get("PIKIT_MANIFEST_URL") or DEFAULT_MANIFEST_URL
    try:
        manifest = fetch_manifest(base_manifest_url)
    except Exception:
        manifest = None

    if manifest and channel == "stable":
        return manifest
    if manifest and channel == "dev":
        version = manifest.get("version") or manifest.get("latest_version")
        if version and "dev" in str(version):
            return manifest

    try:
        api = _gitea_releases_api(base_manifest_url)
        if api is None:
            # Not a Gitea release URL: there is no release list to search.
            return manifest if manifest else fetch_manifest(base_manifest_url)
        releases = _fetch_json(api)

        def pick(predicate):
            # First release (API order) matching predicate that ships a manifest.json.
            for rel in releases:
                if not predicate(rel):
                    continue
                assets = rel.get("assets") or []
                asset = next((a for a in assets if a.get("name") == "manifest.json"), None)
                if asset and asset.get("browser_download_url"):
                    return fetch_manifest(asset["browser_download_url"])
            return None

        if channel == "dev":
            m = pick(lambda r: r.get("prerelease") is True)
            if m:
                return m
        m = pick(lambda r: r.get("prerelease") is False)
        if m:
            return m
    except Exception:
        pass
    if manifest:
        return manifest
    raise RuntimeError("No manifest found for channel")


def download_file(url: str, dest: pathlib.Path):
    """Stream *url* (with auth) into *dest*, creating parent dirs; returns *dest*."""
    ensure_dir(dest.parent)
    with urllib.request.urlopen(_authed_request(url), timeout=60) as resp, dest.open("wb") as f:
        shutil.copyfileobj(resp, f)
    return dest


def fetch_text_with_auth(url: str):
    """GET *url* with auth and return the body decoded as text."""
    with urllib.request.urlopen(_authed_request(url), timeout=10) as resp:
        return resp.read().decode()


def acquire_lock():
    """Take the non-blocking exclusive update lock.

    Returns the open lock file (pass it to release_lock), or None if the lock
    is held elsewhere or cannot be created.
    """
    lockfile = None
    try:
        ensure_dir(UPDATE_LOCK.parent)
        # Open without truncating so a failed attempt does not wipe the current
        # holder's PID; truncate only once the lock is ours.
        lockfile = UPDATE_LOCK.open("a+")
        fcntl.flock(lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        lockfile.seek(0)
        lockfile.truncate()
        lockfile.write(str(os.getpid()))
        lockfile.flush()
        return lockfile
    except Exception:
        if lockfile is not None:
            lockfile.close()
        return None


def release_lock(lockfile):
    """Release a lock obtained from acquire_lock and remove the lock file (best effort)."""
    try:
        # Unlink while still holding the lock: unlinking after LOCK_UN lets
        # another process lock the old inode just before it is removed while a
        # third process locks a freshly created file. This narrows that window.
        UPDATE_LOCK.unlink(missing_ok=True)
        fcntl.flock(lockfile.fileno(), fcntl.LOCK_UN)
    except Exception:
        pass
    finally:
        try:
            lockfile.close()
        except Exception:
            pass


def list_backups():
    """Return backups sorted by mtime (newest first)."""
    ensure_dir(BACKUP_ROOT)
    backups = [p for p in BACKUP_ROOT.iterdir() if p.is_dir()]
    backups.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    return backups


def get_backup_version(path: pathlib.Path):
    """Return the version recorded in backup *path*, or None.

    Reads ``version.txt`` if present, else ``pikit-web/data/version.json``.
    """
    vf = path / "version.txt"
    if not vf.exists():
        web_version = path / "pikit-web" / "data" / "version.json"
        if not web_version.exists():
            return None
        try:
            return json.loads(web_version.read_text()).get("version")
        except Exception:
            return None
    try:
        return vf.read_text().strip()
    except Exception:
        return None


def choose_rollback_backup():
    """
    Pick the most recent backup whose version differs from the currently
    installed version. If none differ, fall back to the newest backup.
    Returns None when there are no backups.
    """
    backups = list_backups()
    if not backups:
        return None
    current = read_current_version()
    for b in backups:
        ver = get_backup_version(b)
        if ver and ver != current:
            return b
    return backups[0]


def _restart_services() -> None:
    """Restart the pikit units; failures are ignored (``check=False``)."""
    for svc in _RESTART_SERVICES:
        subprocess.run(["systemctl", "restart", svc], check=False)


def restore_backup(target: pathlib.Path):
    """Copy every component present in backup *target* over the live install.

    Restores the web root, the API script and the API package when the backup
    contains them, restores/derives VERSION_FILE, then restarts services.
    """
    if (target / "pikit-web").exists():
        shutil.rmtree(WEB_ROOT, ignore_errors=True)
        shutil.copytree(target / "pikit-web", WEB_ROOT, dirs_exist_ok=True)
    if (target / "pikit-api.py").exists():
        shutil.copy2(target / "pikit-api.py", API_PATH)
        os.chmod(API_PATH, 0o755)
    if (target / "pikit_api").exists():
        shutil.rmtree(API_PACKAGE_DIR, ignore_errors=True)
        shutil.copytree(target / "pikit_api", API_PACKAGE_DIR, dirs_exist_ok=True)
    VERSION_FILE.parent.mkdir(parents=True, exist_ok=True)
    if (target / "version.txt").exists():
        shutil.copy2(target / "version.txt", VERSION_FILE)
    else:
        ver = get_backup_version(target)
        if ver:
            VERSION_FILE.write_text(str(ver))
    _restart_services()


def prune_backups(keep: int = 2):
    """Delete all but the *keep* newest backups (at least one is always kept)."""
    keep = max(keep, 1)
    for old in list_backups()[keep:]:
        shutil.rmtree(old, ignore_errors=True)


def check_for_update():
    """Query the manifest for the configured channel and record the result in the state."""
    state = load_update_state()
    lock = acquire_lock()
    if lock is None:
        state["status"] = "error"
        state["message"] = "Another update is running"
        save_update_state(state)
        return state
    channel = state.get("channel") or "dev"
    try:
        diag_log("info", "Update check started", {"channel": channel})
        state["in_progress"] = True
        state["progress"] = "Checking for updates…"
        save_update_state(state)
        try:
            manifest = fetch_manifest_for_channel(channel)
            latest = manifest.get("version") or manifest.get("latest_version")
            state["latest_version"] = latest
            state["last_check"] = _utc_timestamp()
            if channel == "stable" and latest and "dev" in str(latest):
                state["status"] = "up_to_date"
                state["message"] = "Dev release available; enable dev channel to install."
            elif latest and latest != state.get("current_version"):
                state["status"] = "update_available"
                state["message"] = manifest.get("changelog", "Update available")
            else:
                state["status"] = "up_to_date"
                state["message"] = "Up to date"
            diag_log("info", "Update check finished", {"status": state["status"], "latest": str(latest)})
        except Exception as e:
            # NOTE(review): a failed check reports "up_to_date", not "error";
            # confirm the UI relies on that.
            state["status"] = "up_to_date"
            state["message"] = f"Could not reach update server: {e}"
            state["last_check"] = _utc_timestamp()
            diag_log("error", "Update check failed", {"error": str(e)})
    finally:
        state["in_progress"] = False
        state["progress"] = None
        try:
            save_update_state(state)
        finally:
            release_lock(lock)
    return state


def _stage_backup() -> pathlib.Path:
    """Snapshot the current install into a new timestamped dir under BACKUP_ROOT."""
    ts = _utc_now().strftime("%Y%m%d-%H%M%S")
    backup_dir = BACKUP_ROOT / ts
    ensure_dir(backup_dir)
    if WEB_ROOT.exists():
        ensure_dir(backup_dir / "pikit-web")
        shutil.copytree(WEB_ROOT, backup_dir / "pikit-web", dirs_exist_ok=True)
    if API_PATH.exists():
        shutil.copy2(API_PATH, backup_dir / "pikit-api.py")
    if API_PACKAGE_DIR.exists():
        shutil.copytree(API_PACKAGE_DIR, backup_dir / "pikit_api", dirs_exist_ok=True)
    if VERSION_FILE.exists():
        shutil.copy2(VERSION_FILE, backup_dir / "version.txt")
    return backup_dir


def _staging_dir(version: str) -> pathlib.Path:
    """Return TMP_ROOT/<version>, rejecting versions that are not one plain path component."""
    name = str(version)
    stage_dir = TMP_ROOT / name
    # The version comes from the remote manifest; never let it escape TMP_ROOT.
    if name in (".", "..") or stage_dir.name != name:
        raise RuntimeError(f"Invalid version for staging directory: {name!r}")
    return stage_dir


def _bundle_sha256(manifest) -> Optional[str]:
    """Return the sha256 the manifest lists for ``bundle.tar.gz``, if any."""
    for entry in manifest.get("files", []):
        if entry.get("path") == "bundle.tar.gz" and entry.get("sha256"):
            return entry["sha256"]
    return None


def _extract_bundle(bundle_path: pathlib.Path, dest: pathlib.Path) -> None:
    """Extract the downloaded tarball into *dest*.

    Uses the ``"data"`` extraction filter when this Python provides it, so
    members with absolute paths, ``..`` components or links escaping *dest*
    are refused.
    """
    with tarfile.open(bundle_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)


def _install_staged(stage_dir: pathlib.Path) -> None:
    """Copy each component present in *stage_dir* over the live install."""
    staged_web = stage_dir / "pikit-web"
    if staged_web.exists():
        shutil.rmtree(WEB_ROOT, ignore_errors=True)
        shutil.copytree(staged_web, WEB_ROOT, dirs_exist_ok=True)
    staged_api = stage_dir / "pikit-api.py"
    if staged_api.exists():
        shutil.copy2(staged_api, API_PATH)
        os.chmod(API_PATH, 0o755)
    staged_pkg = stage_dir / "pikit_api"
    if staged_pkg.exists():
        shutil.rmtree(API_PACKAGE_DIR, ignore_errors=True)
        shutil.copytree(staged_pkg, API_PACKAGE_DIR, dirs_exist_ok=True)


def _rollback_after_failure(state: Dict[str, Any], backup: pathlib.Path) -> None:
    """Restore *backup* after a failed install and annotate the state message."""
    try:
        restore_backup(backup)
        state["current_version"] = read_current_version()
        state["message"] += f" (rolled back to backup {backup.name})"
        save_update_state(state)
        diag_log("info", "Rollback after failed update", {"backup": backup.name})
    except Exception as re:
        state["message"] += f" (rollback failed: {re})"
        save_update_state(state)
        diag_log("error", "Rollback after failed update failed", {"error": str(re)})


def apply_update():
    """Download, verify and install the latest release for the configured channel.

    The current install is backed up first. If installation fails after live
    files were modified, that backup is restored.
    """
    state = load_update_state()
    if state.get("in_progress"):
        state["message"] = "Update already in progress"
        save_update_state(state)
        return state
    lock = acquire_lock()
    if lock is None:
        state["status"] = "error"
        state["message"] = "Another update is running"
        save_update_state(state)
        return state

    backup_dir: Optional[pathlib.Path] = None
    files_touched = False  # set once live files start being replaced
    try:
        state["in_progress"] = True
        state["status"] = "in_progress"
        state["progress"] = "Starting update…"
        save_update_state(state)
        channel = state.get("channel") or os.environ.get("PIKIT_CHANNEL", "dev")
        diag_log("info", "Update apply started", {"channel": channel})
        try:
            manifest = fetch_manifest_for_channel(channel)
            latest = manifest.get("version") or manifest.get("latest_version")
            if not latest:
                raise RuntimeError("Manifest missing version")
            backup_dir = _stage_backup()
            prune_backups(keep=1)
            bundle_url = manifest.get("bundle") or manifest.get("url")
            if not bundle_url:
                raise RuntimeError("Manifest missing bundle url")

            stage_dir = _staging_dir(str(latest))
            bundle_path = stage_dir / "bundle.tar.gz"
            # Start from an empty staging dir so leftovers from an earlier
            # failed attempt at the same version cannot be installed.
            shutil.rmtree(stage_dir, ignore_errors=True)
            ensure_dir(stage_dir)

            state["progress"] = "Downloading release…"
            save_update_state(state)
            download_file(bundle_url, bundle_path)
            diag_log("debug", "Bundle downloaded", {"url": bundle_url, "path": str(bundle_path)})

            expected_hash = _bundle_sha256(manifest)
            if expected_hash:
                got = sha256_file(bundle_path)
                if got.lower() != expected_hash.lower():
                    raise RuntimeError("Bundle hash mismatch")
                diag_log("debug", "Bundle hash verified", {"expected": expected_hash})

            state["progress"] = "Staging files…"
            save_update_state(state)
            _extract_bundle(bundle_path, stage_dir)

            files_touched = True
            _install_staged(stage_dir)
            # Write the version before restarting so it is already on disk
            # when the services come back up.
            VERSION_FILE.parent.mkdir(parents=True, exist_ok=True)
            VERSION_FILE.write_text(str(latest))
            _restart_services()

            state["current_version"] = str(latest)
            state["latest_version"] = str(latest)
            state["status"] = "up_to_date"
            state["message"] = "Update installed"
            state["progress"] = None
            save_update_state(state)
            diag_log("info", "Update applied", {"version": str(latest)})
        except urllib.error.HTTPError as e:
            # Only raised by network steps, which all precede any file changes.
            state["status"] = "error"
            state["message"] = f"No release available ({e.code})"
            diag_log("error", "Update apply HTTP error", {"code": e.code})
        except Exception as e:
            state["status"] = "error"
            state["message"] = f"Update failed: {e}"
            state["progress"] = None
            save_update_state(state)
            diag_log("error", "Update apply failed", {"error": str(e)})
            # Roll back only if live files were modified, and only to the
            # snapshot taken by this run — never to an older release.
            if files_touched and backup_dir is not None:
                _rollback_after_failure(state, backup_dir)
    finally:
        state["in_progress"] = False
        state["progress"] = None
        try:
            save_update_state(state)
        finally:
            release_lock(lock)
    return state


def rollback_update():
    """Restore the backup chosen by choose_rollback_backup and record the outcome."""
    state = load_update_state()
    lock = acquire_lock()
    if lock is None:
        state["status"] = "error"
        state["message"] = "Another update is running"
        save_update_state(state)
        return state
    try:
        state["in_progress"] = True
        state["status"] = "in_progress"
        state["progress"] = "Rolling back…"
        save_update_state(state)
        diag_log("info", "Rollback started")
        backup = choose_rollback_backup()
        if not backup:
            state["status"] = "error"
            state["message"] = "No backup available to rollback."
            return state
        try:
            restore_backup(backup)
            state["status"] = "up_to_date"
            state["current_version"] = read_current_version()
            state["latest_version"] = state.get("latest_version") or state["current_version"]
            ver = get_backup_version(backup)
            suffix = f" (version {ver})" if ver else ""
            state["message"] = f"Rolled back to backup {backup.name}{suffix}"
            diag_log("info", "Rollback complete", {"backup": backup.name, "version": ver})
        except Exception as e:
            state["status"] = "error"
            state["message"] = f"Rollback failed: {e}"
            diag_log("error", "Rollback failed", {"error": str(e)})
    finally:
        # Runs on every exit path (including the early return above).
        state["in_progress"] = False
        state["progress"] = None
        try:
            save_update_state(state)
        finally:
            release_lock(lock)
    return state


def start_background_task(mode: str):
    """
    Kick off a background update/rollback via systemd-run so nginx/API restarts
    do not break the caller connection.

    mode: "apply" or "rollback"; anything else raises ValueError.
    """
    if mode not in ("apply", "rollback"):
        # Explicit check rather than assert: asserts vanish under `python -O`.
        raise ValueError(f"invalid mode: {mode!r}")
    unit = f"pikit-update-{mode}"
    cmd = ["systemd-run", "--unit", unit, "--quiet"]
    # Forward the same manifest URL fetch_manifest() would use in this process.
    manifest_url = os.environ.get("PIKIT_MANIFEST_URL") or DEFAULT_MANIFEST_URL
    if manifest_url:
        cmd.append(f"--setenv=PIKIT_MANIFEST_URL={manifest_url}")
    token = _auth_token()
    if token:
        # NOTE(review): the token appears on systemd-run's command line, which
        # is visible in the process list while it runs.
        cmd.append(f"--setenv=PIKIT_AUTH_TOKEN={token}")
    cmd += ["/usr/bin/env", "python3", str(API_PATH), f"--{mode}-update"]
    subprocess.run(cmd, check=False)


# Backwards compat aliases
apply_update_stub = apply_update
rollback_update_stub = rollback_update