version = 9
# NOTE: "version = N" must stay on the very first line of this file - the
# update check (_extract_version_from_first_line) parses only the first line
# of the published agent.py to detect newer releases.

import asyncio
import base64
import ipaddress
import json
import os
import ssl
import subprocess
import sys
import time
import traceback
import urllib.parse
from collections import deque
from datetime import datetime
from typing import Any, Optional

# =========================
# HOSTER SETTINGS (edit these)
# =========================
# =========================
# HOSTER SETTINGS (edit these)
# =========================
USERNAME = "your_username"
PASSWORD = "your_password"
WEBSITE_DIR = r"C:\path\to\site"  # or "/home/user/site"

# Custom 404 page + toggle
# If CUSTOM_404_ENABLED is True:
# - CUSTOM_404_PATH can be relative to WEBSITE_DIR (recommended), e.g. "404.html"
# - or an absolute path, e.g. r"C:\path\to\404.html"
# Each feature flag below is assigned directly (the old shared "Enabled"
# temporary was reassigned four times, so editing the wrong line silently
# toggled a different feature).
CUSTOM_404_ENABLED = False
CUSTOM_404_PATH = r"404.html"

# Simple public IP protection
# Every IP_PROTECTION_REFRESH_SECONDS, the agent refreshes the host public IP.
# Before sending a response, it replaces exact matches of that IP with IP_PROTECTION_REPLACEMENT.
IP_PROTECTION_ENABLED = False
IP_PROTECTION_REFRESH_SECONDS = 300
IP_PROTECTION_REPLACEMENT = "BLOCKED"

# Auto start local Python scripts when the agent starts (backend helpers, APIs, etc.)
# These are started once per agent process (not on every reconnect).
AUTO_START_PYTHON_SCRIPTS_ENABLED = False
AUTO_START_PYTHON_SCRIPTS = [
    # r"C:\path\to\backend.py",
    # r"C:\path\to\another_service.py",
]

# Modes:
#   "static" -> only serve files from WEBSITE_DIR
#   "proxy"  -> forward ALL requests to your local server (UPSTREAM_BASE)
#   "hybrid" -> serve files EXCEPT certain prefixes which get proxied (recommended)
MODE = "hybrid"

# Where your backend runs on YOUR PC (bind it to 127.0.0.1 / localhost)
UPSTREAM_BASE = "http://127.0.0.1:5000"

# In hybrid mode, any path starting with one of these prefixes will proxy to UPSTREAM_BASE.
# Paths are relative to your site root on your subdomain, e.g.:
#   https://USERNAME.hostish.site/api/...
PROXY_PREFIXES = [
    "api/",
]

# Hostish VPS WebSocket endpoint
WS_URL = "wss://hostish.site/ws/host"

# Where this agent checks for latest official agent version
LATEST_AGENT_URL = "https://hostish.site/downloads/agent.py"

# Logging / hoster QoL
PRINT_REQUEST_HEADERS = False      # set True if you want header debug
PRINT_REQUEST_BODY_PREVIEW = False # set True for tiny body preview logs
BODY_PREVIEW_BYTES = 120

# QoL: periodic summary + repeated error suppression
PERIODIC_SUMMARY_ENABLED = False
PERIODIC_SUMMARY_SECONDS = 60
REPEAT_ERROR_SUMMARY_EVERY = 10
# =========================


# NOTE: This script uses the "websockets" package (not in Python stdlib).
# Install once (host machine):
#   python -m pip install websockets
try:
    import websockets
except ImportError:
    # Hard requirement: the tunnel cannot run without it. Fail fast with
    # install instructions instead of crashing later with a NameError.
    print("Missing dependency: websockets")
    print("Install with: python -m pip install websockets")
    sys.exit(1)


# Connection-scoped (hop-by-hop) headers; proxy_request strips these from
# both the forwarded request headers and the upstream response headers.
HOP_BY_HOP = {
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
}

MAX_FILE_BYTES = 10 * 1024 * 1024   # 10MB safety cap for static files
MAX_PROXY_BYTES = 20 * 1024 * 1024  # 20MB safety cap for proxied responses
UPSTREAM_TIMEOUT_SECONDS = 20       # per-request timeout for the local upstream

# Runtime stats (hoster-visible logs)
START_TS = time.time()            # process start time, used by uptime_s()
CONNECT_COUNT = 0                 # tunnel connection count (updated elsewhere)
TOTAL_REQS = 0                    # lifetime request count
TOTAL_BYTES_SENT = 0              # lifetime bytes sent in responses
RECENT_EVENTS = deque(maxlen=50)  # ring buffer of recent log_event() lines
CURRENT_REQS = 0                  # requests currently in flight

# Runtime counters (QoL)
STATIC_REQS = 0        # requests served from WEBSITE_DIR
PROXY_REQS = 0         # requests forwarded to UPSTREAM_BASE
STATUS_COUNTS = {}     # status code (as str) -> response count
LAST_SUMMARY_SNAPSHOT = {"reqs": 0, "bytes": 0}  # baseline for periodic delta summaries

# Repeated error suppression (see log_repeatable_error)
_REPEAT_ERRORS = {}  # key -> {"count": int, "last": str, "last_ts": float}

# Auto started companion scripts (see start_configured_python_scripts_once)
_MANAGED_SCRIPT_PROCS = []  # list[{"path": str, "proc": Popen, "started_at": float}]
_AUTO_START_ALREADY_RAN = False  # one-shot guard; set even when the feature is off

# Agent version cache / bandwidth saver (public endpoint uses this)
VERSION_CACHE_COOLDOWN_SECONDS = 30
_version_cache = {
    "last_fetch_ts": 0.0,
    "newest_version": None,   # int | None
    "status": "unknown",      # "up_to_date" | "outdated" | "unknown"
    "error": None,            # str | None
}

# Public IP cache (used by IP protection); URLs tried in order by fetch_public_ip
PUBLIC_IP_CHECK_URLS = [
    "https://api.ipify.org",
    "https://checkip.amazonaws.com",
    "https://ipv4.icanhazip.com",
]
_public_ip_cache = {
    "last_fetch_ts": 0.0,
    "ip": None,      # str | None
    "error": None,   # str | None
}


def ts() -> str:
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"


def log(msg: str) -> None:
    """Write one timestamped line to stdout and flush immediately."""
    line = f"[{ts()}] {msg}"
    print(line, flush=True)


