"""Self-update support: manifest discovery, bundle install, backups and rollback."""

import contextlib
import datetime
import fcntl
import json
import os
import pathlib
import re
import shutil
import subprocess
import tarfile
import tempfile
import threading
import urllib.error
import urllib.parse
import urllib.request
from typing import Any, Dict, List, Optional

from .constants import (
    API_DIR,
    API_PACKAGE_DIR,
    API_PATH,
    AUTH_TOKEN,
    BACKUP_ROOT,
    DEFAULT_MANIFEST_URL,
    TMP_ROOT,
    UPDATE_LOCK,
    UPDATE_STATE,
    UPDATE_STATE_DIR,
    VERSION_FILE,
    WEB_ROOT,
    WEB_VERSION_FILE,
)
from .diagnostics import diag_log
from .helpers import default_host, ensure_dir, sha256_file

# systemd units restarted after files are installed or restored.
_SERVICES = ("pikit-api.service", "dietpi-dashboard-frontend.service")

# Numeric and alphabetic runs of a version string; any other character
# (".", "-", "_", "+", ...) only separates components.
_VERSION_COMPONENT_RE = re.compile(r"[0-9]+|[A-Za-z]+")


def _utc_now_iso() -> str:
    """Return the current UTC time as ISO-8601 with a trailing ``Z``.

    Produces the same text as the former ``utcnow().isoformat() + "Z"``
    without the naive-datetime API deprecated in Python 3.12.
    """
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    return now.isoformat() + "Z"


def read_current_version() -> str:
    """Return the installed version string, or ``"unknown"``.

    ``VERSION_FILE`` wins when present; otherwise the ``version`` key of the
    web bundle's ``WEB_VERSION_FILE`` JSON is used.
    """
    if VERSION_FILE.exists():
        return VERSION_FILE.read_text().strip()
    if WEB_VERSION_FILE.exists():
        try:
            return json.loads(WEB_VERSION_FILE.read_text()).get("version", "unknown")
        except Exception:
            return "unknown"
    return "unknown"


def load_update_state() -> Dict[str, Any]:
    """Load the persisted update state, or build a fresh default state.

    A stored state that cannot be parsed is ignored in favour of defaults.
    Keys added in later releases are back-filled with ``None``.
    """
    UPDATE_STATE_DIR.mkdir(parents=True, exist_ok=True)
    if UPDATE_STATE.exists():
        try:
            state = json.loads(UPDATE_STATE.read_text())
            state.setdefault("changelog_url", None)
            state.setdefault("latest_release_date", None)
            state.setdefault("current_release_date", None)
            return state
        except Exception:
            pass
    return {
        "current_version": read_current_version(),
        "latest_version": None,
        "last_check": None,
        "status": "unknown",
        "message": "",
        "auto_check": False,
        "in_progress": False,
        "progress": None,
        "channel": os.environ.get("PIKIT_CHANNEL", "dev"),
        "changelog_url": None,
        "latest_release_date": None,
        "current_release_date": None,
    }


def save_update_state(state: Dict[str, Any]) -> None:
    """Persist *state* as pretty-printed JSON.

    The JSON goes to a temporary sibling file first and is then moved over
    ``UPDATE_STATE`` with ``os.replace``. A concurrent reader therefore never
    sees a half-written file (``load_update_state`` would otherwise discard it
    and fall back to defaults).
    """
    UPDATE_STATE_DIR.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(state, indent=2)
    # pid + thread id keeps concurrent writers from sharing a temp file.
    tmp = UPDATE_STATE.with_name(
        f".{UPDATE_STATE.name}.{os.getpid()}.{threading.get_ident()}.tmp"
    )
    try:
        tmp.write_text(payload)
        os.replace(tmp, UPDATE_STATE)
    except Exception:
        with contextlib.suppress(Exception):
            tmp.unlink(missing_ok=True)
        raise


def _auth_token():
    """Return the API token: ``PIKIT_AUTH_TOKEN`` env var, else ``AUTH_TOKEN``."""
    return os.environ.get("PIKIT_AUTH_TOKEN") or AUTH_TOKEN


