########################################################################
#
# Copyright (c) 2023, STEREOLABS.
#
# All rights reserved.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
########################################################################

"""
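    This sample shows how to fuse the position of a ZED camera (replayed from an SVO file) with the fixes of an external GNSS sensor (replayed from a recorded NMEA log), and writes the fused geographic position to a KML file.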
"""

import sys
import pyzed.sl as sl
import gnss_tools
import datetime
from scipy import interpolate
import exporter.KMLExporter as export
import gnss_tools as gt
import numpy as np
import time


class Geolocation:
    """Replay a recorded NMEA log as time-indexed GNSS position and accuracy.

    The log is parsed with ``gnss_tools.read_nmea_file``; its ``"GGA"``
    (position) and ``"GST"`` (error statistics) tables are kept as
    ``gnss_nmea_gga`` / ``gnss_nmea_gst`` and turned into linear
    interpolators keyed on seconds since 1970-01-01 00:00:00, computed from
    each table's ``OS_Datetime`` column:

    * ``interp_fun_lat`` / ``interp_fun_lon`` / ``interp_fun_alt`` --
      GGA ``Lat_decdeg`` / ``Lon_decdeg`` / ``Altitude``. Querying a time
      outside the GGA time span raises ``ValueError``.
    * ``interp_fun_lat_err`` / ``interp_fun_lon_err`` / ``interp_fun_alt_err``
      -- GST ``LatitudeError`` / ``LongitudeError`` / ``AltitudeError``.
      Outside the GST time span they return the earliest / latest recorded
      value.

    The GGA table additionally gains the columns ``OS_ts_millisec`` (GGA time
    in milliseconds) and ``Lat_err`` / ``Lon_err`` / ``Alt_err`` (GST errors
    interpolated at the GGA times).
    """

    # Reference instant for the seconds conversion (naive datetime).
    # NOTE(review): the results are compared with camera timestamps in the
    # script below, so OS_Datetime is presumably naive UTC -- confirm against
    # gnss_tools.
    _EPOCH = datetime.datetime(1970, 1, 1, 0, 0, 0)

    def __init__(self, filepath: str):
        """Load the NMEA log at *filepath* and build the interpolators.

        Args:
            filepath: path of the NMEA log, passed to
                ``gnss_tools.read_nmea_file``.
        """
        nmea_data = gnss_tools.read_nmea_file(file_path=filepath)

        # --- GGA: position fixes -----------------------------------------
        self.gnss_nmea_gga = nmea_data["GGA"]
        t_gga = self._seconds_since_epoch(self.gnss_nmea_gga)
        self._append_column(self.gnss_nmea_gga, "OS_ts_millisec", t_gga * 1e3)

        # Position interpolators keep interp1d's default bounds checking, so a
        # time outside the GGA span raises ValueError instead of inventing a fix.
        self.interp_fun_lat = self._interpolator(
            t_gga, self.gnss_nmea_gga["Lat_decdeg"]
        )
        self.interp_fun_lon = self._interpolator(
            t_gga, self.gnss_nmea_gga["Lon_decdeg"]
        )
        self.interp_fun_alt = self._interpolator(
            t_gga, self.gnss_nmea_gga["Altitude"]
        )

        # --- GST: error statistics ---------------------------------------
        self.gnss_nmea_gst = nmea_data["GST"]
        t_gst = self._seconds_since_epoch(self.gnss_nmea_gst)

        # GGA and GST rows carry their own OS_Datetime stamps, so GGA times may
        # fall slightly outside the GST span. Clamping to the nearest recorded
        # error avoids the ValueError interp1d would otherwise raise when the
        # errors are evaluated at the GGA times below.
        self.interp_fun_lat_err = self._interpolator(
            t_gst, self.gnss_nmea_gst["LatitudeError"], clamp=True
        )
        self.interp_fun_lon_err = self._interpolator(
            t_gst, self.gnss_nmea_gst["LongitudeError"], clamp=True
        )
        self.interp_fun_alt_err = self._interpolator(
            t_gst, self.gnss_nmea_gst["AltitudeError"], clamp=True
        )

        # Attach the GST errors to every GGA row.
        self._append_column(
            self.gnss_nmea_gga, "Lat_err", self.interp_fun_lat_err(t_gga)
        )
        self._append_column(
            self.gnss_nmea_gga, "Lon_err", self.interp_fun_lon_err(t_gga)
        )
        self._append_column(
            self.gnss_nmea_gga, "Alt_err", self.interp_fun_alt_err(t_gga)
        )

    @classmethod
    def _seconds_since_epoch(cls, table):
        """Return ``table["OS_Datetime"]`` as float seconds since ``_EPOCH``."""
        return (table["OS_Datetime"] - cls._EPOCH).dt.total_seconds()

    @staticmethod
    def _append_column(table, name, values):
        """Insert *values* in place as a new last column *name* of *table*."""
        table.insert(len(table.columns), name, values, allow_duplicates=True)

    @staticmethod
    def _interpolator(t, values, clamp=False):
        """Return a linear ``interp1d`` of the Series *values* over times *t*.

        Args:
            t: sample times in seconds (need not be sorted).
            values: pandas Series of samples aligned with *t*.
            clamp: when False, out-of-range queries raise ``ValueError``
                (interp1d default); when True, they return the value at the
                earliest / latest time.
        """
        y = values.to_numpy()
        if not clamp:
            return interpolate.interp1d(t, y)
        # interp1d sorts by time internally, so pick the fill values in time
        # order rather than row order.
        order = np.argsort(np.asarray(t))
        return interpolate.interp1d(
            t,
            y,
            bounds_error=False,
            fill_value=(y[order[0]], y[order[-1]]),
        )


if __name__ == "__main__":
    # Recordings replayed when no command-line arguments are given. Usage:
    #     python <this script> [svo_file [nmea_file]]
    DEFAULT_NMEA_PATH = "gnss_nmea_gnss_nmea_tcp_20230512_083519.176194908.txt"
    DEFAULT_SVO_PATH = r"C:\my_desk\zed_files\ZED1002\20230512\ZED1002_20230512_092806.434704-20230512_093304.873144_HD2K_15fps_300s.svo"
    # Maximum number of loop iterations (frames) processed.
    MAX_FRAMES = 500
    # Position covariance sent with every GNSS sample: a 3x3 matrix flattened
    # row by row. The recorded GST errors could be used instead, e.g.
    #     [lat_err**2, 0, 0, 0, lon_err**2, 0, 0, 0, alt_err**2]
    # with lat_err = gl.interp_fun_lat_err(ts_sec), and so on.
    GNSS_COVARIANCE = [1, 0, 0, 0, 1, 0, 0, 0, 1]

    svo_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_SVO_PATH
    nmea_path = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_NMEA_PATH

    # GNSS fixes are replayed from the NMEA log and looked up by camera time.
    gl = Geolocation(nmea_path)

    odometry_pose = sl.Pose()  # camera pose; its timestamp indexes the GNSS log
    camera_pose = sl.Pose()  # pose returned by the fusion module

    init_params = sl.InitParameters(
        camera_resolution=sl.RESOLUTION.HD720,
        coordinate_units=sl.UNIT.METER,
        coordinate_system=sl.COORDINATE_SYSTEM.RIGHT_HANDED_Y_UP,
        sdk_verbose=1,
    )
    print("Using SVO file: {0}".format(svo_path))
    init_params.set_from_svo_file(svo_path)

    # step 1
    # create the camera that will input the position from its odometry
    zed = sl.Camera()
    status = zed.open(init_params)
    if status != sl.ERROR_CODE.SUCCESS:
        print(status)
        sys.exit(1)

    fusion = None
    try:
        communication_parameters = sl.CommunicationParameters()
        communication_parameters.set_for_shared_memory()
        zed.start_publishing(communication_parameters)

        # warmup
        if zed.grab() != sl.ERROR_CODE.SUCCESS:
            print("Unable to initialize the camera.")
            sys.exit(1)
        zed.get_position(odometry_pose, sl.REFERENCE_FRAME.WORLD)

        tracking_params = sl.PositionalTrackingParameters()
        # These parameters are mandatory to initialize the transformation between GNSS and ZED reference frames.
        tracking_params.enable_imu_fusion = True
        tracking_params.set_gravity_as_origin = True
        status = zed.enable_positional_tracking(tracking_params)
        if status != sl.ERROR_CODE.SUCCESS:
            print("Failed to enable positional tracking:", status)
            sys.exit(1)

        camera_info = zed.get_camera_information()

        # step 2
        # init the fusion module that will input both the camera and the GPS
        fusion = sl.Fusion()
        init_fusion_parameters = sl.InitFusionParameters()
        init_fusion_parameters.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Y_UP
        init_fusion_parameters.coordinate_units = sl.UNIT.METER

        fusion.init(init_fusion_parameters)
        fusion.enable_positionnal_tracking()
        fusion.disable_body_tracking()

        uuid = sl.CameraIdentifier(camera_info.serial_number)
        print("Subscribing to", uuid.serial_number, communication_parameters.comm_type)
        status = fusion.subscribe(uuid, communication_parameters, sl.Transform(0, 0, 0))
        if status != sl.FUSION_ERROR_CODE.SUCCESS:
            print("Failed to subscribe to", uuid.serial_number, status)
            sys.exit(1)

        for _ in range(MAX_FRAMES):
            # Grab exactly once per iteration: calling grab() again to test
            # for end-of-file would consume (and drop) another frame.
            grab_status = zed.grab()
            if grab_status == sl.ERROR_CODE.END_OF_SVOFILE_REACHED:
                break
            if grab_status != sl.ERROR_CODE.SUCCESS:
                # No new odometry: skip rather than reuse the previous
                # pose's timestamp for the GNSS lookup.
                continue
            zed.get_position(odometry_pose, sl.REFERENCE_FRAME.WORLD)

            time.sleep(0.01)

            # get the GPS information matching the camera timestamp
            ts = odometry_pose.timestamp
            ts_sec = ts.get_milliseconds() / 1e3
            try:
                lat = float(gl.interp_fun_lat(ts_sec))
                lon = float(gl.interp_fun_lon(ts_sec))
                alt = float(gl.interp_fun_alt(ts_sec))
            except ValueError:
                # interp1d raises outside the recorded GGA time span: there
                # is no GNSS fix to replay for this frame.
                print(
                    "No GNSS data at",
                    ts.get_milliseconds(),
                    "ms; skipping GNSS ingestion",
                )
            else:
                print(
                    ts.get_milliseconds(),
                    datetime.datetime.fromtimestamp(ts_sec),
                    gt.decdeg2ddmm(lat),
                    gt.decdeg2dddmm(lon),
                    alt,
                )
                # GPS coordinates: latitude, longitude, altitude
                gnss_data = sl.GNSSData()
                gnss_data.ts = ts
                gnss_data.set_coordinates(lat, lon, alt, in_radian=False)
                gnss_data.position_covariances = GNSS_COVARIANCE
                ingest_status = fusion.ingest_gnss_data(gnss_data)
                print(ingest_status)

            # get the fused position
            if fusion.process() != sl.FUSION_ERROR_CODE.SUCCESS:
                continue
            fused_tracking_state = fusion.get_position(
                camera_pose, sl.REFERENCE_FRAME.WORLD
            )
            if fused_tracking_state != sl.POSITIONAL_TRACKING_STATE.OK:
                continue

            geopose = sl.GeoPose()
            geo_status = fusion.camera_to_geo(camera_pose, geopose)

            # the fusion will stay in SEARCHING mode until the GNSS has walked at least 5 meters.
            if geo_status != sl.POSITIONAL_TRACKING_STATE.OK:
                print(geo_status)
                continue

            print("OK")
            latitude, longitude, altitude = geopose.latlng_coordinates.get_coordinates(False)
            # Save computed position into KML (one saveKMLData call per fused
            # pose). NOTE(review): if the exporter needs an explicit
            # finalisation call to produce a valid file, it is not made here.
            export.saveKMLData(
                "computed_geoposition.kml",
                {
                    "latitude": latitude,
                    "longitude": longitude,
                    "altitude": altitude,
                },
            )
            print(latitude, longitude, altitude)
    finally:
        # Runs on normal completion and on the sys.exit(1) error paths above.
        if fusion is not None:
            fusion.close()
        zed.close()
