#! /usr/bin/python3
from __future__ import division
from __future__ import print_function
import os, sys, pprint, argparse, re,copy
from string import Template

# add path to the ufo conversion modules
modulepath = os.path.join("/usr/lib64/Herwig",'python')
sys.path.append(modulepath)

from ufo2peg import *

# set up the option parser for command line input
parser = argparse.ArgumentParser(
    description='Create Herwig model files from Feynrules UFO input.'
)
# positional argument: the directory holding the UFO model
parser.add_argument(
    'ufodir',
    metavar='UFO_directory',
    help='the UFO model directory'
)
parser.add_argument(
    '-v', '--verbose',
    action="store_true",
    help="print verbose output"
)
# the name tag is used for the generated class, library and file names
parser.add_argument(
    '-n', '--name',
    default="FRModel",
    help="set custom nametag for the model"
)
parser.add_argument(
    '--ignore-skipped',
    action="store_true",
    help="ignore skipped vertices and produce output anyway"
)
parser.add_argument(
    '--split-model',
    action="store_true",
    help="Split the model file into pieces to improve compilation for models with many parameters"
)
parser.add_argument(
    '--no-generic-loop-vertices',
    action="store_true",
    help="Don't include the automatically generated generic loop vertices for h->gg and h->gamma gamma"
)
parser.add_argument(
    '--include-generic',
    action="store_true",
    help="Include support for generic spin structures"
)
parser.add_argument(
    '--use-Herwig-Higgs',
    action="store_true",
    help="Use the internal Herwig SM modeling and vertices for Higgs interactions, there may be sign issues but some UFO models have very minimal Higgs interactions"
)
parser.add_argument(
    '--use-generic-for-tensors',
    action="store_true",
    help="Use the generic machinery for all tensor vertices (debugging only)"
)
# action="append" adds the user supplied names to the default list,
# so "eta" and "phi" always remain forbidden
parser.add_argument(
    '--forbidden-particle-name',
    action="append",
    default=["eta", "phi"],
    # the two pieces of the help text need a separating space, otherwise
    # the message reads "conflicts withHerwig"
    help="Add particle names not allowed as names in UFO models to avoid conflicts with "
         "Herwig internal particles, names will have _UFO appended"
)
parser.add_argument(
    '--convert',
    action="store_true",
    default=False,
    help="Convert the UFO model for python2 to python3 before loading it."
)
# get the arguments
args = parser.parse_args()
# convert model to python 3 if needed
# (the conversion modifies the files of the UFO model itself)
if args.convert:
    prepForConversion(args.ufodir)
    convertToPython3(args.ufodir)
# import the model: the UFO directory is loaded as a python package whose
# name is the last component of the directory path
ufo_abspath = os.path.abspath(args.ufodir)
path, mod = os.path.split(ufo_abspath)
try:
    import importlib, importlib.util
except ImportError:
    # only a missing importlib.util (python 2) may trigger the fallback;
    # a failure to load the model itself is handled below and must not
    # end up here
    print ("Newer python import did not work, reverting to python 2 imp module")
    import imp
    FR = imp.load_module(mod, *imp.find_module(mod, [path]))
else:
    try:
        spec = importlib.util.spec_from_file_location(mod, os.path.join(ufo_abspath, "__init__.py"))
        # NOTE: Loader.load_module() is deprecated in recent python
        # versions but kept here to leave the loading behaviour untouched
        FR = spec.loader.load_module()
    except Exception:
        print ("Could not load the UFO python module.\n",
               "This is usually because you are using python3 and the UFO model is in python2.\n",
               "The --convert option can be used to try and convert it, but there is no guarantee.\n",
               "As this modifies the model you should make sure you have a backup copy.\n")
        # sys.exit raises SystemExit, which is no longer swallowed by an
        # enclosing bare except; a non-zero status signals the failure
        sys.exit(1)

##################################################
##################################################

# the model name is taken from the command line, the shared library
# is named after it
modelname = args.name
libname = '%s.so' % modelname



# containers for the pieces of generated code, filled further down
#allplist = ""
parmdecls, parmgetters = [], []
parmconstr, doinit = [], []

paramstoreplace_, paramstoreplace_expressions_ = [], []

# map the name of every external parameter onto its value
# (used for printing and for evaluating the internal parameters)
parmsubs = {
    p.name: p.value
    for p in FR.all_parameters
    if p.nature == 'external'
}

