/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.sql.calcite.rule;

import com.google.common.collect.ImmutableList;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.rules.SubstitutionRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlPostfixOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.checkerframework.checker.nullness.qual.Nullable;

public class DruidAggregateCaseToFilterRule
extends RelOptRule
implements SubstitutionRule {
    private boolean extendedFilteredSumRewrite;

    public DruidAggregateCaseToFilterRule(boolean extendedFilteredSumRewrite) {
        super(DruidAggregateCaseToFilterRule.operand(Aggregate.class, (RelOptRuleOperand)DruidAggregateCaseToFilterRule.operand(Project.class, (RelOptRuleOperandChildren)DruidAggregateCaseToFilterRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]));
        this.extendedFilteredSumRewrite = extendedFilteredSumRewrite;
    }

    public boolean matches(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        Project project = (Project)call.rel(1);
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            int singleArg = DruidAggregateCaseToFilterRule.soleArgument(aggregateCall);
            if (singleArg < 0 || !DruidAggregateCaseToFilterRule.isThreeArgCase((RexNode)project.getProjects().get(singleArg))) continue;
            return true;
        }
        return false;
    }

    public void onMatch(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        Project project = (Project)call.rel(1);
        ArrayList<AggregateCall> newCalls = new ArrayList<AggregateCall>(aggregate.getAggCallList().size());
        ArrayList<RexNode> newProjects = new ArrayList<RexNode>(project.getProjects());
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            AggregateCall newCall = this.transform(aggregateCall, project, newProjects);
            if (newCall == null) {
                newCalls.add(aggregateCall);
                continue;
            }
            newCalls.add(newCall);
        }
        if (newCalls.equals(aggregate.getAggCallList())) {
            return;
        }
        RelBuilder relBuilder = call.builder().push(project.getInput()).project(newProjects);
        RelBuilder.GroupKey groupKey = relBuilder.groupKey(aggregate.getGroupSet(), (Iterable)aggregate.getGroupSets());
        relBuilder.aggregate(groupKey, newCalls).convert(aggregate.getRowType(), false);
        call.transformTo(relBuilder.build());
        call.getPlanner().prune((RelNode)aggregate);
    }

    private @Nullable AggregateCall transform(AggregateCall call, Project project, List<RexNode> newProjects) {
        int singleArg = DruidAggregateCaseToFilterRule.soleArgument(call);
        if (singleArg < 0) {
            return null;
        }
        RexNode rexNode = (RexNode)project.getProjects().get(singleArg);
        if (!DruidAggregateCaseToFilterRule.isThreeArgCase(rexNode)) {
            return null;
        }
        RelOptCluster cluster = project.getCluster();
        RexBuilder rexBuilder = cluster.getRexBuilder();
        RexCall caseCall = (RexCall)rexNode;
        boolean flip = RexLiteral.isNullLiteral((RexNode)((RexNode)caseCall.operands.get(1))) && !RexLiteral.isNullLiteral((RexNode)((RexNode)caseCall.operands.get(2)));
        RexNode arg1 = (RexNode)caseCall.operands.get(flip ? 2 : 1);
        RexNode arg2 = (RexNode)caseCall.operands.get(flip ? 1 : 2);
        SqlPostfixOperator op = flip ? SqlStdOperatorTable.IS_NOT_TRUE : SqlStdOperatorTable.IS_TRUE;
        RexNode filterFromCase = rexBuilder.makeCall((SqlOperator)op, new RexNode[]{(RexNode)caseCall.operands.get(0)});
        RexNode filter = call.filterArg >= 0 ? rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, new RexNode[]{(RexNode)project.getProjects().get(call.filterArg), filterFromCase}) : filterFromCase;
        RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
        SqlKind kind = call.getAggregation().getKind();
        if (call.isDistinct()) {
            if (kind == SqlKind.COUNT && RexLiteral.isNullLiteral((RexNode)arg2)) {
                newProjects.add(arg1);
                newProjects.add(filter);
                return AggregateCall.create((SqlAggFunction)SqlStdOperatorTable.COUNT, (boolean)true, (boolean)false, (boolean)false, (List)call.rexList, (List)ImmutableList.of((Object)(newProjects.size() - 2)), (int)(newProjects.size() - 1), null, (RelCollation)RelCollations.EMPTY, (RelDataType)call.getType(), (String)call.getName());
            }
            return null;
        }
        if (kind == SqlKind.COUNT && arg1.isA(SqlKind.LITERAL) && !RexLiteral.isNullLiteral((RexNode)arg1) && RexLiteral.isNullLiteral((RexNode)arg2)) {
            newProjects.add(filter);
            return AggregateCall.create((SqlAggFunction)SqlStdOperatorTable.COUNT, (boolean)false, (boolean)false, (boolean)false, (List)call.rexList, (List)ImmutableList.of(), (int)(newProjects.size() - 1), null, (RelCollation)RelCollations.EMPTY, (RelDataType)call.getType(), (String)call.getName());
        }
        if (kind == SqlKind.SUM0 && DruidAggregateCaseToFilterRule.isIntLiteral(arg1, BigDecimal.ONE) && DruidAggregateCaseToFilterRule.isIntLiteral(arg2, BigDecimal.ZERO)) {
            newProjects.add(filter);
            RelDataType dataType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), false);
            return AggregateCall.create((SqlAggFunction)SqlStdOperatorTable.COUNT, (boolean)false, (boolean)false, (boolean)false, (List)call.rexList, (List)ImmutableList.of(), (int)(newProjects.size() - 1), null, (RelCollation)RelCollations.EMPTY, (RelDataType)dataType, (String)call.getName());
        }
        if (RexLiteral.isNullLiteral((RexNode)arg2) && call.getAggregation().allowsFilter() || kind == SqlKind.SUM0 && DruidAggregateCaseToFilterRule.isIntLiteral(arg2, BigDecimal.ZERO)) {
            newProjects.add(arg1);
            newProjects.add(filter);
            return AggregateCall.create((SqlAggFunction)call.getAggregation(), (boolean)false, (boolean)false, (boolean)false, (List)call.rexList, (List)ImmutableList.of((Object)(newProjects.size() - 2)), (int)(newProjects.size() - 1), null, (RelCollation)RelCollations.EMPTY, (RelDataType)call.getType(), (String)call.getName());
        }
        if (this.extendedFilteredSumRewrite && kind == SqlKind.SUM && DruidAggregateCaseToFilterRule.isIntLiteral(arg2, BigDecimal.ZERO)) {
            if (DruidAggregateCaseToFilterRule.isIntLiteral(arg1, BigDecimal.ONE)) {
                newProjects.add(filter);
                RelDataType dataType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), false);
                return AggregateCall.create((SqlAggFunction)SqlStdOperatorTable.COUNT, (boolean)false, (boolean)false, (boolean)false, (List)call.rexList, (List)ImmutableList.of(), (int)(newProjects.size() - 1), null, (RelCollation)RelCollations.EMPTY, (RelDataType)dataType, (String)call.getName());
            }
            newProjects.add(arg1);
            newProjects.add(filter);
            RelDataType newType = typeFactory.createTypeWithNullability(call.getType(), true);
            return AggregateCall.create((SqlAggFunction)call.getAggregation(), (boolean)false, (boolean)false, (boolean)false, (List)call.rexList, (List)ImmutableList.of((Object)(newProjects.size() - 2)), (int)(newProjects.size() - 1), null, (RelCollation)RelCollations.EMPTY, (RelDataType)newType, (String)call.getName());
        }
        return null;
    }

    private static int soleArgument(AggregateCall aggregateCall) {
        return aggregateCall.getArgList().size() == 1 ? (Integer)aggregateCall.getArgList().get(0) : -1;
    }

    private static boolean isThreeArgCase(RexNode rexNode) {
        return rexNode.getKind() == SqlKind.CASE && ((RexCall)rexNode).operands.size() == 3;
    }

    private static boolean isIntLiteral(RexNode rexNode, BigDecimal value) {
        return rexNode instanceof RexLiteral && SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName()) && value.equals(((RexLiteral)rexNode).getValueAs(BigDecimal.class));
    }
}