def log_event(msg: str) -> None:
    """Record *msg* in the RECENT_EVENTS ring buffer (newest first) and log it."""
    stamped = f"[{ts()}] {msg}"
    RECENT_EVENTS.appendleft(stamped)
    log(msg)


def fmt_bytes(n: int) -> str:
    """Format a byte count for logs: B, KB/MB (1 decimal), GB (2 decimals)."""
    size = int(n or 0)
    kib, mib, gib = 1024, 1024 ** 2, 1024 ** 3
    if size < kib:
        return f"{size} B"
    if size < mib:
        return f"{size / kib:.1f} KB"
    if size < gib:
        return f"{size / mib:.1f} MB"
    return f"{size / gib:.2f} GB"


def uptime_s() -> int:
    """Whole seconds elapsed since the agent process started (START_TS)."""
    elapsed = time.time() - START_TS
    return int(elapsed)


def _status_count_inc(code: int) -> None:
    """Increment the per-status-code response counter in STATUS_COUNTS."""
    key = str(int(code))
    try:
        STATUS_COUNTS[key] += 1
    except KeyError:
        STATUS_COUNTS[key] = 1


def _status_count_top() -> str:
    """Summarize the most frequent status codes as 'code:count', max 6 entries."""
    if not STATUS_COUNTS:
        return "none"
    ranked = sorted(STATUS_COUNTS.items(), key=lambda kv: (-kv[1], kv[0]))
    return ", ".join(f"{code}:{count}" for code, count in ranked[:6])


def log_repeatable_error(key: str, message: str) -> None:
    """
    De-duplicate identical consecutive errors per *key*.

    The first occurrence is logged normally; identical repeats only produce
    a summary line every REPEAT_ERROR_SUMMARY_EVERY hits. A different
    message under the same key closes the previous streak and starts over.
    """
    now_ts = time.time()
    entry = _REPEAT_ERRORS.get(key)

    if entry is None:
        # First error for this key: record and log it verbatim.
        _REPEAT_ERRORS[key] = {"count": 1, "last": message, "last_ts": now_ts}
        log_event(message)
    elif entry["last"] == message:
        # Same message again: count it, log only at the summary interval.
        entry["count"] += 1
        entry["last_ts"] = now_ts
        if entry["count"] % REPEAT_ERROR_SUMMARY_EVERY == 0:
            log_event(f"[repeat x{entry['count']}] {message}")
    else:
        # New message under this key: report how long the old streak ran.
        if entry["count"] > 1:
            log(f"Previous repeated error ended after {entry['count']}x: {entry['last']}")
        entry["count"] = 1
        entry["last"] = message
        entry["last_ts"] = now_ts
        log_event(message)


def _normalize_ip_text(raw: str) -> Optional[str]:
    s = (raw or "").strip()
    if not s:
        return None
    s = s.splitlines()[0].strip().split()[0].strip()
    try:
        ipaddress.ip_address(s)
        return s
    except Exception:
        return None


def fetch_public_ip(force: bool = False) -> Optional[str]:
    """
    Fetch and cache the current public IP.
    Uses a simple cooldown to save bandwidth.

    Tries each URL in PUBLIC_IP_CHECK_URLS in order; the first response that
    parses as a valid IP wins. If every provider fails, the last error is
    recorded in the cache and the previously cached IP (possibly None) is
    returned.

    Args:
        force: when True, bypass the IP_PROTECTION_REFRESH_SECONDS cooldown.

    Returns:
        The cached or freshly fetched public IP string, or None if it was
        never fetched successfully.
    """
    now_ts = time.time()
    age = now_ts - float(_public_ip_cache.get("last_fetch_ts", 0.0) or 0.0)

    # Serve from cache while inside the refresh window (unless forced).
    if (not force) and age < int(IP_PROTECTION_REFRESH_SECONDS):
        return _public_ip_cache.get("ip")

    import urllib.request

    last_err = None
    for url in PUBLIC_IP_CHECK_URLS:
        try:
            req = urllib.request.Request(
                url,
                headers={"User-Agent": f"HostishAgent/{version}"},
                method="GET",
            )
            # These services return a bare IP; 256 bytes is more than enough.
            with urllib.request.urlopen(req, timeout=8) as resp:
                raw = resp.read(256).decode("utf-8", errors="replace")
            ip = _normalize_ip_text(raw)
            if not ip:
                raise ValueError("invalid IP response")
            old_ip = _public_ip_cache.get("ip")
            _public_ip_cache["last_fetch_ts"] = now_ts
            _public_ip_cache["ip"] = ip
            _public_ip_cache["error"] = None
            # Log only on change so steady-state refreshes stay quiet.
            if old_ip != ip:
                if old_ip is None:
                    log("IP protection cached public IP")
                else:
                    log("IP protection detected public IP change and refreshed cache")
            return ip
        except Exception as e:
            last_err = f"{type(e).__name__}: {e}"

    # Every provider failed: remember the error but keep any stale IP.
    _public_ip_cache["last_fetch_ts"] = now_ts
    _public_ip_cache["error"] = last_err
    return _public_ip_cache.get("ip")


def _public_ip_status_text() -> str:
    """One-line IP-protection status used by the startup banner."""
    if not IP_PROTECTION_ENABLED:
        return "IP protection OFF"
    if _public_ip_cache.get("ip"):
        return "IP protection ON (public IP cached)"
    return "IP protection ON (public IP not cached yet)"


async def periodic_public_ip_refresh_loop() -> None:
    """
    Background task: keep the cached public IP fresh.

    Sleeps at least 30s per cycle (longer if IP_PROTECTION_REFRESH_SECONDS
    is larger) and does nothing while protection is disabled. The blocking
    fetch runs in a worker thread so the event loop stays responsive.
    """
    while True:
        interval = max(30, int(IP_PROTECTION_REFRESH_SECONDS))
        await asyncio.sleep(interval)
        if not IP_PROTECTION_ENABLED:
            continue
        previous = _public_ip_cache.get("ip")
        refreshed = await asyncio.to_thread(fetch_public_ip, True)
        if refreshed is None:
            err = _public_ip_cache.get("error") or "unknown error"
            log_repeatable_error("public_ip_refresh", f"IP protection refresh failed: {err}")
        elif previous != refreshed:
            log_event("IP protection cache updated")


