Files
pi-kit/pikit-api.py

1447 lines
55 KiB
Python

#!/usr/bin/env python3
import json, os, subprocess, socket, shutil, pathlib, datetime, tarfile, sys, argparse
from http.server import BaseHTTPRequestHandler, HTTPServer
import re
import urllib.request
import hashlib
import urllib.parse
import fcntl
from functools import partial
import json as jsonlib
import io
from collections import deque
# Local listen address of the dashboard JSON API.
HOST = "127.0.0.1"
PORT = 4000

# Pi-Kit configuration, logs and marker files.
SERVICE_JSON = pathlib.Path("/etc/pikit/services.json")
RESET_LOG = pathlib.Path("/var/log/pikit-reset.log")
API_LOG = pathlib.Path("/var/log/pikit-api.log")
# Legacy debug logging is switched on by the presence of this file at import time.
DEBUG_FLAG = pathlib.Path("/boot/pikit-debug").exists()
READY_FILE = pathlib.Path("/var/run/pikit-ready")

# Ports treated as HTTPS when building service URLs, and the dashboard's own entry.
HTTPS_PORTS = {443, 5252}
CORE_PORTS = {80}
CORE_NAME = "Pi-Kit Dashboard"

# apt / unattended-upgrades configuration files and default schedule times.
_APT_CONF_DIR = pathlib.Path("/etc/apt/apt.conf.d")
APT_AUTO_CFG = _APT_CONF_DIR / "20auto-upgrades"
APT_UA_BASE = _APT_CONF_DIR / "50unattended-upgrades"
APT_UA_OVERRIDE = _APT_CONF_DIR / "51pikit-unattended.conf"
DEFAULT_UPDATE_TIME = "04:00"
DEFAULT_UPGRADE_TIME = "04:30"

# unattended-upgrades Origins-Pattern entries. "${distro_codename}" is a literal
# placeholder for apt tooling, not a Python format field.
SECURITY_PATTERNS = [
    "origin=Debian,codename=${distro_codename},label=Debian-Security",
    "origin=Debian,codename=${distro_codename}-security,label=Debian-Security",
]
ALL_PATTERNS = ["origin=Debian,codename=${distro_codename},label=Debian"] + SECURITY_PATTERNS
# --- Release updater ---------------------------------------------------------
# Installed version marker; the web bundle's version.json is a secondary source.
VERSION_FILE = pathlib.Path("/etc/pikit/version")
WEB_ROOT = pathlib.Path("/var/www/pikit-web")
WEB_VERSION_FILE = WEB_ROOT / "data" / "version.json"
API_PATH = pathlib.Path("/usr/local/bin/pikit-api.py")

# Updater bookkeeping: persisted state, cross-process lock, backups, staging area.
UPDATE_STATE_DIR = pathlib.Path("/var/lib/pikit-update")
UPDATE_STATE = UPDATE_STATE_DIR / "state.json"
UPDATE_LOCK = pathlib.Path("/var/run/pikit-update.lock")
BACKUP_ROOT = pathlib.Path("/var/backups/pikit")
TMP_ROOT = pathlib.Path("/var/tmp/pikit-update")

# Release manifest location and optional Gitea token, both overridable via the environment.
_FALLBACK_MANIFEST_URL = "https://git.44r0n.cc/44r0n7/pi-kit/releases/latest/download/manifest.json"
DEFAULT_MANIFEST_URL = os.environ.get("PIKIT_MANIFEST_URL", _FALLBACK_MANIFEST_URL)
AUTH_TOKEN = os.environ.get("PIKIT_AUTH_TOKEN")
# Diagnostics logging: files live in /dev/shm when it exists, else in /tmp.
# The existence check runs once so both files always share one directory.
_DIAG_DIR = pathlib.Path("/dev/shm") if pathlib.Path("/dev/shm").exists() else pathlib.Path("/tmp")
DIAG_STATE_FILE = _DIAG_DIR / "pikit-diag.state"
DIAG_LOG_FILE = _DIAG_DIR / "pikit-diag.log"
DIAG_MAX_BYTES = 1_048_576  # 1 MiB cap on the log file
DIAG_MAX_ENTRY_CHARS = 2048  # max serialized length of a single log line
DIAG_DEFAULT_STATE = {"enabled": False, "level": "normal"}  # level: normal|debug
# In-process cache of the diagnostics state, filled lazily by _load_diag_state().
_diag_state = None
def _load_diag_state():
    """
    Return the diagnostics state dict, loading it from DIAG_STATE_FILE once.

    The result is cached in the module-level ``_diag_state``. Stored values
    are layered over DIAG_DEFAULT_STATE, so missing keys get defaults. A
    missing, unreadable, corrupt or non-object state file yields a copy of
    the defaults.
    """
    global _diag_state
    if _diag_state is not None:
        return _diag_state
    state = DIAG_DEFAULT_STATE.copy()
    try:
        if DIAG_STATE_FILE.exists():
            loaded = json.loads(DIAG_STATE_FILE.read_text())
            # Anything other than a JSON object (e.g. null or a list) would break
            # the .get()/item assignment done by callers, so it is ignored.
            if isinstance(loaded, dict):
                state.update(loaded)
    except Exception:
        pass
    _diag_state = state
    return _diag_state
def _save_diag_state(enabled=None, level=None):
    """
    Apply diagnostics settings to the cached state and persist it.

    ``enabled`` is coerced to bool when given. ``level`` is accepted only when
    it is "normal" or "debug"; other values are ignored. Persisting is
    best-effort and write failures are swallowed. Returns the mutated state dict.
    """
    current = _load_diag_state()
    if enabled is not None:
        current["enabled"] = bool(enabled)
    if level in ("normal", "debug"):
        current["level"] = level
    try:
        target = DIAG_STATE_FILE
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(json.dumps(current))
    except Exception:
        pass
    return current
def diag_log(level: str, message: str, meta: dict | None = None):
    """
    Append one JSON diagnostic line to the RAM-backed log file.

    Does nothing when diagnostics are disabled, or when ``level`` is "debug"
    and the configured level is not "debug". A line longer than
    DIAG_MAX_ENTRY_CHARS is re-serialized without ``meta`` and with a
    truncated message ending in "…". Once the file exceeds DIAG_MAX_BYTES it
    is trimmed to about three quarters of the cap (oldest lines dropped), so
    the trim does not re-run on every following write. Never raises.
    """
    try:
        # State lookup is inside the try so a broken state can never break the caller.
        state = _load_diag_state()
        if not state.get("enabled"):
            return
        if level == "debug" and state.get("level") != "debug":
            return
        ts = datetime.datetime.utcnow().isoformat() + "Z"
        entry = {"ts": ts, "level": level, "msg": message}
        if meta:
            entry["meta"] = meta
        line = json.dumps(entry, separators=(",", ":"))
        if len(line) > DIAG_MAX_ENTRY_CHARS:
            entry.pop("meta", None)
            entry["msg"] = (message or "")[: DIAG_MAX_ENTRY_CHARS - 40] + "…"
            line = json.dumps(entry, separators=(",", ":"))
        DIAG_LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
        with DIAG_LOG_FILE.open("ab") as f:
            f.write((line + "\n").encode())
        if DIAG_LOG_FILE.stat().st_size > DIAG_MAX_BYTES:
            # Trim below the cap (hysteresis); trimming to exactly the cap would
            # make every subsequent append re-read and rewrite ~1 MiB.
            keep = DIAG_MAX_BYTES * 3 // 4
            with DIAG_LOG_FILE.open("rb") as f:
                f.seek(-keep, io.SEEK_END)
                tail = f.read()
            # Drop the partial first line to keep the file valid JSON lines.
            if b"\n" in tail:
                tail = tail.split(b"\n", 1)[1]
            with DIAG_LOG_FILE.open("wb") as f:
                f.write(tail)
    except Exception:
        # Never break caller
        pass
def diag_read(limit=500):
    """
    Return up to ``limit`` of the most recent log entries (dicts), newest first.

    Lines that fail to decode or parse are skipped. A missing or unreadable
    log file, or a non-positive ``limit``, yields an empty list.
    """
    # [-0:] would slice the whole list, so limit <= 0 must be handled explicitly.
    if limit <= 0:
        return []
    if not DIAG_LOG_FILE.exists():
        return []
    try:
        data = DIAG_LOG_FILE.read_bytes()
    except Exception:
        return []
    out = []
    for line in data.splitlines()[-limit:]:
        try:
            out.append(json.loads(line.decode("utf-8", errors="ignore")))
        except Exception:
            continue
    return out[::-1]
def ensure_dir(path: pathlib.Path):
    """Create ``path`` and any missing parents; an existing directory is fine."""
    path.mkdir(exist_ok=True, parents=True)
def sha256_file(path: pathlib.Path):
    """Return the hex SHA-256 digest of the file at ``path``, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        while block := handle.read(1 << 20):
            digest.update(block)
    return digest.hexdigest()
class FirewallToolMissing(Exception):
    """Signals that a firewall change was requested but ufw is not available."""
def normalize_path(path: str | None) -> str:
"""Normalize optional service path. Empty -> "". Ensure leading slash."""
if not path:
return ""
p = str(path).strip()
if not p:
return ""
if not p.startswith("/"):
p = "/" + p
return p
def default_host():
    """Return this machine's hostname, suffixed with ".local" when it has no dot."""
    name = socket.gethostname()
    return name if "." in name else name + ".local"