# evaluate python cmath
def evaluate(x):
    """Evaluate the python expression string ``x`` and return its value.

    The expression sees ``cmath`` and the ``sec``, ``csc``,
    ``complexconjugate``, ``im`` and ``re`` functions of the model's
    function library as globals; the module-level ``parmsubs`` mapping is
    used as the local namespace, so parameter names resolve to the values
    stored there at the time of the call.
    """
    import cmath
    library = FR.function_library
    # NOTE(review): eval runs arbitrary code taken from the UFO model, so
    # only models from a trusted source should be converted
    global_names = {
        'sec': library.sec,
        'csc': library.csc,
        'cmath': cmath,
        'complexconjugate': library.complexconjugate,
        'im': library.im,
        're': library.re,
    }
    return eval(x, global_names, parmsubs)


## lazily select the internal (derived) parameters of the model
internal = (p for p in FR.all_parameters if p.nature == 'internal')

#paramstoreplaceEW_ = []
#paramstoreplaceEW_expressions_ = []

# calculate internal parameters; every result is stored in parmsubs
# straight away, which makes it visible to the expressions evaluated
# after it
for p in internal:
    parmsubs[p.name] = evaluate(p.value)
#    if 'aS' in p.value and p.name != 'aS':
#        paramstoreplace_.append(p.name)
#        paramstoreplace_expressions_.append(p.value)
#    if 'aEWM1' in p.value and p.name != 'aEWM1':
#        paramstoreplaceEW_.append(p.name)
#        paramstoreplaceEW_expressions_.append(p.value)
# keep a shallow copy of the numerical values, parmsubs itself is
# overwritten later on
parmvalues = dict(parmsubs)


# further containers used for the substitution in the templates
paramsforstream, parmmodelconstr = [], []

# the parameter loop below fills the template pieces according to
# internal/external and complex/real
# WARNING: Complex external parameter input not tested!
if args.verbose:
    # header of the per-parameter listing printed in verbose mode
    for headline in ('verbose mode on: printing all parameters', '-'*60):
        print (headline)
    paramsstuff = ('name', 'expression', 'default value', 'nature')
    pprint.pprint(paramsstuff)



# C++ snippet declaring the interface of one parameter; the fields are
# filled with str.format
interfacedecl_T = """\
static Parameter<{modelname}, {type}> interface{pname}
  ("{pname}",
   "The interface for parameter {pname}",
   &{modelname}::{pname}_, {value}, 0, 0,
   false, false, Interface::nolimits);
"""

# sort out the couplings: map the name of each coupling order onto its
# expansion order, QED and QCD are always present
couplingDefns = { "QED" : 99, "QCD" : 99 }
try:
    for coupling in FR.all_orders:
        couplingDefns[coupling.name.upper()] = coupling.expansion_order
except (AttributeError, TypeError):
    # the model provides no usable list of coupling orders: collect the
    # names from the couplings themselves and give each new one the
    # default value 99.  Only the errors caused by a missing or unusable
    # all_orders are caught, anything else (including KeyboardInterrupt)
    # is no longer silently swallowed.
    for coupling in FR.all_couplings:
        for name in coupling.order:
            if name not in couplingDefns:
                couplingDefns[name] = 99

# sort out the particles: for every non-zero mass / width parameter
# collect the |pdg codes| of the particles which use it

massnames = {}
widthnames = {}

for particle in FR.all_particles:
    # skip ghosts and goldstones
    if isGhost(particle) or isGoldstone(particle):
        continue
    # the mass and the width are recorded in exactly the same way
    for quantity, table in ((particle.mass, massnames),
                            (particle.width, widthnames)):
        if quantity != 'ZERO' and quantity.name != 'ZERO':
            codes = table.setdefault(quantity, [])
            pdg = abs(particle.pdg_code)
            # particle and antiparticle share one entry
            if pdg not in codes:
                codes.append(pdg)

# Main loop over the model parameters.  For every parameter it fills
#   interfaceDecls  - interface declarations (real external parameters only)
#   parmmodelconstr - lines for the model file setting the parameter values
#   modelparameters - values of the parameters which have an LHA block,
#                     keyed by the label BLOCK_code
#   parmconstr      - initialisers for the constructor of the C++ class
#   parmdecls / parmgetters / paramsforstream
#                   - member declarations, getters and member names
#   doinit          - C++ statements which calculate the internal parameters
#                     and synchronise masses and widths with the particle data
# and it overwrites the entry in parmsubs.
interfaceDecls = []
modelparameters = {}