def _mask_public_ip_in_bytes(data: bytes, public_ip: str) -> tuple[bytes, int]:
    """
    Replace every exact occurrence of *public_ip* inside *data*.

    The replacement string is left-justified to the IP's length so the
    payload size never changes (prevents a response-size oracle).
    Returns (possibly-rewritten data, number of replacements).
    """
    if not data or not public_ip:
        return data, 0
    needle = public_ip.encode("utf-8")
    hits = data.count(needle)
    if hits == 0:
        return data, 0
    replacement = IP_PROTECTION_REPLACEMENT.ljust(len(public_ip)).encode("utf-8")
    return data.replace(needle, replacement), hits


def _mask_public_ip_in_headers(headers: dict, public_ip: str) -> tuple[dict, int]:
    """
    Mask the public IP inside header *values* (keys are never rewritten).

    The replacement is padded to the IP's length so value lengths are
    preserved. Returns (new dict with stringified keys/values, total hits).
    """
    if not headers or not public_ip:
        return headers or {}, 0

    masked = IP_PROTECTION_REPLACEMENT.ljust(len(public_ip))
    total = 0
    result = {}
    for key, value in headers.items():
        text = str(value)
        hits = text.count(public_ip)
        if hits:
            text = text.replace(public_ip, masked)
            total += hits
        result[str(key)] = text
    return result, total


def _apply_ip_protection_to_response(
    req_id: Any,
    status: int,
    headers: dict,
    body: bytes,
    meta: dict,
) -> tuple[int, dict, bytes, dict]:
    """
    Mask the host's public IP in an outgoing response when protection is on.

    Header values are always masked. Bodies are skipped for binary content
    types to avoid corrupting payloads. When at least one replacement
    happened, Content-Length is recomputed and ip_protection_* counters are
    added to a copy of *meta*.

    Returns the possibly-rewritten (status, headers, body, meta).
    """
    if not IP_PROTECTION_ENABLED:
        return status, headers, body, meta

    # Cached lookup; None means the IP was never fetched successfully,
    # in which case there is nothing to mask.
    public_ip = fetch_public_ip(force=False)
    if not public_ip:
        return status, headers, body, meta

    # Skip body replacement for binary files to prevent corruption
    content_type = ""
    for k, v in (headers or {}).items():
        if str(k).lower() == "content-type":
            content_type = str(v).lower()
            break

    is_binary = any(b in content_type for b in ["image", "video", "audio", "octet-stream", "pdf", "zip", "font"])

    new_headers, header_hits = _mask_public_ip_in_headers(headers or {}, public_ip)

    if is_binary:
        new_body = body
        body_hits = 0
    else:
        new_body, body_hits = _mask_public_ip_in_bytes(body or b"", public_ip)

    total_hits = header_hits + body_hits
    if total_hits <= 0:
        return status, new_headers, new_body, meta

    # Keep Content-Length honest. The replacement is padded to the IP's
    # length so the size normally does not change, but recompute anyway.
    for hk in list(new_headers.keys()):
        if hk.lower() == "content-length":
            new_headers[hk] = str(len(new_body))

    meta2 = dict(meta or {})
    meta2["ip_protection_hits"] = total_hits
    meta2["ip_protection_body_hits"] = body_hits
    meta2["ip_protection_header_hits"] = header_hits
    meta2["ip_protection_applied"] = True

    log_event(
        f"IP protection masked public IP in response id={req_id} "
        f"status={status} body_hits={body_hits} header_hits={header_hits}"
    )
    return status, new_headers, new_body, meta2


def _normalize_script_path(p: str) -> str:
    p = os.path.expanduser(str(p).strip())
    return os.path.abspath(p)


def start_configured_python_scripts_once() -> None:
    """
    Starts each Python script in AUTO_START_PYTHON_SCRIPTS one time when the agent boots.
    Does not run again on tunnel reconnects.

    Scripts are launched with the agent's own interpreter (sys.executable),
    each with its script directory as the working directory. Launched
    processes are recorded in _MANAGED_SCRIPT_PROCS for the monitor loop
    and for shutdown via stop_managed_scripts().
    """
    global _AUTO_START_ALREADY_RAN

    # One-shot guard: the flag is set even when the feature is disabled,
    # so flipping the toggle at runtime cannot cause a second launch pass.
    if _AUTO_START_ALREADY_RAN:
        return
    _AUTO_START_ALREADY_RAN = True

    if not AUTO_START_PYTHON_SCRIPTS_ENABLED:
        return

    scripts = AUTO_START_PYTHON_SCRIPTS
    if not isinstance(scripts, (list, tuple)):
        log("AUTO_START_PYTHON_SCRIPTS must be a list or tuple")
        return

    if not scripts:
        log("Auto start scripts enabled, but list is empty")
        return

    started_count = 0

    for raw_path in scripts:
        try:
            if raw_path is None:
                continue

            script_path = _normalize_script_path(raw_path)
            # Only plain .py files are accepted; anything else is skipped loudly.
            if not script_path.lower().endswith(".py"):
                log(f"Auto start skipped (not .py): {script_path}")
                continue

            if not os.path.isfile(script_path):
                log(f"Auto start skipped (missing file): {script_path}")
                continue

            workdir = os.path.dirname(script_path) or None

            popen_kwargs = {
                "cwd": workdir,
            }

            # On Windows, avoid flashing a console window per script.
            if os.name == "nt":
                popen_kwargs["creationflags"] = getattr(subprocess, "CREATE_NO_WINDOW", 0)

            proc = subprocess.Popen([sys.executable, script_path], **popen_kwargs)

            _MANAGED_SCRIPT_PROCS.append({
                "path": script_path,
                "proc": proc,
                "started_at": time.time(),
                "exit_logged": False,
            })
            started_count += 1
            log_event(f"Auto started Python script pid={proc.pid} path='{script_path}'")

        except Exception as e:
            # One bad entry must not stop the rest of the list.
            log_repeatable_error(
                "auto_start_script",
                f"Auto start failed for script '{raw_path}': {type(e).__name__}: {e}"
            )

    log(f"Auto start complete | started={started_count} requested={len(scripts)}")


