#!/usr/bin/env python3
# Import necessary libraries
import sys
import numpy as np
import argparse
import torch
import cv2
import pyzed.sl as sl
from ultralytics import YOLO
import time
from threading import Lock, Thread
from time import sleep
import cv_viewer.tracking_viewer as cv_viewer
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from deep_sort_realtime.deepsort_tracker import DeepSort
import open3d as o3d
import ogl_viewer.viewer as gl

# Module-level setup: release cached CUDA memory and create the state shared
# between the main loop (main) and the detection worker (torch_thread).
torch.cuda.empty_cache()
lock = Lock()  # held while the frame in image_net is written or read for detection
run_signal = False  # set by the main loop when a frame is ready; cleared by the worker when done
exit_signal = False  # when True, both the main loop and the worker loop stop

# Open3D point-cloud window, created as soon as this module is imported.
# NOTE(review): check that main() draws into this window rather than creating
# a second Visualizer, otherwise this one is never polled.
vis = o3d.visualization.Visualizer()
vis.create_window("Point Cloud Viewer", width=1280, height=720)

def xywh2abcd(xywh, im_shape):
    """Turn a centre-based box into its four corner points, clipped to the image.

    Args:
        xywh: ``[center_x, center_y, width, height]`` in pixels.
        im_shape: Image shape; ``im_shape[0]`` (height) and ``im_shape[1]``
            (width) bound the corners, and coordinates are clipped at 0 below.

    Returns:
        A ``(4, 2)`` float64 array ordered top-left, top-right, bottom-right,
        bottom-left.
    """
    cx, cy, w, h = xywh[0], xywh[1], xywh[2], xywh[3]
    half_w = 0.5 * w
    half_h = 0.5 * h
    left = max(0, cx - half_w)
    right = min(im_shape[1], cx + half_w)
    top = max(0, cy - half_h)
    bottom = min(im_shape[0], cy + half_h)
    return np.array(
        [[left, top], [right, top], [right, bottom], [left, bottom]],
        dtype=float,
    )

def visualize_depth_map(depth_map):
    """Render a depth map as a JET-coloured image for display.

    Only pixels holding a finite, positive depth take part in the min-max
    scaling (ZED marks missing depth with NaN / +-inf). Invalid pixels are drawn
    with the lowest colour. Previously a single ``inf`` in the map turned the
    whole ``NORM_MINMAX`` result into NaN/garbage.

    Args:
        depth_map: ``sl.Mat`` holding the depth measure; only ``get_data()`` is used.

    Returns:
        The ``uint8`` image produced by ``cv2.applyColorMap`` with the same
        height and width as the depth map.
    """
    depth_data = np.asarray(depth_map.get_data(), dtype=np.float32)
    valid = np.isfinite(depth_data)
    valid[valid] = depth_data[valid] > 0  # avoid comparing NaN values

    normalized_depth = np.zeros(depth_data.shape, dtype=np.uint8)
    if valid.any():
        valid_depths = depth_data[valid]
        lowest = valid_depths.min()
        span = valid_depths.max() - lowest
        if span > 0:
            # Stretch the valid range onto 0..255; flat maps stay at 0.
            scaled = (valid_depths - lowest) * (255.0 / span)
            normalized_depth[valid] = np.clip(scaled, 0, 255).astype(np.uint8)
    return cv2.applyColorMap(normalized_depth, cv2.COLORMAP_JET)

def depth_statistics(depth_map):
    """Print the average, minimum and maximum of the valid pixels of a depth map.

    A pixel is valid when it is finite and strictly positive. Non-finite values
    (ZED uses NaN / +-inf for missing depth) are excluded. Otherwise a single
    ``+inf`` makes the average and maximum print as ``inf``.

    Args:
        depth_map: Object whose ``get_data()`` returns the depth array (an
            ``sl.Mat`` here). Values are printed with an ``mm`` suffix, so
            millimetre units are assumed.
    """
    depth_data = np.asarray(depth_map.get_data())
    finite_depths = depth_data[np.isfinite(depth_data)]
    valid_depths = finite_depths[finite_depths > 0]
    if valid_depths.size == 0:
        print("No valid depth data available")
        return
    avg_depth = np.mean(valid_depths)
    min_depth = np.min(valid_depths)
    max_depth = np.max(valid_depths)
    print(f"Depth Stats - Avg: {avg_depth:.2f}mm, Min: {min_depth:.2f}mm, Max: {max_depth:.2f}mm")