def dbg(msg):
    """
    Write a debug message to the legacy API log and the diagnostics log.

    The legacy file (API_LOG) is only written when DEBUG_FLAG is set. The
    message is also mirrored to diag_log() at "debug" level. Both are
    best-effort: this never raises, because it is called from error handlers.
    """
    # Legacy debug file logging (when /boot/pikit-debug exists)
    if DEBUG_FLAG:
        try:
            API_LOG.parent.mkdir(parents=True, exist_ok=True)
            ts = datetime.datetime.utcnow().isoformat()
            with API_LOG.open("a") as f:
                f.write(f"[{ts}] {msg}\n")
        except OSError:
            # An unwritable log (permissions, read-only fs) must not break the caller.
            pass
    # Mirror into diagnostics if enabled
    try:
        diag_log("debug", msg)
    except Exception:
        pass
def set_ssh_password_auth(allow: bool):
    """
    Enable/disable SSH password authentication without requiring the current password.

    Used during factory reset to restore a predictable state. Rewrites
    /etc/ssh/sshd_config (PasswordAuthentication, PermitRootLogin and related
    keywords) and restarts the ``ssh`` unit; restart failures are ignored.

    Each keyword is replaced in the global section (before the first ``Match``
    line). If it is absent there, it is inserted just before the first
    ``Match`` line, or appended when there is none. Directives inside Match
    blocks are left untouched.

    NOTE(review): sshd uses the first value it reads for a keyword; if the
    image's sshd_config Includes drop-ins before these lines, a drop-in can
    still override them. Confirm on target images.

    Returns:
        (True, message) describing the new state.
    """
    cfg = pathlib.Path("/etc/ssh/sshd_config")
    text = cfg.read_text() if cfg.exists() else ""
    match_re = re.compile(r"^\s*Match\s", re.IGNORECASE)

    def set_opt(key, value):
        # sshd keywords are case-insensitive and may be followed by spaces, tabs or "=".
        key_re = re.compile(rf"^\s*{re.escape(key)}(?:\s|=)", re.IGNORECASE)
        new_line = f"{key} {value}"
        lines = text.splitlines()
        insert_at = len(lines)
        for idx, line in enumerate(lines):
            if match_re.match(line):
                # Lines after the first Match belong to a conditional block; a
                # global directive appended there could be invalid or scoped wrongly.
                insert_at = idx
                break
            if key_re.match(line):
                lines[idx] = new_line
                return "\n".join(lines) + "\n"
        lines.insert(insert_at, new_line)
        return "\n".join(lines) + "\n"

    text = set_opt("PasswordAuthentication", "yes" if allow else "no")
    text = set_opt("KbdInteractiveAuthentication", "no")
    text = set_opt("ChallengeResponseAuthentication", "no")
    text = set_opt("PubkeyAuthentication", "yes")
    text = set_opt("PermitRootLogin", "yes" if allow else "prohibit-password")
    cfg.write_text(text)
    subprocess.run(["systemctl", "restart", "ssh"], check=False)
    return True, f"SSH password auth {'enabled' if allow else 'disabled'}"
def load_services():
    """
    Load services.json and fill in derived fields for each entry.

    Each entry's ``path`` is normalized when non-empty. Entries with a truthy
    ``port`` get a ``scheme`` (kept if already set, otherwise https for
    HTTPS_PORTS and http for the rest) and a ``url`` built from
    default_host(). Returns [] when the file is missing or cannot be
    processed; processing failures are reported through dbg().
    """
    if not SERVICE_JSON.exists():
        return []
    try:
        services = json.loads(SERVICE_JSON.read_text())
        hostname = default_host()
        for entry in services:
            path = normalize_path(entry.get("path"))
            if path:
                entry["path"] = path
            port = entry.get("port")
            if not port:
                continue
            scheme = entry.get("scheme") or ("https" if int(port) in HTTPS_PORTS else "http")
            entry["scheme"] = scheme
            entry["url"] = f"{scheme}://{hostname}:{port}{path}"
        return services
    except Exception:
        dbg("Failed to read services.json")
        return []
def save_services(services):
    """
    Persist ``services`` to SERVICE_JSON as indented JSON, atomically.

    The data is written to a temporary sibling file, fsynced and then renamed
    over the target. A crash or power loss mid-write therefore never leaves a
    truncated services.json, which load_services() would read as "no services".
    """
    SERVICE_JSON.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(services, indent=2)
    tmp = SERVICE_JSON.with_name(f".{SERVICE_JSON.name}.{os.getpid()}.tmp")
    try:
        with tmp.open("w") as fh:
            fh.write(payload)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp, SERVICE_JSON)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
def auto_updates_enabled():
    """Return True if APT_AUTO_CFG exists and turns the Unattended-Upgrade periodic job on."""
    return APT_AUTO_CFG.exists() and 'APT::Periodic::Unattended-Upgrade "1";' in APT_AUTO_CFG.read_text()
def set_auto_updates(enable: bool):
    """
    Turn unattended upgrades on or off in both the apt config and systemd.

    Rewrites APT_AUTO_CFG with both APT::Periodic switches set to "1" or "0".
    When enabling, the apt units are unmasked, then the timers and the
    unattended-upgrades service are enabled and started. When disabling, they
    are stopped and disabled, then all apt units are masked (mirroring DietPi
    defaults). Assumes unattended-upgrades is already installed; systemctl
    failures are ignored.
    """
    timers = ["apt-daily.timer", "apt-daily-upgrade.timer"]
    service = "unattended-upgrades.service"
    managed = timers + [service]
    maskable = ["apt-daily.service", "apt-daily-upgrade.service", *managed]

    def systemctl(action, unit):
        subprocess.run(["systemctl", action, unit], check=False)

    switch = "1" if enable else "0"
    APT_AUTO_CFG.parent.mkdir(parents=True, exist_ok=True)
    APT_AUTO_CFG.write_text(
        f'APT::Periodic::Update-Package-Lists "{switch}";\n'
        f'APT::Periodic::Unattended-Upgrade "{switch}";\n'
    )
    if enable:
        for unit in maskable:
            systemctl("unmask", unit)
        for unit in managed:
            systemctl("enable", unit)
        # Start immediately so the live state matches the config right away.
        for unit in managed:
            systemctl("start", unit)
        return
    for unit in managed:
        systemctl("stop", unit)
        systemctl("disable", unit)
    for unit in maskable:
        systemctl("mask", unit)
def _systemctl_is(unit: str, verb: str) -> bool:
    """
    Run ``systemctl <verb> <unit>`` and report whether the unit is in the expected state.

    For "is-enabled" the expected output is "enabled"; for any other verb it
    is "active". A non-zero exit or a missing systemctl yields False.
    """
    expected = "enabled" if verb == "is-enabled" else "active"
    try:
        reply = subprocess.check_output(["systemctl", verb, unit], text=True)
    except Exception:
        return False
    return reply.strip() == expected
def auto_updates_state():
    """
    Snapshot the unattended-upgrades state from the apt config and systemd.

    Returns a dict with the config switch, the enabled/active status of the
    unattended-upgrades service and of each apt timer, and an overall
    ``enabled`` flag. The flag requires the config switch plus the service
    and both timers to be enabled.
    """
    service = "unattended-upgrades.service"
    timers = ["apt-daily.timer", "apt-daily-upgrade.timer"]
    config_on = auto_updates_enabled()
    service_enabled = _systemctl_is(service, "is-enabled")
    service_active = _systemctl_is(service, "is-active")
    timers_enabled = {}
    timers_active = {}
    for timer in timers:
        timers_enabled[timer] = _systemctl_is(timer, "is-enabled")
        timers_active[timer] = _systemctl_is(timer, "is-active")
    return {
        "config_enabled": config_on,
        "service_enabled": service_enabled,
        "service_active": service_active,
        "timers_enabled": timers_enabled,
        "timers_active": timers_active,
        "enabled": config_on and service_enabled and all(timers_enabled.values()),
    }
def reboot_required():
    """Return True when the /run/reboot-required marker file is present."""
    marker = pathlib.Path("/run/reboot-required")
    return marker.exists()
def _parse_directive(text: str, key: str, default=None, as_bool=False, as_int=False):
    """
    Return the value of the first ``key value;`` directive in apt-style ``text``.

    Full-line comments are ignored and optional double quotes around the value
    are dropped. With ``as_bool`` the value is True for 1/true/yes/on
    (case-insensitive). With ``as_int`` a non-integer value yields
    ``default``. ``default`` is also returned when the key is absent.
    """
    found = re.search(rf'{re.escape(key)}\s+"?([^";\n]+)"?;', _strip_comments(text))
    if found is None:
        return default
    raw = found.group(1).strip()
    if as_bool:
        return raw.lower() in ("1", "true", "yes", "on")
    if not as_int:
        return raw
    try:
        return int(raw)
    except ValueError:
        return default
def _parse_origins_patterns(text: str):
    """
    Return the entries of the ``Unattended-Upgrade::Origins-Pattern { ... };`` block.

    Full-line comments are ignored. Quoted entries are extracted (several per
    line are allowed, and trailing text after the quotes is ignored).
    Unquoted lines are kept with surrounding quotes and semicolons stripped.
    Returns [] when no such block exists.
    """
    text = _strip_comments(text)
    # Entries such as "${distro_codename}" contain "}", so the block must end at
    # the closing "};" rather than at the first "}" (which truncated entries).
    m = re.search(r"Unattended-Upgrade::Origins-Pattern\s*{(.*?)}\s*;", text, re.S)
    if not m:
        return []
    patterns = []
    for line in m.group(1).splitlines():
        quoted = re.findall(r'"([^"]*)"', line)
        if quoted:
            patterns.extend(q.strip() for q in quoted if q.strip())
            continue
        ln = line.strip().strip('";')
        if ln:
            patterns.append(ln)
    return patterns
