#----------------------------------#
# -- Author: V.Garonne
# -- Mail: garonne@lal.in2p3.fr
# -- Date: 08/25/2006
# -- Name: tbroadcast
# -- Description: main class
#----------------------------------#

# 21-Jul-2009 compile most used packages first; protect critical sections

import os
import sys
import time
import string
import os.path
# import commands
import traceback
import exceptions
from threading import BoundedSemaphore

from threadpool  import WorkRequest
from threadpool  import ThreadPool

from  subprocess import Popen

class Scheduler:

    def __init__(self, num_workers=20, file=None, ignore_cycles=False, local=False, sort=False,
                 output=None, error=None, silent = False, perf=False, keep_going=True):
        self.pool            = ThreadPool(num_workers=num_workers, poll_timeout=3)
        self.num_workers     = num_workers
        self.current_project = {'name': None, 'path': None, 'version': None}
        self.packages        = {}
        self.counter         = 0
        self.semaphore       = BoundedSemaphore(1)
        self.local           = local
        self.sort            = sort
        self.ignore_cycles   = ignore_cycles
        self.output          = output
        self.error           = error
        self.silent          = silent
        self.perf            = perf
        self.keep_going      = keep_going
        if self.sort:
            print "Compile packages sorted according to use count"
        if self.perf is not False:
            f = open (self.perf, 'w+')
            f.close()
        if output is not None:
            if not os.path.exists (output):
                print "path",output,"does not exists"
                sys.exit(-1)
            if not os.path.isdir(output):
                print "path",output,"is not a valid directory"
                sys.exit(-1)

        # init cmt stuff
        self.get_current_project()
        self.current_package = self.get_current_package()
        self.instanciate_packages (file)
        if self.local: self.get_local_graph()
        self.check_cycles()
        self.get_use_count()

    def get_current_project(self):
        cmd = 'cmt show projects | grep current'
        status, output = getstatusoutput (cmd)
        #status, output = commands.getstatusoutput (cmd)
        if status != 0:
            print output
            sys.exit(-1)
        lines = string.split(output, '\n')
        for line in lines:
            if line!='' and line [0] != '#':
                item  = string.split (line, ' ')
                self.current_project ['name']    = item[0]
                self.current_project ['version'] = item[1]
                self.current_project ['path']    = item[3][:-1]
                version =  self.current_project ['path'][len(self.current_project ['path'])-len(self.current_project ['version'] ):]
                if  self.current_project ['version'] == version:
                    self.current_project ['path'] =  os.path.normpath(self.current_project ['path'][:-len(self.current_project ['version'] )])
                return

    def get_counter(self):
        self.semaphore.acquire ()
        self.counter = self.counter + 1
        value = self.counter
        self.semaphore.release()
        return value

    def check_cycles (self):
        cmd = 'cmt -private show cycles'
        cycle_found = False
        status, output = getstatusoutput (cmd)
        #status, output = commands.getstatusoutput (cmd)
        if status != 0:
            print output
            sys.exit(-1)
        lines = string.split(output, '\n')
        cycles = list()
        for line in lines:
            if line!='' and line [0] != '#':
               cycles.append (string.split(line))
        cercles =list()
        for cycle in cycles:
            cycleInProject = True
            for package in cycle:
                if not self.packages.has_key(package):
                    cycleInProject = False
            if cycleInProject:
              cercles.append(cycle)
        if len(cercles):
            if not self.ignore_cycles:
                print "# Error: cycles found, not possible to execute broadcast with threads. Please correct the following cycles:"
                for cycle in cercles:
                    loop = ""
                    for package in cycle:
                        loop = loop + package + ' -> '
                    print loop + '...'
                sys.exit(-1)
            else:
                print "# Warning: There are cycles and you have selected the automatic suppress cycles mode"
                for cycle in cercles:
                    loop = ""
                    for package in cycle:
                        loop = loop + package + ' -> '
                    if cycle[0] in self.packages[cycle[len(cycle)-1]]['uses']:
                        print '## In cycle: '+loop + '..., we suppress the dependency '+ cycle[len(cycle)-1]+'->'+cycle[0]
                        self.packages[cycle[len(cycle)-1]]['uses'].remove(cycle[0])