async def managed_script_monitor_loop() -> None:
    """
    Background task: poll auto-started scripts every 3s and log each
    process exit exactly once (guarded by the per-entry 'exit_logged' flag).
    """
    while True:
        await asyncio.sleep(3)

        # Iterate over a snapshot so concurrent appends are safe.
        for entry in list(_MANAGED_SCRIPT_PROCS):
            proc = entry.get("proc")
            if proc is None:
                continue

            try:
                returncode = proc.poll()
            except Exception as e:
                log_repeatable_error(
                    "managed_script_poll",
                    f"Managed script poll failed for '{entry.get('path')}': {type(e).__name__}: {e}"
                )
                continue

            # Still running, or already reported: nothing to do.
            if returncode is None or entry.get("exit_logged"):
                continue

            entry["exit_logged"] = True
            log_event(
                f"Auto started script exited pid={proc.pid} code={returncode} path='{entry.get('path')}'"
            )


def stop_managed_scripts() -> None:
    """
    Shut down all auto-started scripts.

    Three phases: polite terminate() for everything still running, up to
    ~3 seconds of grace while polling for exits, then kill() for any
    survivor. All errors are swallowed - this is a shutdown path.
    """
    if not _MANAGED_SCRIPT_PROCS:
        return

    log("Stopping auto started scripts...")

    # Phase 1: ask every live process to terminate.
    for entry in _MANAGED_SCRIPT_PROCS:
        proc = entry.get("proc")
        if not proc:
            continue
        try:
            if proc.poll() is None:
                proc.terminate()
        except Exception:
            pass

    # Phase 2: poll for up to ~3s, giving processes time to exit cleanly.
    deadline = time.time() + 3.0
    while time.time() < deadline:
        all_done = True
        for entry in _MANAGED_SCRIPT_PROCS:
            proc = entry.get("proc")
            if not proc:
                continue
            try:
                if proc.poll() is None:
                    all_done = False
                    break
            except Exception:
                pass
        if all_done:
            break
        time.sleep(0.1)

    # Phase 3: force-kill anything that ignored the terminate request.
    for entry in _MANAGED_SCRIPT_PROCS:
        proc = entry.get("proc")
        if not proc:
            continue
        try:
            if proc.poll() is None:
                proc.kill()
        except Exception:
            pass


def print_startup_banner() -> None:
    """Log the startup configuration banner: one line per hoster-relevant setting."""
    log("=" * 72)
    log("Hostish agent starting")
    log(f"Version         : {version}")
    log(f"Username        : {USERNAME}")
    log(f"Mode            : {MODE}")
    log(f"Website dir     : {WEBSITE_DIR}")
    log(f"Custom 404      : {'ON' if CUSTOM_404_ENABLED else 'OFF'}")
    log(f"Custom 404 path : {CUSTOM_404_PATH}")
    log(f"IP protection   : {_public_ip_status_text()}")
    log(
        "Auto start py   : "
        + ("ON" if AUTO_START_PYTHON_SCRIPTS_ENABLED else "OFF")
        + f" ({len(AUTO_START_PYTHON_SCRIPTS) if isinstance(AUTO_START_PYTHON_SCRIPTS, (list, tuple)) else 0} configured)"
    )
    # Upstream details only matter when some traffic is proxied.
    if MODE in ("proxy", "hybrid"):
        log(f"Upstream base   : {UPSTREAM_BASE}")
    if MODE == "hybrid":
        log(f"Proxy prefixes  : {PROXY_PREFIXES}")
    log(f"Public site URL : https://{USERNAME}.hostish.site/")
    log(version_status_text())
    log("=" * 72)


def print_summary() -> None:
    """Log a one-line lifetime summary of agent activity."""
    fields = [
        f"uptime={uptime_s()}s",
        f"connections={CONNECT_COUNT}",
        f"requests={TOTAL_REQS}",
        f"static={STATIC_REQS}",
        f"proxy={PROXY_REQS}",
        f"bytes_sent={fmt_bytes(TOTAL_BYTES_SENT)}",
        f"inflight={CURRENT_REQS}",
        f"statuses=[{_status_count_top()}]",
    ]
    log("Summary | " + " | ".join(fields))


def print_periodic_delta_summary() -> None:
    """Log request/byte deltas since the previous snapshot, then advance it."""
    delta_reqs = TOTAL_REQS - LAST_SUMMARY_SNAPSHOT.get("reqs", 0)
    delta_bytes = TOTAL_BYTES_SENT - LAST_SUMMARY_SNAPSHOT.get("bytes", 0)

    # Advance the baseline so the next window measures from here.
    LAST_SUMMARY_SNAPSHOT["reqs"] = TOTAL_REQS
    LAST_SUMMARY_SNAPSHOT["bytes"] = TOTAL_BYTES_SENT

    fields = [
        f"window={PERIODIC_SUMMARY_SECONDS}s",
        f"reqs={delta_reqs}",
        f"bytes={fmt_bytes(delta_bytes)}",
        f"total_reqs={TOTAL_REQS}",
        f"inflight={CURRENT_REQS}",
        f"statuses=[{_status_count_top()}]",
    ]
    log("Periodic summary | " + " | ".join(fields))


async def periodic_summary_loop() -> None:
    """Background task: emit the periodic delta summary when enabled (>=5s interval)."""
    while True:
        delay = max(5, int(PERIODIC_SUMMARY_SECONDS))
        await asyncio.sleep(delay)
        if PERIODIC_SUMMARY_ENABLED:
            print_periodic_delta_summary()


def _extract_version_from_first_line(line: str) -> Optional[int]:
    """
    expects first line like: version = 6
    """
    if not isinstance(line, str):
        return None
    s = line.strip()
    if not s.lower().startswith("version"):
        return None
    parts = s.split("=", 1)
    if len(parts) != 2:
        return None
    rhs = parts[1].strip()
    rhs = rhs.strip("\"' ")
    if rhs.isdigit():
        return int(rhs)
    return None


