#!/usr/bin/python3.13
# -*- coding: utf-8 -*-
#
#       lalpulsar_knope_collate_results.py
#
#       Copyright 2013, 2016
#       Matthew Pitkin <matthew.pitkin@ligo.org>,
#
#       This program is free software; you can redistribute it and/or modify
#       it under the terms of the GNU General Public License as published by
#       the Free Software Foundation; either version 2 of the License, or
#       (at your option) any later version.
#
#       This program is distributed in the hope that it will be useful,
#       but WITHOUT ANY WARRANTY; without even the implied warranty of
#       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#       GNU General Public License for more details.
#
#       You should have received a copy of the GNU General Public License
#       along with this program; if not, write to the Free Software
#       Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
#       MA 02110-1301, USA.

from __future__ import print_function, division

# standard library imports
import argparse
import ast
import configparser
import datetime
import json
import os
import re
import sys
from configparser import ConfigParser

# related third party imports
import numpy as np

import matplotlib

matplotlib.use("Agg")

# local application/library specific imports
from lalpulsar import git_version
from lalpulsar.pulsarpputils import *
from lalpulsar.pulsarhtmlutils import *

# script metadata: the version string and date are taken from the lalpulsar
# git_version module imported above
__author__ = "Matthew Pitkin <matthew.pitkin@ligo.org>"
__version__ = "git id {}".format(git_version.id)
__date__ = git_version.date


# Template for the output HTML page. It is filled in with str.format() using
# the following named fields:
#   cssfile      - value placed in the href of the stylesheet link
#   nsources     - text appended to the "Results table" heading
#   resultstable - the HTML of the results table
#   footnotes    - HTML for any footnotes to the table (may be empty)
#   footer       - HTML placed in the page footer
htmlpage = """
<!DOCTYPE html>
<html lang="en">

<head>
  <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
  <meta name="description" content="Known pulsar search results"/>
  <meta charset="UTF-8">

  <title>Known pulsar search results</title>

  <link rel="stylesheet" type="text/css" href="{cssfile}"/>
</head>

<body>

<h1>Results table {nsources}</h1>

<!-- The table of pulsars and results -->
<div class="tablediv">
{resultstable}
</div>
<br />

<!-- Footnotes from the table -->
{footnotes}
<br />

<!-- a footer -->
<div id="footer">
{footer}
</div>

</body>
</html>
"""


# helper functions used by the main script


def config_option(cp, section, option, default=None):
    """Return a value from the configuration, or a default if it is unavailable.

    Parameters
    ----------
    cp : ConfigParser
        The parsed configuration.
    section : str
        The name of the section to look in.
    option : str
        The name of the option within the section.
    default : object, optional
        The value returned when the section or option is missing, or the
        value cannot be interpolated.

    Returns
    -------
    The option's string value, or ``default``.
    """
    try:
        return cp.get(section, option)
    except configparser.Error:
        return default


def literal_list_option(cp, section, option):
    """Return a configuration value holding a Python-style list of strings.

    The value is evaluated with ast.literal_eval. A single quoted string is
    accepted and returned as a one-element list, so that it is not later
    iterated over character by character.

    Parameters
    ----------
    cp : ConfigParser
        The parsed configuration.
    section : str
        The name of the section to look in.
    option : str
        The name of the option within the section.

    Returns
    -------
    list of str, or None if the option is missing, cannot be evaluated, or
    is not a string or a list/tuple of strings.
    """
    rawvalue = config_option(cp, section, option)
    if rawvalue is None:
        return None

    try:
        value = ast.literal_eval(rawvalue)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None

    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple)) and all(isinstance(v, str) for v in value):
        return list(value)
    return None


def find_json_file(directory):
    """Return the path of a '.json' file in a directory, or None if there is none.

    The directory listing is sorted so that the choice is reproducible when a
    directory holds more than one JSON file.
    """
    for fname in sorted(os.listdir(directory)):
        if os.path.splitext(fname)[-1] == ".json":
            return os.path.join(directory, fname)
    return None


