#! /usr/bin/python3

import os,sys,glob,errno,shutil,time,fnmatch,tempfile
from optparse import OptionParser

def mkdir_p(path):
    """Create directory 'path' and any missing parents (like 'mkdir -p').

    An already existing directory is not an error.  Any other failure,
    including 'path' existing but not being a directory, raises OSError.
    """
    # Attempt the creation first and inspect the failure afterwards: a
    # separate "does it exist?" check could be overtaken by another process
    # creating the same directory in between.
    try:
        os.makedirs(path)
    except OSError as err:
        if err.errno != errno.EEXIST or not os.path.isdir(path):
            raise

def finddirs(pattern, path):
    """Return the directories at or below 'path' holding a file that matches 'pattern'.

    'pattern' is a shell-style wildcard as understood by fnmatch.  The
    directories are listed in the order os.walk visits them.
    """
    matching_dirs = []
    for dirpath, _subdirs, filenames in os.walk(path):
        # One matching file name is enough to select the directory
        if any(fnmatch.fnmatch(name, pattern) for name in filenames):
            matching_dirs.append(dirpath)
    return matching_dirs

# fill the proc.dat file from BornAmplitudes.dat and VirtAmplitudes.dat.
def fillprocs(model,oras,orew,MGVersion):
    """Write the MadGraph command file proc.dat from the amplitude lists.

    Reads one process per line from BornAmplitudes.dat and
    VirtAmplitudes.dat in the current working directory and writes the
    corresponding 'generate' / 'add process' commands to proc.dat in the
    same directory.

    model     -- model name passed to 'import model'; "heft" additionally
                 converts part of the coupling orders into a HIG order
    oras      -- base QCD coupling order (anything int() accepts)
    orew      -- base QED coupling order (anything int() accepts)
    MGVersion -- accepted for interface compatibility; not used here

    Returns (bornlist, virtlist): the process numbers, as strings, assigned
    to the Born and to the virtual processes.  Numbering starts at 1 and
    continues from the Born into the virtual processes.

    Terminates the program with exit status 1 when the model is "heft" and
    oras+orew-legs+2 is odd, i.e. no integer HIG order exists.
    """
    oras = int(oras)
    orew = int(orew)

    borns = "BornAmplitudes.dat"
    virts = "VirtAmplitudes.dat"

    # First pass: find the shortest line (fewest space-separated tokens) of
    # both files and count its legs, i.e. its tokens that are integers.
    minlegs = 100
    legs = 0
    for fname in (borns, virts):
        with open(fname, "r") as f:
            for line in f:
                tokens = line.split(" ")
                if len(tokens) < minlegs:
                    minlegs = len(tokens)
                    # Count afresh: legs must describe this line only, not
                    # the sum over every line that lowered the minimum.
                    legs = sum(
                        1 for tok in tokens
                        if tok.replace("-", "").isdigit()
                    )

    # Conversion for the heft model: (2QCD+1QED) -> 1HIG for each final
    # state Higgs.
    if model == "heft":
        HIG = oras + orew - legs + 2
        if HIG % 2 != 0:
            sys.stderr.write("Warning: No possible coupling power:(int(oras)+int(orew)-legs+2)%2!=0\n")
            sys.exit(1)
        HIG //= 2
    else:
        HIG = 0

    procnr = 0

    bornlist = []
    bornlines = []
    with open(borns, "r") as f:
        for i, line in enumerate(f):
            # Tokens beyond the shortest line are taken to be extra QCD
            # emissions, each raising the allowed QCD order by one.
            addalphas = len(line.split(" ")) - minlegs
            procnr += 1
            bornlist.append(str(procnr))
            # MadGraph wants 'generate' for the first process only
            cmd = "generate   " if i == 0 else "add process"
            hig = ""
            qcd = oras + addalphas
            qed = orew
            if HIG:
                hig = "HIG=%s" % HIG
                qcd -= 2*HIG
                qed -= HIG
            bornlines.append(
                "%s %s %s QCD<=%s QED<=%s @%s"
                % (cmd, line.rstrip(), hig, qcd, qed, procnr)
            )

    virtlist = []
    virtlines = []
    with open(virts, "r") as f:
        for i, line in enumerate(f):
            addalphas = len(line.split(" ")) - minlegs
            procnr += 1
            virtlist.append(str(procnr))
            cmd = "generate   " if i == 0 else "add process"
            qcd = oras + addalphas
            qed = orew
            virtlines.append(
                "%s %s QCD<=%s QED<=%s [ virt=QCD ] @%s"
                % (cmd, line.rstrip(), qcd, qed, procnr)
            )

    result = [
        'set fortran_compiler gfortran --no_save',
        "import model " + model,
    ]
    result.extend(bornlines)

    # With both kinds present the Born processes get their own output
    # command before the virtual ones are generated.
    if virtlines and bornlines:
        result.append("output matchbox MG5 --postpone_model")

    result.extend(virtlines)
    result.append("output matchbox MG5 -f\n")

    with open("proc.dat", "w") as f:
        f.write('\n'.join(result))

    return bornlist, virtlist




