########################################################################
#
# Copyright (c) 2023, STEREOLABS.
#
# All rights reserved.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
########################################################################

"""
    This sample shows how to track the position of the ZED camera,
    fuse it with GNSS data, and display it in an OpenGL window.
"""

import sys

import numpy as np

import ogl_viewer.tracking_viewer as gl
import pyzed.sl as sl
import time
from datetime import datetime, timedelta
import multiprocessing
import threading
import json



# optional libraries : ZED HUB and GPSD
try:
    import pyzed.sl_iot as sliot
    with_zed_hub = True
except ImportError:
    with_zed_hub = False
    print("ZED Hub not detected.")

try:
    from gpsdclient import GPSDClient
    with_gpsd = True
except ImportError:
    with_gpsd = False
    print ("GPSD not detected. GNSS data will be dummy instead.")

# NOTE: `global` statements at module scope have no effect in Python; these
# three lines only name the module-level variables that gps_loop() and the
# __main__ section share. The variables themselves are assigned further down.
global current_gps_value
global gps_stream
global new_gps_data_available


from pyubx2 import UBXReader
from datetime import datetime, timedelta

def _nmea_angle_to_degrees(field, hemisphere, positive_hemisphere):
    """Convert an NMEA ``(d)ddmm.mmmm`` angle to signed decimal degrees."""
    value = float(field)
    degrees = int(value / 100)
    minutes = value - degrees * 100
    sign = 1 if hemisphere == positive_hemisphere else -1
    return sign * (degrees + minutes / 60)


def convert_nmea_to_timestamp_complete(nmea_sentence, reference_date=None):
    """Parse a GGA-style NMEA sentence into a timestamp and a position.

    Args:
        nmea_sentence: Comma-separated sentence whose fields 1-5 and 9 hold
            time (``hhmmss[.fff]``), latitude, N/S, longitude, E/W and
            altitude.
        reference_date: Optional object with ``year``, ``month`` and ``day``
            attributes (``date`` or ``datetime``) supplying the calendar day,
            since the sentence itself only carries a time of day. When
            omitted, the historical default is kept: May 21 of the current
            UTC year.

    Returns:
        Tuple ``(timestamp, lat, lon, alt)``: a naive ``datetime``, latitude
        and longitude in signed decimal degrees (south / west negative), and
        the altitude field as a float (assumed to be metres - the unit field
        is not checked).

    Raises:
        IndexError: if the sentence has fewer than 10 fields.
        ValueError: if a numeric field is empty or malformed.
    """
    parts = nmea_sentence.split(',')

    # Time of day. The fraction is taken from the digits themselves rather
    # than from float arithmetic: int((seconds % 1) * 1e6) can come out one
    # microsecond short (e.g. ".29" -> 289999).
    whole, _, fraction = parts[1].partition('.')
    hours = int(whole[0:2])
    minutes = int(whole[2:4])
    seconds = int(whole[4:6])
    microseconds = int((fraction + "000000")[:6])

    if reference_date is None:
        # Same default as before (current UTC year, May 21), without the
        # deprecated datetime.utcnow().
        year, month, day = time.gmtime().tm_year, 5, 21
    else:
        year, month, day = (reference_date.year, reference_date.month,
                            reference_date.day)
    timestamp = datetime(year, month, day, hours, minutes, seconds,
                         microseconds)

    lat = _nmea_angle_to_degrees(parts[2], parts[3], 'N')
    lon = _nmea_angle_to_degrees(parts[4], parts[5], 'E')

    # Altitude (we're assuming M units, check if different)
    alt = float(parts[9])

    return timestamp, lat, lon, alt





