"""Wrappers around ARC client commands and related utilities."""

from datetime import datetime
import logging
import os
import re
import time
from typing import \
        Any, Callable, Dict, Generic, List, NewType, Optional, Tuple, TypeVar
from subprocess import Popen, PIPE, CalledProcessError
from arcnagios.utils import map_option, host_of_uri
from arcnagios.utils import Result, ResultOk, ResultError
from arcnagios.nagutils import NagiosPerflog

Alpha = TypeVar("Alpha")

class ParseError(Exception):
    """Signals that the output of a command was not in a recognized form."""

# Job States
#

# JobStage is a distinct integer type for the benefit of type checkers.  The
# S_* constants are the stage values attached to the job states defined in
# this module.
JobStage = NewType("JobStage", int)

S_UNKNOWN = JobStage(0)
S_ENTRY = JobStage(1)
S_PRERUN = JobStage(2)
S_INLRMS = JobStage(3)
S_POSTRUN = JobStage(4)
S_FINAL = JobStage(5)

class JobState:
    """A job state, identified by its name and tagged with a stage.

    :param name: The name of the state; this is also its string form.
    :param stage: The stage (one of the ``S_*`` constants) of the state.
    """

    def __init__(self, name: str, stage: JobStage):
        self.name, self.stage = name, stage

    def __str__(self) -> str:
        """Return the state name."""
        return self.name

    def is_final(self) -> bool:
        """Return True iff the stage of this state is ``S_FINAL``."""
        in_final_stage = self.stage == S_FINAL
        return in_final_stage

class InlrmsJobState(JobState):
    """A job state whose stage is always ``S_INLRMS``.

    :param name: The name of the state.
    """

    def __init__(self, name: str):
        super().__init__(name, S_INLRMS)

class PendingJobState(JobState):
    """A job state which additionally refers to another job state.

    :param name: The name of this state.
    :param stage: The stage of this state.
    :param pending: The job state stored in the ``pending`` attribute.
    """

    def __init__(self, name: str, stage: JobStage, pending: JobState):
        super().__init__(name, stage)
        self.pending = pending

# Registry of the job states created so far, keyed by state name.  Entries are
# added by _add_job_state and looked up by jobstate_of_str.
_JOB_STATE_OF_STR: Dict[str, JobState] = {}

def _add_job_state(name: str, stage: JobStage = S_UNKNOWN) -> JobState:
    """Create the job state called ``name``, register it, and return it.

    The class of the new state is chosen from the prefix of ``name``:

    * ``PENDING:`` gives a :class:`PendingJobState` whose ``pending``
      attribute is the state named by the remainder of ``name``.
    * ``INLRMS:`` gives an :class:`InlrmsJobState`; ``stage`` must then be
      ``S_UNKNOWN`` or ``S_INLRMS``.
    * Anything else gives a plain :class:`JobState`.

    An existing registry entry under the same name is overwritten.

    :param name: The name of the state to create.
    :param stage: The stage to assign to the state.
    :returns: The newly created state.
    """
    pending_prefix = 'PENDING:'
    created: JobState
    if name.startswith(pending_prefix):
        # The part after the prefix names the state being waited for; it is
        # looked up (and created if necessary) through the registry.
        inner_state = jobstate_of_str(name[len(pending_prefix):])
        created = PendingJobState(name, stage, inner_state)
    elif name.startswith('INLRMS:'):
        assert stage in (S_UNKNOWN, S_INLRMS)
        created = InlrmsJobState(name)
    else:
        created = JobState(name, stage)
    _JOB_STATE_OF_STR[name] = created
    return created

def jobstate_of_str(name: str) -> JobState:
    """Return the job state registered under ``name``.

    A name which has not been seen before is registered on the fly by
    ``_add_job_state`` with its default stage.

    :param name: The name of the job state.
    :returns: The registered :class:`JobState` instance for ``name``.
    """
    known_state = _JOB_STATE_OF_STR.get(name)
    if known_state is None:
        # _add_job_state stores the new state in the registry itself.
        known_state = _add_job_state(name)
    return known_state