def build_matchbox_tmp(pwd,buildpath,absolute_links):
    """Symlink the MadGraph resource and card files into the MG_TMP directory.

    pwd            -- directory against which a relative 'buildpath' is resolved
    buildpath      -- build directory containing the generated 'MG5' tree
    absolute_links -- if true the links point to absolute paths, otherwise
                      to paths relative to MG_TMP

    Every entry of MG5/SubProcesses/MadLoop5_resources, MG5/Cards and
    MG5/Cards/SubProcesses gets a link of the same base name in MG_TMP,
    unless a file or link of that name is already there.
    """
    start_dir = os.path.abspath(pwd)
    if not os.path.isabs(buildpath):
        buildpath = os.path.join(start_dir, buildpath)

    mg5_dir = os.path.join(buildpath, "MG5")
    source_dirs = (
        os.path.join(mg5_dir, "SubProcesses", "MadLoop5_resources"),
        os.path.join(mg5_dir, "Cards"),
        os.path.join(mg5_dir, "Cards", "SubProcesses"),
    )
    # Collect all link targets before the first link is created
    targets = [
        entry
        for source_dir in source_dirs
        for entry in glob.glob(os.path.join(source_dir, "*"))
    ]

    for target in targets:
        link_name = os.path.join(MG_TMP, os.path.basename(target))
        if os.path.isfile(link_name) or os.path.islink(link_name):
            # Already provided by an earlier entry or an earlier run
            continue
        if not absolute_links:
            target = os.path.relpath(target, MG_TMP)
        os.symlink(target, link_name)





# All options share one help text: the script is meant to be called by
# Herwig, not by hand.
_INTERNAL_HELP = "Do not use this script. Only for Herwig internal use."

# (flags, add_option keywords) in the order the options are registered
_OPTION_TABLE = (
    (("-a", "--buildpath"), {"dest": "buildpath", "help": _INTERNAL_HELP + " "}),
    (("-b", "--build"), {"dest": "build", "action": "store_true",
                         "default": True, "help": _INTERNAL_HELP}),
    (("-c", "--madgraph"), {"dest": "madgraph", "help": _INTERNAL_HELP}),
    (("-d", "--runpath"), {"dest": "runpath", "help": _INTERNAL_HELP}),
    (("-e", "--model"), {"dest": "model", "help": _INTERNAL_HELP}),
    (("-f", "--orderas"), {"dest": "orderas", "help": _INTERNAL_HELP}),
    (("-g", "--orderew"), {"dest": "orderew", "help": _INTERNAL_HELP}),
    (("-i", "--datadir"), {"dest": "datadir", "help": _INTERNAL_HELP}),
    (("-I", "--includedir"), {"dest": "includedir", "help": _INTERNAL_HELP}),
    (("-l", "--absolute-links"), {"dest": "absolute_links", "action": "store_true",
                                  "default": False, "help": _INTERNAL_HELP}),
    (("--cacheprefix",), {"dest": "cacheprefix", "help": _INTERNAL_HELP}),
)

parser = OptionParser()
for _flags, _settings in _OPTION_TABLE:
    parser.add_option(*_flags, **_settings)

(options, args) = parser.parse_args()

# The cache prefix is mandatory: the scratch directory is created below it.
if not options.cacheprefix:
    sys.stderr.write("Need a value for --cacheprefix\n")
    sys.exit(1)

cachedir = os.path.abspath(options.cacheprefix)
mkdir_p(cachedir)

# Scratch directory receiving the links to the MadGraph resources.  Its name
# is fixed rather than randomly generated because the same "MG_tmp" path
# component is written into the patched Fortran sources further down.
MG_TMP = os.path.join(cachedir, 'MG_tmp')
mkdir_p(MG_TMP)

# Directory the script was started from; the script changes directory later.
pwd = os.getcwd()

