#!/usr/bin/env python3
"""
Exbot / ZED2i minimal live micro-fusion sanity test v23.

Purpose
-------
Keep the raw top-down view that worked in v21, but add a *small* rolling fusion
window (a few consecutive frames) to fill in missing patches, such as half a
plank or the edges of a keyboard.

Design rules
------------
- very short live capture window
- no long warmup
- no raw sideview output
- keep left/right/depth/confidence PNGs
- keep a copy of this script in the same version folder (the script does not
  copy itself; it only lists its file name in settings.json)
- separate "low" and "tall" fused maps
"""

import json
import time
from pathlib import Path

import cv2
import numpy as np
import pyzed.sl as sl

# Every artefact of this version goes into one folder under the user's home
# directory; the folder is created at import time.
OUT_DIR = Path.home().joinpath("exbot_ui", "zed_floor_sanity_test_output_v23")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Camera settings --------------------------------------------------------
RESOLUTION = sl.RESOLUTION.HD1080
FPS = 30
# Prefer NEURAL_PLUS depth when this pyzed build exposes it, else plain NEURAL.
if hasattr(sl.DEPTH_MODE, "NEURAL_PLUS"):
    DEPTH_MODE = sl.DEPTH_MODE.NEURAL_PLUS
else:
    DEPTH_MODE = sl.DEPTH_MODE.NEURAL
# Depth range handed to the SDK in main() (metres).
MIN_DEPTH_M, MAX_DEPTH_M = 0.3, 8.0

# --- Short rolling fusion only ----------------------------------------------
FUSION_FRAMES = 8        # successful grabs to fuse
MAX_GRAB_ATTEMPTS = 20   # stop after this many grab() calls regardless

# --- Zone B (metres): forward = point-cloud column 0, lateral = column 1 -----
B_FWD_MIN_M, B_FWD_MAX_M = 0.5, 3.5
B_LAT_MIN_M, B_LAT_MAX_M = -2.0, 2.0

# --- Height ranges relative to the fitted floor plane (metres) ---------------
LOW_CLIP_MIN_M, LOW_CLIP_MAX_M = -0.03, 0.20
TALL_CLIP_MIN_M, TALL_CLIP_MAX_M = -0.03, 2.00

# --- Box of camera-frame points used as floor-fit candidates (metres) --------
FLOOR_FIT_FWD_MIN_M, FLOOR_FIT_FWD_MAX_M = 0.3, 4.0
FLOOR_FIT_LAT_MIN_M, FLOOR_FIT_LAT_MAX_M = -2.5, 2.5
FLOOR_FIT_Z_MIN_M, FLOOR_FIT_Z_MAX_M = -2.5, 2.5

# --- Fused-map grid: square CELL_M cells covering Zone B ---------------------
CELL_M = 0.02
NX = round((B_LAT_MAX_M - B_LAT_MIN_M) / CELL_M)  # lateral cells (grid columns)
NY = round((B_FWD_MAX_M - B_FWD_MIN_M) / CELL_M)  # forward cells (grid rows)

# Top-down canvas size in pixels.
TOP_W, TOP_H = 1400, 900


def save_bgr(path: Path, img_bgr: np.ndarray) -> None:
    """Write a BGR image to ``path`` with OpenCV.

    ``cv2.imwrite`` signals some failures (for example a path that cannot be
    opened) by returning False rather than raising. That result is checked here
    so a missing output image is not silently ignored.

    Raises:
        RuntimeError: if ``cv2.imwrite`` returns False.
    """
    if not cv2.imwrite(str(path), img_bgr):
        raise RuntimeError(f"Failed to write image: {path}")