# Predefined job states, registered at import time with their stages.
J_NOT_SEEN = _add_job_state("NOT_SEEN", stage=S_ENTRY)
J_ACCEPTED = _add_job_state("Accepted", stage=S_ENTRY)
J_PREPARING = _add_job_state("Preparing", stage=S_PRERUN)
J_SUBMITTING = _add_job_state("Submitting", stage=S_PRERUN)
J_HOLD = _add_job_state("Hold", stage=S_PRERUN)
J_QUEUING = _add_job_state("Queuing", stage=S_INLRMS)
J_RUNNING = _add_job_state("Running", stage=S_INLRMS)
J_FINISHING = _add_job_state("Finishing", stage=S_POSTRUN)
J_FINISHED = _add_job_state("Finished", stage=S_FINAL)
J_KILLED = _add_job_state("Killed", stage=S_FINAL)
J_FAILED = _add_job_state("Failed", stage=S_FINAL)
J_DELETED = _add_job_state("Deleted", stage=S_FINAL)
J_UNDEFINED = _add_job_state("Undefined", stage=S_UNKNOWN)
J_OTHER = _add_job_state("Other", stage=S_UNKNOWN)


# ARC Commands
#

def _time_arg(x: float) -> str:
    return str(int(x + 0.5))

class Arcstat:
    """The status of one job as collected by :func:`arcstat`.

    All constructor arguments are keyword-only and are stored unchanged in
    attributes of the same names.

    :param state: The job state.
    :param specific_state: The "Specific state" field of the job.
    :param submitted: The "Submitted" field of the job, if present.
    :param job_error: The "Job Error" field of the job, if present.
    :param exit_code: The "Exit code" field of the job as an integer, if
        present.
    """

    def __init__(
            self, *,
            state: JobState, specific_state: str,
            submitted: Optional[str],
            job_error: Optional[str],
            exit_code: Optional[int]):
        self.state, self.specific_state = state, specific_state
        self.submitted = submitted
        self.job_error = job_error
        self.exit_code = exit_code