# Background loop: fetches the latest GNSS sample from GPSD when it is available; otherwise it only flags placeholder data as ready.
def gps_loop(stop_event=None):
    """Publish the latest GNSS sample through the module-level globals.

    Meant to run in a background thread. About every 10 ms it stores a sample
    in ``current_gps_value`` and raises ``new_gps_data_available``:

    * with GPSD (``with_gpsd`` true) the sample is the next item of
      ``gps_stream``; the loop ends cleanly when that stream is exhausted
      instead of dying with an unhandled ``StopIteration``;
    * without GPSD the sample is the placeholder ``True``.

    Args:
        stop_event: Optional object with an ``is_set()`` method (for example
            ``threading.Event``). When given, the loop returns once it is
            set; when omitted the loop runs until the stream ends or the
            process exits, as before.
    """
    global current_gps_value
    global gps_stream
    global new_gps_data_available

    while stop_event is None or not stop_event.is_set():
        if with_gpsd:
            try:
                sample = next(gps_stream)
            except StopIteration:
                print("GPSD stream ended, stopping GNSS polling.")
                return
        else:
            sample = True

        # Store the value before raising the flag so that a reader which
        # sees the flag also sees the matching sample.
        current_gps_value = sample
        new_gps_data_available = True
        time.sleep(0.01)

def _load_gga_fixes(ubx_path):
    """Read every ``$YDGGA`` sentence of a recorded receiver log.

    Args:
        ubx_path: Path of the recorded UBX/NMEA log file.

    Returns:
        dict mapping each fix time (``datetime``, shifted by +2 h to match
        the clock used for the video start time) to a ``(lat, lon, alt)``
        tuple. If two sentences share a time, the later one wins.
    """
    fixes = {}
    # The context manager closes the log even if a sentence fails to parse.
    with open(ubx_path, 'rb') as stream:
        # protfilter=1 is kept from the original script (NMEA-only according
        # to the pyubx2 documentation).
        for raw_data, _parsed in UBXReader(stream, protfilter=1):
            sentence = raw_data.decode('utf-8')
            if '$YDGGA' not in sentence:
                continue
            fix_time, lat, lon, alt = convert_nmea_to_timestamp_complete(sentence)
            # Keys stay datetime objects: str(datetime) drops the fractional
            # part for whole-second fixes, which would break re-parsing with
            # a "%f" format later on.
            fixes[fix_time + timedelta(hours=2)] = (lat, lon, alt)
    return fixes


def _nearest_fix(fixes, when):
    """Return ``(fix_time, offset_ms)`` of the fix closest in time to *when*.

    Returns ``(None, inf)`` when *fixes* is empty. On a tie the fix that was
    inserted first is returned.
    """
    if not fixes:
        return None, float("inf")
    fix_time = min(fixes, key=lambda candidate: abs(candidate - when))
    return fix_time, abs(fix_time - when).total_seconds() * 1000