def sl_image_to_bgr(mat: sl.Mat) -> np.ndarray:
    """Convert a ZED image ``sl.Mat`` into a standalone BGR uint8 array for OpenCV.

    ZED image views (``sl.VIEW.LEFT``/``RIGHT``/``DEPTH``/``CONFIDENCE``) are
    4-channel BGRA. The first three channels are therefore already in OpenCV's BGR
    order and are kept as-is; the alpha channel is dropped. Converting them with
    ``COLOR_RGB2BGR`` would swap red and blue. Single-channel images are
    expanded to three identical channels. Non-uint8 data is clipped to 0..255.

    Returns:
        A C-contiguous (H, W, 3) uint8 array that owns its memory, so later
        grabs into the same Mat cannot change it.

    Raises:
        RuntimeError: if the Mat returns no data or has an unsupported shape.
    """
    arr = mat.get_data()
    if arr is None:
        raise RuntimeError("No image data returned")

    if arr.ndim == 2 or (arr.ndim == 3 and arr.shape[2] == 1):
        gray = arr.reshape(arr.shape[0], arr.shape[1])
        if gray.dtype != np.uint8:
            gray = np.clip(gray, 0, 255).astype(np.uint8)
        return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)

    if arr.ndim == 3 and arr.shape[2] >= 3:
        bgr = arr[:, :, :3]
        if bgr.dtype != np.uint8:
            bgr = np.clip(bgr, 0, 255).astype(np.uint8)
        # Explicit C-order copy: the slice above is a strided view that may
        # still share memory with the Mat.
        return bgr.copy()

    raise RuntimeError(f"Unexpected image shape: {arr.shape}")


def finite_xyz_from_point_cloud(mat: sl.Mat) -> np.ndarray:
    """Return the XYZ rows of a point-cloud Mat whose three coordinates are all finite.

    The first three channels of every pixel are flattened into an (N, 3) array.
    Any row containing NaN or +/-inf is dropped.
    """
    flat_xyz = mat.get_data()[:, :, :3].reshape(-1, 3)
    all_finite = np.all(np.isfinite(flat_xyz), axis=1)
    return flat_xyz[all_finite]


def fit_floor_plane_ransac(points: np.ndarray, iterations: int = 200, threshold_m: float = 0.015,
                           *, min_inliers: int = 100, min_normal_z: float = 0.8, seed: int = 42):
    """Fit a floor plane ``n . x + d = 0`` with simple RANSAC + SVD refinement.

    Each RANSAC iteration samples three points and builds the plane through
    them. The candidate is rejected unless its unit normal satisfies
    ``|n_z| >= min_normal_z``, i.e. the plane is close to horizontal when z is
    up. Otherwise the points within ``threshold_m`` of it are counted. The
    best candidate's inliers are then refit with SVD: the least-squares plane
    through their centroid.

    Args:
        points: (N, 3) array of XYZ coordinates (metres).
        iterations: number of RANSAC samples to try.
        threshold_m: inlier distance threshold (metres).
        min_inliers: minimum number of input points, and of inliers of the best
            candidate.
        min_normal_z: minimum ``|z|`` component of an accepted unit normal.
        seed: seed for the sampling RNG; a fixed seed makes the fit repeatable.

    Returns:
        ``(n, d)``: ``n`` is a unit normal with ``n[2] >= 0``; ``d`` is a float.
        ``points @ n + d`` is the signed distance to the plane, positive on the
        side ``n`` points to.

    Raises:
        ValueError: if ``points`` is not an (N, 3) array.
        RuntimeError: if there are too few points, or no acceptable plane is found.
    """
    points = np.asarray(points)
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError(f"points must have shape (N, 3), got {points.shape}")
    # Three points are needed to sample a plane at all.
    if len(points) < max(min_inliers, 3):
        raise RuntimeError("Too few points for floor fit")

    rng = np.random.default_rng(seed)
    best_count = -1
    best_inliers = None
    n_points = len(points)

    for _ in range(iterations):
        idx = rng.choice(n_points, 3, replace=False)
        p1, p2, p3 = points[idx]
        n = np.cross(p2 - p1, p3 - p1)
        norm = np.linalg.norm(n)
        if norm < 1e-8:
            continue  # degenerate (collinear / repeated) sample
        n = n / norm

        # Floor normal should point mostly upward.
        if abs(n[2]) < min_normal_z:
            continue
        if n[2] < 0:
            n = -n

        d = -float(np.dot(n, p1))
        inliers = np.abs(points @ n + d) < threshold_m
        count = int(inliers.sum())

        if count > best_count:
            best_count = count
            best_inliers = inliers

    if best_inliers is None or best_count < min_inliers:
        raise RuntimeError("Floor fit failed")

    # Least-squares refinement: the plane normal is the direction of least
    # variance of the centred inliers, i.e. the last right-singular vector.
    inlier_pts = points[best_inliers]
    centroid = inlier_pts.mean(axis=0)
    _, _, vh = np.linalg.svd(inlier_pts - centroid, full_matrices=False)
    n = vh[-1, :]
    if n[2] < 0:
        n = -n
    # The refit may tilt the plane; keep the same "mostly horizontal" contract.
    if abs(n[2]) < min_normal_z:
        raise RuntimeError("Floor fit failed")
    d = -float(np.dot(n, centroid))
    return n, d