def _read_timer_time(timer: str):
    """
    Return the first "HH:MM" found in a timer's OnCalendar schedule, or None.

    Reads ``systemctl show --property=TimersCalendar`` output, e.g.
    ``TimersCalendar={ OnCalendar=*-*-* 03:10:00 ; next_elapse=... }``.
    Any failure to run systemctl, or no parseable time, gives None.
    """
    try:
        listing = subprocess.check_output(
            ["systemctl", "show", "--property=TimersCalendar", timer], text=True
        )
    except Exception:
        return None
    found = re.search(r"OnCalendar=[^0-9]*([0-9]{1,2}):([0-9]{2})", listing)
    if found is None:
        return None
    hour, minute = found.groups()
    return f"{int(hour):02d}:{minute}"
def _strip_comments(text: str):
    """Drop whole lines that start (after indentation) with ``//`` or ``#``; inline comments are kept."""
    kept = [row for row in text.splitlines() if not row.strip().startswith(("//", "#"))]
    return "\n".join(kept)
def _validate_time(val: str, default: str):
    """
    Normalize an "H:MM" / "HH:MM" 24-hour time string to "HH:MM".

    Returns ``default`` for empty, non-string, malformed or out-of-range
    values. Values arrive from dashboard JSON, so a number or bool must fall
    back to the default rather than raise.
    """
    if not val or not isinstance(val, str):
        return default
    m = re.match(r"^(\d{1,2}):(\d{2})$", val.strip())
    if not m:
        return default
    h, mi = int(m.group(1)), int(m.group(2))
    if 0 <= h < 24 and 0 <= mi < 60:
        return f"{h:02d}:{mi:02d}"
    return default
def read_updates_config(state=None):
    """
    Return a normalized unattended-upgrades configuration snapshot.

    The Pi-Kit override file and the base file are concatenated (override
    first), so the first match for each directive comes from the override
    when it sets it. ``scope`` comes from the ``PIKIT_SCOPE`` marker when
    present. Otherwise it is "all" if any Origins-Pattern targets the plain
    ``label=Debian`` archive, else "security". Timer times come from systemd.
    ``state`` defaults to auto_updates_state().
    """
    text = ""
    for path in (APT_UA_OVERRIDE, APT_UA_BASE):
        if path.exists():
            try:
                text += path.read_text() + "\n"
            except Exception:
                pass
    scope_hint = None
    m_scope = re.search(r"PIKIT_SCOPE:\s*(\w+)", text)
    if m_scope:
        scope_hint = m_scope.group(1).lower()
    cleaned = _strip_comments(text)
    patterns = _parse_origins_patterns(cleaned)
    # Match "label=Debian" only as a whole value: "label=Debian-Security" also
    # contains that substring, and the old case-sensitive "-security" test did
    # not exclude it, so security-only configs were reported as "all".
    full_archive = any(re.search(r"label=Debian(?:,|$)", p) for p in patterns)
    scope = scope_hint or ("all" if full_archive else "security")
    cleanup = _parse_directive(text, "Unattended-Upgrade::Remove-Unused-Dependencies", False, as_bool=True)
    auto_reboot = _parse_directive(text, "Unattended-Upgrade::Automatic-Reboot", False, as_bool=True)
    reboot_time = _validate_time(
        _parse_directive(text, "Unattended-Upgrade::Automatic-Reboot-Time", DEFAULT_UPGRADE_TIME),
        DEFAULT_UPGRADE_TIME,
    )
    reboot_with_users = _parse_directive(text, "Unattended-Upgrade::Automatic-Reboot-WithUsers", False, as_bool=True)
    bandwidth = _parse_directive(text, "Acquire::http::Dl-Limit", None, as_int=True)
    update_time = _read_timer_time("apt-daily.timer") or DEFAULT_UPDATE_TIME
    upgrade_time = _read_timer_time("apt-daily-upgrade.timer") or DEFAULT_UPGRADE_TIME
    state = state or auto_updates_state()
    return {
        "enabled": bool(state.get("enabled", False)),
        "scope": scope,
        "cleanup": bool(cleanup),
        "bandwidth_limit_kbps": bandwidth,
        "auto_reboot": bool(auto_reboot),
        "reboot_time": reboot_time,
        "reboot_with_users": bool(reboot_with_users),
        "update_time": update_time,
        "upgrade_time": upgrade_time,
        "state": state,
    }
def _write_timer_override(timer: str, time_str: str):
    """
    Install a systemd drop-in that makes ``timer`` fire daily at ``time_str``.

    Writes /etc/systemd/system/<timer>.d/pikit.conf, reloads systemd and
    restarts the timer (failures ignored). An invalid time falls back to
    DEFAULT_UPDATE_TIME.
    """
    time_norm = _validate_time(time_str, DEFAULT_UPDATE_TIME)
    override_dir = pathlib.Path(f"/etc/systemd/system/{timer}.d")
    override_dir.mkdir(parents=True, exist_ok=True)
    override_file = override_dir / "pikit.conf"
    override_file.write_text(
        "[Timer]\n"
        # OnCalendar= accumulates across unit files; the empty assignment clears
        # the vendor schedule, otherwise the timer also keeps firing at it.
        "OnCalendar=\n"
        f"OnCalendar=*-*-* {time_norm}\n"
        "Persistent=true\n"
        "RandomizedDelaySec=30min\n"
    )
    subprocess.run(["systemctl", "daemon-reload"], check=False)
    subprocess.run(["systemctl", "restart", timer], check=False)
def set_updates_config(opts: dict):
    """
    Apply unattended-upgrades configuration from dashboard inputs.

    Recognized (optional) keys in ``opts``: enable, scope ("all", anything
    else means security-only), cleanup, bandwidth_limit_kbps, auto_reboot,
    reboot_time, reboot_with_users, update_time and upgrade_time (defaults to
    update_time). Inputs are validated before anything on the system changes.
    The function then toggles the apt units via set_auto_updates(), writes
    APT_UA_OVERRIDE, installs the timer drop-ins and returns
    read_updates_config().

    Raises:
        ValueError/TypeError: if bandwidth_limit_kbps is not integer-like.
    """
    enable = bool(opts.get("enable", True))
    # Normalize to the two known scopes. The value is echoed into the generated
    # apt config, so a raw client string (e.g. one containing a newline) could
    # otherwise inject arbitrary apt directives into a root-run tool.
    scope = "all" if (opts.get("scope") or "all") == "all" else "security"
    patterns = ALL_PATTERNS if scope == "all" else SECURITY_PATTERNS
    cleanup = bool(opts.get("cleanup", False))
    bandwidth = opts.get("bandwidth_limit_kbps")
    if bandwidth is None or bandwidth == "":
        bandwidth = None
    else:
        # Parsed up front so bad input fails before any state is changed.
        # Negative limits are meaningless; apt treats 0 as "no limit".
        bandwidth = max(0, int(bandwidth))
    auto_reboot = bool(opts.get("auto_reboot", False))
    reboot_time = _validate_time(opts.get("reboot_time"), DEFAULT_UPGRADE_TIME)
    reboot_with_users = bool(opts.get("reboot_with_users", False))
    update_time = _validate_time(opts.get("update_time"), DEFAULT_UPDATE_TIME)
    upgrade_time = _validate_time(opts.get("upgrade_time") or opts.get("update_time"), DEFAULT_UPGRADE_TIME)
    APT_AUTO_CFG.parent.mkdir(parents=True, exist_ok=True)
    set_auto_updates(enable)
    lines = [
        "// Managed by Pi-Kit dashboard",
        f"// PIKIT_SCOPE: {scope}",
        "Unattended-Upgrade::Origins-Pattern {",
    ]
    for p in patterns:
        lines.append(f'        "{p}";')
    lines.append("};")
    lines.append(f'Unattended-Upgrade::Remove-Unused-Dependencies {"true" if cleanup else "false"};')
    lines.append(f'Unattended-Upgrade::Automatic-Reboot {"true" if auto_reboot else "false"};')
    lines.append(f'Unattended-Upgrade::Automatic-Reboot-Time "{reboot_time}";')
    lines.append(
        f'Unattended-Upgrade::Automatic-Reboot-WithUsers {"true" if reboot_with_users else "false"};'
    )
    if bandwidth is not None:
        lines.append(f'Acquire::http::Dl-Limit "{bandwidth}";')
    APT_UA_OVERRIDE.parent.mkdir(parents=True, exist_ok=True)
    APT_UA_OVERRIDE.write_text("\n".join(lines) + "\n")
    # Timer overrides for when updates/upgrades run
    _write_timer_override("apt-daily.timer", update_time)
    _write_timer_override("apt-daily-upgrade.timer", upgrade_time)
    return read_updates_config()
def detect_https(host, port):
    """Heuristic: True for HTTPS_PORTS, ``*.local`` hosts or the bare "pikit" host."""
    if int(port) in HTTPS_PORTS:
        return True
    name = str(host).lower()
    return name.endswith(".local") or name == "pikit"