#                sys.exit(-1)

    def format_uses (self, content):
        # format variables
        lignes  = string.split(content, '\n')
        lines   = list()
        for ligne in lignes:
           if ligne [0] == '#' and ligne[:5] != "#CMT>" and ligne[:10] != "# Required" and ligne not in ['# Selection :','#']:
               lines.append(ligne)
        lines.reverse()
        return lines

    def format_paths (self, content):
        # format variables
        lignes  = string.split(content, '\n')
        lines   = list()
        for ligne in lignes:
            if ligne[:4] == "use ":
               lines.append(ligne)
        return lines

    def get_paths (self, content):
        lines = self.format_paths(content)
        for line in lines:
                result  = string.split (line[4:len(line)], ' ')
                if  self.packages.has_key(result[0]):
                    if len(result)==4:
                        name, version, offset, path = string.split (line[4:len(line)], " ")
                        #print name, version, offset, path
                        #print path[1:-1] + '/' + offset + '/' +name + '/' + version + '/cmt'
                        if path == '(no_auto_imports)':
                            path   = offset
                            offset = ''
                        if os.path.exists(path[1:-1] + '/' + offset + '/' +name + '/' + version + '/cmt'):
                            full_path = path[1:-1] + '/' + offset + '/' +name + '/' + version + '/cmt'
                        elif os.path.exists(path[1:-1] + '/' + offset + '/' +name + '/cmt'):
                            full_path = path[1:-1] + '/' + offset + '/' +name + '/cmt'
                        else:
                            print '# error path not found for', name
                            sys.exit(-1)
                    elif len(result)==5:
                        name, version, offset, path, importation = string.split (line[4:len(line)], " ")
                        if os.path.exists(path[1:-1] + '/' + offset + '/' +name + '/' + version + '/cmt'):
                            full_path = path[1:-1] + '/' + offset + '/' +name + '/' + version + '/cmt'
                        elif os.path.exists(path[1:-1] + '/' + offset + '/' +name + '/cmt'):
                            full_path = path[1:-1] + '/' + offset + '/' +name + '/cmt'
                        else:
                            print '# error path not found for', name
                            sys.exit(-1)
                    elif len(result)==3:
                        name, version, path = string.split (line[4:len(line)], " ")
                        if os.path.exists(path[1:-1] + '/' +name + '/' + version + '/cmt'):
                            full_path = path[1:-1] + '/' +name + '/' + version + '/cmt'
                        elif os.path.exists(path[1:-1] + '/' +name + + '/cmt'):
                            full_path = path[1:-1] + '/' +name + + '/cmt'
                        else:
                            print '# error path not found for', name
                            sys.exit(-1)
                    else:
                        print "error:",line
                        print str(result)
                        sys.exit(-1)
                    self.packages[result[0]]['path'] = os.path.normpath(full_path)
                    commonprefix = os.path.commonprefix([self.packages[result[0]]['path'], self.current_project ['path']])
                    if os.path.normpath(commonprefix) == self.current_project ['path']:
                        #print result[0], ' belong to project', self.current_project ['name']
                        self.packages[result[0]]['current_project'] = True

    def get_uses(self, content):
        # initiates variables
        lignes = self.format_uses(content)
        if not len(lignes): return
        self.packages [self.current_package] = {'version': '*', 'use_count': 0,
                                                'uses': list(), 'status': 'waiting',
                                                'current_project': True, 'path': os.getcwd()}
        previous_client = self.current_package
        previous_level  = 0
        level_stack    = [{'name':previous_client,'level':previous_level},]
        ligne = lignes.pop()
        while len(lignes)!=0:
            current_level = string.find(ligne, 'use')
            while current_level > previous_level:
                name    = string.split (ligne)[2]
                version = string.split (ligne)[3]
                if not self.packages.has_key (name):
                  self.packages [name] = {'version': version, 'use_count': 0,
                                          'uses': list(), 'status': 'waiting',
                                          'current_project': False, 'path': None}
                if name not in self.packages[previous_client]['uses']:# and name != previous_client:
                   self.packages[previous_client]['uses'].append (name)
                level_stack.append({'name':previous_client,'level':previous_level})
                previous_client = name
                previous_level = current_level
                if len(lignes):
                    ligne = lignes.pop()
                    #print ligne
                    current_level = string.find(ligne, 'use')

            # restore the level
            if len(lignes):
                if len(level_stack):
                    item = level_stack.pop()
                    while item['level'] >= current_level and len(level_stack):
                             item = level_stack.pop()
                    previous_client = item['name']
                    previous_level  = item['level']
            #print previous_client, '-->',string.split (ligne)[2]

    def instanciate_packages(self, file=None):
        # We create the schedule of the work units
        print '# First, we initialize the DAG by parsing "cmt show uses"'
        if file is None:
            cmd  = 'cmt show uses'
        else:
            cmd = 'cat ' + file
        status, output = getstatusoutput (cmd)
        #status, output = commands.getstatusoutput (cmd)
        if status != 0:
            print output
            sys.exit(-1)
        self.get_uses(output)
        self.get_paths(output)
        #self.simulate_execution()

    def get_use_count(self):
        for key in self.packages:
            count = 0
            for parent in self.packages:
                if key in self.packages[parent]['uses']: count += 1
            self.packages[key]['use_count'] = count
            # print "Package",key,"use_count",count

    def get_local_graph(self):
        To_remove = list()
        for key in self.packages:
            if self.packages[key]['current_project']== False:
                for selected in self.packages:
                    if key in self.packages[selected]['uses']:
                       self.packages[selected]['uses'].remove(key)
                To_remove.append (key)
        for item in To_remove:
            del self.packages[item]

    def simulate_execution(self):
        while True:
            ndone = self.simulate_requests()
            if ndone == 0: break

    def simulate_requests(self):
        runnable = self.get_next_work_units()
        if len(runnable):
            print '\n#--------------------------------------------------------------'
            print "# Execute parallel actions within packages - total", len(runnable)
            print '#--------------------------------------------------------------'
        for selected in runnable:
            use_count = self.packages[selected]['use_count']
            path = self.packages[selected]['path']
            print '#--------------------------------------------------------------'
            print '# (%d/%d %d) Now trying [] in %s' % (self.get_counter(), len(self.packages), use_count, path)
            print '#--------------------------------------------------------------'
            self.suppress_work_unit(selected)
        return len(runnable)

    def get_current_package(self):
        cmd            = 'cmt show macro package'
        status, output = getstatusoutput (cmd)
        #status, output = commands.getstatusoutput (cmd)
        if status != 0:
            print output
            sys.exit(-1)
        lines = string.split(output, '\n')
        for line in lines:
            if line [0] != '#':
                start = string.find(line,"'")
                end   = string.find(line[start+1:len(line)],"'")
                return line [start+1:start+end+1]

    def print_dependencies(self):
        print '# -------------------------------------------'
        print '# package --> dependencies, status, use_count'
        print '# -------------------------------------------'
        for key in self.packages.keys():
            print key, '-->', self.packages[key] ['uses'],',', self.packages[key] ['status'],',', self.packages[key] ['use_count']

    def print_status(self, status):
        print '# ------------------------'
        print '# package --> dependencies'
        print '# ------------------------'
        i = 1
        for key in self.packages.keys():
            if self.packages[key] ['status'] == status:
                print i , key, '-->', self.packages[key] ['uses'],',', self.packages[key] ['status']
                i = i + 1

    def is_work_unit_waiting (self, name):
        return self.packages[name] ['status'] == 'waiting'

    def set_work_unit_status (self, name, status):
        self.packages[name] ['status'] = status

    def get_next_work_units (self):
        # by default returned list is in arbitrary order - may be this is better
        # if self.sort is set returned list is sorted - most used packages first
        runnable = list()
        for key in self.packages:
            if len(self.packages[key]['uses']) == 0 and self.packages[key]['status'] == 'waiting':
                use_count = self.packages[key]['use_count']
                runnable.append((use_count,key))
        if self.sort:
            runnable.sort()
            runnable.reverse()
        result = [ pair[1] for pair in runnable ]
        return result

    def suppress_work_unit (self, name):
        #print '# remove', name, 'from schedule'
        self.semaphore.acquire()
        self.packages[name]['status']='done'
        for key in self.packages.keys():
            if name in self.packages[key]['uses']:
                self.packages[key]['uses'].remove(name)
        self.semaphore.release()

    def add_work_unit (self, name, cmd):
        if self.is_work_unit_waiting (name):
            # we create requests
            arg = {'cmd': cmd , 'package': name}
            req = WorkRequest(self.do_execute, [arg] , None, callback=self.result_callback)