def arcstat(
        jobids: Optional[List[int]] = None,
        log: Optional[logging.Logger] = None,
        timeout: int = 5,
        show_unavailable: bool = False) -> Dict[str, Arcstat]:
    """Run ``arcstat -l`` and parse its output into per-job status objects.

    The output is expected to consist of ``Job: <id>`` header lines, each
    followed by field lines of the form `` <key>: <value>`` (indented by one
    space) and optional continuation lines (indented by at least two spaces)
    which are appended to the most recent field.  Parsing stops at the first
    line starting with ``No jobs`` or ``Status of ``.

    :param jobids: The jobs to query, passed to the command after conversion
        with ``str``.  If ``None``, all jobs are queried with ``-a``.
    :param log: If given, warnings from the command and unexpected output are
        logged here.  If ``None``, unexpected output raises
        :class:`ParseError` instead.
    :param timeout: The value passed to the ``--timeout`` option.
    :param show_unavailable: Whether to pass ``-u`` to the command.
    :returns: A mapping from job ID to the parsed :class:`Arcstat`.
    :raises ParseError: On unexpected output when ``log`` is ``None``, and
        always when a job lacks the "State" or "Specific state" field.
    """

    cmd = ['arcstat', '-l', '--timeout', str(timeout)]
    if jobids is None:
        cmd.append('-a')
    else:
        cmd.extend(map(str, jobids))
    if show_unavailable:
        cmd.append('-u')

    with Popen(cmd, stdout = PIPE, encoding = 'utf-8') as process:
        jobstats = {}
        line_number = 0

        def parse_error(msg: str) -> None:
            # Report unexpected output: logged if we have a logger, otherwise
            # raised.  line_number is read from the enclosing scope.
            if log:
                log.error('Unexpected output from arcstat at line %d: %s',
                          line_number, msg)
            else:
                raise ParseError('Unexpected output from arcstat at line %d: %s'
                                 % (line_number, msg))

        def convert(jobid: str, jobstat: Dict[str, str]) -> Arcstat:
            # Turn the collected fields of one job into an Arcstat object.
            # Use .get() so that a missing "State" field is reported as a
            # ParseError rather than escaping as a KeyError.
            state_name = jobstat.get('State')
            if state_name == 'Undefined':
                state = J_UNDEFINED
                specific_state = None
            elif state_name is not None and 'Specific state' in jobstat:
                state = jobstate_of_str(state_name)
                specific_state = jobstat['Specific state']
            else:
                raise ParseError('Missing "State" or "Specific state" for %s.'
                                 % jobid)
            return Arcstat(state = state, specific_state = specific_state,
                           exit_code = map_option(int, jobstat.get('Exit code')),
                           submitted = jobstat.get('Submitted'),
                           job_error = jobstat.get('Job Error'))

        jobid: Optional[str] = None
        jobfield: Optional[str] = None
        jobstat: Dict[str, str] = {}
        assert process.stdout
        for line in process.stdout:
            line_number += 1
            if line.endswith('\n'):
                line = line[0:-1]

            if line.startswith(('No jobs', 'Status of ')):
                break

            if line.startswith('Job:'):
                # A new job starts; store the fields of the previous one.
                if jobid is not None:
                    jobstats[jobid] = convert(jobid, jobstat)
                jobid = line[4:].strip()
                jobstat = {}
                jobfield = None
            elif line.startswith('Warning:'):
                if log:
                    log.warning(line)
            elif line == '':
                pass
            elif line.startswith('  '):
                # Continuation of the value of the most recent field.
                if jobfield is None:
                    parse_error('Continuation line %r before job field.'
                                % line)
                    continue
                jobstat[jobfield] += '\n' + line
            elif line.startswith(' '):
                kv = line.strip()
                try:
                    key, value = kv.split(':', 1)
                except ValueError:
                    parse_error('Expecting "<key>: <value>", got %r' % line)
                    continue
                if jobid is None:
                    parse_error('Missing "Job: ..." header before %r' % kv)
                    continue
                # Only remember the field once it has actually been stored,
                # so that a later continuation line never refers to a key
                # which is absent from jobstat.
                jobfield = key
                jobstat[jobfield] = value.strip()
            else:
                parse_error('Unrecognized output %r' % line)

    # Store the fields of the last job seen, if any.
    if jobid is not None:
        jobstats[jobid] = convert(jobid, jobstat)
    return jobstats

# FileType is a distinct integer type for the benefit of type checkers; the
# three constants below are the values returned by file_type_of_str.
FileType = NewType('FileType', int)

DIR = FileType(0)
FILE = FileType(1)
NEITHER_DIR_NOR_FILE = FileType(2)

def file_type_of_str(s) -> FileType:
    """Map the strings ``'dir'`` and ``'file'`` to their file type constants.

    :param s: The string to classify.
    :returns: ``DIR`` for ``'dir'``, ``FILE`` for ``'file'``, and
        ``NEITHER_DIR_NOR_FILE`` for any other value.
    """
    if s == 'dir':
        file_type = DIR
    elif s == 'file':
        file_type = FILE
    else:
        file_type = NEITHER_DIR_NOR_FILE
    return file_type

class ArclsEntry:
    """One entry of a listing produced by :class:`ArclsProcess`.

    All constructor arguments are keyword-only.  Note that ``name`` is stored
    in the ``filename`` attribute and ``type_`` in ``entry_type``; the other
    arguments are stored in attributes of the same names.

    :param name: The name of the entry.
    :param type_: The file type of the entry.
    :param size: The size of the entry.
    :param modified: The modification time, if available.
    :param checksum: The checksum, if available.
    :param latency: The latency, if available.
    """

    def __init__(
            self, *, name: str, type_: FileType, size: int,
            modified: Optional[datetime] = None,
            checksum: Optional[str] = None,
            latency: Optional[str] = None):
        # Mandatory columns.
        self.filename, self.entry_type, self.size = name, type_, size
        # Optional columns.
        self.modified = modified
        self.checksum = checksum
        self.latency = latency

