#!/usr/bin/env python
#
# Copyright (C) 2009  Larne Pekowsky, based on glitch-page.sh by Duncan
# Brown
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#


from __future__ import print_function
from optparse import OptionParser
import sqlite3
import sys
import os
import glob
import tempfile

from ligo.segments import segmentlist, segment

from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import utils
from glue.ligolw.utils import ligolw_sqlite
from glue.ligolw import dbtables

from glue.segmentdb import query_engine
from glue.segmentdb import segmentdb_utils

from glue import gpstime

import glob
import time
import datetime
import StringIO

#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#

def parse_command_line():
    """
    Parse the command line, return an options object

    Every option value comes back as a string (or None when the option was
    not given and has no default); callers convert the numeric ones.

    If a required argument, or required combination of arguments, is
    missing, a message is written to stderr and the process exits with
    status 1.
    """

    def require(flag, message):
        # Abort with a diagnostic unless `flag` is truthy.  Diagnostics go to
        # stderr so they are not lost or mixed into redirected output.
        if not flag:
            print(message, file=sys.stderr)
            sys.exit(1)

    parser = OptionParser(
        version     = "%prog CVS $Header$",
        usage       = "%prog --trigger-dir dir --segments url --html-file file --ifo ifo --timestamp-file file [other options]",
        description = "Updates or creates the html with recent glitch summary information"
    )

    parser.add_option("-t", "--trigger-dir",   metavar="dir",  help = "Location of XML files containing sngl_inspiral tables")
    parser.add_option("-e", "--trigger-db",    metavar="database",  help = "Location of an already-created sqlite database")
    parser.add_option("-s", "--segments",      metavar="url",  help = "URL to contact for DQ flags (ldbd: or file:)")
    parser.add_option("-f", "--html-file",     metavar="file", help = "Location of html file to write")
    parser.add_option("-i", "--ifo",           metavar="ifo",  help = "IFO")
    parser.add_option("-g", "--timestamp-file", metavar="file", help = "Location of file storing last run time")
    parser.add_option("-m", "--min-glitch-snr", metavar="snr",  help = "Minimum SNR to be considered a glitch (default=15.0)",  default="15.0")
    parser.add_option("-k", "--known-count",    metavar="known_count", help = "Max. number of triggers with DQ flags to print (default=10)", default="10")
    parser.add_option("-u", "--unknown-count",  metavar="unknown_count", help = "Max. number of triggers without DQ flags to print (default=10)", default="10")
    parser.add_option("-b", "--gps-start-time",  metavar="gps_start_time", help = "Provide an explicit start time, rather than using the timestamp file")
    parser.add_option("-n", "--gps-end-time",  metavar="gps_end_time", help = "Provide an explicit end time, rather than using now")

    parser.add_option("-a", "--omega-conf",       metavar="omega_conf",      help = "Omega configuration file")
    parser.add_option("-c", "--omega-frame-dir",  metavar="omega_frame_dir", help = "Top level directory of frame files for omega")
    parser.add_option("-d", "--omega-out-dir",    metavar="omega_out_dir",   help = "Output directory for omega results")
    parser.add_option("-r", "--statistic",        metavar="statistic",       help = "Ranking statistic, snr|newsnr, default=snr", default="snr")

    # Positional arguments are accepted but ignored.
    options, others = parser.parse_args()

    # Make sure we have all required parameters
    require(options.trigger_dir or options.trigger_db, "Missing required argument [--trigger-dir | --trigger-db]")
    require(options.ifo, "Missing required argument --ifo")
    require(options.html_file, "Missing required argument --html-file")
    # The time window comes either from the timestamp file or from an
    # explicit start/end pair.
    require(options.timestamp_file or (options.gps_start_time and options.gps_end_time), "Missing required argument [--timestamp-file | --gps-start-time --gps-end-time]")

    return options