if __name__ == "__main__":

    import simplekml

    # current_gps_value, gps_stream and new_gps_data_available are module
    # globals shared with gps_loop(); code here already runs at module scope.
    kml = simplekml.Kml()            # fused trajectory
    kml_baseline = simplekml.Kml()   # raw GNSS trajectory, for comparison

    new_gps_data_available = False
    current_gps_value = None

    # some variables
    camera_pose = sl.Pose()
    odometry_pose = sl.Pose()
    py_translation = sl.Translation()

    # The recorded log is only needed when GPSD is not available; load it up
    # front so a missing file fails before the camera is opened.
    if with_gpsd:
        time_to_gps = {}
    else:
        time_to_gps = _load_gga_fixes(
            'C:\\Users\\ben93\\PycharmProjects\\test2\\yolov7\\Bodensee\\AR200\\COM9___921600_230521_100310.ubx')

    # connect to ZED Hub
    if with_zed_hub:
        sliot.HubClient.connect("Geotracking sample")

    init_params = sl.InitParameters(camera_resolution=sl.RESOLUTION.HD1080,
                                    coordinate_units=sl.UNIT.METER,
                                    coordinate_system=sl.COORDINATE_SYSTEM.RIGHT_HANDED_Y_UP)

    # If applicable, use the SVO given as parameter
    # Otherwise use ZED live stream
    if len(sys.argv) == 2:
        filepath = sys.argv[1]
        print("Using SVO file: {0}".format(filepath))
        init_params.set_from_svo_file(filepath)

    # step 1
    # create the camera that will input the position from its odometry
    zed = sl.Camera()
    status = zed.open(init_params)
    if status != sl.ERROR_CODE.SUCCESS:
        print(status)
        # Report the failure through the exit status (was a bare exit()).
        sys.exit(1)

    if with_zed_hub:
        status = sliot.HubClient.register_camera(zed)
        print("STATUS", status)

    communication_parameters = sl.CommunicationParameters()
    communication_parameters.set_for_shared_memory()
    zed.start_publishing(communication_parameters)

    # warmup
    if zed.grab() != sl.ERROR_CODE.SUCCESS:
        print("Unable to initialize the camera.")
        sys.exit(1)
    else:
        zed.get_position(odometry_pose, sl.REFERENCE_FRAME.WORLD)

    tracking_params = sl.PositionalTrackingParameters()
    # These parameters are mandatory to initialize the transformation between GNSS and ZED reference frames.
    tracking_params.enable_imu_fusion = True
    tracking_params.set_gravity_as_origin = True
    zed.enable_positional_tracking(tracking_params)

    camera_info = zed.get_camera_information()
    # Create OpenGL viewer
    viewer = gl.GLViewer()
    viewer.init(camera_info.camera_model)

    # step 2
    # init the fusion module that will input both the camera and the GPS
    fusion = sl.Fusion()
    init_fusion_parameters = sl.InitFusionParameters()
    init_fusion_parameters.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Y_UP
    init_fusion_parameters.coordinate_units = sl.UNIT.METER
    init_fusion_parameters.output_performance_metrics = False
    init_fusion_parameters.verbose = True
    fusion.init(init_fusion_parameters)
    fusion.enable_positionnal_tracking()
    uuid = sl.CameraIdentifier(camera_info.serial_number)
    print("Subscribing to", uuid.serial_number, communication_parameters.comm_type)
    status = fusion.subscribe(uuid, communication_parameters, sl.Transform(0,0,0))
    if status != sl.FUSION_ERROR_CODE.SUCCESS:
        print("Failed to subscribe to", uuid.serial_number, status)
        sys.exit(1)

    # initialize the GNSS - gpsd https://gpsd.gitlab.io/gpsd/installation.html
    if with_gpsd:
        client = GPSDClient(host="127.0.0.1")
        gps_stream = client.dict_stream(convert_datetime=True, filter=["TPV"])

    # Wall-clock time of the first video frame; advanced by the elapsed time
    # on every ingest and matched against the recorded fixes.
    time_vid = datetime.strptime("2023-05-21 12:07:52.464567", "%Y-%m-%d %H:%M:%S.%f")

    # Daemon thread: gps_loop() never returns on its own, so joining it (or
    # leaving it non-daemon) would keep the process alive forever.
    gps_thread = threading.Thread(target=gps_loop, daemon=True)
    gps_thread.start()

    previous_time_ms = None   # SDK clock at the previous ingest
    raw_fix = None            # last raw (lat, lon, alt) handed to the fusion
    interrupted = False
    try:
        while viewer.is_available():
            # get the odometry information
            # grab() is called once per iteration: calling it again in the
            # elif would consume (and drop) an extra frame.
            grab_status = zed.grab()
            if grab_status == sl.ERROR_CODE.SUCCESS:
                zed.get_position(odometry_pose, sl.REFERENCE_FRAME.WORLD)
            elif grab_status == sl.ERROR_CODE.END_OF_SVOFILE_REACHED:
                break

            # GPS ingest
            if current_gps_value is not None and new_gps_data_available:
                time.sleep(0.01)
                sample = current_gps_value
                new_gps_data_available = False

                timestamp = sl.Timestamp()
                fix = None

                if with_gpsd:
                    # retrieve the latest value from the GPS; reports that
                    # lack any needed field (e.g. no fix yet) are skipped
                    # instead of feeding placeholder strings to the SDK.
                    gps_time = sample.get("time")
                    fields = [sample.get(key) for key in ("lat", "lon", "alt", "eph", "epv")]
                    if gps_time is not None and None not in fields:
                        latitude, longitude, altitude, eph, epv = fields
                        # "time" is expected to be a datetime because the
                        # stream is opened with convert_datetime=True.
                        timestamp.set_seconds(gps_time.timestamp())
                        fix = (latitude, longitude, altitude)
                else:
                    # Advance the video clock by the time elapsed since the
                    # previous ingest (nothing to add on the first one).
                    now_ms = sl.get_current_timestamp().get_milliseconds()
                    if previous_time_ms is not None:
                        time_vid = time_vid + timedelta(milliseconds=now_ms - previous_time_ms)
                    previous_time_ms = now_ms

                    timestamp.set_seconds(time_vid.timestamp())

                    corresponding_gps_time, vid2gpsdiff = _nearest_fix(time_to_gps, time_vid)
                    print("closest gps point to ", time_vid, " is ", corresponding_gps_time)
                    # More than one second away from every recorded fix.
                    if vid2gpsdiff > 1000:
                        print("nonono")
                        raise ValueError("no corresponding gps found")
                    fix = time_to_gps[corresponding_gps_time]

                    eph = 1
                    epv = 1

                if fix is not None:
                    raw_fix = fix
                    latitude, longitude, altitude = fix

                    gnss_data = sl.GNSSData()
                    gnss_data.set_coordinates(latitude, longitude, altitude, False)
                    gnss_data.ts = timestamp

                    print("raw gnss", latitude, longitude, altitude)

                    # 3x3 position covariance, row-major in a flat list:
                    # eph^2 / eph^2 / epv^2 on the diagonal and 0.1 everywhere
                    # else (replace with your own values if you know them).
                    gnss_data.position_covariances = [
                        eph * eph, 0.1,       0.1,
                        0.1,       eph * eph, 0.1,
                        0.1,       0.1,       epv * epv,
                    ]
                    fusion.ingest_gnss_data(gnss_data)

            # get the fused position
            if fusion.process() == sl.FUSION_ERROR_CODE.SUCCESS:
                fused_tracking_state = fusion.get_position(camera_pose, sl.REFERENCE_FRAME.WORLD)
                # you can also retrieve the un-fused position with
                # tracking_state = fusion.get_current_gnss_data(...)

                if fused_tracking_state == sl.POSITIONAL_TRACKING_STATE.OK:
                    rotation = camera_pose.get_rotation_vector()
                    translation = camera_pose.get_translation(py_translation)
                    text_rotation = str((round(rotation[0], 2), round(rotation[1], 2), round(rotation[2], 2)))
                    text_translation = str((round(translation.get()[0], 2), round(translation.get()[1], 2), round(translation.get()[2], 2)))
                    pose_data = camera_pose.pose_data(sl.Transform())

                    viewer.updateData(pose_data, text_translation, text_rotation, fused_tracking_state)

                    # send data to zed hub
                    # visualize it on https://hub.stereolabs.com/workspaces/<workspace_id>/maps
                    geopose = sl.GeoPose()
                    status = fusion.camera_to_geo(camera_pose, geopose)

                    # the fusion will stay in SEARCHING mode until the GNSS has walked at least 5 meters.
                    if status != sl.POSITIONAL_TRACKING_STATE.OK:
                        print(status)
                    else:
                        print("OK")

                        latitude, longitude, altitude = geopose.latlng_coordinates.get_coordinates(False)

                        if with_zed_hub:
                            gps = {}
                            gps["layer_type"] = "geolocation"
                            gps["label"] = "GPS_data"
                            gps["position"] = {}
                            gps["position"]["latitude"] = latitude
                            gps["position"]["longitude"] = longitude
                            gps["position"]["altitude"] = altitude
                            status = sliot.HubClient.send_data_to_peers("geolocation", json.dumps(gps))

                        print("fused gnss", latitude, longitude, altitude)
                        # Save the computed path and, next to it, the raw fix
                        # it was fused from (KML wants lon, lat, alt order).
                        kml.newpoint(coords=[(longitude, latitude, altitude)])
                        if raw_fix is not None:
                            raw_latitude, raw_longitude, raw_altitude = raw_fix
                            kml_baseline.newpoint(coords=[(raw_longitude, raw_latitude, raw_altitude)])

            if with_zed_hub:
                sliot.HubClient.update()

    except KeyboardInterrupt:
        # got a ^C.  Say bye, bye (after the trajectories are saved below)
        interrupted = True
    finally:
        # Save and release on every exit path (viewer closed, end of SVO,
        # ^C or error), not only on ^C.
        kml.save("trajectory.kml")
        kml_baseline.save("trajectory_baseline.kml")
        print("saved kml")
        if interrupted:
            print('Bye ! (saved trajectory)')

        viewer.exit()
        zed.close()