def load_pulsar_results(jsonfile, ifos):
    """Read and check the results JSON file for one pulsar.

    A warning is printed (to stdout) describing the problem if the file is
    not usable.

    Parameters
    ----------
    jsonfile : str
        The path to the JSON file.
    ifos : list of str
        The detector keys that must be present in the file.

    Returns
    -------
    dict holding the contents of the file, or None if the file could not be
    read, lacks the 'Pulsar data' or 'PSR' keys, or lacks a requested
    detector key.
    """
    try:
        with open(jsonfile, "r") as fp:
            pdata = json.load(fp)
    except (OSError, ValueError):  # ValueError covers JSON and unicode decoding errors
        print(
            "Warning... could not read JSON file '%s'. Skipping this directory."
            % jsonfile
        )
        return None

    # check dictionary contains a 'Pulsar data' key
    if not isinstance(pdata, dict) or "Pulsar data" not in pdata:
        print(
            "Warning... no 'Pulsar data' field in JSON file '%s'. Skipping this directory."
            % jsonfile
        )
        return None

    # check dictionary contains 'PSR' key
    if "PSR" not in pdata:
        print(
            "Warning... no 'PSR' pulsar name field in JSON file '%s'. Skipping this directory."
            % jsonfile
        )
        return None

    # check that the requested detectors are within the dictionary
    for ifo in ifos:
        if ifo not in pdata:
            print(
                "Warning... no key for detector '%s' in JSON file '%s'. Skipping this directory."
                % (ifo, jsonfile)
            )
            return None

    return pdata


def sorted_pulsar_names(resultsdata, sorttype, ifos, reverse=False):
    """Return the pulsar names in ``resultsdata`` sorted on a given quantity.

    Parameters
    ----------
    resultsdata : dict
        Dictionary, keyed on pulsar name, of the per-pulsar results.
    sorttype : str
        'freq', 'dec', 'dist' or 'sdlim' sort on a value in the pulsar's
        'Pulsar data' section; 'h0ul', 'ell' or 'sdrat' sort on the smallest
        value over ``ifos`` in the 'Upper limits' sections; anything else
        (e.g. 'name' or 'ra') sorts on the pulsar name.
    ifos : list of str
        The detector keys used for the upper limit based sorts.
    reverse : bool, optional
        Sort in descending rather than ascending order.

    Returns
    -------
    list of str
        The sorted names. Pulsars for which the sort value is missing or
        None cannot be compared with the others, so they are placed at the
        end of the list (ordered by name) whatever the sort direction.
    """
    # NOTE(review): sorting on frequency reads the key "F0", whereas the table
    # columns use "F0ROT"/"F0GW" - confirm "F0" is written to the JSON files.
    pulsarkeys = {
        "freq": "F0",
        "dec": "DEC",
        "dist": "DIST",
        "sdlim": "spin-down limit",
    }
    limitkeys = {"h0ul": "H0", "ell": "ELL", "sdrat": "spin-down ratio"}

    def sortvalue(pname):
        pulsar = resultsdata[pname]
        if sorttype in pulsarkeys:
            return pulsar.get("Pulsar data", {}).get(pulsarkeys[sorttype])
        if sorttype in limitkeys:
            lims = [
                pulsar.get(ifo, {}).get("Upper limits", {}).get(limitkeys[sorttype])
                for ifo in ifos
            ]
            lims = [lim for lim in lims if lim is not None]
            return min(lims) if lims else None
        return pname  # sort by name (equivalent to sorting by RA) by default

    keyed = [(sortvalue(pname), pname) for pname in resultsdata]
    present = sorted((kv for kv in keyed if kv[0] is not None), reverse=reverse)
    missing = sorted(pname for value, pname in keyed if value is None)

    return [pname for _, pname in present] + missing


def display_strings(parname, value):
    """Return the (HTML, LaTeX) strings used to display a value in the tables.

    A value of None is shown as "*" in both tables. Otherwise the value is
    passed to the entries named ``parname`` in the ``paramhtmldispfunc`` and
    ``paramlatexdispfunc`` objects provided by the lalpulsar imports.
    """
    if value is None:
        return "*", "*"
    disphtmlfunc = paramhtmldispfunc.__dict__[parname]
    displatexfunc = paramlatexdispfunc.__dict__[parname]
    return disphtmlfunc(value), displatexfunc(value)