def sigfigs(num):
    """
    Format a number truncated (not rounded) to two significant figures.

    Returns a string, except that a zero input is returned unchanged.
    Negative values are formatted by magnitude with a leading '-', and
    NaN / infinite values are returned as their plain str() form since
    they cannot be scaled.

    Examples: 5.678 -> '5.6', 0.0123 -> '0.012', 1234.5 -> '1200'.
    """
    if num == 0.0:
        return num

    # NaN and +/-inf would otherwise raise or spin forever in the scaling
    # loops below.
    if num != num or abs(num) == float('inf'):
        return str(num)

    # The scaling loops assume a positive value; a negative one would never
    # leave the "while num < 10" loop.
    if num < 0:
        return '-' + sigfigs(-num)

    if num < 10:
        div      = 1.0
        decimals = 0

        # Scale up until two digits sit left of the decimal point, counting
        # how many decimal places are needed to print them back.
        while num < 10:
            num      *= 10.0
            div      *= 10.0
            decimals += 1

        return '%.*f' % (decimals, int(num) / div)

    mul = 1

    # Scale down to (at most) the leading digits, then pad back with zeros.
    while num > 100:
        num /= 10.0
        mul *= 10

    return str(int(num) * mul)


def _gps_to_utc(gps_time):
    # Run lalapps_tconvert on a GPS second and return the first line of its
    # output, stripped ('' if it printed nothing).  The pipe is always closed.
    pipe = os.popen('lalapps_tconvert %d' % gps_time)
    try:
        return pipe.readline().strip()
    finally:
        pipe.close()


def generate_html(outf, triggers, colors, options):
    """
    Write an HTML table describing `triggers` to the file-like `outf`.

    triggers: sequence of tuples laid out as
              (ifo, end_time, end_time_ns, <statistic columns...>, flags)
              where `flags` is a dict.  Ordinary entries map a flag name to
              an indexable value whose items 0 and 2 are printed; an entry
              whose value is True is an interferometer-status summary and
              only its name is printed, in bold.
    colors:   two background colours, alternated row by row.
    options:  only `options.statistic` is read; 'newsnr' selects the
              new-SNR column headings, anything else the plain SNR ones.
    """
    print('<p>', file=outf)
    print('<table border=1>', file=outf)

    if options.statistic == 'newsnr':
        print('  <tr bgcolor="#9999ff"><th>ifo</th><th>end time UTC</th>', end=' ', file=outf)
        print('<th>end_time</th><th>end_time_ns</th>', end=' ', file=outf)
        print('<th>SNR</th><th>&chi;<sup>2</sup></th><th>New SNR</th><th>eff_distance</th>', file=outf)
        print("<th>M<sub>chirp</sub></th>", end=' ', file=outf)
        print("<th>M<sub>total</sub></th>", end=' ', file=outf)
        print("<th>&Omega; scan</th><th>DQ flags</th></tr>", file=outf)
    else:
        print('  <tr bgcolor="#9999ff"><th>ifo</th><th>end time UTC</th><th>end_time</th><th>end_time_ns</th><th>SNR</th><th>eff_distance</th>', file=outf)
        print("<th>M<sub>chirp</sub></th>", end=' ', file=outf)
        print("<th>M<sub>total</sub></th>", end=' ', file=outf)
        print("<th>&chi;<sup>2</sup></th>", end=' ', file=outf)
        print("<th>Reduced cont. &chi;<sup>2</sup></th>", end=' ', file=outf)
        print("<th>&Omega; scan</th><th>DQ flags</th></tr>", file=outf)

    for count, trig in enumerate(triggers):
        print('  <tr valign="top" bgcolor="%s">' % colors[count % 2], file=outf)

        print('    <td>%s</td>' % (trig[0]), file=outf)

        end_time_utc = _gps_to_utc(trig[1])
        print('    <td>%s</td>' % (end_time_utc), file=outf)

        # Everything between the ifo and the flags dict is a data cell.
        for res in trig[1:-1]:
            # Only values that print with a decimal point are shortened;
            # integer columns (the GPS times) are shown in full.
            if str(res).find('.') > 0:
                res = sigfigs(float(res))

            print('    <td>%s</td>' % res, file=outf)

        print('    <td><a href="omega/%s.%s/index.html">&Omega; scan</a></td>' % (trig[1], trig[2]), file=outf)

        print('    <td>', file=outf)
        for name, value in trig[-1].items():
            # A True value marks the interferometer-status summary, which has
            # no before/after fields to index.  The name test is kept so that
            # entries recognised by name alone are still shown in bold.
            is_status = value is True or any(tag in name for tag in ('Light', 'Up', 'Calibrated'))

            if is_status:
                print('      <b>%s</b><br>' % name, file=outf)
            else:
                print('      %s %s %s<br>' % (name, value[0], value[2]), file=outf)
        print('    </td>', file=outf)

        print('  </tr>', file=outf)
    print('</table>', file=outf)

    print(file=outf)
    print(file=outf)