def fetch_latest_agent_version(force: bool = False) -> Optional[int]:
    """
    Reads only the first line of LATEST_AGENT_URL and parses version.
    Cached for 30 seconds by default to save bandwidth.

    Side effects: updates _version_cache (last_fetch_ts, newest_version,
    status, error). On failure the previously known newest_version is kept
    and returned (possibly None).

    Args:
        force: when True, bypass the VERSION_CACHE_COOLDOWN_SECONDS cooldown.
    """
    now_ts = time.time()
    age = now_ts - _version_cache["last_fetch_ts"]

    # Serve the cached answer while inside the cooldown window.
    if (not force) and age < VERSION_CACHE_COOLDOWN_SECONDS and _version_cache["newest_version"] is not None:
        return _version_cache["newest_version"]

    import urllib.request

    try:
        req = urllib.request.Request(
            LATEST_AGENT_URL,
            headers={"User-Agent": f"HostishAgent/{version}"},
            method="GET",
        )
        # Only the first line ("version = N") is needed; 256 bytes suffice.
        with urllib.request.urlopen(req, timeout=8) as resp:
            chunk = resp.read(256)
            first_line = chunk.splitlines()[0].decode("utf-8", errors="replace") if chunk else ""
            newest = _extract_version_from_first_line(first_line)

        _version_cache["last_fetch_ts"] = now_ts
        _version_cache["newest_version"] = newest
        _version_cache["error"] = None

        if newest is None:
            _version_cache["status"] = "unknown"
        elif newest > version:
            _version_cache["status"] = "outdated"
        else:
            _version_cache["status"] = "up_to_date"

        return newest
    except Exception as e:
        # Record the failure but keep (and return) the last known value.
        _version_cache["last_fetch_ts"] = now_ts
        _version_cache["error"] = f"{type(e).__name__}: {e}"
        if _version_cache["newest_version"] is None:
            _version_cache["status"] = "unknown"
        return _version_cache["newest_version"]


def version_status_text() -> str:
    """Human-readable one-liner describing the cached version-check result."""
    newest = _version_cache.get("newest_version")
    unknown = newest is None or _version_cache.get("status") == "unknown"
    if unknown:
        return f"Agent version check unavailable (current={version})"
    if newest > version:
        return f"Agent outdated: current={version}, newest={newest}"
    return f"Agent up to date (version {version})"


def startup_version_check_or_prompt() -> None:
    """
    Verify the agent version at boot and block on an outdated agent.

    Outdated agents require the hoster to answer "c" to continue; any other
    answer exits the process. A failed version check is non-fatal: the
    agent logs a warning and continues.
    """
    newest = fetch_latest_agent_version(force=True)

    if newest is None:
        log("Could not verify latest agent version right now (continuing)")
        return

    if newest > version:
        prompt = (
            'Your agent is outdated, you can either use the newest version at '
            'https://hostish.site/downloads/agent.py or you can respond "c" to continue with this outdated script\n> '
        )
        try:
            ans = input(prompt).strip().lower()
        except EOFError:
            # Non-interactive runs (no stdin) count as "not approved".
            ans = ""
        if ans != "c":
            log("Exiting because outdated agent was not approved to continue.")
            sys.exit(0)
        log(f"Continuing with outdated agent (current={version}, newest={newest})")
    else:
        log("Agent up to date")


def _norm_prefix(p: str) -> str:
    p = (p or "").lstrip("/")
    if p and not p.endswith("/"):
        p = p + "/"
    return p


# Normalize the configured prefixes once at import time: drop leading
# slashes, guarantee a trailing slash, and discard empty entries.
PROXY_PREFIXES = [_norm_prefix(x) for x in PROXY_PREFIXES if x]


def safe_join(root: str, rel_path: str) -> str:
    """
    Resolve *rel_path* under *root*, refusing any escape from *root*.

    Backslashes are normalized to forward slashes, leading slashes are
    dropped, and the joined path is fully resolved (symlinks included)
    before the containment check.

    Raises:
        ValueError: if the resolved path falls outside *root*.
    """
    cleaned = (rel_path or "").replace("\\", "/").lstrip("/")
    resolved = os.path.realpath(os.path.join(root, cleaned))
    root_resolved = os.path.realpath(root)

    if os.path.commonpath([root_resolved, resolved]) != root_resolved:
        raise ValueError("path traversal blocked")

    return resolved


def should_proxy(rel_path: str) -> bool:
    """Routing decision: True forwards to UPSTREAM_BASE, False serves static."""
    path = (rel_path or "").lstrip("/").lower()
    if MODE == "proxy":
        return True
    if MODE == "static":
        return False
    # hybrid: proxy only the configured prefixes.
    return any(path.startswith(prefix.lower()) for prefix in PROXY_PREFIXES)


def guess_content_type(path: str) -> str:
    """MIME type guessed from the file extension, octet-stream as fallback."""
    import mimetypes
    guessed, _encoding = mimetypes.guess_type(path)
    return guessed or "application/octet-stream"


def _preview_bytes(b: bytes, n: int = BODY_PREVIEW_BYTES) -> str:
    """
    Printable preview of the first *n* bytes of *b*, newlines escaped.

    Appends "..." when the input is longer than the preview window.
    Returns "" for empty input.
    """
    if not b:
        return ""
    sample = b[:n]
    suffix = "..." if len(b) > n else ""
    try:
        text = sample.decode("utf-8", errors="replace")
        return text.replace("\n", "\\n").replace("\r", "\\r") + suffix
    except Exception:
        # decode(errors="replace") should never raise; kept as a guard.
        return f"<{len(sample)} bytes binary>"


def _metadata_version_response() -> tuple[int, dict, bytes, dict]:
    """
    Public lightweight endpoint (served by the agent itself):
      /__hostish/version
    Includes current version + newest version status.
    Bandwidth saver: newest-version fetch is cached for 30s.

    Returns (200, headers, JSON body, meta). The open CORS header allows
    the status to be read cross-origin; no-store keeps it uncached.
    """
    # Cached fetch: respects VERSION_CACHE_COOLDOWN_SECONDS.
    newest = fetch_latest_agent_version(force=False)
    status = _version_cache.get("status", "unknown")
    err = _version_cache.get("error")
    payload = {
        "ok": True,
        "username": USERNAME,
        "current_version": version,
        "newest_version": newest,
        "status": status,
        "checked_at_unix": int(_version_cache.get("last_fetch_ts", 0) or 0),
        "cooldown_seconds": VERSION_CACHE_COOLDOWN_SECONDS,
        "error": err,
        "ip_protection_enabled": bool(IP_PROTECTION_ENABLED),
        "ip_protection_checked_at_unix": int(_public_ip_cache.get("last_fetch_ts", 0) or 0),
    }
    # Compact separators keep the payload small.
    body = json.dumps(payload, separators=(",", ":")).encode("utf-8")
    headers = {
        "content-type": "application/json; charset=utf-8",
        "cache-control": "no-store",
        "access-control-allow-origin": "*",
    }
    return 200, headers, body, {
        "source": "agent-meta",
        "endpoint": "/__hostish/version",
        "size": len(body),
    }