class PerfProcess(Generic[Alpha]):
    """Runs a command as a subprocess and records its run time.

    Subclasses set ``program`` to the name of the executable and may override
    ``_ok`` to parse the standard output into a value of type ``Alpha``.  The
    subprocess is started by the constructor, with ``TZ=UTC`` in its
    environment and with stdout and stderr captured as UTF-8 text.

    Instances are context managers which delegate to the underlying
    ``Popen`` object.

    :param args: The arguments to the command; each is converted with ``str``.
    :param perflog: If given, the run time of the command is added to it
        under the label ``<program>_time`` when ``communicate`` is called.
    :param perfindex: If given (and non-empty), the run time is added with
        ``perflog.addi`` under this index, otherwise with ``perflog.add``.
    """
    program: str

    def __init__(self, args: List[str],
                 perflog: Optional[NagiosPerflog], perfindex: Optional[str]):
        # pylint: disable=R1732
        self._command = [self.program] + [str(arg) for arg in args]
        self._perflog = perflog
        self._perfindex = perfindex
        self._start_time = time.time()
        env = os.environ.copy()
        env["TZ"] = "UTC"
        self._popen = \
            Popen(self._command, stdout=PIPE, stderr=PIPE, encoding='utf-8',
                  env=env)
        # Cached outcome, set by the first call to communicate().
        self._result: Optional[Result[Alpha, CalledProcessError]] = None

    def __enter__(self):
        self._popen = self._popen.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self._popen.__exit__(exc_type, exc_value, exc_tb)

    def _ok(self, stdout: str) \
            -> Result[Alpha, CalledProcessError]:
        """Build the result for a zero exit code; by default wraps stdout."""
        return ResultOk(stdout)

    def _error(self, returncode: int, stderr: str) \
            -> Result[Alpha, CalledProcessError]:
        """Build the result for a non-zero exit code.

        The error carries a ``CalledProcessError`` with the return code, the
        command line, and stderr (passed in the ``output`` position).
        """
        exn = CalledProcessError(returncode, self._command, stderr)
        return ResultError(exn)

    def communicate(self) -> Result[Alpha, CalledProcessError]:
        """Wait for the command to finish and return its outcome.

        On the first call the run time since construction is added to the
        perflog, if any, and the outcome is computed by ``_ok`` or ``_error``
        depending on the exit code.  Later calls return the cached outcome.
        """
        if self._result is None:
            stdout, stderr = self._popen.communicate()
            returncode = self._popen.returncode
            run_time = time.time() - self._start_time
            if self._perflog:
                label = self.program + '_time'
                if self._perfindex:
                    self._perflog.addi(
                            label, self._perfindex, run_time,
                            uom = 's', limit_min = 0)
                else:
                    # Use the same "uom" keyword as the indexed call above.
                    self._perflog.add(label, run_time, uom = 's', limit_min = 0)
            if returncode == 0:
                self._result = self._ok(stdout)
            else:
                self._result = self._error(returncode, stderr)
        return self._result

class ArcsubProcess(PerfProcess):
    """Runs ``arcsub`` on the given job description files.

    :param jobdesc_files: The job description files, passed as the leading
        arguments.  The list is not modified.
    :param cluster: If given, passed with ``-C``.  It is also used as the
        performance index, reduced to its host name with ``host_of_uri`` if
        it contains a colon.
    :param jobids_to_file: If given, passed with ``-o``.
    :param timeout: If given and non-zero, passed with ``-t`` after rounding
        to an integer.
    :param perflog: The performance log receiving the run time.
    :param loglevel: If given, passed with ``-d``.
    :param submissioninterface: If given, passed with ``-S``.
    :param infointerface: If given, passed with ``-I``.
    """
    program = 'arcsub'

    def __init__(
            self,
            jobdesc_files: List[str],
            *,
            cluster: Optional[str] = None,
            jobids_to_file: Optional[str] = None,
            timeout: Optional[float] = None,
            perflog: NagiosPerflog,
            loglevel: Optional[str] = None,
            submissioninterface: Optional[str] = None,
            infointerface: Optional[str] = None):
        # Copy the list, since the options are appended in place below and
        # the caller's list must not be changed.
        args = list(jobdesc_files)
        if cluster:
            args += ['-C', cluster]
            if ':' in cluster:
                perfindex = host_of_uri(cluster)
            else:
                perfindex = cluster
        else:
            perfindex = None
        if jobids_to_file:
            args += ['-o', jobids_to_file]
        if timeout:
            args += ['-t', _time_arg(timeout)]
        if loglevel:
            args += ['-d', loglevel]
        if submissioninterface:
            args += ['-S', submissioninterface]
        if infointerface:
            args += ['-I', infointerface]
        PerfProcess.__init__(self, args, perflog, perfindex)