for p in FR.all_parameters:
    value = parmsubs[p.name]
    if p.type == 'real':
        # NOTE(review): this only rejects a positive imaginary part, a
        # negative one passes and is dropped by the next line -- confirm
        # whether abs(value.imag) was intended
        assert( value.imag < 1.0e-16 )
        value = value.real
        if p.nature == 'external':
            # parameters used as the mass or width of a particle (the keys
            # of massnames / widthnames) get no interface of their own
            if p not in massnames and p not in widthnames:
                interfaceDecls.append(
                    interfacedecl_T.format(modelname=modelname,
                                           pname=p.name,
                                           value=value,
                                           type=typemap(p.type))
                )
            else:
                interfaceDecls.append('\n// no interface for %s. Use particle definition instead.\n' % p.name)
            if hasattr(p,'lhablock'):
                # parameters with an LHA block are written as ${BLOCK_code}
                # placeholders; the value is stored in modelparameters and
                # the label replaces the value in parmsubs
                lhalabel = '{lhablock}_{lhacode}'.format( lhablock=p.lhablock.upper(), lhacode='_'.join(map(str,p.lhacode)) )
                if p not in massnames and p not in widthnames:
                    parmmodelconstr.append('set %s:%s ${%s}' % (modelname, p.name, lhalabel))
                else:
                    parmmodelconstr.append('# %s is taken from the particle setup' % p.name)
                modelparameters[lhalabel] = value
                parmsubs[p.name] = lhalabel
            else:
                # no LHA block: the numerical value is written directly
                if p not in massnames and p not in widthnames:
                    parmmodelconstr.append('set %s:%s %s' % (modelname, p.name, value))
                else:
                    parmmodelconstr.append('# %s is taken from the particle setup' % p.name)
                parmsubs[p.name] = value
            # masses and widths are default constructed, they are read
            # from the particle data in doinit (see the end of the loop)
            if p not in massnames and p not in widthnames:
                parmconstr.append('%s_(%s)' % (p.name, value))
            else:
                parmconstr.append('%s_()' % p.name)
        else :
            # internal real parameter: default constructed, set in doinit
            parmconstr.append('%s_()' % p.name)
            parmsubs[p.name] = value
    elif p.type == 'complex':
        value = complex(value)
        if p.nature == 'external':
            parmconstr.append('%s_(%s,%s)' % (p.name, value.real, value.imag))
        else :
            # internal complex parameter: starts at zero, set in doinit
            parmconstr.append('%s_(%s,%s)' % (p.name, 0.,0.))
        parmsubs[p.name] = value
    else:
        raise Exception('Unknown data type "%s".' % p.type)

    # member declaration, getter and stream entry for every parameter
    parmdecls.append('  %s %s_;' % (typemap(p.type), p.name))
    parmgetters.append('  %s %s() const { return %s_; }' % (typemap(p.type),p.name, p.name))
    paramsforstream.append('%s_' % p.name)
    expression, symbols = 'return %s_' % p.name, None
    if p.nature != 'external':
        # internal parameter: turn the python expression into the text of
        # a C++ statement; py2cpp and add_brackets are not defined in this
        # file (presumably provided by ufo2peg)
        expression, symbols = py2cpp(p.value)
        text = add_brackets(expression, symbols)
        # textual clean-up of the generated expression (presumably
        # artefacts of add_brackets -- TODO confirm)
        text=text.replace('()()','()')
        text=text.replace('sec()(','sec(')
        text=text.replace('csc()(','csc(')
        doinit.append('   %s_ = %s;'  % (p.name, text) )
        # generated check which throws if the calculated value is nan or inf
        if(p.type == 'complex') :
            doinit.append('   if(std::isnan(%s_.real()) || std::isnan(%s_.imag()) || std::isinf(%s_.real()) || std::isinf(%s_.imag())) {throw InitException() << "Calculated parameter %s is nan or inf in Feynrules model. Check your input parameters.";}' % (p.name,p.name,p.name,p.name,p.name) )

        else :
            doinit.append('   if(std::isnan(%s_) || std::isinf(%s_)) {throw InitException() << "Calculated parameter %s is nan or inf in Feynrules model. Check your input parameters.";}' % (p.name,p.name,p.name) )
        # a calculated mass or width is pushed into the particle data of
        # every particle using it
        if p in massnames:
            for idCode in massnames[p] :
                doinit.append('   resetMass(%s,%s_ * GeV);'  % (idCode, p.name) )
        if p in widthnames:
            for idCode in widthnames[p] :
                doinit.append('   getParticleData(%s)->width(%s_ * GeV);' % (idCode, p.name) )
                doinit.append('   getParticleData(%s)->cTau (%s_ == 0.0 ? Length() : hbarc/(%s_*GeV));' % (idCode, p.name, p.name) )
                doinit.append('   getParticleData(%s)->widthCut(10. * %s_ * GeV);' % (idCode, p.name) )


    elif p.nature == 'external':
        # an external mass or width is read back from the particle data;
        # with several particles sharing the parameter one assignment per
        # particle is generated
        if p in massnames:
            for idCode in massnames[p] :
                doinit.append('   %s_ = getParticleData(%s)->mass()  / GeV;'  % (p.name, idCode) )
        if p in widthnames:
            for idCode in widthnames[p] :
                doinit.append('   %s_ = getParticleData(%s)->width() / GeV;'  % (p.name, idCode) )

    if args.verbose:
        pprint.pprint((p.name,p.value, value, p.nature))

