#!/usr/bin/env python

from datetime import datetime as dt
from datetime import datetime, timedelta
import numpy as np
import pytz
import random
import sys
import time

from aman.sys.aco.Ant import Ant
from aman.sys.aco.Configuration import Configuration
from aman.sys.aco.Node import Node
from aman.sys.aco.RunwayManager import RunwayManager
from aman.types.Inbound import Inbound

# This class implements the ant colony optimization described in the following paper
# (DOI 10.1109/CEC.2019.8790135):
# https://doi.org/10.1109/CEC.2019.8790135
class Colony:
    """Ant colony optimizer that sequences inbounds and assigns them runways.

    The constructor creates one planning ``Node`` per inbound and computes the
    delay of the first-come-first-served (FCFS) order as a baseline.
    ``optimize()`` runs the ant colony search and stores the chosen sequence in
    ``Result`` and its accumulated delay in ``ResultDelay``.
    """

    @staticmethod
    def associateInbound(rwyManager : RunwayManager, node : Node, earliestArrivalTime : datetime):
        """Select a runway for ``node`` and write the resulting plan into its inbound.

        The planned arrival time is clamped so it is never earlier than
        ``earliestArrivalTime``. Afterwards ``node`` is stored as the entry of the
        selected runway in ``rwyManager.RunwayInbounds``.
        """
        rwy, eta, _ = rwyManager.selectArrivalRunway(node, earliestArrivalTime)
        eta = max(earliestArrivalTime, eta)
        candidate = node.ArrivalCandidates[rwy.Name]

        node.Inbound.PlannedRunway = rwy
        node.Inbound.PlannedStar = candidate.Star
        node.Inbound.PlannedArrivalRoute = candidate.ArrivalRoute
        node.Inbound.PlannedArrivalTime = eta
        node.Inbound.InitialArrivalTime = candidate.InitialArrivalTime
        node.Inbound.PlannedTrackmiles = candidate.Trackmiles
        rwyManager.RunwayInbounds[rwy.Name] = node

    @staticmethod
    def calculateInitialCosts(rwyManager : RunwayManager, nodes, earliestArrivalTime : datetime):
        """Plan ``nodes`` in the given order and return the summed delay as a ``timedelta``.

        The delay of a node is its planned arrival time minus its initial arrival time.
        """
        overallDelay = timedelta(seconds = 0)

        # assume that the nodes are sorted in FCFS order
        for node in nodes:
            Colony.associateInbound(rwyManager, node, earliestArrivalTime)
            overallDelay += node.Inbound.PlannedArrivalTime - node.Inbound.InitialArrivalTime

        return overallDelay

    def __init__(self, inbounds, configuration : Configuration):
        """Create the planning nodes, the FCFS baseline and the initial pheromone matrix.

        Note: writes ``ThetaZero`` into ``configuration``.
        """
        self.Configuration = configuration
        self.ResultDelay = None
        self.Result = None

        # create the new planning instances relative to the current UTC time (whole seconds)
        currentTime = dt.fromtimestamp(int(time.time()), tz = pytz.UTC)
        self.Nodes = [
            Node(inbound, currentTime, self.Configuration.WeatherModel, self.Configuration.AirportConfiguration, self.Configuration.RunwayConstraints)
            for inbound in inbounds
        ]

        # FCFS baseline used to decide whether the optimized sequence is better
        rwyManager = RunwayManager(self.Configuration)
        delay = Colony.calculateInitialCosts(rwyManager, self.Nodes, self.Configuration.EarliestArrivalTime)
        self.FcfsDelay = delay

        # run the optimization in every cycle to ensure optimal spacings based on TTG;
        # a non-positive FCFS delay is replaced by one second to keep ThetaZero finite
        if 0.0 >= delay.total_seconds():
            delay = timedelta(seconds = 1.0)

        # initial value for the optimization; max(..., 1) avoids a division by zero without inbounds
        nodeCount = len(self.Nodes)
        self.Configuration.ThetaZero = 1.0 / (max(nodeCount, 1) * (delay.total_seconds() / 60.0))
        self.PheromoneMatrix = np.ones((nodeCount, nodeCount), dtype=float) * self.Configuration.ThetaZero

    def sequenceAndPredictInbound(self, rwyManager : RunwayManager, node : Node):
        """Append ``node`` to ``Result``, plan it and derive the PTAs of its arrival route.

        The first route point's PTA is its ETA shifted by the sequencing delay
        (planned minus initial arrival time). Each following point keeps its ETA
        spacing relative to the previous point.
        """
        self.Result.append(node)
        Colony.associateInbound(rwyManager, node, self.Configuration.EarliestArrivalTime)

        route = node.Inbound.PlannedArrivalRoute
        reqTimeDelta = node.Inbound.InitialArrivalTime - node.Inbound.PlannedArrivalTime
        route[0].PTA = route[0].ETA - reqTimeDelta
        for i in range(1, len(route)):
            prev = route[i - 1]
            current = route[i]
            current.PTA = prev.PTA + (current.ETA - prev.ETA)

    def _runGeneration(self, startIndex : int):
        """Let every ant build a sequence starting at ``startIndex``.

        Returns the ``[delay, sequence]`` pair with the smallest delay, or None
        when no ant ran.
        """
        candidates = []

        for _ in range(self.Configuration.AntCount):
            # let the ant find a solution
            ant = Ant(self.PheromoneMatrix, self.Configuration, self.Nodes)
            ant.findSolution(startIndex)

            # fallback to check if findSolution was successful
            # NOTE(review): terminating the process from library code is drastic; consider raising instead
            if ant.SequenceDelay is None or ant.Sequence is None:
                sys.stderr.write('Invalid ANT run detected!')
                sys.exit(-1)

            candidates.append([ant.SequenceDelay, ant.Sequence])

        # min() keeps the first of equally good candidates
        return min(candidates, key = lambda candidate: candidate[0], default = None)

    def _reinforceBestSequence(self, bestSequence):
        """Apply the global pheromone update along the edges of ``bestSequence``.

        The deposit is inversely proportional to the sequence delay in minutes
        (a zero delay counts as one minute). No entry drops below ``ThetaZero``.
        """
        delay, sequence = bestSequence
        dTheta = 1.0 / ((delay.total_seconds() / 60.0) or 1.0)
        epsilon = self.Configuration.Epsilon

        for i in range(1, len(sequence)):
            edge = (sequence[i - 1], sequence[i])
            update = (1.0 - epsilon) * self.PheromoneMatrix[edge] + epsilon * dTheta
            self.PheromoneMatrix[edge] = max(update, self.Configuration.ThetaZero)

    def optimize(self):
        """Run the ant colony search and store the final sequence.

        Sets ``Result`` to the planned nodes in sequence order and ``ResultDelay``
        to the delay of that sequence. The optimized sequence is used when its
        delay does not exceed the FCFS delay; otherwise the FCFS order is kept.
        Once a result exists, further calls return immediately.
        """
        # the sequence has already been computed
        if self.Result is not None:
            return

        bestSequence = None

        # without inbounds there is nothing to explore (and no valid start index)
        if self.Nodes:
            for _ in range(self.Configuration.ExplorationRuns):
                # select the first inbound
                startIndex = random.randrange(len(self.Nodes))
                bestCandidate = self._runGeneration(startIndex)

                # global pheromone update: once per generation with the best sequence found so far
                if bestSequence is not None:
                    self._reinforceBestSequence(bestSequence)

                # check if we find a new best candidate
                if bestCandidate is not None and (bestSequence is None or bestCandidate[0] < bestSequence[0]):
                    bestSequence = bestCandidate

        # create the final sequence
        self.Result = []
        rwyManager = RunwayManager(self.Configuration)

        if bestSequence is not None and self.FcfsDelay >= bestSequence[0]:
            # use the optimized sequence
            self.ResultDelay = bestSequence[0]
            orderedNodes = [self.Nodes[idx] for idx in bestSequence[1]]
        else:
            # use the FCFS sequence
            self.ResultDelay = self.FcfsDelay
            orderedNodes = self.Nodes

        for node in orderedNodes:
            self.sequenceAndPredictInbound(rwyManager, node)