class ArcgetProcess(PerfProcess):
    """Runs ``arcget`` on a single job.

    The host part of the job ID, as extracted by ``host_of_uri``, is used as
    the performance index.

    :param job_id: The job ID, passed as the first argument.
    :param top_output_dir: If not ``None``, passed with ``-D``.
    :param timeout: If not ``None``, passed with ``-t`` after rounding to an
        integer.
    :param perflog: The performance log receiving the run time, if any.
    """
    program = 'arcget'

    def __init__(
            self, job_id: str,
            top_output_dir: Optional[str] = None,
            timeout: Optional[float] = None,
            perflog: Optional[NagiosPerflog] = None):
        perf_index = host_of_uri(job_id)
        command_args = [job_id]
        if timeout is not None:
            command_args.extend(['-t', _time_arg(timeout)])
        if top_output_dir is not None:
            command_args.extend(['-D', top_output_dir])
        super().__init__(command_args, perflog, perf_index)

class ArckillProcess(PerfProcess):
    """Runs ``arckill`` on a single job.

    The host part of the job ID, as extracted by ``host_of_uri``, is used as
    the performance index.

    :param job_id: The job ID, passed as the first argument.
    :param force: Accepted but not used.
    :param timeout: If not ``None``, passed with ``-t`` after rounding to an
        integer.
    :param perflog: The performance log receiving the run time, if any.
    """
    program = 'arckill'

    def __init__(
            self, job_id: str, force: bool = False,
            timeout: Optional[float] = None,
            perflog: Optional[NagiosPerflog] = None):
        # pylint: disable=unused-argument
        perf_index = host_of_uri(job_id)
        command_args = [job_id]
        if timeout is not None:
            command_args.extend(['-t', _time_arg(timeout)])
        super().__init__(command_args, perflog, perf_index)

class ArccleanProcess(PerfProcess):
    """Runs ``arcclean`` on a single job.

    The host part of the job ID, as extracted by ``host_of_uri``, is used as
    the performance index.

    :param job_id: The job ID, passed as the first argument.
    :param force: If true, ``-f`` is appended to the arguments.
    :param timeout: If not ``None``, passed with ``-t`` after rounding to an
        integer.
    :param perflog: The performance log receiving the run time, if any.
    """
    program = 'arcclean'

    def __init__(
            self, job_id: str, force: bool = False,
            timeout: Optional[float] = None,
            perflog: Optional[NagiosPerflog] = None):
        perf_index = host_of_uri(job_id)
        command_args = [job_id]
        if timeout is not None:
            command_args.extend(['-t', _time_arg(timeout)])
        if force:
            command_args += ['-f']
        super().__init__(command_args, perflog, perf_index)

class ArcrmProcess(PerfProcess):
    """Runs ``arcrm`` on a single URL.

    The host part of the URL, as extracted by ``host_of_uri``, is used as the
    performance index.

    :param url: The URL, passed as the first argument.
    :param force: If true, ``-f`` is appended to the arguments.
    :param timeout: If not ``None``, passed with ``-t`` after rounding to an
        integer.
    :param perflog: The performance log receiving the run time, if any.
    """
    program = 'arcrm'

    def __init__(
            self, url: str, force: bool = False,
            timeout: Optional[float] = None,
            perflog: Optional[NagiosPerflog] = None):
        perf_index = host_of_uri(url)
        command_args = [url]
        if timeout is not None:
            command_args.extend(['-t', _time_arg(timeout)])
        if force:
            command_args += ['-f']
        super().__init__(command_args, perflog, perf_index)