# text of the parameter card for the model
pcwriter = ParamCardWriter(FR.all_parameters)
paramcard_output = '\n'.join(pcwriter.output)

### special treatment
#    if p.name == 'aS':
#        expression = '0.25 * sqr(strongCoupling(q2)) / Constants::pi'
#    elif p.name == 'aEWM1':
#        expression = '4.0 * Constants::pi / sqr(electroMagneticCoupling(q2))'
#    elif p.name == 'Gf':
#        expression = 'generator()->standardModel()->fermiConstant() * GeV2'

# initialiser list of the constructor: comma separated entries with a
# line break after the entries 0, 5, 10, ...
pieces = [': ']
lastEntry = len(parmconstr) - 1
for ncount, entry in enumerate(parmconstr):
    pieces.append(entry)
    if ncount != lastEntry:
        pieces.append(',')
    if ncount % 5 == 0:
        pieces.append("\n")
paramconstructor = ''.join(pieces)

# chains of members for the persistent output and input; the first
# member carries no operator, the line breaks follow the same pattern
# as above
outpieces = []
inpieces = []
for ncount, member in enumerate(paramsforstream):
    outpieces.append(("<< " if ncount != 0 else "") + member)
    inpieces.append((">> " if ncount != 0 else "") + member)
    if ncount % 5 == 0:
        outpieces.append("\n")
        inpieces.append("\n")
paramout = "".join(outpieces)
paramin = "".join(inpieces)

# substitutions for the header and source templates of the model class
parmtextsubs = {
    'parmgetters': '\n'.join(parmgetters),
    'parmdecls': '\n'.join(parmdecls),
    'parmconstr': paramconstructor,
    'getters': '',
    'decls': '',
    'addVertex': '',
    'doinit': '\n'.join(doinit),
    'ostream': paramout,
    'istream': paramin,
    'refs': '',
    'parmextinter': ''.join(interfaceDecls),
    'ModelName': modelname,
    'calcfunctions': '',
    'param_card_data': paramcard_output,
}

##################################################
##################################################
##################################################




# set up the conversion of the vertices, passing on the command line
# options, and convert them
vertexConverter = VertexConverter(FR,parmvalues,couplingDefns)
vertexConverter.readArgs(args)
vertexConverter.convert()

# every coupling order gets an integer index: QED and QCD have the
# fixed values 1 and 2, all other orders are numbered from 3 onwards
# and additionally get a constant in the CouplingType namespace
cdefs = ""
couplingOrders = ""
ncount = 2
setCouplingsLine = "  setCouplings(\"%s\",make_pair(%s,%s));\n"
fixedIndex = {"QED": 1, "QCD": 2}
for name, value in couplingDefns.items():
    if name in fixedIndex:
        index = fixedIndex[name]
    else:
        ncount += 1
        index = ncount
        cdefs += "  const T %s = %s;\n" % (name, ncount)
    couplingOrders += setCouplingsLine % (name, index, value)
# coupling definitions
couplingTemplate= """\
namespace ThePEG {{
namespace Helicity {{
namespace CouplingType {{
  typedef unsigned T;
  /**
   *  Enums for couplings
   */
{coup}
}}
}}
}}
"""
# the namespace block is only needed if there are additional orders
if cdefs:
    cdefs = couplingTemplate.format(coup=cdefs)
parmtextsubs['couplings'] = cdefs
parmtextsubs['couplingOrders'] = couplingOrders

# particles