def _authed_request(url: str) -> urllib.request.Request:
    """Build a GET request for *url*, adding ``Authorization: token …`` if configured."""
    req = urllib.request.Request(url)
    token = _auth_token()
    if token:
        req.add_header("Authorization", f"token {token}")
    return req


def _fetch_json(url: str, timeout: int = 10):
    """GET *url* with auth and decode the body as JSON.

    Unlike ``fetch_manifest`` there is no 404 fallback, so this is safe to call
    from the fallback path itself. ``urllib`` errors propagate.
    """
    with urllib.request.urlopen(_authed_request(url), timeout=timeout) as resp:
        return json.loads(resp.read().decode())


def _norm_ver(ver: Any) -> Optional[str]:
    """Return *ver* as a stripped string without a leading ``v``/``V`` (None stays None)."""
    if ver is None:
        return None
    s = str(ver).strip()
    if s.lower().startswith("v"):
        s = s[1:]
    return s


def _version_key(ver: str):
    """Sort key for a version string.

    Numeric runs compare as integers, so "0.10" > "0.9". Alphabetic runs rank
    below numbers, so "1.0.rc1" < "1.0.1". A longer version with an equal prefix
    compares greater, matching the old ``LooseVersion`` behaviour.
    """
    return tuple(
        (1, int(part)) if part.isdigit() else (0, part.lower())
        for part in _VERSION_COMPONENT_RE.findall(ver)
    )


def _newer(a: str, b: str) -> bool:
    """Return True when version *a* is strictly newer than version *b*.

    Replaces ``distutils.version.LooseVersion``: ``distutils`` was removed in
    Python 3.12, where the old code silently fell back to plain string
    comparison.
    """
    return _version_key(str(a)) > _version_key(str(b))


def _release_version(rel: Dict[str, Any]) -> Optional[str]:
    """Return the normalized version of a release (``tag_name``, then ``name``), or None."""
    for key in ("tag_name", "name"):
        val = rel.get(key)
        if val:
            v = _norm_ver(val)
            if v:
                return v
    return None


def _manifest_version(mf: Dict[str, Any]) -> Optional[str]:
    """Return the normalized ``version``/``latest_version`` field of a manifest."""
    return _norm_ver(mf.get("version") or mf.get("latest_version"))


def _gitea_repo_api(target: Optional[str]) -> Optional[str]:
    """Derive ``<scheme>://<host>/api/v1/repos/<owner>/<repo>`` from a release URL.

    Expects URLs shaped like ``https://host/<owner>/<repo>/releases/...``.
    Returns None when *target* is empty, has no ``releases`` path segment, or
    has fewer than two segments before it.
    """
    if not target:
        return None
    parts = target.split("/")
    if "releases" not in parts:
        return None
    idx = parts.index("releases")
    if idx < 2:
        return None
    base = "/".join(parts[:3])
    return f"{base}/api/v1/repos/{parts[idx - 2]}/{parts[idx - 1]}"


def _manifest_from_release(rel: Dict[str, Any]):
    """Fetch the ``manifest.json`` asset of a Gitea release object.

    The release's publish date (``_release_date``) and tag (``_release_tag``)
    are attached to the manifest. Returns None when the release has no such
    asset, or when the asset cannot be fetched or is not a JSON object.
    """
    assets = rel.get("assets") or []
    asset = next((a for a in assets if a.get("name") == "manifest.json"), None)
    if not asset or not asset.get("browser_download_url"):
        return None
    try:
        mf = _fetch_json(asset["browser_download_url"])
    except Exception:
        return None
    if not isinstance(mf, dict) or not mf:
        return None
    dt = rel.get("published_at") or rel.get("created_at")
    if dt:
        mf["_release_date"] = dt
    tag = rel.get("tag_name")
    if tag:
        mf["_release_tag"] = tag
    return mf