def meters_to_topdown_pixels(lat_m: np.ndarray, fwd_m: np.ndarray):
    """Map Zone B lateral/forward coordinates (metres) to top-down canvas pixels.

    Keep cupboard on the right and near points at the bottom. Concretely:
    ``B_LAT_MAX_M`` maps to column 0 and ``B_LAT_MIN_M`` to column
    ``TOP_W - 1``; ``B_FWD_MIN_M`` (near) maps to row ``TOP_H - 1`` and
    ``B_FWD_MAX_M`` (far) to row 0.

    Args:
        lat_m: lateral coordinates in metres (scalar or array-like).
        fwd_m: forward coordinates in metres (scalar or array-like).

    Returns:
        ``(x_px, y_px)`` int32 arrays with the broadcast shape of the inputs.
        Positions outside Zone B map outside ``[0, TOP_W)`` / ``[0, TOP_H)``,
        so callers can bounds-check them.
    """
    lat_m = np.asarray(lat_m)
    fwd_m = np.asarray(fwd_m)

    x_norm = (lat_m - B_LAT_MIN_M) / (B_LAT_MAX_M - B_LAT_MIN_M)
    y_norm = (fwd_m - B_FWD_MIN_M) / (B_FWD_MAX_M - B_FWD_MIN_M)

    # floor() rather than astype-truncation: truncation rounds values in (-1, 0)
    # up to 0, which would let just-outside positions pass bounds checks.
    x_px = np.floor((1.0 - x_norm) * (TOP_W - 1)).astype(np.int32)
    y_px = np.floor((1.0 - y_norm) * (TOP_H - 1)).astype(np.int32)

    return x_px, y_px


def heights_to_colors(height_m: np.ndarray, clip_min: float, clip_max: float) -> np.ndarray:
    """Map heights to BGR colours on OpenCV's TURBO colormap.

    Heights are clipped to ``[clip_min, clip_max]`` and scaled linearly to 0..255
    before ``cv2.applyColorMap(..., COLORMAP_TURBO)``. NaN heights get the
    ``clip_min`` colour; +/-inf are clipped like any other value.

    Args:
        height_m: heights in metres (any shape; flattened in C order).
        clip_min: height mapped to the bottom of the colormap.
        clip_max: height mapped to the top of the colormap.

    Returns:
        (N, 3) uint8 array of BGR colours, one row per input height. Empty input
        gives an empty (0, 3) array.

    Raises:
        ValueError: if ``clip_max <= clip_min``; the scale would divide by zero
            or be inverted.
    """
    if not clip_max > clip_min:
        raise ValueError(f"clip_max ({clip_max}) must be greater than clip_min ({clip_min})")

    h = np.asarray(height_m).ravel()
    if h.size == 0:
        # cv2.applyColorMap rejects empty input.
        return np.empty((0, 3), dtype=np.uint8)

    h01 = (np.clip(h, clip_min, clip_max) - clip_min) / (clip_max - clip_min)
    # NaN survives np.clip and has no defined uint8 value; pin it to 0.
    h01 = np.nan_to_num(h01, nan=0.0)
    gray = np.clip(h01 * 255.0, 0, 255).astype(np.uint8)
    colors = cv2.applyColorMap(gray, cv2.COLORMAP_TURBO)
    return colors.reshape(-1, 3)


def draw_scatter_topdown(points_b: np.ndarray, heights_b: np.ndarray, out_path: Path, title: str, clip_min: float, clip_max: float) -> None:
    """Render raw points as single-pixel dots on a labelled top-down canvas and save it.

    Args:
        points_b: (N, 3) points; column 0 is forward and column 1 is lateral (metres).
        heights_b: (N,) heights used for colouring (metres).
        out_path: destination image path.
        title: text drawn in the top-left corner.
        clip_min: height mapped to the bottom of the colour legend.
        clip_max: height mapped to the top of the colour legend.
    """
    canvas = np.full((TOP_H, TOP_W, 3), 16, dtype=np.uint8)

    draw_grid_and_labels(canvas, title, clip_min, clip_max)

    if len(points_b) > 0:
        x_px, y_px = meters_to_topdown_pixels(points_b[:, 1], points_b[:, 0])
        colors = heights_to_colors(heights_b, clip_min, clip_max)

        inside = (x_px >= 0) & (x_px < TOP_W) & (y_px >= 0) & (y_px < TOP_H)
        # One vectorized scatter instead of a per-point Python loop. Where several
        # points hit the same pixel, only one of their colours survives.
        canvas[y_px[inside], x_px[inside]] = colors[inside]

    save_bgr(out_path, canvas)