param_card = ""






# Provide the MadGraph parameter card for the chosen model.
if options.model=="loop_sm" or options.model=="heft":
    # Built-in models: fill the card template shipped in the data directory
    # with the values listed in MG-Parameter.dat of the run directory.
    if options.model=="loop_sm":
        param_card="param_card.dat"
    else:
        param_card="param_card_"+options.model+".dat"

    with open("%s/MadGraphInterface/%s.in" % (options.datadir,param_card), "r") as template:
        paramcard = template.read()

    # Each line holds a placeholder and its value; every occurrence of the
    # placeholder in the template is replaced.  Lines with fewer than two
    # fields (e.g. blank lines) carry no replacement and are skipped.
    with open(options.runpath+"/MG-Parameter.dat", "r") as params:
        for line in params:
            fields = line.split()
            if len(fields) < 2:
                continue
            paramcard = paramcard.replace(fields[0], fields[1])

    # Write the card only after both inputs were read successfully, so a
    # failure above cannot leave a truncated card behind.
    with open(options.runpath+"/"+param_card, "w") as card:
        card.write(paramcard)
elif options.model.startswith("/"):
    # A model given as an absolute path brings its own card writer
    os.system("python %s/write_param_card.py " % options.model)
else:
    print("---------------------------------------------------------------")
    print("---------------------------------------------------------------")
    print("Warning: The model set for the MadGraph Interface ")
    print("         needs a parameter setting by hand.")
    print("         Please fill the param_card_"+options.model+".dat")
    print("         with your favourite assumptions.")
    print("         And make sure Herwig uses the same parameters.")
    print("---------------------------------------------------------------")
    print("---------------------------------------------------------------")
    # Seed the user's card with MadGraph's default one, never overwriting
    # a card that is already present in the run directory.
    default_card = options.buildpath +"/MG5/Cards/param_card.dat"
    user_card = options.runpath+"/"+"param_card_"+options.model+".dat"
    if os.path.isfile(default_card) and not os.path.isfile(user_card):
        shutil.copyfile(default_card, user_card)
    time.sleep(1)

# First call with a fresh build path: create it and stop; the caller is
# asked to run again.
if not os.path.isdir(options.buildpath):
    print("The MadGraph Install path did not exist. It has been created for you.")
    print("Just start Herwig read again..")
    mkdir_p(options.buildpath)
    sys.exit()

os.chdir(options.buildpath)

# The interface library exists already: only the links in the scratch
# directory have to be (re)created.
if os.path.isfile("InterfaceMadGraph.so"):
    build_matchbox_tmp(pwd,options.buildpath,options.absolute_links)
    sys.exit()

# From here on the library has to be built, which needs the MadGraph path.
if not options.madgraph:
    sys.stderr.write("*** MadGraph build failed, check logfile for details ***")
    sys.exit(1)

# Read the MadGraph major version off the installation path: for a path
# component such as "MG5_aMC_v3_1_0" the third "_"-separated field is "v3"
# and its second character the version digit.  Stays 0 if no component
# contains "MG5_aMC_v"; the last matching component wins.
MGVersion = 0
for component in options.madgraph.split("/"):
    if "MG5_aMC_v" not in component:
        continue
    MGVersion = component.split("_")[2][1]

Bornlist, Virtlist = fillprocs(
    options.model,
    options.orderas,
    options.orderew,
    int(MGVersion),
)

# Let MadGraph process proc.dat.  Version 3 is started with python3 when
# this script itself runs under Python 2; otherwise plain "python" is used.
if sys.version_info[0] == 2 and MGVersion == "3":
    interpreter = "python3 "
else:
    interpreter = "python "
os.system(interpreter + options.madgraph + "/mg5_aMC proc.dat")




