/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.parfor.opt;

import java.util.ArrayList;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.ipa.InterProceduralAnalysis;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.rewrite.HopRewriteRule;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.hops.rewrite.ProgramRewriter;
import org.apache.sysds.hops.rewrite.RewriteConstantFolding;
import org.apache.sysds.hops.rewrite.RewriteRemoveUnnecessaryBranches;
import org.apache.sysds.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.opt.CostEstimator;
import org.apache.sysds.runtime.controlprogram.parfor.opt.CostEstimatorHops;
import org.apache.sysds.runtime.controlprogram.parfor.opt.CostEstimatorRuntime;
import org.apache.sysds.runtime.controlprogram.parfor.opt.OptTree;
import org.apache.sysds.runtime.controlprogram.parfor.opt.OptTreeConverter;
import org.apache.sysds.runtime.controlprogram.parfor.opt.Optimizer;
import org.apache.sysds.runtime.controlprogram.parfor.opt.OptimizerConstrained;
import org.apache.sysds.runtime.controlprogram.parfor.opt.OptimizerHeuristic;
import org.apache.sysds.runtime.controlprogram.parfor.opt.OptimizerRuleBased;
import org.apache.sysds.runtime.controlprogram.parfor.opt.ProgramRecompiler;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
import org.apache.sysds.utils.stats.ParForStatistics;
import org.apache.sysds.utils.stats.Timing;

/**
 * Entry point for optimizing a single ParFOR program block.
 * <p>
 * When dynamic recompilation is enabled, the parfor body is first prepared:
 * <ul>
 * <li>constant scalars are replaced,</li>
 * <li>constant folding and branch removal rewrites are applied,</li>
 * <li>the body and any functions reported by inter-procedural analysis are recompiled.</li>
 * </ul>
 * Afterwards an {@link OptTree} is built for the block. The optimizer for the requested mode
 * and a cost estimator matching that optimizer's cost model are created, and the optimizer
 * is run on the tree.
 */
public class OptimizationWrapper {
    private static final Log LOG = LogFactory.getLog(OptimizationWrapper.class.getName());

    /** Factor applied to the maximum degree of parallelism reported by the infrastructure analyzer. */
    public static final double PAR_FACTOR_INFRASTRUCTURE = 1.0;

    /**
     * Optimizes the given ParFOR program block.
     * <p>
     * The degree of parallelism is derived from the infrastructure analyzer. The memory budget
     * is the analyzer's maximum memory scaled by {@code OptimizerUtils.MEM_UTIL_FACTOR}. If
     * statistics are enabled, the optimization count and elapsed time are recorded.
     *
     * @param type    optimizer mode to use
     * @param sb      statement block of the parfor loop
     * @param pb      runtime program block of the parfor loop
     * @param ec      execution context providing the current variables
     * @param numRuns number of runs, passed through to the optimizer
     * @throws DMLRuntimeException if the optimizer mode is undefined, or if rewrites,
     *                             recompilation, or opt tree creation fail
     */
    public static void optimize(ParForProgramBlock.POptMode type, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, int numRuns) {
        Timing time = new Timing(true);
        if (LOG.isDebugEnabled()) {
            LOG.debug("ParFOR Opt: Running optimization for ParFOR(" + pb.getID() + ")");
        }
        // Reference the named factor instead of a literal so that changing the constant
        // actually takes effect (compile-time constants get inlined at use sites).
        int ck = UtilFunctions.toInt(
            Math.max(InfrastructureAnalyzer.getCkMaxCP(), InfrastructureAnalyzer.getCkMaxMR()) * PAR_FACTOR_INFRASTRUCTURE);
        double cm = InfrastructureAnalyzer.getCmMax() * OptimizerUtils.MEM_UTIL_FACTOR;

        optimize(type, ck, cm, sb, pb, ec, numRuns);

        double timeVal = time.stop();
        if (LOG.isDebugEnabled()) {
            LOG.debug("ParFOR Opt: Finished optimization for PARFOR(" + pb.getID() + ") in " + timeVal + "ms.");
        }
        if (DMLScript.STATISTICS) {
            ParForStatistics.incrementOptimCount();
            ParForStatistics.incrementOptimTime((long) timeVal);
        }
    }

    /**
     * Runs the selected optimizer with an explicit parallelism and memory budget.
     *
     * @param otype   optimizer mode to use
     * @param ck      maximum degree of parallelism
     * @param cm      memory budget
     * @param sb      statement block of the parfor loop
     * @param pb      runtime program block of the parfor loop
     * @param ec      execution context providing the current variables
     * @param numRuns number of runs, passed through to the optimizer
     */
    private static void optimize(ParForProgramBlock.POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, int numRuns) {
        Optimizer opt = createOptimizer(otype);
        Optimizer.CostModelType cmtype = opt.getCostModelType();
        if (LOG.isTraceEnabled()) {
            LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + cmtype + ")");
        }

        if (ConfigurationManager.isDynamicRecompilation()) {
            rewriteAndRecompile(ck, cm, sb, pb, ec);
        }

        OptTree tree = buildOptTree(ck, cm, sb, pb, ec, "ParFOR Opt: Input plan (before optimization):\n");