def draw_grid_and_labels(canvas: np.ndarray, title: str, clip_min: float, clip_max: float) -> None:
    """Draw the shared top-down decorations onto ``canvas`` in place.

    Adds, in this order:
    - 0.5 m grid lines across Zone B;
    - the title and the side/near/far captions;
    - forward-distance tick labels on the left edge;
    - a TURBO colour legend near the top-right corner (top = ``clip_max``,
      bottom = ``clip_min``).
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    aa = cv2.LINE_AA
    grid_grey = (60, 60, 60)
    white = (255, 255, 255)
    caption_grey = (220, 220, 220)
    tick_grey = (200, 200, 200)

    lat_ticks = np.arange(B_LAT_MIN_M, B_LAT_MAX_M + 1e-6, 0.5)
    fwd_ticks = np.arange(B_FWD_MIN_M, B_FWD_MAX_M + 1e-6, 0.5)
    tick_cols, _ = meters_to_topdown_pixels(lat_ticks, np.full_like(lat_ticks, B_FWD_MIN_M))
    _, tick_rows = meters_to_topdown_pixels(np.full_like(fwd_ticks, B_LAT_MIN_M), fwd_ticks)

    # Grid every 0.5 m: vertical lines first, then horizontal ones.
    for col in tick_cols:
        cv2.line(canvas, (int(col), 0), (int(col), TOP_H - 1), grid_grey, 1)
    for row in tick_rows:
        cv2.line(canvas, (0, int(row)), (TOP_W - 1, int(row)), grid_grey, 1)

    # NOTE(review): meters_to_topdown_pixels puts B_LAT_MAX_M (positive lateral)
    # at column 0, the image's left edge, which is where "right" is written.
    # If +y is the camera's left (as in RIGHT_HANDED_Z_UP_X_FWD, requested in
    # main when available), these two captions look swapped - confirm on the scene.
    cv2.putText(canvas, title, (20, 30), font, 0.8, white, 2, aa)
    cv2.putText(canvas, "right", (20, TOP_H - 15), font, 0.6, caption_grey, 1, aa)
    cv2.putText(canvas, "left", (TOP_W - 80, TOP_H - 15), font, 0.6, caption_grey, 1, aa)
    cv2.putText(canvas, f"near {B_FWD_MIN_M:.1f} m", (TOP_W - 180, TOP_H - 15), font, 0.55, caption_grey, 1, aa)
    cv2.putText(canvas, f"far {B_FWD_MAX_M:.1f} m", (TOP_W - 160, 28), font, 0.55, caption_grey, 1, aa)

    # Forward-distance labels on the left edge, kept at least 18 px below the top.
    for fwd, row in zip(fwd_ticks, tick_rows):
        cv2.putText(canvas, f"{fwd:.1f}m", (6, max(18, int(row) - 4)), font, 0.45, tick_grey, 1, aa)

    # Legend: colormap value 255 (clip_max) at the top, 0 (clip_min) at the bottom.
    lx, ly, lw, lh = TOP_W - 70, 70, 24, 220
    ramp = np.linspace(255, 0, lh, dtype=np.uint8).reshape(-1, 1)
    ramp = cv2.resize(cv2.applyColorMap(ramp, cv2.COLORMAP_TURBO), (lw, lh), interpolation=cv2.INTER_NEAREST)
    canvas[ly:ly + lh, lx:lx + lw] = ramp
    cv2.rectangle(canvas, (lx, ly), (lx + lw, ly + lh), white, 1)
    cv2.putText(canvas, f"{clip_max:.2f}m", (lx - 5, ly - 6), font, 0.45, white, 1, aa)
    cv2.putText(canvas, f"{clip_min:.2f}m", (lx - 5, ly + lh + 18), font, 0.45, white, 1, aa)


def draw_grid_map(height_map: np.ndarray, valid_mask: np.ndarray, out_path: Path, title: str, clip_min: float, clip_max: float) -> None:
    """Render a fused (NY, NX) height grid as coloured dots on a labelled canvas and save it.

    Only cells where ``valid_mask`` is set are drawn. Each one becomes a filled
    radius-2 circle at the cell centre, coloured from ``height_map`` on the same
    TURBO scale as the legend.
    """
    canvas = np.full((TOP_H, TOP_W, 3), 16, dtype=np.uint8)
    draw_grid_and_labels(canvas, title, clip_min, clip_max)

    cell_rows, cell_cols = np.nonzero(valid_mask)
    if cell_rows.size > 0:
        # Cell centres in metres: grid rows step forward, grid columns step laterally.
        centre_fwd = B_FWD_MIN_M + (cell_rows + 0.5) * CELL_M
        centre_lat = B_LAT_MIN_M + (cell_cols + 0.5) * CELL_M
        px, py = meters_to_topdown_pixels(centre_lat, centre_fwd)
        palette = heights_to_colors(height_map[valid_mask], clip_min, clip_max)
        for x, y, bgr in zip(px, py, palette):
            cv2.circle(canvas, (int(x), int(y)), 2, tuple(map(int, bgr)), thickness=-1)

    save_bgr(out_path, canvas)


def draw_support_map(support_map: np.ndarray, out_path: Path, title: str) -> None:
    """Render per-cell frame support as shaded dots and save the image.

    The canvas gets 0.5 m grid lines and the title only: no side captions, tick
    labels or legend. Every cell with ``support_map > 0`` becomes a filled
    radius-2 circle. Its BONE-colormap shade scales with the cell's count
    relative to the map's maximum.
    """
    canvas = np.full((TOP_H, TOP_W, 3), 16, dtype=np.uint8)
    line_grey = (60, 60, 60)

    lat_ticks = np.arange(B_LAT_MIN_M, B_LAT_MAX_M + 1e-6, 0.5)
    tick_cols, _ = meters_to_topdown_pixels(lat_ticks, np.full_like(lat_ticks, B_FWD_MIN_M))
    for col in tick_cols:
        cv2.line(canvas, (int(col), 0), (int(col), TOP_H - 1), line_grey, 1)

    fwd_ticks = np.arange(B_FWD_MIN_M, B_FWD_MAX_M + 1e-6, 0.5)
    _, tick_rows = meters_to_topdown_pixels(np.full_like(fwd_ticks, B_LAT_MIN_M), fwd_ticks)
    for row in tick_rows:
        cv2.line(canvas, (0, int(row)), (TOP_W - 1, int(row)), line_grey, 1)

    cv2.putText(canvas, title, (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)

    cell_rows, cell_cols = np.nonzero(support_map > 0)
    if cell_rows.size > 0:
        peak = int(np.max(support_map))
        counts = support_map[cell_rows, cell_cols].astype(np.float32)
        shades = (counts / peak * 255.0).astype(np.uint8)
        palette = cv2.applyColorMap(shades, cv2.COLORMAP_BONE).reshape(-1, 3)
        centre_fwd = B_FWD_MIN_M + (cell_rows + 0.5) * CELL_M
        centre_lat = B_LAT_MIN_M + (cell_cols + 0.5) * CELL_M
        px, py = meters_to_topdown_pixels(centre_lat, centre_fwd)
        for x, y, bgr in zip(px, py, palette):
            cv2.circle(canvas, (int(x), int(y)), 2, tuple(map(int, bgr)), thickness=-1)

    save_bgr(out_path, canvas)


def accumulate_into_grid(points_b: np.ndarray, heights_b: np.ndarray,
                         low_sum: np.ndarray, low_count: np.ndarray,
                         tall_max: np.ndarray, support_frames: np.ndarray) -> None:
    """Accumulate one frame into fused low/tall/support grids, in place.

    Each point is binned into the (NY, NX) grid of CELL_M cells spanning Zone B.
    The grid row comes from the forward coordinate and the column from the
    lateral one. Points outside the grid are ignored.

    Args:
        points_b: (N, 3) points; column 0 is forward and column 1 is lateral (metres).
        heights_b: (N,) heights for those points (metres).
        low_sum: running per-cell sum of heights clipped to
            [LOW_CLIP_MIN_M, LOW_CLIP_MAX_M]; updated in place.
        low_count: number of points added into each cell of ``low_sum``.
        tall_max: running per-cell maximum of heights clipped to
            [TALL_CLIP_MIN_M, TALL_CLIP_MAX_M]. NaN marks cells not yet hit.
        support_frames: number of frames in which each cell got at least one point.
    """
    if len(points_b) == 0:
        return

    xi = np.floor((points_b[:, 1] - B_LAT_MIN_M) / CELL_M).astype(np.int32)
    yi = np.floor((points_b[:, 0] - B_FWD_MIN_M) / CELL_M).astype(np.int32)

    inside = (xi >= 0) & (xi < NX) & (yi >= 0) & (yi < NY)
    if not np.any(inside):
        return
    xi = xi[inside]
    yi = yi[inside]
    h = heights_b[inside]

    # Low map: running sum and count of clipped heights; the caller divides
    # them to get the mean.
    np.add.at(low_sum, (yi, xi), np.clip(h, LOW_CLIP_MIN_M, LOW_CLIP_MAX_M))
    np.add.at(low_count, (yi, xi), 1)

    # Tall map: per-cell maximum of this frame, merged into the running maximum.
    h_tall = np.clip(h, TALL_CLIP_MIN_M, TALL_CLIP_MAX_M)
    linear = yi * NX + xi
    order = np.argsort(linear)
    # After sorting, each cell's points form one contiguous segment that starts
    # at seg_start. reduceat takes the max of every segment in a single call.
    cells, seg_start = np.unique(linear[order], return_index=True)
    frame_max = np.fmax.reduceat(h_tall[order], seg_start)
    uy = cells // NX
    ux = cells % NX
    # np.fmax, not np.maximum: untouched cells hold NaN. np.maximum would
    # propagate that NaN, leaving every cell NaN and the tall map empty.
    tall_max[uy, ux] = np.fmax(tall_max[uy, ux], frame_max)

    # Support map: each touched cell counts once for this frame (cells are unique).
    support_frames[uy, ux] += 1


def main() -> int:
    """Run the short live capture, fuse Zone B into top-down maps and write outputs.

    Opens the ZED camera, then delegates capture, fusion and output writing to
    ``_capture_fuse_and_save``. Once the camera has been opened it is always
    closed, including when processing raises.

    Returns:
        Process exit code: 0 on success, 1 if the camera could not be opened or
        no frame could be grabbed.
    """
    start_time = time.time()

    zed = sl.Camera()
    init = sl.InitParameters()
    init.camera_resolution = RESOLUTION
    init.camera_fps = FPS
    init.depth_mode = DEPTH_MODE
    init.coordinate_units = sl.UNIT.METER
    # The zone masks treat point-cloud columns 0/1/2 as forward/lateral/up, and
    # the floor fit expects a near-vertical normal in column 2.
    if hasattr(sl, "COORDINATE_SYSTEM") and hasattr(sl.COORDINATE_SYSTEM, "RIGHT_HANDED_Z_UP_X_FWD"):
        init.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Z_UP_X_FWD
    init.depth_minimum_distance = MIN_DEPTH_M
    init.depth_maximum_distance = MAX_DEPTH_M

    err = zed.open(init)
    if err != sl.ERROR_CODE.SUCCESS:
        print(f"Open failed: {err}")
        return 1

    try:
        return _capture_fuse_and_save(zed, start_time)
    finally:
        # Release the camera even if capture or output writing raises.
        zed.close()


def _capture_fuse_and_save(zed, start_time: float) -> int:
    """Grab and fuse frames from an already-opened camera, then write all outputs.

    Does not close ``zed``; ``main`` owns that.

    Args:
        zed: an opened ``sl.Camera``.
        start_time: ``time.time()`` at program start, used for the reported duration.

    Returns:
        0 on success, 1 if no frame could be grabbed.
    """
    runtime = sl.RuntimeParameters()
    runtime.enable_depth = True
    if hasattr(runtime, "measure3D_reference_frame") and hasattr(sl.REFERENCE_FRAME, "CAMERA"):
        runtime.measure3D_reference_frame = sl.REFERENCE_FRAME.CAMERA

    # Mats reused for every grab; after the loop they hold the latest good frame.
    left = sl.Mat()
    right = sl.Mat()
    depth_view = sl.Mat()
    confidence_view = sl.Mat()
    point_cloud = sl.Mat()

    low_sum = np.zeros((NY, NX), dtype=np.float32)
    low_count = np.zeros((NY, NX), dtype=np.int32)
    tall_max = np.full((NY, NX), np.nan, dtype=np.float32)
    support_frames = np.zeros((NY, NX), dtype=np.int32)

    last_points_b = None
    last_heights_b = None
    floor_ok_count = 0
    floor_candidate_total = 0
    frames_without_floor = 0
    last_floor_n = np.array([0.0, 0.0, 1.0], dtype=np.float32)
    last_floor_d = 0.0

    successful_frames = 0
    attempts = 0

    while successful_frames < FUSION_FRAMES and attempts < MAX_GRAB_ATTEMPTS:
        attempts += 1
        if zed.grab(runtime) != sl.ERROR_CODE.SUCCESS:
            continue

        successful_frames += 1

        # Save image views from the latest successful frame
        zed.retrieve_image(left, sl.VIEW.LEFT)
        zed.retrieve_image(right, sl.VIEW.RIGHT)
        zed.retrieve_image(depth_view, sl.VIEW.DEPTH)
        if hasattr(sl.VIEW, "CONFIDENCE"):
            zed.retrieve_image(confidence_view, sl.VIEW.CONFIDENCE)
        zed.retrieve_measure(point_cloud, sl.MEASURE.XYZRGBA, sl.MEM.CPU)

        xyz = finite_xyz_from_point_cloud(point_cloud)

        floor_candidates = xyz[
            (xyz[:, 0] >= FLOOR_FIT_FWD_MIN_M) & (xyz[:, 0] <= FLOOR_FIT_FWD_MAX_M) &
            (xyz[:, 1] >= FLOOR_FIT_LAT_MIN_M) & (xyz[:, 1] <= FLOOR_FIT_LAT_MAX_M) &
            (xyz[:, 2] >= FLOOR_FIT_Z_MIN_M) & (xyz[:, 2] <= FLOOR_FIT_Z_MAX_M)
        ]
        floor_candidate_total += int(len(floor_candidates))

        try:
            floor_n, floor_d = fit_floor_plane_ransac(floor_candidates)
        except (RuntimeError, ValueError, np.linalg.LinAlgError):
            if floor_ok_count == 0:
                # No plane fitted yet. The placeholder plane (n = [0, 0, 1], d = 0)
                # would make "height" the raw camera-frame z instead of the height
                # above the floor, so this frame is not fused.
                frames_without_floor += 1
                continue
            # Best effort: reuse the most recent good plane for this frame.
            floor_n, floor_d = last_floor_n, last_floor_d
        else:
            floor_ok_count += 1
            last_floor_n, last_floor_d = floor_n, floor_d

        heights = xyz @ floor_n + floor_d

        b_mask = (
            (xyz[:, 0] >= B_FWD_MIN_M) & (xyz[:, 0] <= B_FWD_MAX_M) &
            (xyz[:, 1] >= B_LAT_MIN_M) & (xyz[:, 1] <= B_LAT_MAX_M) &
            np.isfinite(heights) &
            (heights >= LOW_CLIP_MIN_M - 0.05) & (heights <= TALL_CLIP_MAX_M + 0.10)
        )

        points_b = xyz[b_mask]
        heights_b = heights[b_mask]
        last_points_b = points_b
        last_heights_b = heights_b

        accumulate_into_grid(points_b, heights_b, low_sum, low_count, tall_max, support_frames)

    if successful_frames == 0:
        print("Grab failed")
        return 1

    # Save raw views from the latest good frame
    save_bgr(OUT_DIR / "zed_left.png", sl_image_to_bgr(left))
    save_bgr(OUT_DIR / "zed_right.png", sl_image_to_bgr(right))
    save_bgr(OUT_DIR / "zed_depth_view.png", sl_image_to_bgr(depth_view))
    if confidence_view.get_width() > 0 and confidence_view.get_height() > 0:
        save_bgr(OUT_DIR / "confidence_view.png", sl_image_to_bgr(confidence_view))

    # Raw scatter from the last fused frame (v21-style)
    if last_points_b is not None and len(last_points_b) > 0:
        draw_scatter_topdown(
            last_points_b, last_heights_b,
            OUT_DIR / "raw_topdown_points_B_low.png",
            "Zone B raw top-down points - low range",
            LOW_CLIP_MIN_M, LOW_CLIP_MAX_M
        )

    # Fused low map: per-cell mean of the clipped low heights.
    low_valid = low_count > 0
    low_map = np.zeros((NY, NX), dtype=np.float32)
    low_map[low_valid] = low_sum[low_valid] / low_count[low_valid]
    draw_grid_map(
        low_map, low_valid,
        OUT_DIR / "fused_topdown_points_B_low.png",
        "Zone B fused top-down - low range",
        LOW_CLIP_MIN_M, LOW_CLIP_MAX_M
    )

    # Fused tall map: per-cell maximum; NaN marks cells never hit.
    tall_valid = np.isfinite(tall_max)
    tall_map = np.zeros((NY, NX), dtype=np.float32)
    tall_map[tall_valid] = tall_max[tall_valid]
    draw_grid_map(
        tall_map, tall_valid,
        OUT_DIR / "fused_topdown_points_B_tall.png",
        "Zone B fused top-down - tall range",
        TALL_CLIP_MIN_M, TALL_CLIP_MAX_M
    )

    # Fused support
    draw_support_map(
        support_frames,
        OUT_DIR / "fused_support_B.png",
        "Zone B fused support"
    )

    settings = {
        "version": "v23",
        # Derived from RESOLUTION so the report follows the constant.
        "resolution": getattr(RESOLUTION, "name", str(RESOLUTION).split(".")[-1]),
        "fps": FPS,
        "depth_mode": str(DEPTH_MODE).split(".")[-1],
        "min_depth_m": MIN_DEPTH_M,
        "max_depth_m": MAX_DEPTH_M,
        "fusion_frames_requested": FUSION_FRAMES,
        "fusion_frames_successful": successful_frames,
        "max_grab_attempts": MAX_GRAB_ATTEMPTS,
        "b_zone_forward_m": [B_FWD_MIN_M, B_FWD_MAX_M],
        "b_zone_lateral_m": [B_LAT_MIN_M, B_LAT_MAX_M],
        "cell_m": CELL_M,
        "low_height_clip_m": [LOW_CLIP_MIN_M, LOW_CLIP_MAX_M],
        "tall_height_clip_m": [TALL_CLIP_MIN_M, TALL_CLIP_MAX_M],
        "right_left_reversed_for_display": True,
        "front_back_reversed_for_display": False,
        "outputs_kept": [
            "zed_left.png",
            "zed_right.png",
            "zed_depth_view.png",
            "confidence_view.png",
            "raw_topdown_points_B_low.png",
            "fused_topdown_points_B_low.png",
            "fused_topdown_points_B_tall.png",
            "fused_support_B.png",
            "settings.json",
            "summary.json",
            "zed_floor_sanity_test_v23.py",
        ],
    }
    (OUT_DIR / "settings.json").write_text(json.dumps(settings, indent=2))

    # List the folder only after settings.json exists. summary.json is added by
    # name because it is written right below. Files left by earlier runs in
    # the same folder are listed too (nothing here removes them).
    output_files = {p.name for p in OUT_DIR.glob("*") if p.is_file()}
    output_files.add("summary.json")

    summary = {
        "timestamp_unix": time.time(),
        "processing_seconds": round(time.time() - start_time, 3),
        "candidate_floor_points_total": floor_candidate_total,
        "floor_fit_successful_frames": floor_ok_count,
        "last_floor_plane_normal": [float(v) for v in last_floor_n],
        "last_floor_plane_d": float(last_floor_d),
        "last_frame_b_points_used": int(len(last_points_b)) if last_points_b is not None else 0,
        "fused_low_valid_cells": int(np.sum(low_valid)),
        "fused_tall_valid_cells": int(np.sum(tall_valid)),
        "fused_support_nonzero_cells": int(np.sum(support_frames > 0)),
        "frames_skipped_no_floor_fit": frames_without_floor,
        "output_files": sorted(output_files),
    }
    (OUT_DIR / "summary.json").write_text(json.dumps(summary, indent=2))

    print("Done.")
    print(f"Output folder: {OUT_DIR}")
    print(f"Processing seconds: {summary['processing_seconds']}")
    print(f"Fusion frames: {successful_frames}")
    print(f"Last frame B points used: {summary['last_frame_b_points_used']}")
    print("Main PNGs: raw_topdown_points_B_low.png, fused_topdown_points_B_low.png, fused_topdown_points_B_tall.png, fused_support_B.png")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