class ArccpProcess(PerfProcess):
    """Runs ``arccp`` with a source and a destination URL.

    The performance index is the host of the source URL if that URL contains
    a colon, else the host of the destination URL if that one contains a
    colon, else ``None``.

    :param src_url: The source, passed as the first argument.
    :param dst_url: The destination, passed as the second argument.
    :param timeout: If not ``None``, passed with ``-t`` after rounding to an
        integer.
    :param transfer: If false, ``-T`` is appended to the arguments.
    :param perflog: The performance log receiving the run time, if any.
    """
    program = 'arccp'

    def __init__(
            self, src_url: str, dst_url: str,
            timeout: Optional[float] = 20, transfer: bool = True,
            perflog: Optional[NagiosPerflog] = None):
        # Pick the first of the two URLs which looks like a URI.
        if ':' in src_url:
            perf_index = host_of_uri(src_url)
        elif ':' in dst_url:
            perf_index = host_of_uri(dst_url)
        else:
            perf_index = None
        command_args = [src_url, dst_url]
        if timeout is not None:
            command_args.extend(['-t', _time_arg(timeout)])
        if not transfer:
            command_args += ['-T']
        super().__init__(command_args, perflog, perf_index)

def _str_or_na(arg: str) -> Optional[str]:
    if arg == '(n/a)':
        return None
    return arg

def _parse_modified(arg: str) -> Optional[datetime]:
    try:
        return datetime.fromisoformat(arg)
    except ValueError:
        return None

# For each recognized arcls column header: the name of the ArclsEntry.__init__
# keyword argument it provides, and the function which converts the text of
# the column to the value of that argument.
_ARCLS_COLUMNS: Dict[str, Tuple[str, Callable[[str], Any]]] = {
    '<Name>': ('name', str),
    '<Type>': ('type_', file_type_of_str),
    '<Size>': ('size', int),
    '<Modified>': ('modified', _parse_modified),
    '<CheckSum>': ('checksum', _str_or_na),
    '<Latency>': ('latency', _str_or_na),
}

# Matches a "YYYY-MM-DD hh:mm:ss" time stamp.  The replacement joins the date
# and the time with "T" and appends an explicit zero UTC offset, since
# datetime.fromisoformat does not support Z.
_ARCLS_DATE_RE = re.compile(r'(\d{4}-\d{2}-\d{2}) (\d{2}:\d{2}:\d{2})')
_ARCLS_DATE_REPL = r'\1T\2+00:00'

