"""Convenience wrappers around paramiko for running remote commands and browsing remote directories over SFTP"""
import stat
from typing import List
import paramiko
from paramiko.channel import Channel


class SSHError(paramiko.ssh_exception.SSHException):
    """Error raised by ``SSHConn`` when connecting, running a remote command
    or listing a remote directory fails."""


class CmdResult:
    """Return type for ``SSHConn.run()``

    Opens a new session channel on ``conn``'s transport and executes ``cmd``.
    It then reads the remote stdout and stderr to EOF and records the exit
    status. The channel is closed in every case, including when executing the
    command or reading its output raises.

    Attributes:
        stdout (str): remote standard output, decoded as UTF-8; undecodable
            bytes are preserved via the ``surrogateescape`` error handler
        stderr (str): remote standard error, decoded the same way, with a
            leading ``"stdin: is not a tty\\n"`` line removed
        rcode (int): exit status as returned by ``recv_exit_status()``
    """

    # Spurious stderr line stripped from the start of the command's stderr.
    _TTY_WARNING = b'stdin: is not a tty\n'

    def __init__(self, conn, cmd: str):
        chan: Channel
        chan = conn.get_transport().open_session()
        try:
            # NOTE(review): this sends an environment variable literally named
            # 'shell'; 'TERM' may have been intended. Servers usually ignore
            # variables not whitelisted via AcceptEnv -- confirm.
            chan.set_environment_variable('shell', 'xterm')
            chan.exec_command(cmd)
            # NOTE(review): stdout is drained to EOF before stderr is read, so
            # a command writing a very large amount of stderr could stall on
            # SSH flow control. Consider reading both streams concurrently.
            raw_stdout = chan.makefile('rb', -1).read()
            raw_stderr = chan.makefile_stderr('rb', -1).read()
            rcode = chan.recv_exit_status()
        finally:
            # Always release the channel, even if the command or a read fails.
            chan.close()

        if raw_stderr.startswith(self._TTY_WARNING):
            # this is caused by 'mesg y' in /etc/bashrc. Using a pty would
            # avoid that, but then mix stdout and stderr which is not wanted.
            raw_stderr = raw_stderr[len(self._TTY_WARNING):]

        self.stdout = str(raw_stdout, 'utf-8', errors='surrogateescape')
        self.stderr = str(raw_stderr, 'utf-8', errors='surrogateescape')
        self.rcode = rcode


class SSHConn(paramiko.SSHClient):
    """Context manager that handles SSH connections

    **Arguments are the same as paramiko.SSHClient.connect, but with one extra
    argument, "sftp"**

    Args:
        sftp (bool): whether to enable SFTP and listdir
            features on connect
    Raises:
        SSHError: on failure to connect
    Attributes:
        sftp: A paramiko.SFTPClient object while connected with sftp=True,
            otherwise None
    """

    def __init__(self, sftp: bool, **kwargs):
        """Args for the context manager are read from here.
        If you don't need sftp functionality, set sftp=False"""
        super().__init__()
        # NOTE(review): AutoAddPolicy accepts unknown host keys on first
        # connect, which does not protect against MITM attacks; consider a
        # stricter policy with a known_hosts file if that matters here.
        self.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self._kwargs = kwargs
        self._use_sftp = sftp
        self.sftp = None

    def __enter__(self):
        """Connects to SSH using the context manager 'with' statement

        Raises:
            SSHError: if connecting or opening the SFTP session fails; any
                partially established connection is closed first
        """
        try:
            self.connect(**self._kwargs)
            if self._use_sftp:
                self.sftp = self.open_sftp()
        except Exception as exc:
            # __exit__ is not invoked when __enter__ raises, so release a
            # partially opened connection (e.g. open_sftp() failed) here.
            self._disconnect()
            raise SSHError(exc) from exc
        return self

    def __exit__(self, exc_type, exc_val, traceback):
        """Disconnects from SSH on un-indent, even if an exception is raised"""
        self._disconnect()

    def _disconnect(self):
        """Close the SFTP session (if any), then always close the SSH client."""
        sftp, self.sftp = self.sftp, None
        try:
            if sftp is not None:
                sftp.close()
        finally:
            # Close the SSH connection even if closing SFTP failed.
            self.close()

    def listdir(self, path: str) -> List[str]:
        """List contents of a remote dir similar to os.listdir()

        Args:
            path: remote path to browse
        Raises:
            SSHError: if SFTP is not enabled/connected or the listing fails
        """
        return self._listdir(path, _type=None)

    def ls_dirs(self, path: str) -> List[str]:
        """listdir() but only return directories

        Args:
            path: remote path to browse
        Raises:
            SSHError: if SFTP is not enabled/connected or the listing fails
        """
        return self._listdir(path, _type=stat.S_ISDIR)

    def ls_files(self, path: str) -> List[str]:
        """listdir() but only return regular files

        Args:
            path: remote path to browse
        Raises:
            SSHError: if SFTP is not enabled/connected or the listing fails
        """
        return self._listdir(path, _type=stat.S_ISREG)

    def _listdir(self, path, _type):
        """Shared implementation of listdir(), ls_dirs() and ls_files().

        Args:
            path: remote path to browse
            _type: predicate applied to each entry's ``st_mode`` (such as
                ``stat.S_ISDIR``), or None to keep every entry
        Raises:
            SSHError: if SFTP is not available or the SFTP call fails
        """
        if self.sftp is None:
            # Without this check the caller would get an opaque
            # "'NoneType' object has no attribute" error wrapped in SSHError.
            raise SSHError(
                'SFTP is not available: create SSHConn with sftp=True '
                'and use it inside a "with" block')
        try:
            # listdir_iter is consumed inside the try so errors raised while
            # iterating are wrapped too.
            return [entry.filename
                    for entry in self.sftp.listdir_iter(path)
                    if _type is None or _type(entry.st_mode)]
        except Exception as exc:
            raise SSHError(exc) from exc

    def run(self, cmd: str) -> CmdResult:
        """Execute a remote command over SSH

        Args:
            cmd: command to run
        Returns:
            CmdResult holding decoded stdout/stderr and the exit status
        Raises:
            SSHError: wrapping any error raised while running the command
        """
        try:
            return CmdResult(self, cmd)
        except Exception as exc:
            raise SSHError(exc) from exc