def factory_reset():
    """
    Restore Pi-Kit to factory defaults and schedule a reboot.

    Steps:
    - Restore services.json from /boot/custom-files, or write the built-in defaults.
    - Reset the firewall (best-effort).
    - Re-enable SSH password login.
    - Ensure the dietpi user exists, then set root/dietpi passwords to "pikit".
    - Write RESET_LOG and reboot after a 2 second delay.

    A firewall failure (e.g. ufw missing) is logged but does not abort the
    credential reset.
    """
    # Restore services config
    SERVICE_JSON.parent.mkdir(parents=True, exist_ok=True)
    boot_services = pathlib.Path("/boot/custom-files/pikit-services.json")
    if boot_services.exists():
        shutil.copy(boot_services, SERVICE_JSON)
    else:
        SERVICE_JSON.write_text(json.dumps([
            {"name": "Pi-Kit Dashboard", "port": 80},
            {"name": "DietPi Dashboard", "port": 5252},
        ], indent=2))
    # Reset firewall; failure here must not leave the box with unknown passwords.
    try:
        reset_firewall()
    except Exception as exc:
        dbg(f"Factory reset: firewall reset failed: {exc}")
    # Reset SSH auth to password and set defaults
    set_ssh_password_auth(True)
    # Ensure dietpi exists before its password is set
    if not pathlib.Path("/home/dietpi").exists():
        subprocess.run(["useradd", "-m", "-s", "/bin/bash", "dietpi"], check=False)
    for user in ("root", "dietpi"):
        try:
            subprocess.run(["chpasswd"], input=f"{user}:pikit".encode(), check=True)
        except Exception:
            pass
    # Log and reboot
    RESET_LOG.parent.mkdir(parents=True, exist_ok=True)
    RESET_LOG.write_text("Factory reset triggered\n")
    subprocess.Popen(["/bin/sh", "-c", "sleep 2 && systemctl reboot >/dev/null 2>&1"], close_fds=True)
def port_online(host, port):
    """Return True if a TCP connection to (host, port) succeeds within 1.5 s."""
    try:
        sock = socket.create_connection((host, int(port)), timeout=1.5)
        sock.close()
    except Exception:
        return False
    return True
def ufw_status_allows(port: int) -> bool:
    """
    Return True if ``ufw status`` lists an inbound ALLOW rule for exactly ``port``.

    A rule matches when its "To" column is the port, optionally with /tcp or
    /udp and a "(v6)" marker, and its action is ALLOW but not ALLOW OUT.
    Failures to run ufw (missing, not root) yield False.
    """
    try:
        out = subprocess.check_output(["ufw", "status"], text=True)
    except Exception:
        return False
    # The old substring test matched "80" inside "8080" or "5252" and any ALLOW
    # anywhere in the output; match whole rule lines instead.
    rule = re.compile(
        rf"^{re.escape(str(port))}(?:/(?:tcp|udp))?(?:\s+\(v6\))?\s+ALLOW(?!\s+OUT)\b"
    )
    return any(rule.match(line.strip()) for line in out.splitlines())
def allow_port_lan(port: int):
    """
    Allow inbound traffic to ``port`` from private/LAN address ranges via ufw.

    Covers RFC1918, CGNAT (100.64/10) and link-local ranges. Individual ufw
    failures are ignored.

    Raises:
        FirewallToolMissing: if ufw is not installed, so callers can surface it.
    """
    if shutil.which("ufw") is None:
        raise FirewallToolMissing("Cannot update firewall: ufw is not installed on this system.")
    port_arg = str(port)
    for cidr in ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "100.64.0.0/10", "169.254.0.0/16"):
        subprocess.run(["ufw", "allow", "from", cidr, "to", "any", "port", port_arg], check=False)
def remove_port_lan(port: int):
    """
    Delete the LAN allow rules for ``port`` created by allow_port_lan().

    Individual ufw failures (e.g. a rule that does not exist) are ignored.

    Raises:
        FirewallToolMissing: if ufw is not installed, so callers can surface it.
    """
    if shutil.which("ufw") is None:
        raise FirewallToolMissing("Cannot update firewall: ufw is not installed on this system.")
    port_arg = str(port)
    for cidr in ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "100.64.0.0/10", "169.254.0.0/16"):
        subprocess.run(["ufw", "delete", "allow", "from", cidr, "to", "any", "port", port_arg], check=False)
def reset_firewall():
    """
    Reset ufw to Pi-Kit's default policy and enable it.

    The policy denies all incoming and outgoing traffic, then allows:
    - outbound DNS, HTTP(S), NTP, DHCP and loopback, plus anything to LAN ranges;
    - inbound SSH, HTTP(S), 5252 and 5253 from LAN ranges.
    Individual ufw failures are ignored.

    Raises:
        FirewallToolMissing: if ufw is not installed (consistent with
        allow_port_lan/remove_port_lan, instead of an opaque FileNotFoundError).
    """
    if not shutil.which("ufw"):
        raise FirewallToolMissing("Cannot update firewall: ufw is not installed on this system.")
    lan_subnets = ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "100.64.0.0/10", "169.254.0.0/16")

    def ufw(*args):
        subprocess.run(["ufw", *args], check=False)

    ufw("--force", "reset")
    ufw("default", "deny", "incoming")
    ufw("default", "deny", "outgoing")
    # Outbound essentials + LAN
    for port in ("53", "80", "443", "123", "67", "68"):
        ufw("allow", "out", port)
    ufw("allow", "out", "on", "lo")
    for subnet in lan_subnets:
        ufw("allow", "out", "to", subnet)
    # Inbound management ports from LAN only
    for subnet in lan_subnets:
        for port in ("22", "80", "443", "5252", "5253"):
            ufw("allow", "from", subnet, "to", "any", "port", port)
    ufw("--force", "enable")
def read_current_version():
    """
    Return the installed Pi-Kit version string.

    VERSION_FILE wins when present. Otherwise the ``version`` field of
    WEB_VERSION_FILE is used. "unknown" is returned when neither source is
    available or the JSON cannot be read.
    """
    if VERSION_FILE.exists():
        return VERSION_FILE.read_text().strip()
    if not WEB_VERSION_FILE.exists():
        return "unknown"
    try:
        payload = json.loads(WEB_VERSION_FILE.read_text())
        return payload.get("version", "unknown")
    except Exception:
        return "unknown"
def load_update_state():
    """
    Return the persisted updater state, or a fresh default snapshot.

    UPDATE_STATE_DIR is created if needed. An unreadable or corrupt state
    file is ignored in favor of defaults (channel from $PIKIT_CHANNEL, "dev"
    otherwise).
    """
    UPDATE_STATE_DIR.mkdir(parents=True, exist_ok=True)
    if UPDATE_STATE.exists():
        try:
            raw = UPDATE_STATE.read_text()
            return json.loads(raw)
        except Exception:
            pass
    return dict(
        current_version=read_current_version(),
        latest_version=None,
        last_check=None,
        status="unknown",
        message="",
        auto_check=False,
        in_progress=False,
        progress=None,
        channel=os.environ.get("PIKIT_CHANNEL", "dev"),
    )
def save_update_state(state: dict):
    """
    Persist the updater ``state`` to UPDATE_STATE as indented JSON, atomically.

    The API process and the background updater unit both read and write this
    file. Writing to a per-process temp file, fsyncing and then renaming
    ensures readers never see a half-written file (which load_update_state()
    would silently replace with defaults).
    """
    UPDATE_STATE_DIR.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(state, indent=2)
    tmp = UPDATE_STATE.with_name(f".{UPDATE_STATE.name}.{os.getpid()}.tmp")
    try:
        with tmp.open("w") as fh:
            fh.write(payload)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp, UPDATE_STATE)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
def _auth_token():
    """Return the Gitea token: $PIKIT_AUTH_TOKEN if set and non-empty, else AUTH_TOKEN."""
    token = os.environ.get("PIKIT_AUTH_TOKEN")
    return token if token else AUTH_TOKEN
def _gitea_latest_manifest(target: str):
    """
    Fallback: when a manifest URL 404s, fetch the manifest.json asset of the
    repository's latest release through the Gitea API.

    ``target`` is expected to look like
    ``https://host/owner/repo/releases/download/vX/manifest.json``; owner and
    repo are taken from the two path segments before "releases". Returns the
    decoded manifest, or None on any failure.
    """
    try:
        parts = target.split("/")
        if "releases" not in parts:
            return None
        idx = parts.index("releases")
        # Need "scheme:", "", host, owner, repo before "releases".
        if idx < 5:
            return None
        base = "/".join(parts[:3])  # scheme + host
        owner = parts[idx - 2]
        repo = parts[idx - 1]
        api_url = f"{base}/api/v1/repos/{owner}/{repo}/releases/latest"
        token = _auth_token()
        req = urllib.request.Request(api_url)
        if token:
            req.add_header("Authorization", f"token {token}")
        with urllib.request.urlopen(req, timeout=10) as resp:
            rel = json.loads(resp.read().decode())
        assets = rel.get("assets") or []
        manifest_asset = next((a for a in assets if a.get("name") == "manifest.json"), None)
        asset_url = manifest_asset.get("browser_download_url") if manifest_asset else None
        if not asset_url:
            return None
        # Download the asset directly. Going through fetch_manifest() would send
        # another 404 straight back here and recurse indefinitely.
        asset_req = urllib.request.Request(asset_url)
        if token:
            asset_req.add_header("Authorization", f"token {token}")
        with urllib.request.urlopen(asset_req, timeout=10) as resp:
            return json.loads(resp.read().decode())
    except Exception:
        return None