class ArclsProcess(PerfProcess[List[ArclsEntry]]):
    """Runs ``arcls -l`` on a URL and parses the listing.

    The host part of the URL, as extracted by ``host_of_uri``, is used as the
    performance index.  On success, ``communicate`` gives the list of parsed
    :class:`ArclsEntry` objects.

    :param url: The URL to list.
    :param timeout: If not ``None``, passed with ``-t`` after rounding to an
        integer.
    :param perflog: The performance log receiving the run time, if any.
    """
    program = 'arcls'

    def __init__(self, url: str, timeout: Optional[float] = 20,
                 perflog: Optional[NagiosPerflog] = None):
        se_host = host_of_uri(url)
        args = ['-l', url]
        if timeout is not None:
            args += ['-t', _time_arg(timeout)]
        PerfProcess.__init__(self, args, perflog, se_host)

    def _ok(self, stdout: str) -> Result[List[ArclsEntry], CalledProcessError]:
        """Parse the output of a successful ``arcls -l`` run.

        The first line is taken as the header; the columns whose headers are
        listed in ``_ARCLS_COLUMNS`` are converted and passed to the
        ``ArclsEntry`` constructor, other columns are ignored.

        :param stdout: The captured standard output of the command.
        :returns: The parsed entries wrapped in ``ResultOk``; an empty list
            if there was no output at all.
        :raises RuntimeError: If a line does not have as many fields as the
            header.
        """
        lines = stdout.split('\n')
        # Drop only the empty piece after a final newline, so that a last
        # line without a terminating newline is still parsed.
        if lines[-1] == '':
            lines.pop()
        if not lines:
            # No header at all, hence nothing to list.
            return ResultOk([])

        header_line = lines[0]
        header_fields = header_line.split()

        # (keyword argument, column index, converter) per recognized column.
        columns = [
            (_ARCLS_COLUMNS[header_field][0], column,
             _ARCLS_COLUMNS[header_field][1])
            for column, header_field in enumerate(header_fields)
            if header_field in _ARCLS_COLUMNS]

        entries = []
        for line in lines[1:]:
            # The <Modified> column contains a space, so convert it to RFC 3339.
            line = _ARCLS_DATE_RE.sub(_ARCLS_DATE_REPL, line)
            # Limit the split in case the first field, which should be the file
            # name, contains spaces.
            fields = line.rsplit(None, len(header_fields) - 1)
            if len(fields) != len(header_fields):
                raise RuntimeError(
                        'Line %r does not match header %r in output from %s'
                        % (line, header_line, ' '.join(self._command)))
            kwargs = {key: convert(fields[column])
                      for key, column, convert in columns}
            entries.append(ArclsEntry(**kwargs))
        return ResultOk(entries)

class ArcClient:
    """Runs ARC client commands, logging their run times to a common perflog.

    Each method starts the corresponding ``*Process`` class with the given
    arguments, waits for the command to finish, and returns the outcome of
    its ``communicate`` method.  Any ``perflog`` keyword argument passed to
    a method is replaced by the perflog of this client.

    :param perflog: The performance log passed on to every command, if any.
    """

    def __init__(self, perflog: Optional[NagiosPerflog] = None):
        self._perflog = perflog

    def _run(self, process_class: Callable[..., PerfProcess],
             args: Tuple[Any, ...], kwargs: Dict[str, Any]) \
            -> Result[Any, CalledProcessError]:
        """Run one command to completion and return its outcome."""
        kwargs['perflog'] = self._perflog
        with process_class(*args, **kwargs) as process:
            return process.communicate()

    def arcsub(self, *args, **kwargs) -> Result[str, CalledProcessError]:
        """Run ``arcsub``; see :class:`ArcsubProcess` for the arguments."""
        return self._run(ArcsubProcess, args, kwargs)

    def arcget(self, *args, **kwargs) -> Result[str, CalledProcessError]:
        """Run ``arcget``; see :class:`ArcgetProcess` for the arguments."""
        return self._run(ArcgetProcess, args, kwargs)

    def arcrm(self, *args, **kwargs) -> Result[str, CalledProcessError]:
        """Run ``arcrm``; see :class:`ArcrmProcess` for the arguments."""
        return self._run(ArcrmProcess, args, kwargs)

    def arcclean(self, *args, **kwargs) -> Result[str, CalledProcessError]:
        """Run ``arcclean``; see :class:`ArccleanProcess` for the arguments."""
        return self._run(ArccleanProcess, args, kwargs)

    def arckill(self, *args, **kwargs) -> Result[str, CalledProcessError]:
        """Run ``arckill``; see :class:`ArckillProcess` for the arguments."""
        return self._run(ArckillProcess, args, kwargs)

    def arccp(self, *args, **kwargs) -> Result[str, CalledProcessError]:
        """Run ``arccp``; see :class:`ArccpProcess` for the arguments."""
        return self._run(ArccpProcess, args, kwargs)

    def arcls(self, *args, **kwargs) \
            -> Result[List[ArclsEntry], CalledProcessError]:
        """Run ``arcls``; see :class:`ArclsProcess` for the arguments."""
        return self._run(ArclsProcess, args, kwargs)
