/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.utils.PythonUtil;

/**
 * Planner rule that collapses a chain of three {@link FlinkLogicalCalc}s of the shape
 * {@code top(middle(bottom))} into a single Calc.
 *
 * <p>The rule applies when:
 *
 * <ul>
 *   <li>the top and bottom Calcs each project exactly one Python function call and have no
 *       condition;
 *   <li>the middle Calc is accepted by {@link PythonUtil#isFlattenCalc} and has no condition;
 *   <li>the top call's operands are exactly the middle Calc's output fields, in order.
 * </ul>
 *
 * <p>The rewrite feeds the bottom Calc's single output column directly into the top Python call.
 * It then merges that program with the bottom program, so the resulting Calc reads from the
 * bottom Calc's input.
 */
public class PythonMapMergeRule extends RelOptRule {
    public static final PythonMapMergeRule INSTANCE = new PythonMapMergeRule();

    private PythonMapMergeRule() {
        super(
                operand(
                        FlinkLogicalCalc.class,
                        operand(FlinkLogicalCalc.class, operand(FlinkLogicalCalc.class, none()))),
                "PythonMapMergeRule");
    }

    /**
     * Checks whether the matched {@code top(middle(bottom))} Calc chain can be merged.
     *
     * @param call the rule call holding the top (0), middle (1) and bottom (2) Calcs
     * @return {@code true} if {@link #onMatch} may rewrite the chain
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        FlinkLogicalCalc topCalc = call.rel(0);
        FlinkLogicalCalc middleCalc = call.rel(1);
        FlinkLogicalCalc bottomCalc = call.rel(2);

        RexProgram topProgram = topCalc.getProgram();
        List<RexNode> topProjects = expandProjects(topProgram);
        RexProgram bottomProgram = bottomCalc.getProgram();
        List<RexNode> bottomProjects = expandProjects(bottomProgram);

        // Top and bottom must each project a single Python call and must not filter.
        if (topProjects.size() != 1
                || !PythonUtil.isPythonCall(topProjects.get(0), null)
                || topProgram.getCondition() != null
                || bottomProjects.size() != 1
                || !PythonUtil.isPythonCall(bottomProjects.get(0), null)
                || bottomProgram.getCondition() != null) {
            return false;
        }

        RexCall topPythonCall = (RexCall) topProjects.get(0);
        if (!PythonUtil.takesRowAsInput(topPythonCall)) {
            return false;
        }

        // Merge only when both calls are of kind GENERAL, or neither is.
        if (PythonUtil.isPythonCall(topPythonCall, PythonFunctionKind.GENERAL)
                ^ PythonUtil.isPythonCall(bottomProjects.get(0), PythonFunctionKind.GENERAL)) {
            return false;
        }

        // onMatch discards the middle program entirely, so a filter on it would be silently
        // lost. Reject it explicitly rather than relying on isFlattenCalc to do so.
        if (middleCalc.getProgram().getCondition() != null) {
            return false;
        }

        return PythonUtil.isFlattenCalc(middleCalc)
                && isTopCalcTakesWholeMiddleCalcAsInputs(
                        topPythonCall, middleCalc.getRowType().getFieldCount());
    }

    /** Expands every projected expression of {@code program} into a self-contained RexNode. */
    private static List<RexNode> expandProjects(RexProgram program) {
        return program.getProjectList().stream()
                .map(program::expandLocalRef)
                .collect(Collectors.toList());
    }

    /**
     * Returns whether {@code pythonCall}'s operands are exactly {@code $0, $1, ..., $(n-1)}, where
     * {@code n == inputColumnCount}.
     *
     * @param pythonCall the top Python call
     * @param inputColumnCount the number of fields produced by the middle Calc
     * @return {@code true} if the call consumes every input field once, in order
     */
    private boolean isTopCalcTakesWholeMiddleCalcAsInputs(RexCall pythonCall, int inputColumnCount) {
        List<RexNode> pythonCallInputs = pythonCall.getOperands();
        if (pythonCallInputs.size() != inputColumnCount) {
            return false;
        }
        for (int i = 0; i < pythonCallInputs.size(); i++) {
            RexNode input = pythonCallInputs.get(i);
            if (!(input instanceof RexInputRef) || ((RexInputRef) input).getIndex() != i) {
                return false;
            }
        }
        return true;
    }

    /**
     * Rewrites {@code top(middle(bottom))} into a single Calc over the bottom Calc's input.
     *
     * @param call the rule call holding the top (0), middle (1) and bottom (2) Calcs
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        FlinkLogicalCalc topCalc = call.rel(0);
        FlinkLogicalCalc middleCalc = call.rel(1);
        FlinkLogicalCalc bottomCalc = call.rel(2);
        RexBuilder rexBuilder = call.builder().getRexBuilder();

        // matches() guarantees the top program projects exactly one Python call.
        RexProgram topProgram = topCalc.getProgram();
        RexCall topPythonCall =
                (RexCall) topProgram.expandLocalRef(topProgram.getProjectList().get(0));

        // Merge top and middle: feed the bottom Calc's single output column straight into the
        // top Python call instead of the middle Calc's fields.
        RexCall newPythonCall =
                topPythonCall.clone(
                        topPythonCall.getType(),
                        Collections.singletonList(RexInputRef.of(0, bottomCalc.getRowType())));
        FlinkLogicalCalc topMiddleMergedCalc =
                new FlinkLogicalCalc(
                        middleCalc.getCluster(),
                        middleCalc.getTraitSet(),
                        bottomCalc,
                        RexProgram.create(
                                bottomCalc.getRowType(),
                                Collections.singletonList(newPythonCall),
                                null,
                                Collections.singletonList("f0"),
                                rexBuilder));

        // Merge with bottom: inline the bottom program so the new Calc reads the bottom's input.
        RexProgram mergedProgram =
                RexProgramBuilder.mergePrograms(
                        topMiddleMergedCalc.getProgram(), bottomCalc.getProgram(), rexBuilder);
        Calc newCalc =
                topMiddleMergedCalc.copy(
                        topMiddleMergedCalc.getTraitSet(), bottomCalc.getInput(), mergedProgram);
        call.transformTo(newCalc);
    }
}