def _gitea_latest_manifest(target: str):
    """
    Fallback: when a manifest URL 404s, try hitting the Gitea API to grab the
    latest release asset named manifest.json.

    Returns the decoded manifest, or None on any failure. The asset is fetched
    without the 404 fallback, so a missing asset cannot recurse back here.
    """
    try:
        repo_api = _gitea_repo_api(target)
        if repo_api is None:
            return None
        rel = _fetch_json(f"{repo_api}/releases/latest")
        assets = rel.get("assets") or []
        manifest_asset = next((a for a in assets if a.get("name") == "manifest.json"), None)
        if manifest_asset and manifest_asset.get("browser_download_url"):
            return _fetch_json(manifest_asset["browser_download_url"])
    except Exception:
        return None
    return None


def fetch_manifest(url: str | None = None):
    """Fetch and decode a release manifest.

    *url* defaults to ``PIKIT_MANIFEST_URL`` or ``DEFAULT_MANIFEST_URL``.
    On HTTP 404 the latest Gitea release's ``manifest.json`` is tried instead.

    Raises:
        urllib.error.HTTPError: for non-404 errors, or a 404 with no fallback.
    """
    target = url or os.environ.get("PIKIT_MANIFEST_URL") or DEFAULT_MANIFEST_URL
    try:
        return _fetch_json(target)
    except urllib.error.HTTPError as e:
        if e.code == 404:
            alt = _gitea_latest_manifest(target)
            if alt:
                return alt
        raise


def _try_fetch(url: Optional[str]):
    """Best-effort ``fetch_manifest``: None for an empty URL or any error."""
    if not url:
        return None
    try:
        return fetch_manifest(url)
    except Exception:
        return None


def fetch_manifest_for_channel(channel: str, with_meta: bool = False):
    """
    For stable: use normal manifest (latest non-prerelease).
    For dev: try normal manifest; if it points to stable, fetch latest prerelease
    manifest via Gitea API. If a stable build is newer than the latest dev build,
    prefer the newer stable even on dev channel.

    Args:
        channel: ``"dev"`` (default when falsy) or ``"stable"``.
        with_meta: when True, return ``(manifest, {"version_dates": {...}})``.
            ``version_dates`` maps normalized release versions to their
            published/created timestamps.

    Raises:
        RuntimeError: when no manifest could be found for the channel.
    """
    channel = channel or "dev"
    base_manifest_url = os.environ.get("PIKIT_MANIFEST_URL") or DEFAULT_MANIFEST_URL
    dev_manifest_url = os.environ.get("PIKIT_DEV_MANIFEST_URL")
    stable_manifest_url = os.environ.get("PIKIT_STABLE_MANIFEST_URL") or base_manifest_url
    version_dates: Dict[str, Optional[str]] = {}

    def _result(mf):
        return (mf, {"version_dates": version_dates}) if with_meta else mf

    manifest = _try_fetch(stable_manifest_url)

    try:
        repo_api = _gitea_repo_api(base_manifest_url)
        if repo_api is None:
            # Not a Gitea release URL: there is no release list to consult.
            if manifest:
                return _result(manifest)
            return _result(fetch_manifest(base_manifest_url))

        releases = _fetch_json(f"{repo_api}/releases")

        # Map release versions to published dates so callers can surface them.
        for rel in releases:
            v = _release_version(rel)
            if v and v not in version_dates:
                version_dates[v] = rel.get("published_at") or rel.get("created_at")

        # Newest prerelease (dev) and newest full release (stable).
        dev_rel = stable_rel = None
        dev_ver = stable_ver = None
        for rel in releases:
            parsed = _release_version(rel)
            if not parsed:
                continue
            cmp_ver = parsed.replace("-", ".")
            if rel.get("prerelease") is True:
                if dev_ver is None or _newer(cmp_ver, dev_ver):
                    dev_rel, dev_ver = rel, cmp_ver
            elif rel.get("prerelease") is False:
                if stable_ver is None or _newer(cmp_ver, stable_ver):
                    stable_rel, stable_ver = rel, cmp_ver

        latest_dev = _manifest_from_release(dev_rel) if dev_rel else None
        latest_stable = _manifest_from_release(stable_rel) if stable_rel else None

        # If the API gave no dev manifest, try the explicitly configured dev URL.
        if dev_manifest_url and latest_dev is None:
            latest_dev = _try_fetch(dev_manifest_url)
        if latest_dev and "_release_date" not in latest_dev:
            latest_dev["_release_date"] = version_dates.get(_manifest_version(latest_dev), None)

        # Attach the publish date to the base manifest when possible.
        if manifest:
            mver = _manifest_version(manifest)
            if mver and mver in version_dates and "_release_date" not in manifest:
                manifest["_release_date"] = version_dates[mver]

        if channel == "dev":
            # Newest by version across dev, stable and base candidates.
            best = None
            best_ver = None
            for candidate in (latest_dev, latest_stable, manifest):
                if not candidate:
                    continue
                ver = _manifest_version(candidate)
                if not ver:
                    continue
                ver_cmp = ver.replace("-", ".")
                if best_ver is None or _newer(ver_cmp, best_ver):
                    best, best_ver = candidate, ver_cmp
            manifest = best
        else:  # stable channel
            manifest = latest_stable or manifest
    except Exception as e:
        # Release API is optional; fall through to the configured URLs below.
        diag_log("debug", "Release lookup failed", {"error": str(e)})

    # Last resort for dev: explicitly configured dev manifest without API data.
    if channel == "dev" and manifest is None and dev_manifest_url:
        manifest = _try_fetch(dev_manifest_url)
    # If still nothing and a distinct stable manifest URL is set, try it once more.
    if manifest is None and stable_manifest_url and stable_manifest_url != base_manifest_url:
        manifest = _try_fetch(stable_manifest_url)
    if manifest:
        return _result(manifest)
    raise RuntimeError("No manifest found for channel")