def setup_files(dir_name, db_name, gps_start_time, gps_end_time):
    """
    Open the trigger database, building a temporary one from XML if needed.

    When `db_name` is given it is opened directly with sqlite3 and nothing
    temporary is created.  Otherwise the XML files found under `dir_name`
    for the requested GPS range are loaded into a freshly created
    '.sqlite' temporary file.

    Returns (temp_db, connection); temp_db is the path of the temporary
    file, or None when an existing database was used.
    """
    # An existing database is used as-is, so there is no scratch file.
    if db_name:
        return None, sqlite3.connect(db_name)

    # Keep only the files that fall inside our time range
    trigger_files = segmentdb_utils.get_all_files_in_range(dir_name, gps_start_time, gps_end_time)

    # mkstemp returns an open descriptor as well; only the path is needed.
    descriptor, scratch_path = tempfile.mkstemp(suffix='.sqlite')
    os.close(descriptor)

    working_name = dbtables.get_connection_filename(scratch_path, None, True, False)
    database     = ligolw_sqlite.setup(working_name)

    ligolw_sqlite.insert_from_urls(database, trigger_files)

    return scratch_path, database
    

def get_times_with_flags(gps_start_time, gps_end_time):
    """
    Return a coalesced segmentlist of the segments overlapping
    [gps_start_time, gps_end_time] during which some DQ flag was on.

    The DMT-SCIENCE / LIGHT / UP / CALIBRATED / INJECTION flags are not
    counted.  The server is taken from the S6_SEGMENT_SERVER environment
    variable and the interferometer from the module-level `options.ifo`.
    """
    # Find times where there were no DQ flags, rather than checking the time of each trigger
    server = segmentdb_utils.setup_database(os.environ['S6_SEGMENT_SERVER'])
    engine = query_engine.LdbdQueryEngine(server)

    window = (gps_start_time, gps_end_time)

    pieces = [
        "SELECT start_time, end_time FROM segment ",
        " WHERE NOT (%s > segment.end_time OR segment.start_time > %s) " % window,
        " AND segment.segment_def_id IN (SELECT DISTINCT segment_definer.segment_def_id FROM segment_definer, segment_summary ",
        "                                WHERE segment_definer.ifos = '%s' " % options.ifo,
        "                                AND segment_definer.segment_def_id = segment_summary.segment_def_id ",
        "                                AND segment_summary.segment_def_cdb = segment_definer.creator_db ",
        "                                AND segment_definer.name NOT IN ('DMT-SCIENCE', 'DMT-LIGHT', 'DMT-UP', 'DMT-CALIBRATED', 'DMT-INJECTION') ",
        "                                AND NOT (%s > segment.end_time OR segment.start_time > %s)) " % window,
    ]
    query = ''.join(pieces)

    flagged = segmentlist()
    for row in engine.query(query):
        flagged.append(segment(row[0], row[1]))

    # Merge overlapping / adjacent segments in place.
    flagged.coalesce()

    return flagged


def newsnr(snr, chisq, chisq_dof):
    """
    Return `snr` down-weighted by the reduced chi-squared.

    The reduced chi-squared is chisq / (2 * chisq_dof - 2).  When it is
    greater than 1 the SNR is divided by
    ((1 + rchisq**3) / 2) ** (1/6); otherwise the SNR is returned as is.
    """
    exponent   = 6.0
    high_power = 2.
    reduced    = chisq/(2 * chisq_dof - 2)

    # Triggers that are not above the threshold are left untouched.
    if not reduced > 1.:
        return snr

    penalty = ((1. + reduced**(exponent/high_power))/2)**(1./exponent)
    return snr / penalty