def sahi_results_to_custom_box(object_prediction_list, im0):
    """Wrap SAHI predictions as ``sl.CustomBoxObjectData`` for ZED ingestion.

    Args:
        object_prediction_list: SAHI predictions; each must expose
            ``bbox.minx/miny/maxx/maxy``, ``category.id`` and ``score.value``.
        im0: Image the predictions were made on; only ``im0.shape`` is read,
            to clip the boxes to the image bounds.

    Returns:
        One ``sl.CustomBoxObjectData`` per prediction, in input order, marked as
        not grounded.
    """
    def _to_custom_box(pred):
        # Rebuild the centre/size form from the corners, then let xywh2abcd
        # produce the clipped four-corner layout ZED expects.
        custom = sl.CustomBoxObjectData()
        box = pred.bbox
        center_x = (box.minx + box.maxx) / 2
        center_y = (box.miny + box.maxy) / 2
        width = box.maxx - box.minx
        height = box.maxy - box.miny
        custom.bounding_box_2d = xywh2abcd([center_x, center_y, width, height], im0.shape)
        custom.label = pred.category.id
        custom.probability = pred.score.value
        custom.is_grounded = False
        return custom

    return [_to_custom_box(pred) for pred in object_prediction_list]

def torch_thread(weights, img_size, conf_thres=0.2, iou_thres=0.45):
    """Load the SAHI-wrapped YOLOv8 model and serve detection requests until exit.

    The handshake with the main loop goes through module globals:
    1. The main loop stores a BGRA frame in ``image_net`` under ``lock`` and sets
       ``run_signal``.
    2. This thread runs sliced inference and stores ZED custom boxes in
       ``detections``.
    3. This thread then clears ``run_signal``.
    The loop ends once ``exit_signal`` is set.

    Args:
        weights: Path to the YOLOv8 weights file.
        img_size: Accepted for CLI compatibility; not used by the sliced pipeline.
        conf_thres: Minimum confidence passed to the SAHI detection model.
        iou_thres: Accepted for compatibility; currently unused.
    """
    global image_net, exit_signal, run_signal, detections
    print("Initializing Network...")
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=weights,
        confidence_threshold=conf_thres,
        device="cuda:0"
    )
    while not exit_signal:
        if run_signal:
            # `with` guarantees the lock is released even if inference raises.
            with lock:
                try:
                    img = cv2.cvtColor(image_net, cv2.COLOR_BGRA2RGB)
                    result = get_sliced_prediction(
                        img, detection_model, slice_height=512, slice_width=512,
                        overlap_height_ratio=0.2, overlap_width_ratio=0.2
                    )
                    detections = sahi_results_to_custom_box(result.object_prediction_list, image_net)
                except Exception as err:
                    # Worker boundary: if this thread died, the main loop would wait on
                    # run_signal forever. Log, report no boxes, keep serving frames.
                    print(f"Detection failed, skipping frame: {err!r}")
                    detections = []
            run_signal = False
        sleep(0.01)