#           req = WorkRequest(self.do_execute, [arg] , None,
#                 callback=self.result_callback, exc_callback=self.handle_exception)
            # then we put the work request in the queue...
            self.set_work_unit_status (name, 'queued')
            self.pool.putRequest(req)
            # print "# Work request #%s added on %s." % (req.requestID, str(arg['package']))

    def execute (self, command):
        #self.print_dependencies ()
        self.semaphore.acquire()
        packages = self.get_next_work_units()
        if len(packages):
            print '\n#--------------------------------------------------------------'
            print '# Execute parallel actions within packages (total',len(packages),')',packages
            print '\n#--------------------------------------------------------------'
            for package in packages:
                self.add_work_unit (package, command)
            sys.stdout.flush()
        self.semaphore.release()

    def execute_all(self,command):
        #self.print_dependencies ()
        self.execute (command)
        self.wait()
        if self.counter != len(self.packages):
            print 'tbroadcast: warning: compiled',self.counter,'out of',len(self.packages),'packages'
        self.pool.dismissWorkers(self.num_workers, do_join=True)
        #self.print_dependencies ()
        #self.print_status (status='waiting')

    def wait (self):
       self.pool.wait()

    # this will be called each time a result is available
    def result_callback(self, request, result):
      #print "**Result: %s from request #%s" % (str(result), request.requestID)
      #print "# Result: %s from request #%s" % (result['package'], request.requestID)
      self.execute (result['cmd'])

    # the work the threads will have to do
    def do_execute(self, arg):
      package = arg['package']
      path = self.packages[package]['path']
      if path == None or not os.path.exists(path):
          raise RuntimeError('Path to package '+ package +' not found')
      self.set_work_unit_status(package, 'running')
      header =          '#--------------------------------------------------------------\n'
      header = header + '# ('+str(self.get_counter())+'/'+str(len(self.packages))+') Now trying ['+arg['cmd']+'] in '+path+'\n'
      header = header + '#--------------------------------------------------------------\n'
      print header
      sys.stdout.flush()
      project_path = self.current_project['path']+'/'+self.current_project['version']+'/'
      log_name     = string.replace(path, project_path, '')
      log_name     = string.replace(log_name, '/cmt', '')
      log_name     = string.replace(log_name, '/', '_')
      log_name     = log_name+'.loglog'
      # arg['log']   = log_name
      cmd = "cd "+ path +";"+ arg['cmd']
      # status, output = commands.getstatusoutput(cmd)
      # init output file

      self.packages[package] ['startTime'] = time.time()

      if self.output is not None:
           f1 = open (self.output+'/'+ log_name, 'w+')
           f1.write (header)
           f1.flush()
           if self.error is not None:
               f2 = open (self.error+'/error'+log_name, 'w+')
               fp = Popen(cmd, shell=True, stdout=f1, stderr=f2)
               fp.communicate()
               status = fp.wait()
               f2.close()
           else:
               fp = Popen(cmd, shell=True, stdout=f1, stderr=f1)
               fp.communicate()
               status = fp.wait()
           f1.close()
      else:
           fp = Popen(cmd, shell=True)
           fp.communicate()
           status = fp.wait()
      sys.stdout.flush()
      sys.stderr.flush()

      # Error is not handled - exit() is forbidden here
      if not self.keep_going and status > 0:
          print 'Error',status,'for package',package
          # sys.exit(status)

      self.packages[package] ['endTime'] = time.time()
      if self.perf:
          self.semaphore.acquire()
          f = open (self.perf, 'a')
          f.write (package+" "+str(self.packages[package]['startTime'])+" "+str(self.packages[package]['endTime'] )+'\n')
          f.close()
          self.semaphore.release()
      self.suppress_work_unit(package)
      return {'cmd': arg['cmd'], 'package':arg['package']}


    # this will be called when an exception occurs within a thread
    def handle_exception(self, request, exc_info):
      #traceback.print_stack()
      print '#--------------------------------------------------------------'
      #print "# Exception occured in request #%s: %s" %(request.requestID, exc_info[1])
      if exc_info[0]== exceptions.SystemExit:
        print "Stop execution (No_keep_going option enabled): exit code == %s " %(exc_info[1])
        print '#--------------------------------------------------------------'
        sys.exit(exc_info[1])
      print "# Exception occured: %s" %(exc_info[1])
      print exc_info
      print '#--------------------------------------------------------------'
      #sys.exit(-1)


    def generate_make (self, file, command):
        makefile = open (file, 'w+')
        makefile.write ('MAKE=make\n')
        #MFLAGS= -j10
        self.counter = len(self.packages)
        self.recursive_make (self.current_package, command, makefile, len(self.packages))
        makefile.close ()

    def recursive_make (self, package, command, makefile, indice,actions=list()):
        lines = self.generate_action_make (package, command, indice)
        makefile.write (lines)
        #print lines
        for pkg in self.packages[package] ['uses']:
            if pkg not in actions:
                actions.append(pkg)
                indice = indice - 1
                self.counter = self.counter - 1
                self.recursive_make(pkg, command,makefile, indice, actions)

    def generate_action_make (self, package, command, indice):
        lines = package + ' :: '
        # add dependencies
        for pkg in self.packages[package] ['uses']:
            lines = lines + ' ' + pkg

        # add the action itself
        newcommand = string.replace (command, '<package>', package)
        if command =='':
            newcommand='$(MAKE)'
        lines = lines + '\n'
        lines = lines +  '\t@echo "#--------------------------------------------------------------"\n'
        lines = lines +  '\t@echo "# ('+str(self.counter)+'/'+str(len(self.packages))+') Now trying ['+newcommand+'] in '+ self.packages[package]['path']+'"\n'
        lines = lines +  '\t@echo "#--------------------------------------------------------------"\n'
        lines = lines +  'ifdef LOCATION\n'
        lines = lines +  '\t@echo "#--------------------------------------------------------------"> $(LOCATION)/'+ package +'.loglog\n'
        lines = lines +  '\t@echo "# ('+str(self.counter)+'/'+str(len(self.packages))+') Now trying ['+newcommand+'] in '+ self.packages[package]['path']+'">> $(LOCATION)/'+ package +'.loglog\n'
        lines = lines +  '\t@echo "#--------------------------------------------------------------">> $(LOCATION)/'+ package +'.loglog\n'
        lines = lines + '\t+@cd ' + self.packages[package]['path']
        lines = lines + ' && ' + newcommand + ' >> $(LOCATION)/'+ package +'.loglog 2>&1\n'
        lines = lines + 'else\n'
        lines = lines + '\t+@cd ' + self.packages[package]['path']
        lines = lines + ' && ' + newcommand + '\n'
        lines = lines + 'endif\n\n'
        return lines


# own copy of getstatusoutput
def getstatusoutput(cmd):
    """Run *cmd* in a shell and return a ``(status, output)`` pair.

    status is 0 on success; otherwise it is the non-zero value returned by
    closing the os.popen pipe (a wait()-style encoded status). output is
    what the command wrote to stdout, minus at most one trailing
    os.linesep.
    """
    pipe = os.popen(cmd, 'r')
    text = pipe.read()
    code = pipe.close() or 0
    sep = os.linesep
    if text.endswith(sep):
        text = text[:len(text) - len(sep)]
    return code, text

#--------- EoF --------#
