"""
This module defines the country widgets.


Copyright (c) 2023 Proton AG

This file is part of Proton VPN.

Proton VPN is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton VPN is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""

from __future__ import annotations

from dataclasses import dataclass

from typing import List, Tuple, Set
from gi.repository import GLib, GObject

from proton.vpn.app.gtk.utils.accessibility import add_accessibility
from proton.vpn.app.gtk.utils.search import normalize
from proton.vpn.connection.enum import ConnectionStateEnum
from proton.vpn.session.servers import Country
from proton.vpn import logging
from proton.vpn.app.gtk import Gtk
from proton.vpn.app.gtk.controller import Controller
from proton.vpn.app.gtk.widgets.vpn.serverlist.icons import \
    SmartRoutingIcon, P2PIcon, TORIcon, UnderMaintenanceIcon
from proton.vpn.app.gtk.widgets.vpn.serverlist.server import ServerRow
from proton.vpn.session.servers import LogicalServer
from proton.vpn.session.servers import ServerFeatureEnum

logger = logging.getLogger(__name__)


@dataclass
class CountryAnalysis:
    """Aggregated view of every server that belongs to a single country.

    Instances are created positionally by ``_analyze_servers``, so the field
    order below is part of the contract.
    """
    # CONNECTED when one of the country's servers is the connected one,
    # DISCONNECTED otherwise.
    country_connection_state: ConnectionStateEnum
    # True only when every server of the country has a ``host_country`` set,
    # i.e. it is physically located in another country.
    smart_routing_country: bool
    # True only when none of the country's servers is enabled.
    under_maintenance: bool
    # True when at least one server of the country is on the free tier (0).
    is_free_country: bool
    # Union of the features offered by all the servers of the country.
    country_features: Set[ServerFeatureEnum]


def _analyze_servers(ordered_servers: List[LogicalServer],
                     connected_server_id: str = None) -> CountryAnalysis:
    """
    Summarizes the state of all the servers of a country.

    :param ordered_servers: servers of the country. The order does not affect
        the result; any iterable is accepted.
    :param connected_server_id: ID of the server the user is currently
        connected to, or None if there is none.
    :return: a CountryAnalysis where:
        - the connection state is CONNECTED if one of the servers has
          ``connected_server_id`` as ID, and DISCONNECTED otherwise;
        - smart routing is flagged only if *all* servers are physically
          located in a neighboring country (``host_country`` is set);
        - the country is under maintenance only if *none* of its servers is
          enabled;
        - the country is free if at least one server is on tier 0 (always a
          bool, False for an empty list);
        - the features are the union of the features of all the servers.
    """
    # The servers are traversed several times below, so materialize them once
    # to keep supporting one-shot iterables.
    servers = list(ordered_servers)

    country_features = set()
    for server in servers:
        country_features.update(server.features)

    is_free_country = any(server.tier == 0 for server in servers)

    # Under maintenance only when not a single server is enabled.
    under_maintenance = not any(server.enabled for server in servers)

    # A country is flagged as a "Smart routing" location if *all* servers are
    # actually physically located in a neighboring country.
    smart_routing_country = all(server.host_country is not None for server in servers)

    if any(connected_server_id == server.id for server in servers):
        country_connection_state = ConnectionStateEnum.CONNECTED
    else:
        country_connection_state = ConnectionStateEnum.DISCONNECTED

    return CountryAnalysis(
        country_connection_state,
        smart_routing_country,
        under_maintenance,
        is_free_country,
        country_features
    )


class CountryHeader(Gtk.Box):  # pylint: disable=too-many-instance-attributes
    """Header with the country name shown at the beginning of each CountryRow.

    From start to end it contains: the country name, then either the "under
    maintenance" icon or the country details (feature icons plus a connect or
    upgrade button), and finally the button to expand/collapse the servers of
    the country.
    """
    # pylint: disable=too-many-arguments
    def __init__(
            self,
            country: Country,
            under_maintenance: bool,
            upgrade_required: bool,
            server_features: Set[ServerFeatureEnum],
            smart_routing: bool,
            connection_state: ConnectionStateEnum,
            controller: Controller,
            show_country_servers: bool = False
    ):
        super().__init__(orientation=Gtk.Orientation.HORIZONTAL)
        self._country = country
        self._under_maintenance = under_maintenance
        self._upgrade_required = upgrade_required
        self._server_features = server_features
        self._smart_routing = smart_routing
        self._controller = controller
        # Also (re)assigned by the connection_state setter at the end of _build_ui.
        self._connection_state = connection_state
        self._show_country_servers = show_country_servers

        self._country_name_label = None
        self._under_maintenance_icon = None
        self._country_details = None

        self._connect_button = None
        self._connect_button_handler_id = None
        self._toggle_button = None
        self._toggle_button_handler_id = None

        self._collapsed_img = Gtk.Image.new_from_icon_name("pan-down-symbolic")
        self._expanded_img = Gtk.Image.new_from_icon_name("pan-up-symbolic")

        self._build_ui(connection_state)

        # This setter needs to be called after the UI has been built since it
        # modifies the toggle button.
        self.show_country_servers = show_country_servers

    def _build_ui(self, connection_state: ConnectionStateEnum):
        self._country_name_label = Gtk.Label(label=self.country_name)
        self._country_name_label.set_halign(Gtk.Align.START)
        self.prepend(self._country_name_label)
        self.set_spacing(10)

        self._show_under_maintenance_icon_or_country_details()

        self._toggle_button = Gtk.Button()
        self._toggle_button.add_css_class("secondary")
        self._toggle_button_handler_id = self._toggle_button.connect(
            "clicked", self._on_toggle_button_clicked
        )
        # The toggle button sits at the trailing edge of the header.
        self._toggle_button.set_halign(Gtk.Align.END)
        self.append(self._toggle_button)

        self.connection_state = connection_state

    def update_under_maintenance_status(self, under_maintenance: bool):
        """Shows or hides the under maintenance status for the country.

        The stored connection state is re-applied afterwards: when the country
        leaves maintenance, the connect button may have just been created (or
        may have missed state updates while unavailable) and would otherwise
        show an empty or stale label.
        """
        self._under_maintenance = under_maintenance
        self._show_under_maintenance_icon_or_country_details()
        self.connection_state = self._connection_state

    def _show_under_maintenance_icon_or_country_details(self):
        if self._under_maintenance:
            self._show_under_maintenance_icon()
        else:
            self._show_country_details()

    def _show_under_maintenance_icon(self):
        if self._country_details is not None:
            self._country_details.set_visible(False)

        if self._under_maintenance_icon is None:
            self._under_maintenance_icon = UnderMaintenanceIcon(self.country_name)
            self._under_maintenance_icon.set_halign(Gtk.Align.END)
            # Keep the icon right after the name (and thus before the toggle
            # button), even when it is created after the UI was built.
            self.insert_child_after(self._under_maintenance_icon, self._country_name_label)

        # The icon may have been hidden by a previous _show_country_details call.
        self._under_maintenance_icon.set_visible(True)
        self._country_name_label.set_property("sensitive", False)

    def _show_country_details(self):
        if self._under_maintenance_icon is not None:
            self._under_maintenance_icon.set_visible(False)

        if self._country_details is None:
            self._country_details = self._build_country_details()
            # Keep the details right after the name (and thus before the
            # toggle button), even when they are created after the UI was built.
            self.insert_child_after(self._country_details, self._country_name_label)

        self._country_details.set_visible(True)
        self._country_name_label.set_property("sensitive", True)

    def _build_country_details(self):
        country_details = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL)
        country_details.set_halign(Gtk.Align.END)
        country_details.set_hexpand(True)
        country_details.set_spacing(10)

        if self._upgrade_required:
            button = self._build_upgrade_required_link_button()
            country_details.append(button)
        else:
            button = self._build_connect_button()
            self._connect_button = button
            country_details.append(self._connect_button)

        add_accessibility(button, Gtk.AccessibleRelation.LABELLED_BY, self._country_name_label)

        country_row_icons = []
        if self._smart_routing:
            country_row_icons.append(SmartRoutingIcon())

        server_feature_icons = self._build_server_feature_icons()
        country_row_icons.extend(server_feature_icons)
        # Icons are prepended so they end up before the button.
        for icon in country_row_icons:
            country_details.prepend(icon)

        if country_row_icons:
            add_accessibility(button, Gtk.AccessibleRelation.DESCRIBED_BY, country_row_icons)

        return country_details

    @property
    def under_maintenance(self) -> bool:
        """Indicates whether all the servers for this country are under maintenance or not."""
        return self._under_maintenance

    @property
    def upgrade_required(self) -> bool:
        """Indicates whether the user needs to upgrade to have access to this country or not."""
        return self._upgrade_required

    def _build_upgrade_required_link_button(self) -> Gtk.LinkButton:
        upgrade_button = Gtk.LinkButton.new_with_label("Upgrade")
        upgrade_button.set_uri("https://account.protonvpn.com/")
        return upgrade_button

    def _build_connect_button(self) -> Gtk.Button:
        connect_button = Gtk.Button()
        self._connect_button_handler_id = connect_button.connect(
            "clicked", self._on_connect_button_clicked
        )
        connect_button.add_css_class("secondary")
        return connect_button

    def _build_server_feature_icons(self) -> List[Gtk.Image]:
        server_feature_icons = []
        if ServerFeatureEnum.P2P in self._server_features:
            server_feature_icons.append(P2PIcon())
        if ServerFeatureEnum.TOR in self._server_features:
            server_feature_icons.append(TORIcon())
        return server_feature_icons

    @property
    def server_features(self) -> Set[ServerFeatureEnum]:
        """Returns the set of features supported by the servers in this country."""
        return self._server_features

    @GObject.Signal(name="toggle-country-servers")
    def toggle_country_servers(self):
        """Signal when the user clicks the button to expand/collapse the servers
        from a country."""

    @property
    def country_code(self):
        """Returns the code of the country this header is for."""
        return self._country.code

    @property
    def country_name(self):
        """Returns the name of the country this header is for."""
        return self._country.name

    @property
    def show_country_servers(self):
        """Returns whether the country servers should be shown or not."""
        return self._show_country_servers

    @show_country_servers.setter
    def show_country_servers(self, show_country_servers: bool):
        """Sets whether the country servers should be shown or not,
        updating the toggle button icon and tooltip accordingly."""
        self._show_country_servers = show_country_servers
        self._toggle_button.set_child(
            self._expanded_img if self.show_country_servers else self._collapsed_img
        )
        self._toggle_button.set_tooltip_text(
            f"Hide all servers from {self.country_name}" if self.show_country_servers else
            f"Show all servers from {self.country_name}"
        )

    @property
    def available(self) -> bool:
        """Returns True if the country is available, meaning the user can
        connect to one of its servers. Otherwise, it returns False."""
        return not self.upgrade_required and not self.under_maintenance

    @property
    def connection_state(self):
        """Returns the connection state of the country shown in this header."""
        return self._connection_state

    @connection_state.setter
    def connection_state(self, connection_state: ConnectionStateEnum):
        """Sets the connection state, updating the connect button if the
        country is available. The state is always stored so that it can be
        re-applied later (see update_under_maintenance_status)."""
        # pylint: disable=duplicate-code
        self._connection_state = connection_state

        # Only available countries have a connect button to update. The button
        # is also missing after cleanup().
        if not self.available or self._connect_button is None:
            return

        # Dispatch to _on_connection_state_<state>, if such a handler exists.
        handler = getattr(self, f"_on_connection_state_{connection_state.name.lower()}", None)
        if handler is not None:
            handler()

    def _on_toggle_button_clicked(self, _toggle_button: Gtk.Button):
        self.show_country_servers = not self.show_country_servers
        self.emit("toggle-country-servers")

    def _on_connect_button_clicked(self, _connect_button: Gtk.Button):
        future = self._controller.connect_to_country(self.country_code)

        def _propagate_result():
            # result() re-raises, on the GLib main loop, any exception raised
            # while connecting. Returning False removes the idle source so this
            # runs exactly once, whatever result() returns.
            future.result()
            return False

        future.add_done_callback(lambda _future: GLib.idle_add(_propagate_result))

    def _on_connection_state_disconnected(self):
        """Flags this country as "not connected"."""
        self._connect_button.set_sensitive(True)
        self._connect_button.set_label("Connect")

    def _on_connection_state_connecting(self):
        """Flags this country as "connecting"."""
        self._connect_button.set_label("Connecting...")
        self._connect_button.set_sensitive(False)

    def _on_connection_state_connected(self):
        """Flags this country as "connected"."""
        self._connect_button.set_sensitive(False)
        self._connect_button.set_label("Connected")

    def _on_connection_state_disconnecting(self):
        """Leaves the connect button untouched while disconnecting."""

    def _on_connection_state_error(self):
        """Flags this country as "error", which looks like "not connected"."""
        self._on_connection_state_disconnected()

    def click_toggle_country_servers_button(self):
        """Clicks the button to toggle the country servers.
        This method was made available for tests."""
        self._toggle_button.emit("clicked")

    def click_connect_button(self):
        """Clicks the button to connect to the country.
        This method was made available for tests."""
        self._connect_button.emit("clicked")

    def grab_focus(self):  # pylint: disable=arguments-differ
        """Focuses on the connect button if available, otherwise on the toggle button."""
        if self._connect_button is not None and not self._under_maintenance:
            self._connect_button.grab_focus()
        elif self._toggle_button is not None:
            self._toggle_button.grab_focus()

    def cleanup(self):
        """Clean up signal connections to allow garbage collection."""
        if self._toggle_button_handler_id:
            self._toggle_button.disconnect(self._toggle_button_handler_id)
            self._toggle_button_handler_id = None
            self._toggle_button = None

        if self._connect_button_handler_id:
            self._connect_button.disconnect(self._connect_button_handler_id)
            self._connect_button_handler_id = None
            self._connect_button = None


class DeferredCountryRow(Gtk.Box):  # pylint: disable=too-many-instance-attributes
    """Row with a country header followed by the servers of that country.

    Server rows are created lazily, the first time the servers are shown.
    """

    # pylint: disable=too-many-arguments
    def __init__(
            self,
            country: Country,
            user_tier: int,
            controller: Controller,
            connected_server_id: str = None,
            show_country_servers: bool = False,
    ):
        super().__init__(orientation=Gtk.Orientation.VERTICAL)

        self._controller = controller
        self._indexed_server_rows = {}
        # Kept up to date by connection_status_update; read by the deferred
        # server row creation below.
        self._connected_server_id = connected_server_id

        free_servers, plus_servers = self._group_servers_by_tier(country.servers)
        is_free_user = user_tier == 0

        # Servers matching the user's tier are listed first.
        if is_free_user:
            ordered_servers = free_servers + plus_servers
        else:
            ordered_servers = plus_servers + free_servers

        analysis = _analyze_servers(ordered_servers, connected_server_id)

        self._under_maintenance = analysis.under_maintenance  # noqa: E501 # pylint: disable=line-too-long # nosemgrep: python.lang.maintainability.is-function-without-parentheses.is-function-without-parentheses
        self._is_free_country = analysis.is_free_country  # noqa: E501 # pylint: disable=line-too-long # nosemgrep: python.lang.maintainability.is-function-without-parentheses.is-function-without-parentheses
        self._country_features = analysis.country_features  # noqa: E501 # pylint: disable=line-too-long # nosemgrep: python.lang.maintainability.is-function-without-parentheses.is-function-without-parentheses
        self._upgrade_required = is_free_user and not self._is_free_country

        self._server_rows_revealer = Gtk.Revealer()
        server_rows_container = Gtk.Box(orientation=Gtk.Orientation.VERTICAL)
        server_rows_container.set_spacing(10)
        self._server_rows_revealer.set_margin_top(15)
        self._server_rows_revealer.set_child(server_rows_container)

        def add_servers_to_country():
            for server in ordered_servers:
                server_row = ServerRow(
                    server=server,
                    user_tier=user_tier,
                    controller=self._controller
                )
                server_rows_container.append(server_row)

                self._indexed_server_rows[server.id] = server_row

                # If we are currently connected to a server then set its row
                # state to "connected".
                #
                # We use self._connected_server_id instead of
                # connected_server_id because there's a chance it might change
                # before this function is called.
                if self._connected_server_id == server.id:
                    server_row.connection_state = ConnectionStateEnum.CONNECTED

        # Reset to None once the rows have been created (or on cleanup).
        self._add_servers_to_country = add_servers_to_country

        self._country_header = CountryHeader(
            country=country,
            under_maintenance=self._under_maintenance,
            upgrade_required=self._upgrade_required,
            server_features=self._country_features,
            smart_routing=analysis.smart_routing_country,
            connection_state=analysis.country_connection_state,
            controller=controller,
            show_country_servers=show_country_servers
        )
        self._toggle_signal_handler_id = self._country_header.connect(
            "toggle-country-servers", self._on_toggle_country_servers
        )

        self.append(self._country_header)
        self.append(self._server_rows_revealer)

        if show_country_servers:
            self._server_rows_revealer.set_reveal_child(True)
            # The servers are visible straight away, so their rows must exist.
            self._generate_servers_if_needed(self._country_header)

    def _generate_servers_if_needed(self, country_header: CountryHeader):
        """Creates the server rows the first time the servers are shown."""
        if country_header.show_country_servers and self._add_servers_to_country:
            self._add_servers_to_country()
            self._add_servers_to_country = None

    def toggle_row(self):
        """Toggles the view of the children of the country row."""
        self._country_header.click_toggle_country_servers_button()

    @property
    def country_code(self):
        """Returns the code of the country. A short unique string"""
        return self._country_header.country_code

    @property
    def country_name(self):
        """Returns the name of the country.
        This method was made available for tests."""
        return self._country_header.country_name

    @property
    def upgrade_required(self):
        """Returns True if this country is not in the currently logged-in
        user tier, and therefore it requires a plan upgrade. Otherwise, it
        returns False."""
        return self._upgrade_required

    @property
    def is_free_country(self) -> bool:
        """Returns True if this country has any servers available to
        users with a free account. Otherwise, it returns False."""
        return self._is_free_country

    @property
    def showing_servers(self):
        """Returns True if the servers are being shown and False otherwise.
        This method was made available for tests."""
        return self._server_rows_revealer.get_reveal_child()

    def click_toggle_country_servers_button(self):
        """
        Clicks the button to toggle the visibility of the country servers.
        This method was made available for tests.
        """
        self._country_header.click_toggle_country_servers_button()

    def grab_focus(self):  # pylint: disable=arguments-differ
        """Focuses on the country row."""
        self._country_header.grab_focus()

    @property
    def server_rows(self) -> List[ServerRow]:
        """Returns the list of server rows created so far for this country.
        This method was made available for tests."""
        server_rows = []
        revealer_child = self._server_rows_revealer.get_child()
        if revealer_child:
            child = revealer_child.get_first_child()
            while child:
                server_rows.append(child)
                child = child.get_next_sibling()
        return server_rows

    @property
    def connection_state(self):
        """Returns the connection state for this row."""
        return self._country_header.connection_state

    @property
    def header_searchable_content(self) -> str:
        """Returns the normalized searchable content for the country header."""
        return normalize(self.country_name)

    @staticmethod
    def _group_servers_by_tier(
            country_servers
    ) -> Tuple[List[LogicalServer], List[LogicalServer]]:
        """Splits the servers into (free servers, plus servers), keeping their order."""
        free_servers = []
        plus_servers = []
        for server in country_servers:
            if server.tier == 0:
                free_servers.append(server)
            else:
                plus_servers.append(server)

        return free_servers, plus_servers

    def _on_toggle_country_servers(self, country_header: CountryHeader):
        self._server_rows_revealer.set_reveal_child(
            country_header.show_country_servers
        )
        self._generate_servers_if_needed(country_header)

    def set_servers_visibility(self, visible: bool):
        """Country servers will be shown if set to True. Otherwise, they'll be hidden."""
        self._country_header.show_country_servers = visible
        self._server_rows_revealer.set_reveal_child(visible)
        # Server rows are created lazily; make sure they exist when revealed.
        self._generate_servers_if_needed(self._country_header)

    def connection_status_update(self, connection_state):
        """This method is called by VPNWidget whenever the VPN connection status changes."""
        self._country_header.connection_state = connection_state.type
        server_id = connection_state.context.connection.server_id
        server = self._indexed_server_rows.get(server_id, None)
        if server:
            server.connection_state = connection_state.type

        # maintain connected server id only when connected
        if self._controller.is_connection_active:  # noqa: E501 # pylint: disable=line-too-long # nosemgrep: python.lang.maintainability.is-function-without-parentheses.is-function-without-parentheses
            self._connected_server_id = server_id
        else:
            self._connected_server_id = None

    def click_connect_button(self):
        """Clicks the button to connect to the country.
        This method was made available for tests."""
        self._country_header.click_connect_button()

    def update_server_loads(self, new_country: Country):
        """Refreshes the UI after new server loads were retrieved.

        Every existing server row refreshes its load. If ``new_country`` is
        given, the maintenance status of the whole country is recomputed from
        its servers and pushed to the header.
        """
        for server_row in self._indexed_server_rows.values():
            server_row.update_server_load()

        if new_country is None:
            return

        self._under_maintenance = _analyze_servers(
            new_country.servers,
            self._connected_server_id
        ).under_maintenance

        self._country_header.update_under_maintenance_status(
            self._under_maintenance
        )

    def cleanup(self):
        """Clean up signal connections and references to allow garbage collection.

        Calling it more than once is a no-op.
        """
        if self._country_header is None:
            return

        # Disconnect country header signal
        self._country_header.disconnect(self._toggle_signal_handler_id)

        # Clean up country header's signal connections
        self._country_header.cleanup()
        self._country_header = None

        # Clean up all server rows and their signal connections
        for server_row in self._indexed_server_rows.values():
            server_row.cleanup()
        self._indexed_server_rows.clear()

        # Clear the deferred server creation function; its closure references
        # self and the server rows container.
        self._add_servers_to_country = None

        # Remove child to break widget tree references
        if self._server_rows_revealer.get_child():
            self._server_rows_revealer.set_child(None)
        self._server_rows_revealer = None

        # Clear controller reference
        self._controller = None