def oldsnr(snr, chisq, chisq_dof):
    """
    Return `snr` unchanged.

    Takes the same three arguments as newsnr() so that either function can
    be registered under the same SQL function name; `chisq` and `chisq_dof`
    are ignored.
    """
    return snr

def calc_mchirp(mass1, mass2):
    """
    Return the chirp mass, (m1 * m2)**(3/5) / (m1 + m2)**(1/5).
    """
    product = mass1 * mass2
    total   = mass1 + mass2

    return product ** (3.0/5.0) / total ** (1.0/5.0)

def get_triggers_in_times(options, gps_start_time, gps_end_time, valid_times, count):
    """
    Return at most `count` of the loudest triggers whose end time (rounded
    to the nearest second) lies in `valid_times`.

    The triggers are read through the module-level sqlite `connection`.
    Only triggers for `options.ifo` with end_time in
    [gps_start_time, gps_end_time) and a ranking statistic of at least
    `options.min_glitch_snr` are considered, loudest first.

    Each returned tuple is the query row followed by a dict of the DQ flags
    reported by ligolw_dq_query at the trigger time:

        (ifo, end_time, end_time_ns, <6 statistic columns>, flags)

    The statistic columns are in the order of the table headings written by
    generate_html() for the selected `options.statistic`.  `flags` maps a
    flag name to the (before, at, after) fields of the report; the
    DMT-LIGHT/UP/CALIBRATED/SCIENCE/INJECTION flags are folded into a single
    comma-separated key whose value is True.
    """
    if not abs(valid_times):
        return []

    # Note, the S5 version of this script had the condition
    #    search = 'FindChirpSPtwoPN' 
    # The triggers from MBTA don't set this, although they could if desirable.
    # 
    # When hitting DB2 we could use the following query:
    #
    #  SELECT sngl_inspiral.ifo, 
    #         sngl_inspiral.end_time,
    #         sngl_inspiral.end_time_ns,
    #         sngl_inspiral.snr, 
    #         sngl_inspiral.eff_distance,
    #         sngl_inspiral.f_final, 
    #         sngl_inspiral.ttotal
    #  FROM sngl_inspiral 
    #  WHERE (sngl_inspiral.end_time, sngl_inspiral.snr) IN 
    #         (select end_time, MAX(snr) AS snr 
    #          FROM sngl_inspiral
    #          WHERE end_time >= ? AND
    #                end_time <  ? AND
    #                ifo = ? AND
    #                snr >= 15.0 
    #          GROUP BY end_time
    #          ORDER BY snr desc)
    #  AND sngl_inspiral.ifo = ? 
    #  ORDER BY snr DESC""",  (start_time, end_time, ifo, ifo) )
    #
    # But sqlite doesn't allow this, since:
    #   SQL error: only a single result allowed for a SELECT that is part of an expression
    #
    # So we have to be a little trickier...
    #

    connection.create_function("mchirp", 2, calc_mchirp)

    # Both branches select nine columns, in the order of the headings that
    # generate_html() writes for the corresponding statistic.
    if options.statistic == 'newsnr':
        connection.create_function("newsnr", 3, newsnr)
        fields = """sngl_inspiral.ifo, 
             sngl_inspiral.end_time,
             sngl_inspiral.end_time_ns,
             sngl_inspiral.snr,
             sngl_inspiral.chisq,
             newsnr(sngl_inspiral.snr,sngl_inspiral.chisq,sngl_inspiral.chisq_dof), 
             sngl_inspiral.eff_distance,
             mchirp(sngl_inspiral.mass1,sngl_inspiral.mass2),
             sngl_inspiral.mass1 + sngl_inspiral.mass2
        """
    else:
        # Registering oldsnr under the name "newsnr" lets the query below
        # rank by plain SNR without needing a second copy of the SQL.
        connection.create_function("newsnr", 3, oldsnr)
        fields = """sngl_inspiral.ifo, 
             sngl_inspiral.end_time,
             sngl_inspiral.end_time_ns,
             sngl_inspiral.snr,
             sngl_inspiral.eff_distance,
             mchirp(sngl_inspiral.mass1,sngl_inspiral.mass2),
             sngl_inspiral.mass1 + sngl_inspiral.mass2,
             sngl_inspiral.chisq,
             sngl_inspiral.cont_chisq / sngl_inspiral.cont_chisq_dof
        """

    #
    # The join above was meant to ensure that we only pick
    # the loudest trigger from each second, so we don't
    # fill the table with triggers from the same glitch.
    # But since we only ever look at 16-second clustered
    # triggers anyway, we don't need that.
    #
    rows = connection.cursor().execute("SELECT " + fields + """
      FROM sngl_inspiral 
      WHERE sngl_inspiral.end_time >= ? AND
            sngl_inspiral.end_time <  ? AND
            sngl_inspiral.ifo = ? AND
            newsnr(sngl_inspiral.snr, sngl_inspiral.chisq, sngl_inspiral.chisq_dof) >= ?
      ORDER BY newsnr(sngl_inspiral.snr,sngl_inspiral.chisq,sngl_inspiral.chisq_dof) DESC""",
      (gps_start_time, gps_end_time, options.ifo, float(options.min_glitch_snr)) )

    ret = []

    # Only the first three columns are interpreted here; the rest of each
    # row is passed through untouched, whichever statistic is in use.
    for row in rows:
        # Stop as soon as the requested number of triggers has been collected.
        if len(ret) >= count:
            break

        ifo, end_time, end_time_ns = row[0], row[1], row[2]

        # Round the trigger time to the nearest second.
        trig_time = end_time
        if end_time_ns >= 500000000:
            trig_time += 1

        if trig_time not in valid_times:
            continue

        # Find the flags on at this time.  We need this even for "unknown" times, since
        # we need to check for DMT-SCIENCE etc
        flags = {}

        pipe  = os.popen('ligolw_dq_query --segment=%s --include-segments %s --in-segments-only --report %d' % (options.segments, ifo, end_time))
        try:
            for line in pipe:
                if not line.strip():
                    continue

                flag, beforet, timet, aftert = line.split()

                # Flags are reported as ifo:name:version.  Separate names
                # are used here so the row's own ifo is not overwritten.
                flag_ifo, name, version = flag.split(':')
                flags[name] = (beforet, timet, aftert)
        finally:
            pipe.close()

        # Collapse the interferometer-state flags into one summary entry.
        status = []

        for flag_name in ['Light','Up','Calibrated','Science','Injection']:
            flag = 'DMT-' + flag_name.upper()

            if flag in flags:
                status.append(flag_name)
                del flags[flag]

        if status:
            flags[','.join(status)] = True

        ret.append(tuple(row) + (flags,))

    return ret


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#