def make_case_stmts(func):
    """Return a Fortran SELECT CASE block dispatching on the process number.

    func -- name and argument list of the per-process routine; process i
            is dispatched to 'CALL MG5_<i>_<func>'

    Processes are taken from the module-level Bornlist followed by those
    entries of Virtlist not in Bornlist, each in list order.  For the
    loop routine "SLOOPMATRIX(momenta,virt)" only processes in Virtlist
    get a case.  Returns the empty string if there is no case at all.
    """
    virt_ids = set(Virtlist)

    # Born processes first, then the remaining virtual ones.  The order is
    # built from the lists themselves: iterating a set of strings would make
    # the generated source differ from run to run.
    proc_ids = list(Bornlist)
    seen = set(Bornlist)
    for proc in Virtlist:
        if proc not in seen:
            seen.add(proc)
            proc_ids.append(proc)

    virt_only = func == "SLOOPMATRIX(momenta,virt)"

    cases = []
    for proc in proc_ids:
        if virt_only and proc not in virt_ids:
            continue
        cases.append("         CASE(%s)" % proc)
        cases.append("            CALL MG5_%s_%s" % (proc, func))

    if not cases:
        return ""

    # Wrap the cases in the SELECT header and a default branch
    header = ["         SELECT CASE (proc)"]
    footer = [
        "         CASE DEFAULT",
        "             WRITE(*,*) '##W02A WARNING No id found '",
        "         END SELECT    ",
    ]
    return "\n".join(header + cases + footer)

def write_interfaceMG():
    """Generate InterfaceMadGraph.f in the current directory.

    The template <datadir>/MadGraphInterface/InterfaceMadGraph.f.in is
    filled, via str.format, with one SELECT CASE block per interface
    routine and with the MG_vxxxxx wrapper subroutine.
    """
    # Template placeholder -> per-process routine called in its CASE block
    routines = (
        ("MG_CalculateBORNtxt", "BORN(momenta,hel)"),
        ("MG_CalculateVIRTtxt", "SLOOPMATRIX(momenta,virt)"),
        ("MG_Jamptxt",          "GET_JAMP(color,Jamp)"),
        ("MG_LNJamptxt",        "GET_LNJAMP(color,Jamp)"),
        ("MG_NColtxt",          "GET_NCOL(color)"),
        ("MG_ColourMattxt",     "GET_NCOLOR(i,j,color)"),
    )
    subs = dict((placeholder, make_case_stmts(call)) for placeholder, call in routines)

    # Index of the first of the four VCtmp entries copied into VC: 5 if
    # there are virtual (loop) cases, 3 otherwise.
    if subs["MG_CalculateVIRTtxt"]:
        first = 5
    else:
        first = 3

    wrapper = """\
      subroutine  MG_vxxxxx(p, n,inc,VC)
     $  bind(c, name='MG_vxxxxx')
        IMPLICIT NONE
        double precision p(0:3)
        double precision n(0:3)
        INTEGER inc
        double precision VC(0:7)
        double complex  VCtmp({d})
        call vxxxxx(p, 0d0,1,inc ,VCtmp)
        VC(0)= real(VCtmp({a}))
        VC(1)=aimag(VCtmp({a}))
        VC(2)= real(VCtmp({b}))
        VC(3)=aimag(VCtmp({b}))
        VC(4)= real(VCtmp({c}))
        VC(5)=aimag(VCtmp({c}))
        VC(6)= real(VCtmp({d}))
        VC(7)=aimag(VCtmp({d}))
      END
    """
    subs["MG_vxxxxxtxt"] = wrapper.format(
        a=first,
        b=first + 1,
        c=first + 2,
        d=first + 3,
    )

    template_path = os.path.join(
        options.datadir, "MadGraphInterface", "InterfaceMadGraph.f.in"
    )
    with open(template_path, 'r') as template_file:
        template = template_file.read()
    with open("InterfaceMadGraph.f", 'w') as fortran_file:
        fortran_file.write(template.format(**subs))

write_interfaceMG()


def _heptools_path(subdir):
    """Return the HEPTools/<subdir> path belonging to options.madgraph.

    options.madgraph is expected to name MadGraph's 'bin' directory; only
    that last path component is exchanged, so a "bin" occurring elsewhere
    in the path (e.g. "/home/robin/...") stays untouched.
    """
    head, tail = os.path.split(options.madgraph.rstrip("/"))
    if tail == "bin":
        return os.path.join(head, "HEPTools", subdir)
    # Path does not end in 'bin': keep the historic textual substitution
    return options.madgraph.replace("bin", "HEPTools/" + subdir)


# Fortran sources two and three directory levels below the build directory,
# without the stand-alone checkers, the f2py wrappers and symbolic links.
# Each entry ends in a makefile line continuation.
fortanfiles=glob.glob('*/*/*.f')+glob.glob('*/*/*/*.f')
make = " " + "".join(
    " " + source + "\\\n                  "
    for source in fortanfiles
    if "check_sa" not in source
    and "f2py_wrapper" not in source
    and not os.path.islink(source)
)