def bbox_iou(bbox1, bbox2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        bbox1: ``[left, top, right, bottom]`` (e.g. a DeepSORT ``to_ltrb()`` result).
        bbox2: Four corner points ordered top-left, top-right, bottom-right,
            bottom-left; only index 0 (top-left) and index 2 (bottom-right) are read.

    Returns:
        The IoU. A ``1e-6`` term in the denominator guards against division by zero.
    """
    left1, top1, right1, bottom1 = bbox1[0], bbox1[1], bbox1[2], bbox1[3]
    left2, top2 = bbox2[0][0], bbox2[0][1]
    right2, bottom2 = bbox2[2][0], bbox2[2][1]

    overlap_w = max(0, min(right1, right2) - max(left1, left2))
    overlap_h = max(0, min(bottom1, bottom2) - max(top1, top2))
    inter = overlap_w * overlap_h

    area1 = (right1 - left1) * (bottom1 - top1)
    area2 = (right2 - left2) * (bottom2 - top2)
    return inter / (area1 + area2 - inter + 1e-6)

# Reusable ZED image buffer; main() retrieves the left view into it every frame.
image_left = sl.Mat()

def _zed_objects_to_deepsort(object_list):
    """Convert ZED objects into DeepSORT input tuples.

    Args:
        object_list: ``sl.ObjectData`` items taken from ``sl.Objects.object_list``.

    Returns:
        ``([left, top, width, height], confidence, label)`` tuples for
        ``DeepSort.update_tracks``. ZED reports ``confidence`` on a 0-100 scale,
        so it is divided by 100 to match the 0-1 detector scores.
    """
    ds_detections = []
    for obj in object_list:
        box = obj.bounding_box_2d
        if len(box) < 4:
            # Defensive: skip objects without a 2D box this frame.
            continue
        x1, y1 = box[0]
        x2, y2 = box[2]
        ds_detections.append(([x1, y1, x2 - x1, y2 - y1], obj.confidence / 100.0, obj.label))
    return ds_detections


def _apply_track_ids(tracks, object_list, iou_threshold=0.5):
    """Copy confirmed DeepSORT track ids onto overlapping ZED objects.

    An object whose 2D box has IoU above ``iou_threshold`` with a confirmed track
    gets ``unique_object_id = str(track_id)``. A later matching track overwrites
    an earlier one.
    """
    boxed = [obj for obj in object_list if len(obj.bounding_box_2d) >= 4]
    for track in tracks:
        if not track.is_confirmed():
            continue
        ltrb = track.to_ltrb()
        for obj in boxed:
            if bbox_iou(ltrb, obj.bounding_box_2d) > iou_threshold:
                obj.unique_object_id = str(track.track_id)


def _report_object(obj, depth_map):
    """Print centre depth (for objects tracked OK) and 3D state of one ZED object.

    ``obj.confidence`` is already a percentage on ZED's 0-100 scale, so it is
    printed as-is rather than multiplied by 100.
    """
    if obj.tracking_state == sl.OBJECT_TRACKING_STATE.OK:
        box = obj.bounding_box_2d
        center_x = int((box[0][0] + box[1][0]) / 2)
        center_y = int((box[0][1] + box[2][1]) / 2)
        err, depth = depth_map.get_value(center_x, center_y)
        if err == sl.ERROR_CODE.SUCCESS and np.isfinite(depth):
            print(f"Object ID: {obj.unique_object_id}, Depth: {depth:.2f}mm, Confidence: {obj.confidence:.2f}%")
        else:
            print(f"Object ID: {obj.unique_object_id}, Depth: Invalid, Confidence: {obj.confidence:.2f}%")

    position = obj.position
    dimensions = obj.dimensions
    velocity = obj.velocity
    print(f"Object ID: {obj.unique_object_id}")
    print(f"Position: X: {position[0]:.2f}, Y: {position[1]:.2f}, Z: {position[2]:.2f}")
    print(f"Dimensions: W: {dimensions[0]:.2f}, H: {dimensions[1]:.2f}, L: {dimensions[2]:.2f}")
    print(f"Velocity: X: {velocity[0]:.2f}, Y: {velocity[1]:.2f}, Z: {velocity[2]:.2f}")


def _xyzrgba_to_o3d(pc_data):
    """Build an Open3D point cloud from an ``(H, W, 4)`` ZED XYZRGBA array.

    The first three floats of each pixel are X, Y, Z. The fourth float packs the
    four RGBA colour bytes, so it is reinterpreted bytewise rather than read as
    separate channels. Points with any non-finite coordinate are dropped.
    """
    xyz = pc_data[:, :, :3].reshape(-1, 3)
    rgba = np.ascontiguousarray(pc_data[:, :, 3]).view(np.uint8).reshape(-1, 4)
    valid = np.isfinite(xyz).all(axis=1)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz[valid].astype(np.float64))
    pcd.colors = o3d.utility.Vector3dVector(rgba[valid, :3].astype(np.float64) / 255.0)
    return pcd


def _run_pipeline(zed, capture_thread):
    """Open the camera, configure detection/tracking, and run the frame loop.

    Returns when the user presses Esc/q/Q, the Open3D window is closed, or the
    SVO recording ends. Calls ``sys.exit(1)`` if the camera cannot be opened.
    Raises RuntimeError if the detection thread dies while a frame is pending.
    """
    global image_net, exit_signal, run_signal, detections

    input_type = sl.InputType()
    if opt.svo is not None:
        input_type.set_from_svo_file(opt.svo)

    # DeepSORT tracker (MobileNet embedder on GPU, configured for BGR frames).
    deep_sort = DeepSort(
        max_age=30, n_init=3, nms_max_overlap=1.0, max_cosine_distance=0.3,
        nn_budget=None, override_track_class=None, embedder="mobilenet",
        half=True, bgr=True, embedder_gpu=True
    )

    # Camera: millimetre units, ULTRA depth mode, depth capped at 1000 mm.
    init_params = sl.InitParameters(input_t=input_type, svo_real_time_mode=True)
    init_params.coordinate_units = sl.UNIT.MILLIMETER
    init_params.depth_mode = sl.DEPTH_MODE.ULTRA
    init_params.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Y_UP
    init_params.depth_maximum_distance = 1000
    init_params.depth_stabilization = True

    status = zed.open(init_params)
    if status != sl.ERROR_CODE.SUCCESS:
        print(repr(status))
        sys.exit(1)

    positional_tracking_parameters = sl.PositionalTrackingParameters()
    positional_tracking_parameters.set_as_static = True
    zed.enable_positional_tracking(positional_tracking_parameters)

    # ZED object detection fed with our own boxes, with ZED-side tracking enabled.
    obj_param = sl.ObjectDetectionParameters()
    obj_param.detection_model = sl.OBJECT_DETECTION_MODEL.CUSTOM_BOX_OBJECTS
    obj_param.enable_tracking = True
    obj_param.enable_segmentation = False
    zed.enable_object_detection(obj_param)

    objects = sl.Objects()
    obj_runtime_param = sl.ObjectDetectionRuntimeParameters()
    runtime_params = sl.RuntimeParameters()

    camera_infos = zed.get_camera_information()
    camera_config = camera_infos.camera_configuration
    camera_res = camera_config.resolution

    point_cloud_res = sl.Resolution(min(camera_res.width, 720), min(camera_res.height, 404))
    # NOTE(review): this OpenGL viewer is initialised but never fed data in this
    # loop; kept to preserve existing behaviour — confirm whether it is still needed.
    viewer = gl.GLViewer()
    viewer.init(camera_infos.camera_model, point_cloud_res, obj_param.enable_tracking)
    # Buffers reused every frame instead of being reallocated per grab.
    point_cloud = sl.Mat(point_cloud_res.width, point_cloud_res.height, sl.MAT_TYPE.F32_C4, sl.MEM.CPU)
    depth_map = sl.Mat()

    display_resolution = sl.Resolution(min(camera_res.width, 1280), min(camera_res.height, 720))
    image_scale = [display_resolution.width / camera_res.width, display_resolution.height / camera_res.height]

    tracks_resolution = sl.Resolution(400, display_resolution.height)
    track_view_generator = cv_viewer.TrackingViewer(tracks_resolution, camera_config.fps, init_params.depth_maximum_distance)
    track_view_generator.set_camera_calibration(camera_config.calibration_parameters)
    image_track_ocv = np.zeros((tracks_resolution.height, tracks_resolution.width, 4), np.uint8)

    cam_w_pose = sl.Pose()
    view_initialized = False  # reset the Open3D camera only until a cloud has been shown

    while not exit_signal:
        grab_status = zed.grab(runtime_params)
        if grab_status == sl.ERROR_CODE.END_OF_SVOFILE_REACHED:
            print("End of SVO file reached")
            break
        if grab_status == sl.ERROR_CODE.SUCCESS:
            # Publish the left image to the detection thread and wait for its boxes.
            with lock:
                zed.retrieve_image(image_left, sl.VIEW.LEFT, sl.MEM.CPU, display_resolution)
                image_left_ocv = image_left.get_data()
                image_net = image_left_ocv
            run_signal = True
            while run_signal:
                if not capture_thread.is_alive():
                    raise RuntimeError("Detection thread stopped before finishing the frame")
                sleep(0.001)

            # NOTE(review): detection runs on the display-resolution image, while ZED
            # documents custom boxes in original-image pixels. If the camera
            # resolution exceeds 1280x720, these boxes would be misplaced — confirm.
            zed.ingest_custom_box_objects(detections)
            zed.retrieve_objects(objects, obj_runtime_param)
            # Fetch the Python list once and reuse it. pyzed appears to build fresh
            # ObjectData copies on each `object_list` access, so ids assigned to one
            # access would not appear in another (confirm against pyzed).
            object_list = objects.object_list

            # The left view is BGRA; DeepSORT is configured with bgr=True, so only alpha is dropped.
            frame_bgr = cv2.cvtColor(image_left_ocv, cv2.COLOR_BGRA2BGR)
            tracks = deep_sort.update_tracks(_zed_objects_to_deepsort(object_list), frame=frame_bgr)
            _apply_track_ids(tracks, object_list)

            zed.retrieve_measure(depth_map, sl.MEASURE.DEPTH)
            cv2.imshow("Depth Map", visualize_depth_map(depth_map))
            depth_statistics(depth_map)

            zed.retrieve_measure(point_cloud, sl.MEASURE.XYZRGBA, sl.MEM.CPU, point_cloud_res)
            pcd = _xyzrgba_to_o3d(point_cloud.get_data())
            vis.clear_geometries()
            # Keep the user's viewpoint once a non-empty cloud has framed the view.
            vis.add_geometry(pcd, reset_bounding_box=not view_initialized)
            view_initialized = view_initialized or pcd.has_points()
            vis.update_renderer()

            for obj in object_list:
                _report_object(obj, depth_map)

            zed.get_position(cam_w_pose, sl.REFERENCE_FRAME.WORLD)

            # NOTE(review): the overlays read `objects`, so they may still show ZED's
            # own ids rather than the DeepSORT ids assigned above — confirm.
            cv_viewer.render_2D(image_left_ocv, image_scale, objects, obj_param.enable_tracking)
            global_image = cv2.hconcat([image_left_ocv, image_track_ocv])
            track_view_generator.generate_view(objects, cam_w_pose, image_track_ocv, objects.is_tracked)
            cv2.imshow("ZED | 2D View and Birds View", global_image)

            key = cv2.waitKey(10)
            if key in (27, ord('q'), ord('Q')):
                exit_signal = True

        # Closing the Open3D window also ends the loop.
        if not vis.poll_events():
            break


def main():
    """Run the ZED + SAHI/YOLOv8 + DeepSORT pipeline using the module-level ``opt``.

    Starts the detection worker, runs the capture/visualization loop, and on every
    exit path stops the worker and releases the camera and all windows. Exit paths
    are: user quit, window closed, end of SVO, camera error, or exception.
    """
    global exit_signal
    capture_thread = Thread(target=torch_thread, kwargs={'weights': opt.weights, 'img_size': opt.img_size, "conf_thres": opt.conf_thres})
    capture_thread.start()

    print("Initializing Camera...")
    zed = sl.Camera()
    try:
        _run_pipeline(zed, capture_thread)
    finally:
        # The worker is a non-daemon thread. Unless exit_signal is set, it keeps
        # running and prevents the interpreter from exiting.
        exit_signal = True
        capture_thread.join()
        zed.close()
        vis.destroy_window()
        cv2.destroyAllWindows()

# Script entry point: parse CLI options into the module-level `opt` read by main().
if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    cli.add_argument('--weights', type=str, default='C:/Program Files (x86)/ZED SDK/zed-sdk-master/runs/detect/train32/weights/best.pt', help='model.pt path(s)')
    cli.add_argument('--svo', type=str, default=None, help='optional svo file')
    cli.add_argument('--img_size', type=int, default=640, help='inference size (pixels)')
    cli.add_argument('--conf_thres', type=float, default=0.4, help='object confidence threshold')
    # NOTE(review): --depth_conf_thres is not referenced elsewhere in this file.
    cli.add_argument('--depth_conf_thres', type=int, default=2, help='ZED depth confidence threshold')
    opt = cli.parse_args()

    # NOTE: torch.no_grad() is thread-local; it covers this thread only, not the
    # inference that torch_thread runs on its own thread.
    with torch.no_grad():
        main()