        CostEstimator est = createCostEstimator(cmtype, tree, ec.getVariables());
        if (LOG.isTraceEnabled()) {
            LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");
        }

        opt.optimize(sb, pb, tree, est, numRuns, ec);
        if (LOG.isDebugEnabled()) {
            LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));
        }
    }

    /**
     * Prepares the parfor body for optimization.
     * <p>
     * The steps run in order: constant scalar replacement, program rewrites, then
     * recompilation of the body and affected functions. When debug logging is on, the
     * pre-recompilation plan is logged first.
     */
    private static void rewriteAndRecompile(int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec) {
        ForStatement fs = (ForStatement) sb.getStatement(0);
        if (LOG.isDebugEnabled()) {
            // Debug-only tree; the tree used for optimization is rebuilt after recompilation.
            buildOptTree(ck, cm, sb, pb, ec, "ParFOR Opt: Input plan (before recompilation):\n");
        }
        replaceConstantScalars(sb, ec);
        applyProgramRewrites(sb, pb, fs);
        recompileBodyAndFunctions(sb, pb, ec);
    }

    /** Replaces scalar variables that ProgramRecompiler deems reusable with their constant values in the parfor block. */
    private static void replaceConstantScalars(ParForStatementBlock sb, ExecutionContext ec) {
        try {
            LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
            ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /** Applies constant folding and branch removal; rebuilds the runtime child blocks if branches were removed. */
    private static void applyProgramRewrites(ParForStatementBlock sb, ParForProgramBlock pb, ForStatement fs) {
        try {
            ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
            ProgramRewriteStatus state = new ProgramRewriteStatus();
            rewriter.rRewriteStatementBlockHopDAGs(sb, state);
            fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state, true));
            if (state.getRemovedBranches()) {
                LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /** Recompiles the parfor body and every function that inter-procedural analysis reports as a candidate. */
    private static void recompileBodyAndFunctions(ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec) {
        try {
            // Recompile against a cloned variable map rather than the live one.
            LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
            Recompiler.ResetType reset = ConfigurationManager.isCodegenEnabled()
                ? Recompiler.ResetType.RESET_KNOWN_DIMS
                : Recompiler.ResetType.RESET;
            Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0L, true, reset);

            if (!pb.hasFunctions()) {
                return;
            }
            InterProceduralAnalysis ipa = new InterProceduralAnalysis(sb);
            Set<String> fcand = ipa.analyzeSubProgram();
            for (String func : fcand) {
                String[] funcparts = DMLProgram.splitFunctionKey(func);
                FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                // Only recompile-once functions get the reset; all others keep their flags (NO_RESET).
                Recompiler.ResetType reset2 = fpb.isRecompileOnce() ? reset : Recompiler.ResetType.NO_RESET;
                Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0L, true, reset2);
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /**
     * Builds the opt tree for the parfor block. If debug logging is on, the tree's explain
     * output is logged, prefixed with {@code header}.
     *
     * @throws DMLRuntimeException wrapping any failure during tree creation or explain
     */
    private static OptTree buildOptTree(int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, String header) {
        try {
            OptTree tree = OptTreeConverter.createOptTree(ck, cm, sb, pb, ec);
            if (LOG.isDebugEnabled()) {
                LOG.debug(header + tree.explain(false));
            }
            return tree;
        }
        catch (Exception ex) {
            throw new DMLRuntimeException("Unable to create opt tree.", ex);
        }
    }

    /**
     * Creates the optimizer implementation for the given mode.
     *
     * @throws DMLRuntimeException if the mode has no associated optimizer
     */
    private static Optimizer createOptimizer(ParForProgramBlock.POptMode otype) {
        switch (otype) {
            case HEURISTIC:
                return new OptimizerHeuristic();
            case RULEBASED:
                return new OptimizerRuleBased();
            case CONSTRAINED:
                return new OptimizerConstrained();
            default:
                throw new DMLRuntimeException("Undefined optimizer: '" + otype + "'.");
        }
    }

    /**
     * Creates the cost estimator for the given cost model type.
     * The runtime-metrics estimator receives a clone of {@code vars}.
     *
     * @throws DMLRuntimeException if the cost model type is not supported
     */
    private static CostEstimator createCostEstimator(Optimizer.CostModelType cmtype, OptTree tree, LocalVariableMap vars) {
        switch (cmtype) {
            case STATIC_MEM_METRIC:
                return new CostEstimatorHops(tree.getPlanMapping());
            case RUNTIME_METRICS:
                return new CostEstimatorRuntime(tree.getPlanMapping(), (LocalVariableMap) vars.clone());
            default:
                throw new DMLRuntimeException("Undefined cost model type: '" + cmtype + "'.");
        }
    }

    /** Builds a rewriter with constant folding (HOP level) and unnecessary-branch removal (statement-block level). */
    private static ProgramRewriter createProgramRewriterWithRuleSets() {
        ArrayList<HopRewriteRule> hRewrites = new ArrayList<>();
        hRewrites.add(new RewriteConstantFolding());
        ArrayList<StatementBlockRewriteRule> sbRewrites = new ArrayList<>();
        sbRewrites.add(new RewriteRemoveUnnecessaryBranches());
        return new ProgramRewriter(hRewrites, sbRewrites);
    }
}