# Every directory holding a nexternal.inc becomes an include path
incfiles=glob.glob('*/*/*.inc')+glob.glob('*/*/*/*.inc')
coefdir = "".join(
    " -I" + inc.replace("nexternal.inc", " ")
    for inc in incfiles
    if "nexternal.inc" in inc
)

with open("makefile", "w") as makefile:
    makefile.write("include MG5/Source/make_opts  ")

    if Virtlist:
        # Loop processes: link the reduction libraries that are available
        collierlib = _heptools_path("collier")
        onelooplib = _heptools_path("oneloop")

        makefile.write("\nLIBDIR = MG5/lib")
        makefile.write("\nLINKLIBS =-L$(LIBDIR) -lcts -liregi ")
        if os.path.exists("MG5/lib/ninja_lib"):
            makefile.write("\nLINKLIBS += -L$(LIBDIR)/ninja_lib -lninja   ")
        if os.path.exists(collierlib):
            makefile.write("\nLINKLIBS += -L%s -lcollier "%(collierlib))
        if os.path.exists(onelooplib):
            makefile.write("\nLINKLIBS += -L%s -lavh_olo "%(onelooplib))
        if os.path.exists("MG5/lib/golem95_lib"):
            makefile.write("\nLINKLIBS += -L$(LIBDIR)/golem95_lib -lgolem")

    makefile.write("\nPROCESS= InterfaceMadGraph.f "+make+"\n\nall:  \n\t gfortran  -O2 -Wall -U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=3 -fstack-protector-strong -funwind-tables -fasynchronous-unwind-tables -fstack-clash-protection -Werror=return-type -flto=auto -g  -fdefault-integer-8 -std=legacy -ffixed-line-length-none -w -fbounds-check -ffixed-line-length-none -fPIC -fno-f2c -shared -s -o  InterfaceMadGraph.so -IMG5/SubProcesses/" )

    if Virtlist:
        if os.path.exists("MG5/lib/golem95_include"):
            makefile.write(" -IMG5/lib/golem95_include ")
        if os.path.exists("MG5/lib/ninja_include"):
            makefile.write(" -IMG5/lib/ninja_include ")

        incldir = _heptools_path("include")
        makefile.write(" -I%s -I./ "%incldir)

        # Find all .mod files, also in /usr/include if golem was built there.
        # There can be an error message in the MadGraph output to add the
        # golem include path to the makefiles.  Usually MadGraph finds the
        # path if its Golem was built in a separate directory.  Our bootstrap
        # script installs golem with gosam beside boost.  Here MadGraph
        # creates a link (-> error message).  If we can find the mod files
        # easily the user doesn't need to change the makefiles.
        # Without --includedir there is nothing to search there.
        moddirs = finddirs('*.mod',options.includedir) if options.includedir else []
        for moddir in moddirs:
            makefile.write(" -I%s " % moddir)
        if os.path.isdir("/usr/include"):
            for moddir in finddirs('*.mod',"/usr/include"):
                makefile.write(" -I%s " % moddir)

    if coefdir:
        makefile.write(coefdir)
    makefile.write("   $(PROCESS) $(LINKLIBS) ")


def replacetext(filename, sourceText, replaceText):
    """Rewrite 'filename' with every 'sourceText' replaced by 'replaceText'.

    The file is read completely and then written back, whether or not
    'sourceText' occurred in it.
    """
    with open(filename, "r") as source:
        original = source.read()
    updated = original.replace(sourceText, replaceText)
    with open(filename, "w") as target:
        target.write(updated)

# Go back to the start directory first, then into the build directory
os.chdir(pwd)
os.chdir(options.buildpath)

# Make the generated model reader use the copies of its card and log file
# below <cacheprefix>/MG_tmp.
lha_read = "MG5/Source/MODEL/lha_read.f"
for card_name in ("ident_card.dat", "param.log"):
    replacetext(
        lha_read,
        card_name,
        os.path.join(options.cacheprefix, "MG_tmp", card_name),
    )

# With loop processes the MadLoop prefix is set to the scratch directory too
if Virtlist:
    replacetext(
        "MG5/SubProcesses/MadLoopCommons.f",
        "PREFIX='./'",
        "PREFIX='%s/'" % MG_TMP,
    )

os.system("make")
library_built = os.path.isfile("InterfaceMadGraph.so")
if not library_built:
    print("Second trial to make MadGraph Interface. ")
    print("Needed if new .mod files are produced by make.")
    os.system("make")

build_matchbox_tmp(pwd, options.buildpath, options.absolute_links)