if __name__ == '__main__':
    # `options` and `connection` are deliberately module-level names:
    # get_times_with_flags() and get_triggers_in_times() read them as globals.
    options = parse_command_line()    

    # Do we have everything we need for omega scans?
    omega_ok = options.omega_conf and options.omega_frame_dir and options.omega_out_dir

    if not omega_ok:
        print("Warning: Missing omega options, omega scans will not be run", file=sys.stderr)


    kcount = int(options.known_count)
    ucount = int(options.unknown_count)

    # Run up to now or the user-provided time
    if options.gps_end_time:
        gps_end_time = int(options.gps_end_time)
    else:
        gps_end_time = gpstime.GpsSecondsFromPyUTC(time.time())

    # Find the last time we ran (in principle we could get this from the time
    # stamp on the HTML file, but we should allow for the possibility that 
    # someone might hand-edit that file)
    if options.gps_start_time:
        gps_start_time = int(options.gps_start_time)
    elif os.path.exists(options.timestamp_file):
        with open(options.timestamp_file) as f:
            gps_start_time = int(next(f))
    else:
        # Looks like we've never run before, start a day ago
        gps_start_time = gps_end_time - (60 * 60 * 24)


    # Load the relevant trigger XML files into a sqlite DB and
    # get a connection
    temp_db, connection = setup_files(options.trigger_dir, options.trigger_db, gps_start_time, gps_end_time)

    # Did we find any triggers?  If not there won't even be a sngl_inspiral
    # table, so the normal query will fail

    have_triggers = connection.cursor().execute("SELECT COUNT(*) FROM sqlite_master WHERE name = 'sngl_inspiral'").fetchone()[0]

    if have_triggers: 
        times_with_flags    = get_times_with_flags(gps_start_time, gps_end_time)
        times_without_flags = segmentlist([segment(gps_start_time,gps_end_time)]) - times_with_flags

        known_trigs   = get_triggers_in_times(options, gps_start_time, gps_end_time, times_with_flags, kcount)
        unknown_trigs = get_triggers_in_times(options, gps_start_time, gps_end_time, times_without_flags, ucount)
    else:
        # Use an empty array so the rest of the code will flow through and
        # update the page accordingly
        known_trigs    = []
        unknown_trigs  = []

    # ligolw_sicluster --cluster-window ${clusterwindow} --sort-descending-snr ${xmlfile}

    # Do the omega scans if requested
    if omega_ok:
        for trig in (known_trigs + unknown_trigs):
            ifo         = trig[0]
            end_time    = trig[1]
            end_time_ns = trig[2]

            # The frames may live in a directory with subdirectories by time.  If the given directory
            # contains a '%d' populate it with the first 4 digits of the time.
            full_path = options.omega_frame_dir

            if '%d' in full_path:
                # Floor division: only the leading digits are wanted.
                full_path = full_path.replace('%d', str(end_time // 100000))

            if '%s' in full_path:
                full_path = full_path.replace('%s', ifo[0])

            conf = options.omega_conf.replace('%s',ifo[0])

            # The output directory keeps the unpadded "<sec>.<ns>" form so it
            # matches the link written by generate_html().  The scan time
            # itself must have the nanoseconds zero-padded to nine digits,
            # otherwise e.g. 5000000 ns would be read as 0.5 s.
            os.system("nohup condor_run ~omega/opt/omega/bin/wpipeline scan -r -c %s -f %s -o %s/%d.%d  %d.%09d < /dev/null &>/dev/null &" % ( conf, full_path, options.omega_out_dir, end_time, end_time_ns, end_time, end_time_ns ))


    # Convert to html, prepend to the file.
    out_tmp = StringIO.StringIO()
   
    # Could these be done in Python?
    pipe = os.popen('lalapps_tconvert %d' % gps_start_time)
    try:
        start_time_str = pipe.readline().strip()
    finally:
        pipe.close()

    pipe = os.popen('lalapps_tconvert %d' % gps_end_time)
    try:
        end_time_str = pipe.readline().strip()
    finally:
        pipe.close()

    print("<h3>%s through %s</h3>" % (start_time_str, end_time_str), file=out_tmp)

    print('<h3>Glitches without associated DQ flag</h3>', file=out_tmp)
    generate_html(out_tmp, unknown_trigs, ['#ffdddd', '#ffcccc'], options)
    

    print('<h3>Glitches with associated DQ flag</h3>', file=out_tmp)
    generate_html(out_tmp, known_trigs, ['#ddffdd', '#ccffcc'], options)

    print("<hr>", file=out_tmp)

    # Append the previous page contents after the new section.  The old lines
    # are copied verbatim; printing them with end=' ' would add a leading
    # space to every old line on every run.
    if os.path.exists(options.html_file):
        with open(options.html_file) as in_tmp:
            for l in in_tmp:
                out_tmp.write(l)

    with open(options.html_file,'w') as out_html:
        print(out_tmp.getvalue(), file=out_html)

    # All done, if we weren't given an ending time update the timestamp
    if not options.gps_end_time:
        with open(options.timestamp_file, 'w') as f:
            print(gps_end_time, file=f)

    # Release the database before (possibly) deleting the file behind it.
    connection.close()

    # If we created our own database, clean up after ourselves.
    if temp_db:
        os.remove(temp_db)