def fetch_manifest(url: str | None = None):
    """
    Download and decode a release manifest (JSON).

    ``url`` defaults to $PIKIT_MANIFEST_URL, then DEFAULT_MANIFEST_URL. The
    Gitea token from _auth_token() is sent when configured. On HTTP 404 the
    latest release's manifest.json is tried via _gitea_latest_manifest().

    Raises:
        urllib.error.HTTPError / urllib.error.URLError: request failures with no usable fallback.
        ValueError: if the body is not valid UTF-8 JSON.
    """
    target = url or os.environ.get("PIKIT_MANIFEST_URL") or DEFAULT_MANIFEST_URL
    req = urllib.request.Request(target)
    token = _auth_token()
    if token:
        req.add_header("Authorization", f"token {token}")
    try:
        # Context manager closes the HTTP response instead of relying on GC.
        with urllib.request.urlopen(req, timeout=10) as resp:
            data = resp.read()
    except urllib.error.HTTPError as e:
        if e.code == 404:
            alt = _gitea_latest_manifest(target)
            if alt:
                return alt
        raise
    return json.loads(data.decode())
def fetch_manifest_for_channel(channel: str):
    """
    Return the release manifest appropriate for ``channel`` ("stable" or "dev").

    Resolution order:
      1. Fetch the manifest at $PIKIT_MANIFEST_URL / DEFAULT_MANIFEST_URL
         (errors ignored).
         - stable: return it if it was fetched.
         - dev: return it if its version string contains "dev".
      2. Otherwise, if that URL contains a "releases" segment, list the repo's
         releases via the Gitea API and return the manifest.json asset of the
         first matching release in API order. Prereleases are tried first for
         dev, then non-prereleases for both channels. If the URL has no
         "releases" segment, return the step-1 manifest or retry
         fetch_manifest() once. Any error in this step is swallowed.
      3. Fall back to the manifest from step 1, if there was one.

    Raises:
        RuntimeError: if no manifest could be obtained at all.
    """
    channel = channel or "dev"
    base_manifest_url = os.environ.get("PIKIT_MANIFEST_URL") or DEFAULT_MANIFEST_URL
    manifest = None
    try:
        manifest = fetch_manifest(base_manifest_url)
    except Exception:
        manifest = None
    # If we already have a manifest and channel is stable, return it
    if manifest and channel == "stable":
        return manifest
    # If dev channel and manifest is dev, return it
    if manifest:
        version = manifest.get("version") or manifest.get("latest_version")
        if channel == "dev" and version and "dev" in str(version):
            return manifest
    # Try Gitea API for latest release (include prerelease)
    try:
        parts = base_manifest_url.split("/")
        if "releases" not in parts:
            if manifest:
                return manifest
            return fetch_manifest(base_manifest_url)
        idx = parts.index("releases")
        # NOTE(review): no bounds check here; an unexpected URL shape raises
        # IndexError or builds a bogus API URL, which the except below swallows.
        owner = parts[idx - 2]
        repo = parts[idx - 1]
        base = "/".join(parts[:3])
        api_url = f"{base}/api/v1/repos/{owner}/{repo}/releases"
        req = urllib.request.Request(api_url)
        token = _auth_token()
        if token:
            req.add_header("Authorization", f"token {token}")
        # NOTE(review): the response object is not closed explicitly.
        resp = urllib.request.urlopen(req, timeout=10)
        releases = json.loads(resp.read().decode())

        def pick(predicate):
            # Return the manifest.json of the first release (in API order) that
            # satisfies predicate and has such an asset; None if none does.
            for r in releases:
                if predicate(r):
                    asset = next((a for a in r.get("assets", []) if a.get("name") == "manifest.json"), None)
                    if asset and asset.get("browser_download_url"):
                        return fetch_manifest(asset["browser_download_url"])
            return None

        if channel == "dev":
            m = pick(lambda r: r.get("prerelease") is True)
            if m:
                return m
        m = pick(lambda r: r.get("prerelease") is False)
        if m:
            return m
    except Exception:
        pass
    # last resort: return whatever manifest we had
    if manifest:
        return manifest
    raise RuntimeError("No manifest found for channel")
def download_file(url: str, dest: pathlib.Path):
    """
    Stream ``url`` into ``dest``, creating parent directories, and return ``dest``.

    The Gitea token is sent when configured. The socket timeout is 60 s.
    """
    ensure_dir(dest.parent)
    request = urllib.request.Request(url)
    auth = _auth_token()
    if auth:
        request.add_header("Authorization", f"token {auth}")
    with urllib.request.urlopen(request, timeout=60) as response:
        with dest.open("wb") as sink:
            shutil.copyfileobj(response, sink)
    return dest
def fetch_text_with_auth(url: str):
    """Return the body of ``url`` decoded as UTF-8, sending the Gitea token when configured."""
    request = urllib.request.Request(url)
    auth = _auth_token()
    if auth:
        request.add_header("Authorization", f"token {auth}")
    with urllib.request.urlopen(request, timeout=10) as response:
        body = response.read()
    return body.decode()
def check_for_update():
    """
    Query the manifest for the configured channel and record the result.

    Holds the update lock during the check. If the lock is busy, status
    "error" is recorded. Network or manifest errors are not raised: they are
    recorded with status "up_to_date" and an explanatory message. On the
    stable channel, a "dev" version is reported but not offered. Returns the
    saved state.
    """
    state = load_update_state()
    lock = acquire_lock()
    if lock is None:
        state["status"] = "error"
        state["message"] = "Another update is running"
        save_update_state(state)
        return state
    channel = state.get("channel") or "dev"
    diag_log("info", "Update check started", {"channel": channel})
    state["in_progress"] = True
    state["progress"] = "Checking for updates…"
    save_update_state(state)
    try:
        manifest = fetch_manifest_for_channel(channel)
        latest = manifest.get("version") or manifest.get("latest_version")
        state["latest_version"] = latest
        state["last_check"] = datetime.datetime.utcnow().isoformat() + "Z"
        if channel == "stable" and latest and "dev" in str(latest):
            verdict = ("up_to_date", "Dev release available; enable dev channel to install.")
        elif latest and latest != state.get("current_version"):
            verdict = ("update_available", manifest.get("changelog", "Update available"))
        else:
            verdict = ("up_to_date", "Up to date")
        state["status"], state["message"] = verdict
        diag_log("info", "Update check finished", {"status": state["status"], "latest": str(latest)})
    except Exception as exc:
        state["status"] = "up_to_date"
        state["message"] = f"Could not reach update server: {exc}"
        state["last_check"] = datetime.datetime.utcnow().isoformat() + "Z"
        diag_log("error", "Update check failed", {"error": str(exc)})
    finally:
        state["in_progress"] = False
        state["progress"] = None
        save_update_state(state)
        release_lock(lock)
    return state
def _safe_extract_bundle(tar: tarfile.TarFile, dest: pathlib.Path):
    """
    Extract every member of ``tar`` into ``dest`` without letting any escape it.

    Uses tarfile's ``data`` extraction filter when the running Python provides
    it (PEP 706). Older interpreters fall back to a manual check that rejects
    paths resolving outside ``dest``, plus links and device/FIFO members.

    Raises:
        RuntimeError: if the fallback check rejects a member.
        tarfile.FilterError: if the ``data`` filter rejects a member.
    """
    if hasattr(tarfile, "data_filter"):
        tar.extractall(dest, filter="data")
        return
    root = dest.resolve()
    for member in tar.getmembers():
        target = (root / member.name).resolve()
        if target != root and root not in target.parents:
            raise RuntimeError(f"Unsafe path in bundle: {member.name}")
        # Links can be chained to escape dest even when each one looks local, so
        # the fallback refuses them (assumes bundles hold only files and dirs).
        if member.issym() or member.islnk() or member.isdev():
            raise RuntimeError(f"Unsupported member type in bundle: {member.name}")
    tar.extractall(dest)