plist, names = thepeg_particles(FR,parmsubs,modelname,modelparameters,args.forbidden_particle_name,args.use_Herwig_Higgs)

# one commented out "insert" line per particle for the example .in file
particlelist = [
    "# insert HPConstructor:Outgoing 0 /Herwig/{n}/Particles/{p}".format(n=modelname,p=p)
    for p in names
]
# make the first one active to have something runnable in the example .in file
# (guarded: a model without any particle names must not fail with an IndexError)
if particlelist:
    particlelist[0] = particlelist[0][2:]
particlelist = '\n'.join(particlelist)

# substitutions for the model file template
modelfilesubs = { 'plist' : plist,
                  'vlist' : vertexConverter.get_vertices(libname),
                  'setcouplings': '\n'.join(parmmodelconstr),
                  'ModelName': modelname
                  }

# write the files from templates according to the above subs
if vertexConverter.should_print():
    MODEL_HWIN = getTemplate('LHC-FR.in')
    if args.split_model:
        # split mode: the doinit statements are distributed over the
        # functions initCalc0, initCalc1, ... (20 statements each), every
        # function goes into a source file of its own
        MODEL_EXTRA_CC = getTemplate('Model6.cc')
        extra_names = []
        extra_calcs = []
        parmtextsubs['doinit'] = ""
        for i, statement in enumerate(doinit):
            if i % 20 == 0:
                function_name = "initCalc" + str(i // 20)
                parmtextsubs['doinit'] += function_name + "();\n"
                parmtextsubs['calcfunctions'] += "void " + function_name + "();\n"
                extra_names.append(function_name)
                extra_calcs.append("")
            extra_calcs[-1] += statement + "\n"
        for i, (extra_name, extra_body) in enumerate(zip(extra_names, extra_calcs)):
            ccname = '%s_extra_%s.cc' % (modelname, i)
            writeFile(ccname,
                      MODEL_EXTRA_CC.substitute({'ModelName' : modelname,
                                                 'functionName' : extra_name,
                                                 'functionCalcs' : extra_body}))

        # the model class itself is spread over five source files
        MODEL_CC = [getTemplate(template_name)
                    for template_name in ('Model1.cc', 'Model2.cc', 'Model3.cc',
                                          'Model4.cc', 'Model5.cc')]
    else:
        # everything in a single source file
        MODEL_CC = [getTemplate('Model.cc')]
    MODEL_H = getTemplate('Model.h')
    print ('LENGTH',len(MODEL_CC))
    MODELINFILE = getTemplate('FR.model')

    # example input file
    hwin_text = MODEL_HWIN.substitute({ 'ModelName' : modelname,
                                        'Particles' : particlelist })
    writeFile('LHC-%s.in' % modelname, hwin_text)

    # the model file is substituted twice: first the particles, vertices
    # and couplings, the parameter placeholders are filled further down
    modeltemplate = Template(MODELINFILE.substitute(modelfilesubs))

    writeFile('%s.h' % modelname, MODEL_H.substitute(parmtextsubs))
    # a single source file is called <name>.cc, several are numbered
    single_source = len(MODEL_CC) == 1
    for i, cc_template in enumerate(MODEL_CC):
        if single_source:
            ccname = '%s.cc' % modelname
        else:
            ccname = '%s.cc' % (modelname + str(i))
        writeFile(ccname, cc_template.substitute(parmtextsubs))
    writeFile(modelname + '.template', modeltemplate.template)
    writeFile(modelname + '.model', modeltemplate.substitute(modelparameters))
    # copy the Makefile-FR to current directory,
    # replace with the modelname for compilation
    with open(os.path.join(modulepath, 'Makefile-FR'), 'r') as orig, \
         open('Makefile', 'w') as dest:
        dest.write(orig.read().replace("FeynrulesModel.so", libname))


# closing summary for the user
print('finished generating model:\t', modelname)
print('model directory:\t\t', args.ufodir)
print('generated:\t\t\t', len(FR.all_vertices), 'vertices')
print('='*60)
print('library:\t\t\t', libname)
print('input file:\t\t\t', 'LHC-' + modelname +'.in')
print('model file:\t\t\t', modelname +'.model')
print('='*60)
# the name of the example input file depends on the --name option, so it
# must not be hard-coded as LHC-FRModel.in
print("""\
To complete the installation, compile by typing "make".
An example input file is provided as LHC-{name}.in,
you'll need to change the required particles in there.
""".format(name=modelname))
print('DONE!')
print('='*60)