def _hostish_default_404_redirect() -> tuple[int, dict, bytes, dict]:
    body = b""
    headers = {
        "location": "https://hostish.site/404/",
        "cache-control": "no-store",
    }
    return 302, headers, body, {
        "source": "agent-404-redirect",
        "target": "https://hostish.site/404/",
    }


def _resolve_custom_404_full_path() -> Optional[str]:
    """
    Absolute path of the configured custom 404 page, or None when unset.

    Relative paths are resolved (and traversal-checked) against WEBSITE_DIR
    via safe_join; absolute paths are used as-is.
    """
    configured = (CUSTOM_404_PATH or "").strip()
    if not configured:
        return None

    if os.path.isabs(configured):
        return os.path.abspath(configured)

    return safe_join(WEBSITE_DIR, configured)


def _serve_custom_404_if_configured() -> Optional[tuple[int, dict, bytes, dict]]:
    """
    Best-effort: build a 404 response from the hoster's custom page.

    Returns None (so the caller falls back to default 404 handling) when
    the feature is off, the file is missing, oversized, or unreadable.
    Exceptions are deliberately swallowed: a broken custom 404 page must
    never break error responses themselves.
    """
    if not CUSTOM_404_ENABLED:
        return None

    try:
        full = _resolve_custom_404_full_path()
    except Exception:
        return None

    if not full or not os.path.isfile(full):
        return None

    try:
        # Same size cap as regular static files.
        size = os.path.getsize(full)
        if size > MAX_FILE_BYTES:
            return None
        with open(full, "rb") as f:
            body = f.read()
        headers = {
            "content-type": guess_content_type(full),
            "cache-control": "no-store",
        }
        return 404, headers, body, {
            "source": "custom-404",
            "served_rel": CUSTOM_404_PATH,
            "full_path": full,
            "size": len(body),
        }
    except Exception:
        return None


def read_static_file(rel_path: str) -> tuple[int, dict, bytes, dict]:
    """
    Serve one file from WEBSITE_DIR.

    Return (status, headers, body_bytes, meta); meta keys may include
    source, served_rel, requested_rel, full_path, size, error.

    Status codes: 200 on success, 403 when path traversal is blocked,
    404 when the file is missing, 413 when it exceeds MAX_FILE_BYTES.
    """
    requested_rel = rel_path

    # directory default: "" or a trailing slash maps to index.html
    if rel_path == "" or rel_path.endswith("/"):
        rel_path = rel_path + "index.html"

    try:
        full = safe_join(WEBSITE_DIR, rel_path)
    except ValueError:
        return 403, {"content-type": "text/plain; charset=utf-8"}, b"Forbidden", {
            "source": "static",
            "served_rel": rel_path,
            "requested_rel": requested_rel,
            "full_path": None,
            "error": "path traversal blocked",
        }

    if not os.path.exists(full) or not os.path.isfile(full):
        return 404, {"content-type": "text/plain; charset=utf-8"}, b"Not Found", {
            "source": "static",
            "served_rel": rel_path,
            "requested_rel": requested_rel,
            "full_path": full,
            "error": "not found",
        }

    # Reject oversized files before reading them into memory.
    size = os.path.getsize(full)
    if size > MAX_FILE_BYTES:
        return 413, {"content-type": "text/plain; charset=utf-8"}, b"File too large", {
            "source": "static",
            "served_rel": rel_path,
            "requested_rel": requested_rel,
            "full_path": full,
            "error": "too large",
        }

    with open(full, "rb") as f:
        body = f.read()

    headers = {
        "content-type": guess_content_type(full),
        "cache-control": "no-store",
    }
    return 200, headers, body, {
        "source": "static",
        "served_rel": rel_path,
        "requested_rel": requested_rel,
        "full_path": full,
        "size": size,
    }


def parse_upstream_base(up: str):
    """
    Split an http(s) base URL into (scheme, host, port, base_path).

    The port defaults to 80/443 by scheme, the host to 127.0.0.1, and
    base_path loses a single trailing slash.

    Raises:
        ValueError: for any scheme other than http or https.
    """
    from urllib.parse import urlparse
    parsed = urlparse(up)
    if parsed.scheme not in ("http", "https"):
        raise ValueError("UPSTREAM_BASE must start with http:// or https://")
    default_port = 443 if parsed.scheme == "https" else 80
    base_path = (parsed.path or "").removesuffix("/")
    return parsed.scheme, parsed.hostname or "127.0.0.1", parsed.port or default_port, base_path