def apply_update_stub():
    """
    Download and install the release bundle for the configured channel.

    Steps:
    1. Take the update lock and fetch the channel manifest.
    2. Back up the current web root, API script and version file into a new
       BACKUP_ROOT entry; it becomes the only kept backup.
    3. Download the bundle into a fresh staging directory and verify its
       sha256 when the manifest lists one.
    4. Extract the bundle safely and deploy it.
    5. Persist the new version and state, then restart the Pi-Kit services.

    If deployment fails after live files were touched, the backup made by this
    run is restored. Failures before that point (e.g. no network, bad hash)
    leave the installation untouched. Errors are reported via the returned
    state rather than raised.
    """
    state = load_update_state()
    lock = acquire_lock()
    if lock is None:
        if state.get("in_progress"):
            state["message"] = "Update already in progress"
        else:
            state["status"] = "error"
            state["message"] = "Another update is running"
        save_update_state(state)
        return state
    # Holding the lock means no other check/update/rollback is alive, so an
    # "in_progress" flag still in the state file is stale (its owner died).
    backup_dir = None
    deploy_started = False
    state["in_progress"] = True
    state["status"] = "in_progress"
    state["progress"] = "Starting update…"
    save_update_state(state)
    diag_log("info", "Update apply started", {"channel": state.get("channel") or "dev"})
    try:
        channel = state.get("channel") or os.environ.get("PIKIT_CHANNEL", "dev")
        manifest = fetch_manifest_for_channel(channel)
        latest = manifest.get("version") or manifest.get("latest_version")
        if not latest:
            raise RuntimeError("Manifest missing version")
        # Backup current BEFORE download/install to guarantee a rollback point
        ts = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
        backup_dir = BACKUP_ROOT / ts
        ensure_dir(backup_dir)
        if WEB_ROOT.exists():
            ensure_dir(backup_dir / "pikit-web")
            shutil.copytree(WEB_ROOT, backup_dir / "pikit-web", dirs_exist_ok=True)
        if API_PATH.exists():
            shutil.copy2(API_PATH, backup_dir / "pikit-api.py")
        if VERSION_FILE.exists():
            shutil.copy2(VERSION_FILE, backup_dir / "version.txt")
        # Keep only this run's backup. Deleting "everything else" by name avoids
        # mtime ordering, which breaks when an RTC-less Pi boots with a stale clock.
        for old in list_backups():
            if old != backup_dir:
                shutil.rmtree(old, ignore_errors=True)
        bundle_url = manifest.get("bundle") or manifest.get("url")
        if not bundle_url:
            raise RuntimeError("Manifest missing bundle url")
        # The version comes from the remote manifest; reduce it to one safe path
        # component so it cannot point the staging dir outside TMP_ROOT.
        stage_name = re.sub(r"[^A-Za-z0-9._-]", "_", str(latest)).strip(".") or "release"
        stage_dir = TMP_ROOT / stage_name
        bundle_path = stage_dir / "bundle.tar.gz"
        # Start from an empty staging dir so leftovers from an earlier attempt
        # cannot be deployed together with this bundle.
        shutil.rmtree(stage_dir, ignore_errors=True)
        ensure_dir(stage_dir)
        state["progress"] = "Downloading release…"
        save_update_state(state)
        download_file(bundle_url, bundle_path)
        diag_log("debug", "Bundle downloaded", {"url": bundle_url, "path": str(bundle_path)})
        # Verify hash if provided
        expected_hash = next(
            (
                f["sha256"]
                for f in manifest.get("files", [])
                if f.get("path") == "bundle.tar.gz" and f.get("sha256")
            ),
            None,
        )
        if expected_hash:
            got = sha256_file(bundle_path)
            if got.lower() != expected_hash.lower():
                raise RuntimeError("Bundle hash mismatch")
            diag_log("debug", "Bundle hash verified", {"expected": expected_hash})
        state["progress"] = "Staging files…"
        save_update_state(state)
        with tarfile.open(bundle_path, "r:gz") as tar:
            _safe_extract_bundle(tar, stage_dir)
        staged_web = stage_dir / "pikit-web"
        staged_api = stage_dir / "pikit-api.py"
        # From here on live files change; a failure must restore this run's backup.
        deploy_started = True
        if staged_web.exists():
            shutil.rmtree(WEB_ROOT, ignore_errors=True)
            shutil.copytree(staged_web, WEB_ROOT)
        if staged_api.exists():
            shutil.copy2(staged_api, API_PATH)
            os.chmod(API_PATH, 0o755)
        VERSION_FILE.parent.mkdir(parents=True, exist_ok=True)
        VERSION_FILE.write_text(str(latest))
        state["current_version"] = str(latest)
        state["latest_version"] = str(latest)
        state["status"] = "up_to_date"
        state["message"] = "Update installed"
        state["progress"] = None
        # Persist the finished state before restarting: if this runs inside
        # pikit-api.service, restarting it may terminate this process.
        state["in_progress"] = False
        save_update_state(state)
        diag_log("info", "Update applied", {"version": str(latest)})
        # Restart services (best-effort)
        for svc in ("pikit-api.service", "dietpi-dashboard-frontend.service"):
            subprocess.run(["systemctl", "restart", svc], check=False)
    except urllib.error.HTTPError as e:
        # Only the download can raise this, i.e. before any live file changed.
        state["status"] = "error"
        state["message"] = f"No release available ({e.code})"
        diag_log("error", "Update apply HTTP error", {"code": e.code})
    except Exception as e:
        state["status"] = "error"
        state["message"] = f"Update failed: {e}"
        state["progress"] = None
        save_update_state(state)
        diag_log("error", "Update apply failed", {"error": str(e)})
        # Roll back only if live files were touched, and only to this run's
        # backup; otherwise a mere network error would downgrade the install.
        if deploy_started and backup_dir is not None and backup_dir.exists():
            try:
                restore_backup(backup_dir)
                state["current_version"] = read_current_version()
                state["message"] += f" (rolled back to backup {backup_dir.name})"
                save_update_state(state)
                diag_log("info", "Rollback after failed update", {"backup": backup_dir.name})
            except Exception as rb_err:
                state["message"] += f" (rollback failed: {rb_err})"
                save_update_state(state)
                diag_log("error", "Rollback after failed update failed", {"error": str(rb_err)})
    finally:
        state["in_progress"] = False
        state["progress"] = None
        save_update_state(state)
        release_lock(lock)
    return state
def rollback_update_stub():
    """
    Restore the backup chosen by choose_rollback_backup() and record the outcome.

    Holds the update lock for the whole operation. The in-progress markers are
    always cleared and the lock always released, even if an unexpected error
    occurs; previously such an error leaked the lock in long-lived processes.
    Errors are reported through the returned state rather than raised.
    """
    state = load_update_state()
    lock = acquire_lock()
    if lock is None:
        state["status"] = "error"
        state["message"] = "Another update is running"
        save_update_state(state)
        return state
    try:
        state["in_progress"] = True
        state["status"] = "in_progress"
        state["progress"] = "Rolling back…"
        save_update_state(state)
        diag_log("info", "Rollback started")
        backup = choose_rollback_backup()
        if not backup:
            state["status"] = "error"
            state["message"] = "No backup available to rollback."
        else:
            restore_backup(backup)
            state["status"] = "up_to_date"
            state["current_version"] = read_current_version()
            state["latest_version"] = state.get("latest_version") or state["current_version"]
            ver = get_backup_version(backup)
            suffix = f" (version {ver})" if ver else ""
            state["message"] = f"Rolled back to backup {backup.name}{suffix}"
            diag_log("info", "Rollback complete", {"backup": backup.name, "version": ver})
    except Exception as e:
        state["status"] = "error"
        state["message"] = f"Rollback failed: {e}"
        diag_log("error", "Rollback failed", {"error": str(e)})
    finally:
        state["in_progress"] = False
        state["progress"] = None
        save_update_state(state)
        release_lock(lock)
    return state