# main function
if __name__ == "__main__":
    description = """This script will collate results pages from multiple pulsars into a single table."""
    epilog = """An example configuration file could contain the following:

# a section containing output information
[output]
path = path_to_output_directory  # the path to the output directory to contain the page [required]

# a section containing input directory information
[input]
path = path_to_input_directory   # the path to the directory in which the individual results page directories live [required]

# a section containing general information for the output table
[general]
detectors = ['H1', 'L1', 'Joint']   # a list of detectors (including 'Joint') whose results will be output
sort_value = name                   # the pulsar parameter on which to sort the results (default: 'name') [Allowed: 'name', 'freq', 'ra', 'dec', 'sdlim', 'h0ul', 'ell', 'sdrat', 'dist']
sort_direction = ascending          # sort in ascending or descending order (default: 'ascending')
results = ['h0ul', 'ell']            # a list of result parameters to output (default: ['h0ul'] - the h0 upper limit) [Allowed: 'h0ul', 'ell', 'sdrat', 'q22', 'bsn', 'bci', 'bcin', 'C21ul', 'C22ul']
parameters = ['f0rot', 'ra', 'dec'] # a list of pulsar parameters to output (default: ['f0rot'] - the pulsar's rotation frequency) [Allowed: 'f0rot', 'f0gw', 'f1rot', 'f1gw', 'sdlim', 'ra', 'dec', 'dist']
"""

    parser = argparse.ArgumentParser(
        description=description,
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("inifile", help="The configuration (.ini) file")

    # parse input options
    opts = parser.parse_args()

    inifile = opts.inifile

    # Allow inline '#' comments, as used in the example configuration shown in
    # the help text; without this they would become part of the values.
    cp = ConfigParser(inline_comment_prefixes=("#",))
    try:
        # read() skips files it cannot open and returns the list of files that
        # were successfully parsed, so an empty list also means failure
        parsedfiles = cp.read(inifile)
    except (configparser.Error, UnicodeDecodeError):
        parsedfiles = []
    if not parsedfiles:
        print(
            "Error... cannot parse configuration file '%s'" % inifile, file=sys.stderr
        )
        sys.exit(1)

    # get the output path
    outpath = config_option(cp, "output", "path")
    if outpath is None:
        print(
            "Error... no output directory 'path' specified in [output] section.",
            file=sys.stderr,
        )
        sys.exit(1)

    # make the output directory if it does not already exist
    try:
        os.makedirs(outpath, exist_ok=True)
    except OSError:
        print("Error... cannot make output directory '%s'." % outpath, file=sys.stderr)
        sys.exit(1)

    # get the directory containing all the individual pulsar result directories
    inpath = config_option(cp, "input", "path")
    if inpath is None:
        print(
            "Error... no input directory 'path' specified in [input] section.",
            file=sys.stderr,
        )
        sys.exit(1)

    if not os.path.isdir(inpath):
        print("Error... input path '%s' does not exist." % inpath, file=sys.stderr)
        sys.exit(1)

    # check how to sort the results (default to sorting by pulsar name)
    sorttype = config_option(cp, "general", "sort_value", "name")

    # sorting can be by name (equivalently right ascension), frequency, declination, spin-down limit, h0 upper limit, ellipticity upper limit, spin-down ratio or distance
    allowedsorts = (
        "name",
        "freq",
        "ra",
        "dec",
        "sdlim",
        "h0ul",
        "ell",
        "sdrat",
        "dist",
    )
    if sorttype not in allowedsorts:
        print(
            "Error... sorting can only be on either 'name', 'freq', 'ra', 'dec', 'sdlim', 'h0ul', 'ell', 'sdrat' or 'dist'.",
            file=sys.stderr,
        )
        sys.exit(1)

    # get the direction of the sorting: anything other than 'descending'
    # (including a missing value) gives the default of ascending
    sortdirection = config_option(cp, "general", "sort_direction", "ascending")
    reverse = sortdirection == "descending"

    # get list of detectors (and/or 'Joint') to use in the output table
    ifos = literal_list_option(cp, "general", "detectors")
    if ifos is None:
        print("Error... could not parse list of detectors to include.", file=sys.stderr)
        sys.exit(1)

    # get list of values to output (e.g. ['h0ul', 'ell'] for the h0 upper limit
    # and ellipticity limit), converted to upper case; as default output the
    # h0 upper limit
    resultnames = literal_list_option(cp, "general", "results")
    if resultnames is None:
        outputlims = ["H0UL"]
    else:
        outputlims = [ol.upper() for ol in resultnames]

    # get a list of the pulsar parameters to output, converted to upper case;
    # as default output the rotational frequency
    parameternames = literal_list_option(cp, "general", "parameters")
    if parameternames is None:
        outputvals = ["F0ROT"]
    else:
        outputvals = [ov.upper() for ov in parameternames]

    # dictionary of values to insert into the output html page
    htmlinput = {}

    # get directories containing results pages
    resdirs = os.listdir(inpath)
    if len(resdirs) == 0:
        print(
            "Error... the input directory '%s' contains no other directories." % inpath,
            file=sys.stderr,
        )
        sys.exit(1)

    resultsdata = {}  # dictionary to hold results data

    # a source directory is one containing an ini file named after the directory
    sourcedirs = [
        os.path.join(inpath, rd)
        for rd in resdirs
        if os.path.isfile(os.path.join(inpath, rd, "{}.ini".format(rd)))
    ]
    totalsources = len(sourcedirs)
    cursources = 0  # current number of completed sources
    for d in sourcedirs:
        jsonfile = find_json_file(d)
        if jsonfile is None:  # no file found, so move on to next directory
            continue

        pdata = load_pulsar_results(jsonfile, ifos)
        if pdata is None:  # file not usable (a warning has been printed)
            continue

        # add data into dictionary
        pdata["path"] = os.path.relpath(d, outpath)  # relative path to result directory
        resultsdata[pdata["PSR"]] = pdata
        cursources += 1

    if len(resultsdata) == 0:
        print("Error... no results were found!", file=sys.stderr)
        sys.exit(1)

    # perform sorting to get a list of sorted names
    sortedlist = sorted_pulsar_names(resultsdata, sorttype, ifos, reverse=reverse)

    nifos = len(ifos)  # number of IFOs (including "Joint")
    numprepars = 1 + len(outputvals)  # number of parameters in start of table

    # check if outputting coherent vs incoherent Bayes factors (Bci or Bcin) -
    # these are only shown if 'Joint' is present in the detectors, and they
    # are removed from the list of per-detector limits in either case
    showbci = int("BCI" in outputlims and "Joint" in ifos)
    showbcin = int("BCIN" in outputlims and "Joint" in ifos)
    outputlims = [ol for ol in outputlims if ol not in ("BCI", "BCIN")]
    numlims = len(outputlims)  # number of limit parameters to output per detector

    # start creating the results table
    restable = htmltable()
    ltable = latextable(ncolumns=(numprepars + numlims * nifos + showbci + showbcin))
    ltable.set_caption("Limits on the gravitational wave amplitude for known pulsars")

    # first row (detector names spanning their columns)
    restable.addrow()
    ltable.addrow(underline=True)
    restable.adddata("", dataclass="bottomborder", header=True, colspan=numprepars)
    ltable.adddata("~", multicolumn=numprepars)
    for ifo in ifos:
        ncolspan = numlims
        if ifo == "Joint":
            ncolspan += showbci + showbcin
        restable.adddata(
            ifo,
            dataclass=ifo,
            header=True,
            datastyle="text-align:center; border-left:1px solid #000; border-bottom:1px solid #000",
            colspan=ncolspan,
        )
        ltable.adddata(ifo, multicolumn=ncolspan)

    # second row (column headings)
    restable.addrow()
    ltable.addrow(underline=True)
    restable.adddata("Pulsar", dataclass="bottomborder", header=True)
    ltable.adddata("Pulsar")
    for prepar in outputvals:
        restable.adddata(paramhtmldict[prepar], dataclass="bottomborder", header=True)
        ltable.adddata(paramlatexdict[prepar])

    # the credible region shown in the upper limit headings is taken from the
    # first pulsar in the results dictionary
    firstpulsar = next(iter(resultsdata.values()))

    for ifo in ifos:
        dataclass = "leftborder"  # class for first value
        datastyle = "border-bottom:1px solid #000"  # style for first value
        for ol in outputlims:
            if ol.endswith("UL"):  # if value is an upper limit add credible region
                cr = firstpulsar[ifo]["Upper limits"]["credible region"]
                htmlheading = paramhtmldict[ol].format(cr)
                latexheading = paramlatexdict[ol].format(cr)
            else:
                htmlheading = paramhtmldict[ol]
                latexheading = paramlatexdict[ol]
            restable.adddata(
                htmlheading,
                dataclass=dataclass,
                datastyle=datastyle,
                header=True,
            )
            ltable.adddata(latexheading)
            dataclass = "bottomborder"
            datastyle = ""
        if (
            ifo == "Joint"
        ):  # check whether to show additional coherent vs incoherent Bayes factor values
            for bu, show in [("BCI", showbci), ("BCIN", showbcin)]:
                if show:
                    restable.adddata(
                        paramhtmldict[bu],
                        dataclass=dataclass,
                        datastyle=datastyle,
                        header=True,
                    )
                    ltable.adddata(paramlatexdict[bu])

    # a dictionary to convert between parameter names
    convdict = {
        "RA": "RA",
        "DEC": "DEC",
        "H0UL": "H0",
        "C21UL": "C21",
        "C22UL": "C22",
        "I21UL": "I21",
        "I31UL": "I31",
        "SDLIM": "spin-down limit",
        "SDRAT": "spin-down ratio",
        "BSN": "Signal vs Noise",
        "BCI": "Coherent vs Incoherent",
        "BCIN": "Coherent vs Incoherent or Noise",
    }

    # checks for whether footnotes are required
    dagger = False
    ddagger = False

    # loop through the pulsars adding a row to each table
    for pname in sortedlist:
        restable.addrow()
        ltable.addrow()

        pulsar = resultsdata[pname]  # the dictionary of information for this pulsar

        # output pulsar names and pulsar parameters (include link to results page)
        restable.adddata(
            atag(os.path.join(pulsar["path"], pname + ".html"), linktext=pname).text
        )
        ltable.adddata(
            re.sub(r"\-", "\\\\textminus", pname)
        )  # substitute - ('hyphen') for \textminus in names

        for prepar in outputvals:
            htmlsdtag = ""
            latexsdtag = ""

            # set footnotes for pulsars that have had a "corrected" spin-down
            if "SDLIM" in prepar:
                p1i = pulsar["Pulsar data"].get("P1_I")
                assoc = pulsar["Pulsar data"].get("ASSOC")
                if (
                    p1i is not None
                ):  # spin-down has been corrected for intrinsic motion effects
                    if (
                        p1i > 0.0
                    ):  # double check that intrinsic period derivative is positive
                        htmlsdtag = htmltag("sup", tagtext="&dagger;").text
                        latexsdtag = r"$\dagger$"
                        dagger = True
                elif assoc is not None:
                    if "GC" in assoc:
                        htmlsdtag = htmltag("sup", tagtext="&Dagger;").text
                        latexsdtag = r"$\ddagger$"
                        ddagger = True

            # a missing value and a value of None are both shown as "*"
            pn = convdict.get(prepar, prepar)
            prevalhtml, prevallatex = display_strings(
                prepar, pulsar["Pulsar data"].get(pn)
            )

            restable.adddata(prevalhtml + htmlsdtag)
            ltable.adddata(prevallatex + latexsdtag)

        for ifo in ifos:
            dataclass = "leftborder"

            for limpar in outputlims:
                ln = convdict.get(limpar, limpar)

                # generally values will be in the 'Upper limits' section, but
                # if getting the Bayes factor look in 'Bayes factors' section
                section = "Bayes factors" if limpar == "BSN" else "Upper limits"

                limval = pulsar[ifo][section].get(ln)
                if limval is not None and limpar == "BSN":
                    limval = limval / np.log(10.0)  # convert to log base 10
                limvalhtml, limvallatex = display_strings(limpar, limval)

                restable.adddata(limvalhtml, dataclass=dataclass)
                dataclass = ""
                ltable.adddata(limvallatex)

            # check if needing to output coherent vs incoherent Bayes factors
            if ifo == "Joint":
                for bu, show in [("BCI", showbci), ("BCIN", showbcin)]:
                    if show:
                        bval = pulsar[ifo]["Bayes factors"].get(convdict[bu])
                        if bval is not None:
                            bval = bval / np.log(10.0)  # convert to log base 10
                        bvalhtml, bvallatex = display_strings(bu, bval)

                        restable.adddata(bvalhtml, dataclass=dataclass)
                        dataclass = ""
                        ltable.adddata(bvallatex)

    htmlinput["nsources"] = "(%d/%d sources)" % (cursources, totalsources)
    htmlinput["resultstable"] = restable.tabletext

    # add footnotes
    if dagger or ddagger:
        htmlfootnotes = htmltag("div", tagclass="footnotes")
        if dagger:
            htmlfootnotes += htmltag("sup", tagtext="&dagger;").text
            htmlfootnotes += " The spin-down limit was calculated using a spin-down "
            htmlfootnotes += atag(
                "http://www.atnf.csiro.au/research/pulsar/psrcat/psrcat_help.html?type=normal&highlight=p1_i#p1_i",
                linktext="corrected",
            ).text
            htmlfootnotes += " for proper motion effects.\n"
            htmlfootnotes += "<br />\n"
            ltable.set_postamble(
                "$\\dagger$ The pulsar's spin-down is corrected for proper motion effects. \\\\\n"
            )
        if ddagger:
            htmlfootnotes += htmltag("sup", tagtext="&ddagger;").text
            htmlfootnotes += " The spin-down limit was calculated using a characteristic spin-down age of 10<sup>9</sup> years.\n"
            ltable.set_postamble(
                ltable.postamble
                + "$\\ddagger$ The pulsar's spin-down is calculated using a characteristic spin-down age of $10^9$ years.\n"
            )
        htmlinput["footnotes"] = htmlfootnotes.text
    else:
        htmlinput["footnotes"] = ""

    # create CSS
    cssfile = os.path.join(outpath, "table.css")
    try:
        with open(cssfile, "w") as fp:
            fp.write(results_table_css)
    except OSError:
        print("Error... there was a problem outputting the CSS file.", file=sys.stderr)
        sys.exit(1)

    # the page and the CSS file live in the same directory
    htmlinput["cssfile"] = os.path.basename(cssfile)

    # get time/date for file creation
    now = datetime.datetime.now()

    # add footer containing author, date and command lines used for page
    htmlinput["footer"] = "{} - {}<br><br>Command lines used:<br>{}<br>{}<br>".format(
        __author__, now.strftime("%a %d %b %Y"), " ".join(sys.argv), __version__
    )

    # create page (written as UTF-8 to match the charset declared in the page)
    pagetext = htmlpage.format(**htmlinput)
    htmlfile = os.path.join(outpath, "index.html")
    try:
        with open(htmlfile, "w", encoding="utf-8") as fp:
            fp.write(pagetext)
    except OSError:
        print("Error... there was a problem outputting the html page.", file=sys.stderr)
        sys.exit(1)

    # output the LaTeX table
    latextext = ltable.tabletext
    latexfile = os.path.join(outpath, "resultstable.tex")
    try:
        with open(latexfile, "w") as fp:
            fp.write(latextext)
    except OSError:
        print(
            "Error... there was a problem outputting the LaTeX table.", file=sys.stderr
        )
        sys.exit(1)