def proxy_request(method: str, rel_path: str, query: str, headers: dict, body: bytes):
    """
    Forward to local server (UPSTREAM_BASE). Return (status, headers, body_bytes, meta)
    Uses only stdlib HTTP client.

    Refuses non-localhost upstreams, strips hop-by-hop headers in both
    directions, removes CR/LF from forwarded header names/values, and caps
    the upstream response at MAX_PROXY_BYTES. Failures map to 502 (400 for
    a non-localhost upstream).
    """
    import http.client
    import urllib.parse

    scheme, host, port, base_path = parse_upstream_base(UPSTREAM_BASE)

    # Security: only allow localhost by default (prevents turning into open proxy).
    if host not in ("127.0.0.1", "localhost", "::1"):
        return 400, {"content-type": "text/plain; charset=utf-8"}, b"UPSTREAM must be localhost", {
            "source": "proxy",
            "error": "UPSTREAM must be localhost",
            "upstream": f"{scheme}://{host}:{port}",
        }

    # Defend against HTTP Request Smuggling via URI
    # NOTE(review): quote() also escapes "%", so a path/query that arrives
    # already percent-encoded gets double-encoded here ("%20" -> "%2520").
    # This is safe against CRLF injection but assumes the tunnel delivers
    # decoded paths/queries - confirm against the VPS side.
    rel_path = urllib.parse.quote((rel_path or "").lstrip("/"), safe="/")
    query = urllib.parse.quote(query or "", safe="&=")

    path = "/" + rel_path
    if base_path:
        path = base_path + path
    if query:
        path = path + "?" + query

    out_headers = {}
    for k, v in (headers or {}).items():
        lk = str(k).lower()
        if lk in HOP_BY_HOP:
            continue
        if lk == "host":
            continue

        # Defend against HTTP Request Smuggling via Headers
        safe_k = str(k).replace("\r", "").replace("\n", "")
        safe_v = str(v).replace("\r", "").replace("\n", "")
        out_headers[safe_k] = safe_v

    # The upstream sees itself as the host; body length is always explicit.
    out_headers["Host"] = host
    out_headers["Content-Length"] = str(len(body))

    if scheme == "https":
        conn = http.client.HTTPSConnection(host, port, timeout=UPSTREAM_TIMEOUT_SECONDS)
    else:
        conn = http.client.HTTPConnection(host, port, timeout=UPSTREAM_TIMEOUT_SECONDS)

    try:
        conn.request(method, path, body=body, headers=out_headers)
        resp = conn.getresponse()
        status = resp.status

        # Read one byte past the cap so oversized responses are detectable.
        resp_body = resp.read(MAX_PROXY_BYTES + 1)
        if len(resp_body) > MAX_PROXY_BYTES:
            return 502, {"content-type": "text/plain; charset=utf-8"}, b"Upstream response too large", {
                "source": "proxy",
                "error": "response too large",
                "upstream_path": path,
                "status": status,
            }

        # Strip hop-by-hop headers from the upstream response as well.
        resp_headers = {}
        for (hk, hv) in resp.getheaders():
            lk = hk.lower()
            if lk in HOP_BY_HOP:
                continue
            resp_headers[hk] = hv

        return status, resp_headers, resp_body, {
            "source": "proxy",
            "upstream_path": path,
            "status": status,
            "resp_len": len(resp_body),
        }
    except Exception as e:
        msg = f"PROXY fail repeated upstream='{path}' error={type(e).__name__}: {e}"
        log_repeatable_error("proxy_request", msg)
        return 502, {"content-type": "text/plain; charset=utf-8"}, b"Upstream error", {
            "source": "proxy",
            "error": f"{type(e).__name__}: {e}",
            "upstream_path": path,
        }
    finally:
        try:
            conn.close()
        except Exception:
            pass


async def handle_req(msg: dict) -> dict:
    """Handle one tunneled HTTP request message and build the reply message.

    Dispatches to the agent metadata endpoint, the local proxy upstream, or
    the static file reader; then applies custom-404 handling and IP
    protection, updates the module-level counters, and returns a "resp" dict
    ready to be JSON-encoded over the WebSocket. Any unexpected failure is
    logged and turned into a generic 500 response.
    """
    global TOTAL_REQS, TOTAL_BYTES_SENT, CURRENT_REQS, STATIC_REQS, PROXY_REQS

    req_id = msg.get("id")
    method = msg.get("method", "GET")
    rel_path = msg.get("path", "")
    query = msg.get("query", "")
    headers = msg.get("headers", {}) or {}
    body_b64 = msg.get("body_b64", "") or ""

    t0 = time.time()
    CURRENT_REQS += 1

    try:
        # Best-effort body decode: malformed base64 degrades to an empty body.
        body = b""
        if body_b64:
            try:
                body = base64.b64decode(body_b64.encode("ascii"))
            except Exception:
                body = b""

        host_header = headers.get("host", "")
        route = "proxy" if should_proxy(rel_path) else "static"

        trimmed = (rel_path or "").lstrip("/")
        if trimmed == "__hostish/version":
            # Agent metadata endpoint: answered locally, no request logging.
            status, out_headers, out_body, meta = _metadata_version_response()
            route = "agent-meta"
        else:
            line_parts = [f"REQ start  id={req_id} method={method} path='/{rel_path}'"]
            if query:
                line_parts.append(f" query='{query}'")
            if host_header:
                line_parts.append(f" host='{host_header}'")
            line_parts.append(f" body={fmt_bytes(len(body))} mode={route}")
            log_event("".join(line_parts))

            if PRINT_REQUEST_HEADERS and headers:
                log("  headers: " + json.dumps(headers, ensure_ascii=False))
            if PRINT_REQUEST_BODY_PREVIEW and body:
                log("  body preview: " + _preview_bytes(body))

            # Both handlers are blocking, so run them off the event loop.
            if route == "proxy":
                call = (proxy_request, method, rel_path, query, headers, body)
            else:
                call = (read_static_file, rel_path)
            status, out_headers, out_body, meta = await asyncio.to_thread(*call)

        # Custom 404 handling (applies to static/proxy 404s, not metadata endpoint)
        if status == 404:
            src_tag = meta.get("source")
            if src_tag != "agent-meta" and src_tag != "agent-404-redirect":
                if CUSTOM_404_ENABLED:
                    replacement = await asyncio.to_thread(_serve_custom_404_if_configured)
                    if replacement is not None:
                        status, out_headers, out_body, meta = replacement
                else:
                    status, out_headers, out_body, meta = _hostish_default_404_redirect()

        # IP protection (simple string replacement in headers/body)
        status, out_headers, out_body, meta = await asyncio.to_thread(
            _apply_ip_protection_to_response, req_id, status, out_headers, out_body, meta
        )

        took_ms = int((time.time() - t0) * 1000)

        TOTAL_REQS += 1
        TOTAL_BYTES_SENT += len(out_body)
        _status_count_inc(status)

        if route == "proxy":
            PROXY_REQS += 1
        elif route == "static":
            STATIC_REQS += 1

        source = meta.get("source")

        # Per-source result logging; fallthrough handles the metadata endpoint.
        if source == "static":
            served_rel = meta.get("served_rel", rel_path)
            full_path = meta.get("full_path")
            if full_path:
                log_event(
                    f"FILE served id={req_id} status={status} rel='{served_rel}' "
                    f"bytes={fmt_bytes(len(out_body))} type='{out_headers.get('content-type','?')}' "
                    f"time={took_ms}ms"
                )
                log(f"  local file: {full_path}")
            else:
                log_event(
                    f"FILE fail  id={req_id} status={status} rel='{served_rel}' "
                    f"reason='{meta.get('error','unknown')}' time={took_ms}ms"
                )

        elif source == "proxy":
            log_event(
                f"PROXY done  id={req_id} status={status} upstream='{meta.get('upstream_path','?')}' "
                f"bytes={fmt_bytes(len(out_body))} time={took_ms}ms"
            )
            if meta.get("error"):
                log(f"  proxy note: {meta['error']}")

        elif source == "custom-404":
            log_event(
                f"404 custom id={req_id} served='{meta.get('served_rel')}' "
                f"bytes={fmt_bytes(len(out_body))} time={took_ms}ms"
            )
            log(f"  custom 404 file: {meta.get('full_path')}")

        elif source == "agent-404-redirect":
            log_event(
                f"404 redirect id={req_id} -> {meta.get('target')} time={took_ms}ms"
            )

        else:
            # agent-meta
            log(
                f"META served id={req_id} endpoint={meta.get('endpoint')} "
                f"bytes={fmt_bytes(len(out_body))} time={took_ms}ms"
            )

        print_summary()

        return {
            "type": "resp",
            "id": req_id,
            "status": int(status),
            "headers": out_headers,
            "body_b64": base64.b64encode(out_body).decode("ascii"),
        }

    except Exception as e:
        took_ms = int((time.time() - t0) * 1000)
        failure = f"REQ error id={req_id} {type(e).__name__}: {e} time={took_ms}ms"
        log_repeatable_error("handle_req", failure)
        traceback.print_exc()
        return {
            "type": "resp",
            "id": req_id,
            "status": 500,
            "headers": {"content-type": "text/plain; charset=utf-8"},
            "body_b64": base64.b64encode(b"Internal agent error").decode("ascii"),
        }
    finally:
        CURRENT_REQS = max(0, CURRENT_REQS - 1)