def start_background_task(mode: str):
    """
    Kick off a background update/rollback via systemd-run so nginx/API restarts
    do not break the caller connection.

    mode: "apply" or "rollback". Runs this script with ``--<mode>-update`` in
    a transient unit named ``pikit-update-<mode>``, forwarding the manifest URL
    and the auth token (if set) as environment variables.

    Returns:
        True if systemd-run accepted the job (exit status 0), else False.

    Raises:
        ValueError: for an unknown mode (an explicit check; ``assert`` is
        stripped under ``python -O``).
    """
    if mode not in ("apply", "rollback"):
        raise ValueError(f"invalid mode: {mode!r}")
    unit = f"pikit-update-{mode}"
    flag = f"--{mode}-update"
    # A failed transient unit stays loaded under its name and makes later
    # systemd-run calls with the same --unit fail; clear any such leftover.
    subprocess.run(
        ["systemctl", "reset-failed", f"{unit}.service"],
        check=False,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    # --collect unloads the unit even if it fails, so the name is reusable.
    cmd = ["systemd-run", "--unit", unit, "--collect", "--quiet"]
    # Pass manifest URL/token if set in environment
    if DEFAULT_MANIFEST_URL:
        cmd += [f"--setenv=PIKIT_MANIFEST_URL={DEFAULT_MANIFEST_URL}"]
    if AUTH_TOKEN:
        cmd += [f"--setenv=PIKIT_AUTH_TOKEN={AUTH_TOKEN}"]
    cmd += ["/usr/bin/env", "python3", str(API_PATH), flag]
    result = subprocess.run(cmd, check=False)
    return result.returncode == 0
def acquire_lock():
    """
    Try to take the exclusive, non-blocking update lock (flock on UPDATE_LOCK).

    Returns the open lock file (pass it to release_lock()) with this
    process's PID written into it, or None if the lock is held elsewhere or
    cannot be created.
    """
    lockfile = None
    try:
        ensure_dir(UPDATE_LOCK.parent)
        # "a+" creates the file without truncating it, so a current holder's PID
        # is not wiped out by a contender that then fails to get the lock.
        lockfile = UPDATE_LOCK.open("a+")
        fcntl.flock(lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        lockfile.seek(0)
        lockfile.truncate()
        lockfile.write(str(os.getpid()))
        lockfile.flush()
        return lockfile
    except Exception:
        # Close explicitly instead of leaving the descriptor to the GC.
        if lockfile is not None:
            try:
                lockfile.close()
            except Exception:
                pass
        return None
def release_lock(lockfile):
    """
    Release a lock from acquire_lock(): unlock, close and remove UPDATE_LOCK.

    Best-effort: any failure stops the remaining steps silently.
    """
    try:
        descriptor = lockfile.fileno()
        fcntl.flock(descriptor, fcntl.LOCK_UN)
        lockfile.close()
        UPDATE_LOCK.unlink(missing_ok=True)
    except Exception:
        pass
def prune_backups(keep: int = 2):
    """Delete all but the ``keep`` newest backups (at least one is always kept)."""
    keep = max(keep, 1)
    for stale in list_backups()[keep:]:
        shutil.rmtree(stale, ignore_errors=True)
def list_backups():
    """Return backup directories under BACKUP_ROOT sorted by mtime (newest first); creates BACKUP_ROOT if missing."""
    ensure_dir(BACKUP_ROOT)
    return sorted(
        (entry for entry in BACKUP_ROOT.iterdir() if entry.is_dir()),
        key=lambda entry: entry.stat().st_mtime,
        reverse=True,
    )
def get_backup_version(path: pathlib.Path):
    """
    Return the version string recorded in a backup directory, or None.

    Prefers `version.txt` (whitespace-stripped). Falls back to the "version"
    key of `pikit-web/data/version.json` only when `version.txt` is absent.
    Any read/parse error yields None.
    """
    version_txt = path / "version.txt"
    if version_txt.exists():
        try:
            return version_txt.read_text().strip()
        except Exception:
            return None
    bundle_meta = path / "pikit-web" / "data" / "version.json"
    if not bundle_meta.exists():
        return None
    try:
        meta = json.loads(bundle_meta.read_text())
        return meta.get("version")
    except Exception:
        return None
def choose_rollback_backup():
    """
    Select the backup to roll back to.

    Walks backups newest-first and returns the first one whose recorded
    version is known and differs from the installed version. If no such
    backup exists, returns the newest backup; returns None when there are
    no backups at all.
    """
    candidates = list_backups()
    if not candidates:
        return None
    installed = read_current_version()
    for candidate in candidates:
        candidate_version = get_backup_version(candidate)
        if not candidate_version or candidate_version == installed:
            continue
        return candidate
    return candidates[0]
def restore_backup(target: pathlib.Path):
    """
    Restore files from a backup directory, then restart related services.

    Restored pieces (each only if present in `target`):
      * `pikit-web/`   -> WEB_ROOT
      * `pikit-api.py` -> API_PATH (chmod 0755)
      * `version.txt`  -> VERSION_FILE; otherwise the version found by
        get_backup_version() is written there (if any).

    The web bundle is copied into a sibling staging directory first and then
    swapped in with renames, so a copy failure leaves the existing WEB_ROOT in
    place instead of deleted. The API script is written to a temporary file
    and moved over API_PATH with os.replace() for the same reason.

    Raises:
        OSError / shutil.Error if copying or swapping files fails.
    """
    web_src = target / "pikit-web"
    if web_src.exists():
        staging = WEB_ROOT.with_name(WEB_ROOT.name + ".restore-new")
        previous = WEB_ROOT.with_name(WEB_ROOT.name + ".restore-old")
        # Clear leftovers from an earlier interrupted restore.
        shutil.rmtree(staging, ignore_errors=True)
        shutil.rmtree(previous, ignore_errors=True)
        # Copy fully before touching the live tree.
        shutil.copytree(web_src, staging)
        if WEB_ROOT.exists():
            WEB_ROOT.rename(previous)
        staging.rename(WEB_ROOT)
        shutil.rmtree(previous, ignore_errors=True)
    api_src = target / "pikit-api.py"
    if api_src.exists():
        tmp_api = API_PATH.with_name(API_PATH.name + ".restore-tmp")
        shutil.copy2(api_src, tmp_api)
        os.chmod(tmp_api, 0o755)
        # Atomic swap: API_PATH is never left half-written.
        os.replace(tmp_api, API_PATH)
    VERSION_FILE.parent.mkdir(parents=True, exist_ok=True)
    if (target / "version.txt").exists():
        shutil.copy2(target / "version.txt", VERSION_FILE)
    else:
        # Fall back to the version recorded in the web bundle
        ver = get_backup_version(target)
        if ver:
            VERSION_FILE.write_text(str(ver))
    for svc in ("pikit-api.service", "dietpi-dashboard-frontend.service"):
        subprocess.run(["systemctl", "restart", svc], check=False)
class Handler(BaseHTTPRequestHandler):
    """Minimal JSON API for the dashboard (status, services, updates, reset)."""

    def _send(self, code, data):
        """Write `data` as a JSON response body with HTTP status `code`."""
        body = json.dumps(data).encode()
        self.send_response(code)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, fmt, *args):
        # Suppress BaseHTTPRequestHandler's default per-request stderr line.
        return

    def _read_json_payload(self):
        """
        Read and decode the JSON request body.

        Returns {} for an empty body. Raises ValueError (json.JSONDecodeError
        is a subclass) for a bad Content-Length, invalid JSON, or a JSON value
        that is not an object.
        """
        length = int(self.headers.get("Content-Length", 0) or 0)
        if length < 0:
            # rfile.read(-1) would block waiting for EOF on a keep-alive socket.
            raise ValueError("invalid Content-Length")
        raw = self.rfile.read(length) if length else b""
        payload = json.loads(raw or "{}")
        if not isinstance(payload, dict):
            raise ValueError("JSON body must be an object")
        return payload

    @staticmethod
    def _as_int(value):
        """Return `value` converted to int, or None if it cannot be converted."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _is_http_url(url):
        """Return True if `url` is an absolute http/https URL with a host."""
        parsed = urllib.parse.urlparse(url)
        return parsed.scheme in ("http", "https") and bool(parsed.netloc)

    @staticmethod
    def _probe_service(svc):
        """Return a copy of `svc`; when it has a port, add live `online` and `firewall_open` flags."""
        svc = dict(svc)
        port = svc.get("port")
        if port:
            svc["online"] = port_online("127.0.0.1", port)
            svc["firewall_open"] = ufw_status_allows(port)
        return svc

    def do_GET(self):
        """Dispatch read-only endpoints by path prefix."""
        if self.path.startswith("/api/status"):
            with open("/proc/uptime") as fh:
                uptime = float(fh.read().split()[0])
            load1, load5, load15 = os.getloadavg()
            meminfo = {}
            with open("/proc/meminfo") as fh:
                for ln in fh:
                    k, v = ln.split(":", 1)
                    meminfo[k] = int(v.strip().split()[0])
            # meminfo values are kB; //1024 converts to MB for memory_mb.
            total = meminfo.get("MemTotal", 0)//1024
            free = meminfo.get("MemAvailable", 0)//1024
            disk = shutil.disk_usage("/")
            # CPU temperature (best-effort); the sysfs value is divided by 1000 to get degrees C.
            cpu_temp = None
            for path in ("/sys/class/thermal/thermal_zone0/temp",):
                if pathlib.Path(path).exists():
                    try:
                        cpu_temp = float(pathlib.Path(path).read_text().strip())/1000.0
                        break
                    except Exception:
                        pass
            # LAN IP (first address printed by `hostname -I`)
            ip_addr = None
            try:
                out = subprocess.check_output(["hostname", "-I"], text=True).strip()
                ip_addr = out.split()[0] if out else None
            except Exception:
                pass
            # OS version
            os_ver = "DietPi"
            try:
                for line in pathlib.Path("/etc/os-release").read_text().splitlines():
                    if line.startswith("PRETTY_NAME="):
                        os_ver = line.split("=", 1)[1].strip().strip('"')
                        break
            except Exception:
                pass
            updates_state = auto_updates_state()
            updates_config = read_updates_config(updates_state)
            services = [self._probe_service(svc) for svc in load_services()]
            data = {
                "hostname": socket.gethostname(),
                "uptime_seconds": uptime,
                "load": [load1, load5, load15],
                "memory_mb": {"total": total, "free": free},
                "disk_mb": {"total": disk.total//1024//1024, "free": disk.free//1024//1024},
                "cpu_temp_c": cpu_temp,
                "lan_ip": ip_addr,
                "os_version": os_ver,
                "auto_updates_enabled": updates_state.get("enabled", False),
                "auto_updates": updates_state,
                "updates_config": updates_config,
                "reboot_required": reboot_required(),
                "ready": READY_FILE.exists(),
                "services": services
            }
            self._send(200, data)
        elif self.path.startswith("/api/services"):
            services = []
            for svc in load_services():
                svc = self._probe_service(svc)
                port = svc.get("port")
                if port:
                    # Rebuild URL with preferred host (adds .local)
                    host = default_host()
                    path = normalize_path(svc.get("path"))
                    scheme = svc.get("scheme") or ("https" if detect_https(host, port) else "http")
                    svc["scheme"] = scheme
                    svc["url"] = f"{scheme}://{host}:{port}{path}"
                services.append(svc)
            self._send(200, {"services": services})
        elif self.path.startswith("/api/updates/auto"):
            state = auto_updates_state()
            self._send(200, {"enabled": state.get("enabled", False), "details": state})
        elif self.path.startswith("/api/updates/config"):
            cfg = read_updates_config()
            self._send(200, cfg)
        elif self.path.startswith("/api/update/status"):
            state = load_update_state()
            state["current_version"] = read_current_version()
            state["channel"] = state.get("channel", os.environ.get("PIKIT_CHANNEL", "dev"))
            self._send(200, state)
        elif self.path.startswith("/api/update/changelog"):
            # Fetch changelog text (URL param ?url= overrides manifest changelog)
            try:
                qs = urllib.parse.urlparse(self.path).query
                params = urllib.parse.parse_qs(qs)
                url = params.get("url", [None])[0]
                if not url:
                    manifest = fetch_manifest()
                    url = manifest.get("changelog")
                if not url:
                    return self._send(404, {"error": "no changelog url"})
                # Allow only http(s). If fetch_text_with_auth is built on
                # urllib.request, a client-supplied file:// URL would otherwise
                # read arbitrary local files with the API's privileges.
                if not self._is_http_url(url):
                    return self._send(400, {"error": "changelog url must be http(s)"})
                # NOTE(review): a client-supplied ?url= may point at any host; if
                # fetch_text_with_auth attaches PIKIT_AUTH_TOKEN, the token can
                # leak to that host. Consider restricting allowed hosts.
                text = fetch_text_with_auth(url)
                return self._send(200, {"text": text})
            except Exception as e:
                return self._send(500, {"error": str(e)})
        elif self.path.startswith("/api/diag/log"):
            entries = diag_read()
            state = _load_diag_state()
            return self._send(200, {"entries": entries, "state": state})
        else:
            self._send(404, {"error": "not found"})

    def do_POST(self):
        """Dispatch state-changing endpoints by path prefix; the body must be a JSON object."""
        try:
            payload = self._read_json_payload()
        except ValueError as e:
            # Previously an exception here dropped the connection with no response.
            return self._send(400, {"error": f"invalid JSON body: {e}"})
        if self.path.startswith("/api/reset"):
            if payload.get("confirm") == "YES":
                self._send(200, {"message": "Resetting and rebooting..."})
                dbg("Factory reset triggered via API")
                diag_log("info", "Factory reset requested")
                factory_reset()
            else:
                self._send(400, {"error": "type YES to confirm"})
            return
        if self.path.startswith("/api/updates/auto"):
            enable = bool(payload.get("enable"))
            set_auto_updates(enable)
            dbg(f"Auto updates set to {enable}")
            state = auto_updates_state()
            diag_log("info", "Auto updates toggled", {"enabled": enable})
            return self._send(200, {"enabled": state.get("enabled", False), "details": state})
        if self.path.startswith("/api/updates/config"):
            try:
                cfg = set_updates_config(payload or {})
                dbg(f"Update settings applied: {cfg}")
                diag_log("info", "Update settings saved", cfg)
                return self._send(200, cfg)
            except Exception as e:
                dbg(f"Failed to apply updates config: {e}")
                diag_log("error", "Update settings save failed", {"error": str(e)})
                return self._send(500, {"error": str(e)})
        if self.path.startswith("/api/update/check"):
            state = check_for_update()
            return self._send(200, state)
        if self.path.startswith("/api/update/apply"):
            # Save the pending state BEFORE launching. The background job also
            # saves update state, so saving afterwards could overwrite a status
            # it has already written (e.g. a fast failure).
            state = load_update_state()
            state["status"] = "in_progress"
            state["message"] = "Starting background apply"
            save_update_state(state)
            # Start background apply to avoid breaking caller during service restart
            start_background_task("apply")
            return self._send(202, state)
        if self.path.startswith("/api/update/rollback"):
            # Same ordering rationale as apply: save first, then launch.
            state = load_update_state()
            state["status"] = "in_progress"
            state["message"] = "Starting rollback"
            save_update_state(state)
            start_background_task("rollback")
            return self._send(202, state)
        if self.path.startswith("/api/update/auto"):
            state = load_update_state()
            state["auto_check"] = bool(payload.get("enable"))
            save_update_state(state)
            diag_log("info", "Release auto-check toggled", {"enabled": state["auto_check"]})
            return self._send(200, state)
        if self.path.startswith("/api/update/channel"):
            chan = payload.get("channel", "dev")
            if chan not in ("dev", "stable"):
                return self._send(400, {"error": "channel must be dev or stable"})
            state = load_update_state()
            state["channel"] = chan
            save_update_state(state)
            diag_log("info", "Release channel set", {"channel": chan})
            return self._send(200, state)
        if self.path.startswith("/api/diag/log/level"):
            state = _save_diag_state(payload.get("enabled"), payload.get("level"))
            diag_log("info", "Diag level updated", state)
            return self._send(200, {"state": state})
        if self.path.startswith("/api/diag/log/clear"):
            try:
                DIAG_LOG_FILE.unlink(missing_ok=True)
            except Exception:
                pass
            diag_log("info", "Diag log cleared")
            return self._send(200, {"cleared": True, "state": _load_diag_state()})
        if self.path.startswith("/api/services/add"):
            name = payload.get("name")
            # Non-numeric ports become None and are rejected like a missing port.
            port = self._as_int(payload.get("port", 0))
            if not name or not port:
                return self._send(400, {"error": "name and port required"})
            if port in CORE_PORTS and name != CORE_NAME:
                return self._send(400, {"error": f"Port {port} is reserved for {CORE_NAME}"})
            services = load_services()
            if any(s.get("port") == port for s in services):
                return self._send(400, {"error": "port already exists"})
            host = default_host()
            scheme = payload.get("scheme")
            if scheme not in ("http", "https"):
                scheme = "https" if detect_https(host, port) else "http"
            notice = (payload.get("notice") or "").strip()
            notice_link = (payload.get("notice_link") or "").strip()
            self_signed = bool(payload.get("self_signed"))
            path = normalize_path(payload.get("path"))
            svc = {"name": name, "port": port, "scheme": scheme, "url": f"{scheme}://{host}:{port}{path}"}
            if notice:
                svc["notice"] = notice
            if notice_link:
                svc["notice_link"] = notice_link
            if self_signed:
                svc["self_signed"] = True
            if path:
                svc["path"] = path
            services.append(svc)
            save_services(services)
            try:
                allow_port_lan(port)
            except FirewallToolMissing as e:
                return self._send(500, {"error": str(e)})
            diag_log("info", "Service added", {"name": name, "port": port, "scheme": scheme})
            return self._send(200, {"services": services, "message": f"Added {name} on port {port} and opened LAN firewall"})
        if self.path.startswith("/api/services/remove"):
            port = self._as_int(payload.get("port", 0))
            if not port:
                return self._send(400, {"error": "port required"})
            if port in CORE_PORTS:
                return self._send(400, {"error": f"Cannot remove core service on port {port}"})
            services = [s for s in load_services() if s.get("port") != port]
            try:
                remove_port_lan(port)
            except FirewallToolMissing as e:
                return self._send(500, {"error": str(e)})
            save_services(services)
            diag_log("info", "Service removed", {"port": port})
            return self._send(200, {"services": services, "message": f"Removed service on port {port}"})
        if self.path.startswith("/api/services/update"):
            port = self._as_int(payload.get("port", 0))
            if port is None:
                # Must not fall through: None would match services lacking a port.
                return self._send(400, {"error": "port must be an integer"})
            new_name = payload.get("name")
            new_port = payload.get("new_port")
            new_scheme = payload.get("scheme")
            notice = payload.get("notice")
            notice_link = payload.get("notice_link")
            new_path = payload.get("path")
            self_signed = payload.get("self_signed")
            services = load_services()
            updated = False
            for svc in services:
                if svc.get("port") == port:
                    if new_name:
                        # Prevent renaming core service to something else if still on core port
                        if port in CORE_PORTS and new_name != CORE_NAME:
                            return self._send(400, {"error": f"Core service on port {port} must stay named {CORE_NAME}"})
                        svc["name"] = new_name
                    target_port = svc.get("port")
                    if new_port is not None:
                        new_port_int = self._as_int(new_port)
                        if new_port_int is None:
                            return self._send(400, {"error": "new_port must be an integer"})
                        if new_port_int != port:
                            if new_port_int in CORE_PORTS and svc.get("name") != CORE_NAME:
                                return self._send(400, {"error": f"Port {new_port_int} is reserved for {CORE_NAME}"})
                            if any(s.get("port") == new_port_int and s is not svc for s in services):
                                return self._send(400, {"error": "new port already in use"})
                            try:
                                remove_port_lan(port)
                                allow_port_lan(new_port_int)
                            except FirewallToolMissing as e:
                                return self._send(500, {"error": str(e)})
                            svc["port"] = new_port_int
                            target_port = new_port_int
                    host = default_host()
                    if new_path is not None:
                        path = normalize_path(new_path)
                        if path:
                            svc["path"] = path
                        elif "path" in svc:
                            svc.pop("path", None)
                    else:
                        path = normalize_path(svc.get("path"))
                        if path:
                            svc["path"] = path
                    if new_scheme:
                        # Unknown scheme values fall through to auto-detection below.
                        scheme = new_scheme if new_scheme in ("http", "https") else None
                    else:
                        scheme = svc.get("scheme")
                    if not scheme or scheme == "auto":
                        scheme = "https" if detect_https(host, target_port) else "http"
                    svc["scheme"] = scheme
                    svc["url"] = f"{scheme}://{host}:{target_port}{path}"
                    if notice is not None:
                        text = (notice or "").strip()
                        if text:
                            svc["notice"] = text
                        elif "notice" in svc:
                            svc.pop("notice", None)
                    if notice_link is not None:
                        link = (notice_link or "").strip()
                        if link:
                            svc["notice_link"] = link
                        elif "notice_link" in svc:
                            svc.pop("notice_link", None)
                    if self_signed is not None:
                        if bool(self_signed):
                            svc["self_signed"] = True
                        else:
                            svc.pop("self_signed", None)
                    updated = True
                    break
            if not updated:
                return self._send(404, {"error": "service not found"})
            save_services(services)
            diag_log("info", "Service updated", {"port": target_port, "name": new_name or None, "scheme": scheme})
            return self._send(200, {"services": services, "message": "Service updated"})
        self._send(404, {"error": "not found"})
def main():
    """Serve the dashboard JSON API on HOST:PORT; blocks until the process is stopped."""
    # Loaded once up front and the result discarded; presumably to initialise or
    # validate the service registry before serving (TODO confirm).
    load_services()
    HTTPServer((HOST, PORT), Handler).serve_forever()
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Pi-Kit API / updater")
    cli.add_argument("--apply-update", action="store_true", help="Apply latest release (non-HTTP mode)")
    cli.add_argument("--check-update", action="store_true", help="Check for latest release (non-HTTP mode)")
    cli.add_argument("--rollback-update", action="store_true", help="Rollback to last backup (non-HTTP mode)")
    opts = cli.parse_args()
    if not (opts.apply_update or opts.check_update or opts.rollback_update):
        # No one-shot flag: run the HTTP API server.
        main()
    else:
        # One-shot CLI mode; if several flags are given, apply wins over check,
        # and check wins over rollback.
        if opts.apply_update:
            apply_update_stub()
        elif opts.check_update:
            check_for_update()
        else:
            rollback_update_stub()
        sys.exit(0)