def download_file(url: str, dest: pathlib.Path):
    """Stream *url* (with auth) into *dest*, creating parent directories; return *dest*."""
    ensure_dir(dest.parent)
    with urllib.request.urlopen(_authed_request(url), timeout=60) as resp, dest.open("wb") as f:
        shutil.copyfileobj(resp, f)
    return dest


def fetch_text_with_auth(url: str):
    """GET *url* with auth and return the body decoded as text."""
    with urllib.request.urlopen(_authed_request(url), timeout=10) as resp:
        return resp.read().decode()


def acquire_lock():
    """Take the exclusive, non-blocking update lock.

    Returns the open lock file on success (pass it to ``release_lock``), or
    None if another process holds the lock or it could not be taken.
    """
    try:
        ensure_dir(UPDATE_LOCK.parent)
        # "a+" instead of "w": do not truncate the holder's pid before locking.
        lockfile = UPDATE_LOCK.open("a+")
    except Exception:
        return None
    try:
        fcntl.flock(lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        # release_lock() unlinks the path. If that happened between our open()
        # and flock(), we now hold an orphaned inode that would not exclude a
        # process re-creating the path, so treat it as "not acquired".
        held = os.fstat(lockfile.fileno())
        current = os.stat(UPDATE_LOCK)
        if (held.st_dev, held.st_ino) != (current.st_dev, current.st_ino):
            raise RuntimeError("update lock file was replaced")
        lockfile.seek(0)
        lockfile.truncate()
        lockfile.write(str(os.getpid()))
        lockfile.flush()
        return lockfile
    except Exception:
        lockfile.close()
        return None


def release_lock(lockfile):
    """Release a lock returned by ``acquire_lock`` and remove the lock file (best effort)."""
    # Unlink while still holding the lock so waiters can detect a stale inode.
    with contextlib.suppress(Exception):
        UPDATE_LOCK.unlink(missing_ok=True)
    with contextlib.suppress(Exception):
        fcntl.flock(lockfile.fileno(), fcntl.LOCK_UN)
    with contextlib.suppress(Exception):
        lockfile.close()


def list_backups():
    """Return backups sorted by mtime (newest first)."""
    ensure_dir(BACKUP_ROOT)
    backups = [p for p in BACKUP_ROOT.iterdir() if p.is_dir()]
    backups.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    return backups


def get_backup_version(path: pathlib.Path):
    """Return the version recorded in backup *path*, or None.

    Reads ``version.txt``. If that file is absent, reads the ``version`` key of
    ``pikit-web/data/version.json``.
    """
    vf = path / "version.txt"
    if not vf.exists():
        web_version = path / "pikit-web" / "data" / "version.json"
        if not web_version.exists():
            return None
        try:
            return json.loads(web_version.read_text()).get("version")
        except Exception:
            return None
    try:
        return vf.read_text().strip()
    except Exception:
        return None


def choose_rollback_backup():
    """
    Pick the most recent backup whose version differs from the currently
    installed version. If none differ, fall back to the newest backup.
    Returns None when there are no backups.
    """
    backups = list_backups()
    if not backups:
        return None
    current = read_current_version()
    for b in backups:
        ver = get_backup_version(b)
        if ver and ver != current:
            return b
    return backups[0]


def _restart_services() -> None:
    """Restart the pikit API and dashboard frontend units (failures ignored)."""
    for svc in _SERVICES:
        subprocess.run(["systemctl", "restart", svc], check=False)


def restore_backup(target: pathlib.Path):
    """Copy the contents of backup *target* over the live install and restart services.

    Each component (web root, API script, API package, version file) is only
    restored when present in the backup.
    """
    web_backup = target / "pikit-web"
    if web_backup.exists():
        shutil.rmtree(WEB_ROOT, ignore_errors=True)
        shutil.copytree(web_backup, WEB_ROOT, dirs_exist_ok=True)
    api_backup = target / "pikit-api.py"
    if api_backup.exists():
        shutil.copy2(api_backup, API_PATH)
        os.chmod(API_PATH, 0o755)
    pkg_backup = target / "pikit_api"
    if pkg_backup.exists():
        shutil.rmtree(API_PACKAGE_DIR, ignore_errors=True)
        shutil.copytree(pkg_backup, API_PACKAGE_DIR, dirs_exist_ok=True)
    VERSION_FILE.parent.mkdir(parents=True, exist_ok=True)
    version_backup = target / "version.txt"
    if version_backup.exists():
        shutil.copy2(version_backup, VERSION_FILE)
    else:
        ver = get_backup_version(target)
        if ver:
            VERSION_FILE.write_text(str(ver))
    _restart_services()


def prune_backups(keep: int = 2):
    """Delete all but the *keep* newest backups (*keep* is clamped to at least 1)."""
    if keep < 1:
        keep = 1
    backups = list_backups()
    for old in backups[keep:]:
        shutil.rmtree(old, ignore_errors=True)


def check_for_update():
    """Query the update server for the configured channel and record the result.

    Returns the updated state dict, which is also persisted. Network failures
    are reported in ``message`` without raising.
    """
    state = load_update_state()
    lock = acquire_lock()
    if lock is None:
        state["status"] = "error"
        state["message"] = "Another update is running"
        save_update_state(state)
        return state
    channel = state.get("channel") or "dev"
    diag_log("info", "Update check started", {"channel": channel})
    state["in_progress"] = True
    state["progress"] = "Checking for updates…"
    save_update_state(state)
    try:
        manifest, meta = fetch_manifest_for_channel(channel, with_meta=True)
        latest = manifest.get("version") or manifest.get("latest_version")
        state["latest_version"] = latest
        state["changelog_url"] = manifest.get("changelog")
        state["last_check"] = _utc_now_iso()
        # version_dates is keyed by normalized versions (no leading "v").
        version_dates = (meta or {}).get("version_dates") or {}
        if manifest.get("_release_date"):
            state["latest_release_date"] = manifest.get("_release_date")
        elif latest:
            state["latest_release_date"] = version_dates.get(_norm_ver(latest))
        else:
            state["latest_release_date"] = None

        state["current_release_date"] = None
        current_ver = state.get("current_version")
        current_key = _norm_ver(current_ver) if current_ver else None
        if current_key and current_key in version_dates:
            state["current_release_date"] = version_dates[current_key]
        elif current_ver and current_ver == latest and state["latest_release_date"]:
            # Current matches latest and we have a date for latest: reuse it.
            state["current_release_date"] = state["latest_release_date"]

        if channel == "stable" and latest and "dev" in str(latest):
            state["status"] = "up_to_date"
            state["message"] = "Dev release available; enable dev channel to install."
        elif latest and latest != state.get("current_version"):
            state["status"] = "update_available"
            state["message"] = manifest.get("notes") or manifest.get("message") or "Update available"
        else:
            state["status"] = "up_to_date"
            state["message"] = "Up to date"
        diag_log("info", "Update check finished", {"status": state["status"], "latest": str(latest)})
    except Exception as e:
        state["status"] = "up_to_date"
        state["message"] = f"Could not reach update server: {e}"
        state["last_check"] = _utc_now_iso()
        state["latest_release_date"] = None
        diag_log("error", "Update check failed", {"error": str(e)})
    finally:
        state["in_progress"] = False
        state["progress"] = None
        save_update_state(state)
        if lock:
            release_lock(lock)
    return state


def _stage_backup() -> pathlib.Path:
    """Copy the live install into a new UTC-timestamped directory under ``BACKUP_ROOT``."""
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S")
    backup_dir = BACKUP_ROOT / ts
    ensure_dir(backup_dir)
    if WEB_ROOT.exists():
        shutil.copytree(WEB_ROOT, backup_dir / "pikit-web", dirs_exist_ok=True)
    if API_PATH.exists():
        shutil.copy2(API_PATH, backup_dir / "pikit-api.py")
    if API_PACKAGE_DIR.exists():
        shutil.copytree(API_PACKAGE_DIR, backup_dir / "pikit_api", dirs_exist_ok=True)
    if VERSION_FILE.exists():
        shutil.copy2(VERSION_FILE, backup_dir / "version.txt")
    return backup_dir


def _safe_extract(tar: tarfile.TarFile, dest: pathlib.Path) -> None:
    """Extract *tar* into *dest*, refusing members that would land outside it.

    Uses the stdlib ``"data"`` extraction filter when available. Otherwise it
    validates member paths, link targets and device entries by hand.

    Raises:
        RuntimeError / tarfile.FilterError: for an unsafe member.
    """
    if hasattr(tarfile, "data_filter"):
        tar.extractall(dest, filter="data")
        return
    root = dest.resolve()
    for member in tar.getmembers():
        target = (root / member.name).resolve()
        if not target.is_relative_to(root):
            raise RuntimeError(f"Unsafe path in bundle: {member.name!r}")
        if member.issym():
            link_target = (target.parent / member.linkname).resolve()
        elif member.islnk():
            link_target = (root / member.linkname).resolve()
        else:
            link_target = None
        if link_target is not None and not link_target.is_relative_to(root):
            raise RuntimeError(f"Unsafe link in bundle: {member.name!r}")
        if member.isdev():
            raise RuntimeError(f"Device file in bundle: {member.name!r}")
    tar.extractall(dest)


def _install_staged(stage_dir: pathlib.Path) -> None:
    """Copy an extracted bundle from *stage_dir* over the live web root and API files."""
    staged_web = stage_dir / "pikit-web"
    if staged_web.exists():
        shutil.rmtree(WEB_ROOT, ignore_errors=True)
        shutil.copytree(staged_web, WEB_ROOT)
    staged_api = stage_dir / "pikit-api.py"
    if staged_api.exists():
        shutil.copy2(staged_api, API_PATH)
        os.chmod(API_PATH, 0o755)
    staged_pkg = stage_dir / "pikit_api"
    if staged_pkg.exists():
        shutil.rmtree(API_PACKAGE_DIR, ignore_errors=True)
        shutil.copytree(staged_pkg, API_PACKAGE_DIR, dirs_exist_ok=True)


def _rollback_after_failed_update(state: Dict[str, Any], backup_dir: Optional[pathlib.Path]) -> None:
    """Restore the backup taken for this update run and annotate *state* with the outcome.

    Falls back to ``choose_rollback_backup()`` only if the staged backup
    directory has disappeared.
    """
    if backup_dir is not None and backup_dir.exists():
        backup = backup_dir
    else:
        backup = choose_rollback_backup()
    if not backup:
        return
    try:
        restore_backup(backup)
        state["current_version"] = read_current_version()
        state["message"] += f" (rolled back to backup {backup.name})"
        save_update_state(state)
        diag_log("info", "Rollback after failed update", {"backup": backup.name})
    except Exception as exc:
        state["message"] += f" (rollback failed: {exc})"
        save_update_state(state)
        diag_log("error", "Rollback after failed update failed", {"error": str(exc)})


def apply_update():
    """Download, verify and install the newest release for the configured channel.

    Steps:
        1. Back up the live install.
        2. Download the bundle into a private temp dir.
        3. Check its sha256 when the manifest lists one.
        4. Extract it safely, copy it over the live files and write the version file.
        5. Restart the services.

    If a failure happens after live files were touched, the backup taken in
    this run is restored. Earlier failures leave the install alone. Returns
    the persisted state dict.
    """
    state = load_update_state()
    if state.get("in_progress"):
        state["message"] = "Update already in progress"
        save_update_state(state)
        return state
    lock = acquire_lock()
    if lock is None:
        state["status"] = "error"
        state["message"] = "Another update is running"
        save_update_state(state)
        return state

    state["in_progress"] = True
    state["status"] = "in_progress"
    state["progress"] = "Starting update…"
    save_update_state(state)
    diag_log("info", "Update apply started", {"channel": state.get("channel") or "dev"})

    backup_dir: Optional[pathlib.Path] = None
    stage_dir: Optional[pathlib.Path] = None
    live_files_touched = False
    try:
        channel = state.get("channel") or os.environ.get("PIKIT_CHANNEL", "dev")
        manifest, meta = fetch_manifest_for_channel(channel, with_meta=True)
        latest = manifest.get("version") or manifest.get("latest_version")
        if not latest:
            raise RuntimeError("Manifest missing version")

        backup_dir = _stage_backup()
        prune_backups(keep=1)

        bundle_url = manifest.get("bundle") or manifest.get("url")
        if not bundle_url:
            raise RuntimeError("Manifest missing bundle url")

        # Fresh private directory per run: no stale files from earlier attempts,
        # and the path never depends on the remote version string.
        ensure_dir(TMP_ROOT)
        stage_dir = pathlib.Path(tempfile.mkdtemp(prefix="pikit-update-", dir=str(TMP_ROOT)))
        bundle_path = stage_dir / "bundle.tar.gz"

        state["progress"] = "Downloading release…"
        save_update_state(state)
        download_file(bundle_url, bundle_path)
        diag_log("debug", "Bundle downloaded", {"url": bundle_url, "path": str(bundle_path)})

        expected_hash = next(
            (
                f["sha256"]
                for f in manifest.get("files") or []
                if f.get("path") == "bundle.tar.gz" and f.get("sha256")
            ),
            None,
        )
        if expected_hash:
            got = sha256_file(bundle_path)
            if got.lower() != expected_hash.lower():
                raise RuntimeError("Bundle hash mismatch")
            diag_log("debug", "Bundle hash verified", {"expected": expected_hash})

        state["progress"] = "Staging files…"
        save_update_state(state)
        with tarfile.open(bundle_path, "r:gz") as tar:
            _safe_extract(tar, stage_dir)

        # From here on the live install is being modified; failures must roll back.
        live_files_touched = True
        _install_staged(stage_dir)
        # Write the version before restarting so the restarted API reports it.
        VERSION_FILE.parent.mkdir(parents=True, exist_ok=True)
        VERSION_FILE.write_text(str(latest))
        _restart_services()

        version_dates = (meta or {}).get("version_dates") or {}
        state["current_version"] = str(latest)
        state["latest_version"] = str(latest)
        state["changelog_url"] = manifest.get("changelog")
        state["latest_release_date"] = manifest.get("_release_date") or version_dates.get(_norm_ver(latest))
        state["current_release_date"] = state.get("latest_release_date")
        state["status"] = "up_to_date"
        state["message"] = "Update installed"
        state["progress"] = None
        save_update_state(state)
        diag_log("info", "Update applied", {"version": str(latest)})
    except urllib.error.HTTPError as e:
        # Only raised by the manifest fetch or the download, i.e. before any
        # live file was modified, so no rollback is needed.
        state["status"] = "error"
        state["message"] = f"No release available ({e.code})"
        state["latest_release_date"] = None
        diag_log("error", "Update apply HTTP error", {"code": e.code})
    except Exception as e:
        state["status"] = "error"
        state["message"] = f"Update failed: {e}"
        state["progress"] = None
        state["latest_release_date"] = None
        save_update_state(state)
        diag_log("error", "Update apply failed", {"error": str(e)})
        if live_files_touched:
            _rollback_after_failed_update(state, backup_dir)
    finally:
        state["in_progress"] = False
        state["progress"] = None
        save_update_state(state)
        if stage_dir is not None:
            shutil.rmtree(stage_dir, ignore_errors=True)
        if lock:
            release_lock(lock)
    return state


def rollback_update():
    """Restore the backup chosen by ``choose_rollback_backup`` and record the outcome.

    The update lock is released on every path, including unexpected errors.
    Returns the persisted state dict.
    """
    state = load_update_state()
    lock = acquire_lock()
    if lock is None:
        state["status"] = "error"
        state["message"] = "Another update is running"
        save_update_state(state)
        return state
    try:
        state["in_progress"] = True
        state["status"] = "in_progress"
        state["progress"] = "Rolling back…"
        save_update_state(state)
        diag_log("info", "Rollback started")
        backup = choose_rollback_backup()
        if not backup:
            state["status"] = "error"
            state["message"] = "No backup available to rollback."
        else:
            restore_backup(backup)
            state["status"] = "up_to_date"
            state["current_version"] = read_current_version()
            state["latest_version"] = state.get("latest_version") or state["current_version"]
            ver = get_backup_version(backup)
            suffix = f" (version {ver})" if ver else ""
            state["message"] = f"Rolled back to backup {backup.name}{suffix}"
            diag_log("info", "Rollback complete", {"backup": backup.name, "version": ver})
    except Exception as e:
        state["status"] = "error"
        state["message"] = f"Rollback failed: {e}"
        diag_log("error", "Rollback failed", {"error": str(e)})
    finally:
        state["in_progress"] = False
        state["progress"] = None
        save_update_state(state)
        release_lock(lock)
    return state


def start_background_task(mode: str):
    """
    Kick off a background update/rollback via systemd-run so nginx/API restarts
    do not break the caller connection.
    mode: "apply" or "rollback"

    The transient unit does not inherit this process's environment, so the
    manifest/channel overrides and the auth token are forwarded explicitly.

    Raises:
        ValueError: for any other *mode* (checked even under ``python -O``).
    """
    if mode not in ("apply", "rollback"):
        raise ValueError(f"invalid mode: {mode!r}")
    unit = f"pikit-update-{mode}"
    cmd = ["systemd-run", "--unit", unit, "--quiet"]
    manifest_url = os.environ.get("PIKIT_MANIFEST_URL") or DEFAULT_MANIFEST_URL
    if manifest_url:
        cmd.append(f"--setenv=PIKIT_MANIFEST_URL={manifest_url}")
    for name in ("PIKIT_DEV_MANIFEST_URL", "PIKIT_STABLE_MANIFEST_URL", "PIKIT_CHANNEL"):
        value = os.environ.get(name)
        if value:
            cmd.append(f"--setenv={name}={value}")
    token = _auth_token()
    if token:
        # NOTE(review): the token is visible in the systemd-run argv (e.g. via
        # ps). systemd-run's bare "--setenv=NAME" form copies the value from
        # the caller's environment instead; confirm the installed systemd
        # supports it before switching.
        cmd.append(f"--setenv=PIKIT_AUTH_TOKEN={token}")
    cmd += ["/usr/bin/env", "python3", str(API_PATH), f"--{mode}-update"]
    subprocess.run(cmd, check=False)


# Backwards compat aliases
apply_update_stub = apply_update
rollback_update_stub = rollback_update