async def main():
    """Run the agent: validate config, start background loops, and keep a
    reconnecting WebSocket session to the Hostish VPS alive.

    Blocks until cancelled (KeyboardInterrupt propagates to the caller).
    Each incoming "req" message is handled on its own task so a slow request
    never stalls the WebSocket read loop.
    """
    global CONNECT_COUNT

    # --- configuration sanity checks -------------------------------------
    if MODE not in ("static", "proxy", "hybrid"):
        print("MODE must be: static | proxy | hybrid")
        return

    if MODE in ("static", "hybrid") and not os.path.isdir(WEBSITE_DIR):
        print("WEBSITE_DIR not found:", WEBSITE_DIR)
        return

    # BUGFIX: asyncio.create_task() results were previously discarded. The
    # event loop keeps only a weak reference to tasks, so an un-referenced
    # task may be garbage-collected mid-execution (documented asyncio
    # caveat). Hold strong references here; the done-callback prunes
    # finished tasks so the set stays small.
    background_tasks: set = set()

    def _spawn(coro):
        """Create a task and keep a strong reference until it finishes."""
        task = asyncio.create_task(coro)
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
        return task

    # Warm public IP cache if enabled (before startup banner)
    if IP_PROTECTION_ENABLED:
        ip = fetch_public_ip(force=True)
        if ip:
            log("IP protection initialized")
        else:
            err = _public_ip_cache.get("error") or "unknown error"
            log(f"IP protection could not fetch public IP at startup (will retry): {err}")

    # Version check before startup banner
    startup_version_check_or_prompt()
    print_startup_banner()

    # Auto start companion Python scripts once
    start_configured_python_scripts_once()

    # Start periodic tasks
    if PERIODIC_SUMMARY_ENABLED:
        _spawn(periodic_summary_loop())
    if IP_PROTECTION_ENABLED:
        _spawn(periodic_public_ip_refresh_loop())
    if AUTO_START_PYTHON_SCRIPTS_ENABLED:
        _spawn(managed_script_monitor_loop())

    # TLS context (default verification)
    ssl_ctx = ssl.create_default_context()

    backoff = 2  # seconds between reconnects; doubled on failure, capped at 15

    while True:
        try:
            log(f"Connecting to Hostish VPS WebSocket: {WS_URL}")
            async with websockets.connect(
                WS_URL,
                ssl=ssl_ctx,
                max_size=25 * 1024 * 1024,
                ping_interval=20,
                ping_timeout=20,
            ) as ws:
                CONNECT_COUNT += 1
                log_event(f"WebSocket connected (connection #{CONNECT_COUNT})")

                # send auth
                auth_payload = {"username": USERNAME, "password": PASSWORD}
                await ws.send(json.dumps(auth_payload))
                log("Sent authentication request")

                # wait auth response; a failure drops back to reconnect loop
                raw = await ws.recv()
                auth = json.loads(raw)
                if auth.get("type") != "auth" or not auth.get("ok"):
                    log_repeatable_error(
                        "auth_failed",
                        "Authentication failed (check USERNAME/PASSWORD in agent.py)",
                    )
                    await asyncio.sleep(3)
                    continue

                backoff = 2  # reset backoff after a successful auth
                log_event(f"Authenticated as '{USERNAME}' | mode={MODE}")
                log(f"Public URL is live (while this agent stays connected): https://{USERNAME}.hostish.site/")
                log(version_status_text())
                log(_public_ip_status_text())

                # Task wrapper for handling requests non-blocking
                async def handle_and_send(msg_payload):
                    try:
                        resp = await handle_req(msg_payload)
                        await ws.send(json.dumps(resp))
                        try:
                            resp_len = len(base64.b64decode(resp["body_b64"].encode("ascii")))
                        except Exception:
                            resp_len = 0
                        log(f"RESP sent   id={resp.get('id')} status={resp.get('status')} bytes={fmt_bytes(resp_len)}")
                    except Exception as e:
                        log_repeatable_error("task_error", f"Error in background task: {e}")

                # serve loop
                while True:
                    raw = await ws.recv()
                    msg = json.loads(raw)

                    mtype = msg.get("type")
                    if mtype in ("ping", "pong"):
                        log(f"Tunnel keepalive: {mtype}")
                        continue

                    if mtype != "req":
                        log(f"Ignoring unknown message type: {mtype!r}")
                        continue

                    # Hand off request immediately to avoid freezing the WebSocket listener
                    _spawn(handle_and_send(msg))

        except KeyboardInterrupt:
            raise
        except Exception as e:
            log_repeatable_error("main_disconnect", f"Disconnected: {type(e).__name__}: {e}")
            log(f"Reconnecting in {backoff}s...")
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, 15)


if __name__ == "__main__":
    # Run the agent; whatever way it exits, shut down the managed scripts.
    try:
        try:
            asyncio.run(main())
        except KeyboardInterrupt:
            log("Stopped by user")
        except Exception:
            traceback.print_exc()
    finally:
        stop_managed_scripts()