#!/usr/bin/env python
#
# Copyright (C) 2011 Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## @file
# An HTCondor DAG generator to recolor frame data

"""
This program makes a dag to recolor frames
"""

__author__ = 'Chad Hanna <chad.hanna@ligo.org>'


##############################################################################
# import standard modules and append the lalapps prefix to the python path
import sys, os, copy, math
import subprocess, socket, tempfile

##############################################################################
# import the modules we need to build the pipeline
from glue import iterutils
from glue import pipeline
from glue import lal
from glue.ligolw import lsctables
from glue import segments
from glue.ligolw import ligolw
import glue.ligolw.utils as utils
import glue.ligolw.utils.segments as ligolw_segments
from optparse import OptionParser
from gstlal import datasource
from gstlal import dagparts

class ContentHandler(ligolw.LIGOLWContentHandler):
        pass
lsctables.use_in(ContentHandler)

#
# Classes for generating reference psds
#

class gstlal_reference_psd_job(pipeline.CondorDAGJob):
	"""
	A gstlal_reference_psd job
	"""
	def __init__(self, executable=dagparts.which('gstlal_reference_psd'), tag_base='gstlal_reference_psd'):
		"""
		"""
		self.__prog__ = 'gstlal_reference_psd'
		self.__executable = executable
		self.__universe = 'vanilla'
		pipeline.CondorDAGJob.__init__(self,self.__universe,self.__executable)
		self.add_condor_cmd('getenv','True')
		self.add_condor_cmd('requirements', 'Memory > 1999') #FIXME is this enough?
		self.tag_base = tag_base
		self.add_condor_cmd('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base+'.sub')
		self.set_stdout_file('logs/'+tag_base+'-$(macroid)-$(process).out')
		self.set_stderr_file('logs/'+tag_base+'-$(macroid)-$(process).err')


class gstlal_median_psd_job(pipeline.CondorDAGJob):
	"""
	A gstlal_median_psd job
	"""
	def __init__(self, executable=dagparts.which('gstlal_ninja_median_of_psds'), tag_base='gstlal_ninja_median_of_psds'):
		"""
		"""
		self.__prog__ = 'gstlal_ninja_median_of_psds'
		self.__executable = executable
		self.__universe = 'vanilla'
		pipeline.CondorDAGJob.__init__(self,self.__universe,self.__executable)
		self.add_condor_cmd('getenv','True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base+'.sub')
		self.set_stdout_file('logs/'+tag_base+'-$(macroid)-$(process).out')
		self.set_stderr_file('logs/'+tag_base+'-$(macroid)-$(process).err')


class gstlal_ninja_smooth_reference_psd_job(pipeline.CondorDAGJob):
	"""
	A gstlal_ninja_smooth_reference_psd job
	"""
	def __init__(self, executable=dagparts.which('gstlal_ninja_smooth_reference_psd'), tag_base='gstlal_ninja_smooth_reference_psd'):
		"""
		"""
		self.__prog__ = 'gstlal_ninja_smooth_reference_psd'
		self.__executable = executable
		self.__universe = 'vanilla'
		pipeline.CondorDAGJob.__init__(self,self.__universe,self.__executable)
		self.add_condor_cmd('getenv','True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base+'.sub')
		self.set_stdout_file('logs/'+tag_base+'-$(macroid)-$(process).out')
		self.set_stderr_file('logs/'+tag_base+'-$(macroid)-$(process).err')


class gstlal_reference_psd_node(pipeline.CondorDAGNode):
	"""
	A gstlal_reference_psd node
	"""
	def __init__(self, job, dag, frame_cache, gps_start_time, gps_end_time, instrument, channel, injections=None, p_node=[]):

		pipeline.CondorDAGNode.__init__(self,job)
		self.add_var_opt("frame-cache", frame_cache)
		self.add_var_opt("gps-start-time", gps_start_time)
		self.add_var_opt("gps-end-time", gps_end_time)
		self.add_var_opt("data-source", "frames")
		self.add_var_arg("--channel-name=%s=%s" % (instrument, channel))
		if injections:
			self.add_var_opt("injections", injections)
		path = os.getcwd()
		output_name = self.output_name = '%s/%s-%d-%d-reference_psd.xml.gz' % (path, instrument, gps_start_time, gps_end_time)
		self.add_var_opt("write-psd",output_name)
		dag.output_cache.append(lal.CacheEntry(instrument, "-", segments.segment(gps_start_time, gps_end_time), "file://localhost/%s" % (output_name,)))
		for p in p_node:
			self.add_parent(p)
		dag.add_node(self)


class gstlal_ninja_smooth_reference_psd_node(pipeline.CondorDAGNode):
	"""
	A gstlal_ninja_smooth_reference_psd node
	"""
	def __init__(self, job, dag, instrument, input_psd, p_node=[]):
		pipeline.CondorDAGNode.__init__(self,job)
		path = os.getcwd()
		#FIXME shouldn't be hardcoding stuff like this
		output_name = self.output_name = input_psd.replace('reference_psd', 'smoothed_reference_psd')
		self.add_var_opt("instrument", instrument)
		self.add_var_opt("input-psd", input_psd)
		self.add_var_opt("output-psd", output_name)
		for p in p_node:
			self.add_parent(p)
		dag.add_node(self)


class gstlal_median_psd_node(pipeline.CondorDAGNode):
	"""
	A gstlal_median_psd node
	"""
	def __init__(self, job, dag, instrument, input_psds, output, p_node=[]):
		pipeline.CondorDAGNode.__init__(self,job)
		path = os.getcwd()
		#FIXME shouldn't be hardcoding stuff like this
		output_name = self.output_name = output
		self.add_var_opt("instrument", instrument)
		self.add_var_opt("output-name", output_name)
		for psd in input_psds:
			self.add_file_arg(psd)
		for p in p_node:
			self.add_parent(p)
		dag.add_node(self)


#
# classes for generating recolored frames
#

class gstlal_fake_frames_job(pipeline.CondorDAGJob):
	"""
	A gstlal_fake_frames job
	"""
	def __init__(self, executable=dagparts.which('gstlal_fake_frames'), tag_base='gstlal_fake_frames'):
		"""
		"""
		self.__prog__ = 'gstlal_fake_frames'
		self.__executable = executable
		self.__universe = 'vanilla'
		pipeline.CondorDAGJob.__init__(self,self.__universe,self.__executable)
		self.add_condor_cmd('getenv','True')
		self.add_condor_cmd('requirements', 'Memory > 1999') #FIXME is this enough?
		self.tag_base = tag_base
		self.add_condor_cmd('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base+'.sub')
		self.set_stdout_file('logs/'+tag_base+'-$(macroid)-$(process).out')
		self.set_stderr_file('logs/'+tag_base+'-$(macroid)-$(process).err')


class gstlal_fake_frames_node(pipeline.CondorDAGNode):
	"""
	A gstlal_fake_frames node
	"""
	def __init__(self, job, dag, frame_cache, gps_start_time, gps_end_time, channel, reference_psd, color_psd, sample_rate, injections=None, output_channel_name = None, duration = 4096, output_path = None, frame_type = None, shift = None, whiten_track_psd = False, frames_per_file = 1, p_node=[]):

		pipeline.CondorDAGNode.__init__(self,job)
		self.add_var_opt("frame-cache", frame_cache)
		self.add_var_opt("gps-start-time",gps_start_time)
		self.add_var_opt("gps-end-time",gps_end_time)
		self.add_var_opt("data-source", "frames")
		self.add_var_arg("--channel-name=%s=%s" % (instrument, channel))
		self.add_var_opt("whiten-reference-psd",reference_psd)
		self.add_var_opt("color-psd", color_psd)
		self.add_var_opt("sample-rate", sample_rate)
		if injections is not None:
			self.add_var_opt("injections", injections)
		self.add_var_opt("output-channel-name", output_channel_name)
		self.add_var_opt("frame-duration", duration)
		if output_path is not None:
			self.add_var_opt("output-path", output_path)
		self.add_var_opt("frame-type", frame_type)
		if whiten_track_psd:
			self.add_var_opt("whiten-track-psd",reference_psd)
		if shift:
			self.add_var_opt("shift", shift)
		self.add_var_opt("frames-per-file", frames_per_file)
		for p in p_node:
			self.add_parent(p)
		dag.add_node(self)


def choosesegs(seglists, min_segment_length):
	for instrument, seglist in seglists.iteritems():
		newseglist = segments.segmentlist()
		for seg in seglist:
			if abs(seg) > min_segment_length:
				newseglist.append(segments.segment(seg))
		seglists[instrument] = newseglist


def parse_command_line():
	parser = OptionParser(description = __doc__)
	
	parser.add_option("--frame-cache", metavar = "filename", help = "Set the name of the LAL cache listing the LIGO-Virgo .gwf frame files (optional)")
	parser.add_option("--injections", metavar = "filename", help = "Set the name of the LIGO light-weight XML file from which to load injections (optional).")
	parser.add_option("--channel-name", metavar = "name", action = "append", help = "Set the name of the channels to process.  Can be given multiple times as --channel-name=IFO=CHANNEL-NAME")
	parser.add_option("--frame-segments-file", metavar = "filename", help = "Set the name of the LIGO light-weight XML file from which to load frame segments. Required")
	parser.add_option("--frame-segments-name", metavar = "name", help = "Set the name of the segments to extract from the segment tables. Required")
 
	parser.add_option("--min-segment-length", metavar = "SECONDS", help = "Set the minimum segment length to process (required)", type="float")
	parser.add_option("--shift", metavar = "NANOSECONDS", help = "Number of nanoseconds to delay (negative) or advance (positive) the time stream", type = "int")
	parser.add_option("--sample-rate", metavar = "HZ", default = 16384, type = "int", help = "Sample rate at which to generate the data, should be less than or equal to the sample rate of the measured psds provided, default = 16384 Hz")
	parser.add_option("--whiten-type", metavar="psdperseg|medianofpsdperseg|FILE", help = "Whiten whatever data is coming out of datasource either from the data or from a fixed reference psd if a file is given")
	parser.add_option("--whiten-track-psd", action = "store_true", help = "Calculate PSD from input data and track with time.")
	parser.add_option("--color-psd", metavar = "FILE", help = "Set the name of psd xml file to color the data with")
	parser.add_option("--output-path", metavar = "IFO=PATH", action = "append", help = "Set the instrument dependent output path for frames, defaults to current working directory. eg H1=/path/to/H1/frames. Can be given more than once.")
	parser.add_option("--output-channel-name", metavar = "IFO=NAME", action="append", help = "The name of the channel in the output frames. The default is the same as the channel name. can be given more than once. Required ")
	parser.add_option("--frame-type", metavar = "IFO=TYPE", action = "append", help = "Set the instrument dependent frame type, H1=TYPE. Can be given more than once and is required for each instrument processed.")
	parser.add_option("--frame-duration", metavar = "SECONDS", default = 16, type = "int", help = "Set the duration of the output frames.  The duration of the frame file will be multiplied by --frames-per-file.  Default: 16s")
	parser.add_option("--frames-per-file", metavar = "INT", default = 256, type = "int", help = "Set the number of frames per file.  Default: 256")
	parser.add_option("--verbose", action = "store_true", help = "Be verbose")
	
	options, filenames = parser.parse_args()

	fail = ""
	for option in ("min_segment_length", "frame_type", "frame_segments_file", "frame_segments_name"):
		if getattr(options, option) is None:
			fail += "must provide option %s\n" % (option)
	if fail:
		raise ValueError(fail)

	inchannels = datasource.channel_dict_from_channel_list(options.channel_name)
	outchannels = datasource.channel_dict_from_channel_list(options.output_channel_name)
	frametypes = datasource.channel_dict_from_channel_list(options.frame_type)
	outpaths = datasource.channel_dict_from_channel_list(options.output_path)

	if not (set(frametypes) == set(inchannels) == set(outchannels)):
		raise ValueError('--frame-type, --channel-name and --output-channel-name must contain same instruments')

	return options, inchannels, outchannels, outpaths, frametypes, filenames


options, inchannels, outchannels, outpaths, frametypes, filenames = parse_command_line()

try:
	os.mkdir("logs")
except:
	pass

dag = dagparts.CondorDAG("gstlal_fake_frames_pipe")

seglists = ligolw_segments.segmenttable_get_by_name(utils.load_filename(options.frame_segments_file, verbose = options.verbose, contenthandler = ContentHandler), options.frame_segments_name).coalesce()
choosesegs(seglists, options.min_segment_length)

psdJob = gstlal_reference_psd_job()
smoothJob = gstlal_ninja_smooth_reference_psd_job()
medianJob = gstlal_median_psd_job()
colorJob = gstlal_fake_frames_job()

smoothnode = {}
mediannode = {}
p_node = dict([(i, []) for i in seglists])

if options.whiten_type in ("psdperseg", "medianofpsdperseg"):
	psd = {}
	for instrument, seglist in seglists.iteritems():
		mediannode[instrument] = {}
		smoothnode[instrument] = {}
		psd[instrument] = {}
		for seg in seglist:
			#FIXME if there are sements without frame caches this will barf
			psdnode = gstlal_reference_psd_node(psdJob, dag, options.frame_cache, int(seg[0]), int(seg[1]), instrument, inchannels[instrument], injections=None, p_node=[])
			smoothnode[instrument][seg] = gstlal_ninja_smooth_reference_psd_node(smoothJob, dag, instrument, psdnode.output_name,  p_node=[psdnode])
			if options.whiten_type == "psdperseg":
				psd[instrument][seg] = smoothnode[instrument][seg].output_name

		mediannode[instrument] = gstlal_median_psd_node(medianJob, dag, instrument, [v.output_name for v in smoothnode[instrument].values()], "%s_median_psd.xml.gz" % instrument, p_node=smoothnode[instrument].values())
		p_node[instrument] = [mediannode[instrument]]
		if options.whiten_type == "medianofpsdperseg":
			psd[instrument] = mediannode[instrument].output_name

elif options.whiten_type is not None:
	psd = lalseries.read_psd_xmldoc(utils.load_filename(options.whiten_reference_psd, verbose = options.verbose, contenthandler = ligolw.LIGOLWContentHandler))
else:
	psd = dict([(i, None) for i in seglists])

for instrument, seglist in seglists.iteritems():
	try:
		output_path = outpaths[instrument]
	except KeyError:
		output_path = None
	for seg in seglist:
		try:
			reference_psd = psd[instrument][seg]
		except TypeError:
			reference_psd = psd[instrument]
		gstlal_fake_frames_node(colorJob, dag, options.frame_cache, int(seg[0]), int(seg[1]), inchannels[instrument], reference_psd, color_psd=options.color_psd, sample_rate = options.sample_rate, injections=options.injections, output_channel_name = outchannels[instrument], output_path = output_path, duration = options.frame_duration, frame_type = frametypes[instrument], shift = options.shift, whiten_track_psd = options.whiten_track_psd, frames_per_file = options.frames_per_file, p_node=p_node[instrument])
		
dag.write_sub_files()
dag.write_dag()
dag.write_script()
dag.write_cache()
