/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.physical.stream;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelDistribution;
import org.apache.calcite.rel.RelNode;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.config.AggregatePhaseStrategy;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery;
import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalExchange;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalGlobalGroupAggregate;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalGroupAggregate;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLocalGroupAggregate;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRel;
import org.apache.flink.table.planner.plan.rules.physical.FlinkExpandConversionRule;
import org.apache.flink.table.planner.plan.rules.physical.stream.ImmutableTwoStageOptimizedAggregateRule;
import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
import org.apache.flink.table.planner.plan.trait.FlinkRelDistributionTraitDef;
import org.apache.flink.table.planner.plan.trait.ModifyKindSetTrait;
import org.apache.flink.table.planner.plan.trait.RelModifiedMonotonicity;
import org.apache.flink.table.planner.plan.trait.UpdateKindTrait;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
import org.apache.flink.table.planner.plan.utils.ChangelogPlanUtils;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.planner.utils.TableConfigUtils;
import org.immutables.value.Value;
import scala.Option;

/**
 * Planner rule that rewrites a {@link StreamPhysicalGroupAggregate} whose input is a {@link
 * StreamPhysicalExchange} into a two-stage aggregation.
 *
 * <p>The rewrite has three parts:
 *
 * <ul>
 *   <li>a {@link StreamPhysicalLocalGroupAggregate} applied directly to the exchange's own input,
 *       bypassing the original exchange;
 *   <li>a new distribution step produced by {@link FlinkExpandConversionRule#satisfyDistribution};
 *   <li>a {@link StreamPhysicalGlobalGroupAggregate} that replaces the original aggregate.
 * </ul>
 *
 * <p>The rule only fires when all of the following hold:
 *
 * <ul>
 *   <li>mini-batch execution is enabled;
 *   <li>the configured aggregate phase strategy is not {@link AggregatePhaseStrategy#ONE_PHASE};
 *   <li>every aggregate call supports partial merge;
 *   <li>the exchange's input does not already satisfy the distribution the grouping keys require.
 * </ul>
 */
@Value.Enclosing
public class TwoStageOptimizedAggregateRule
        extends RelRule<TwoStageOptimizedAggregateRule.TwoStageOptimizedAggregateRuleConfig> {

    /** Shared rule instance built from {@link TwoStageOptimizedAggregateRuleConfig#DEFAULT}. */
    public static final TwoStageOptimizedAggregateRule INSTANCE =
            TwoStageOptimizedAggregateRuleConfig.DEFAULT.toRule();

    private TwoStageOptimizedAggregateRule(TwoStageOptimizedAggregateRuleConfig config) {
        super(config);
    }

    /**
     * Checks the configuration switches, then delegates to {@link #matchesTwoStage}.
     *
     * <p>Operand 0 is the group aggregate. Operand 2 is the node below the exchange, i.e. the
     * aggregate's "real" input.
     *
     * @param call the rule call holding the matched operands
     * @return {@code true} if the two-stage rewrite should be applied
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        TableConfig tableConfig = ShortcutUtils.unwrapTableConfig(call);
        boolean isMiniBatchEnabled =
                tableConfig.get(ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_ENABLED);
        boolean isTwoPhaseEnabled =
                TableConfigUtils.getAggPhaseStrategy((ReadableConfig) tableConfig)
                        != AggregatePhaseStrategy.ONE_PHASE;
        if (!isMiniBatchEnabled || !isTwoPhaseEnabled) {
            return false;
        }
        StreamPhysicalGroupAggregate agg = call.rel(0);
        RelNode realInput = call.rel(2);
        return matchesTwoStage(agg, realInput);
    }

    /**
     * Decides whether {@code agg} can and should be split into local and global aggregates.
     *
     * @param agg the original group aggregate
     * @param realInput the node feeding the aggregate's exchange; cast to {@link StreamPhysicalRel}
     *     to inspect its changelog mode
     * @return {@code true} if every aggregate call supports partial merge and {@code realInput} is
     *     not already distributed as the grouping keys require
     */
    public static boolean matchesTwoStage(StreamPhysicalGroupAggregate agg, RelNode realInput) {
        // Retraction handling is needed whenever the real input is not insert-only.
        boolean needRetraction = !ChangelogPlanUtils.isInsertOnly((StreamPhysicalRel) realInput);
        FlinkRelMetadataQuery fmq =
                FlinkRelMetadataQuery.reuseOrCreate(agg.getCluster().getMetadataQuery());
        RelModifiedMonotonicity monotonicity = fmq.getRelModifiedMonotonicity(agg);
        boolean[] needRetractionArray =
                AggregateUtil.deriveAggCallNeedRetractions(
                        agg.grouping().length, agg.aggCalls(), needRetraction, monotonicity);

        AggregateInfoList aggInfoList =
                AggregateUtil.transformToStreamAggregateInfoList(
                        ShortcutUtils.unwrapTypeFactory(agg),
                        FlinkTypeFactory.toLogicalRowType(agg.getInput().getRowType()),
                        agg.aggCalls(),
                        needRetractionArray,
                        needRetraction,
                        true,
                        true);

        return AggregateUtil.doAllSupportPartialMerge(aggInfoList.aggInfos())
                && !isInputSatisfyRequiredDistribution(realInput, agg.grouping());
    }

    /**
     * Checks whether {@code input}'s distribution trait satisfies the distribution required by
     * {@code keys}.
     *
     * @param input the node whose distribution trait is inspected
     * @param keys the grouping key indices
     * @return {@code true} if the input's distribution already satisfies the required one
     */
    private static boolean isInputSatisfyRequiredDistribution(RelNode input, int[] keys) {
        FlinkRelDistribution requiredDistribution = createDistribution(keys);
        RelTraitSet traitSet = input.getTraitSet();
        RelDistribution inputDistribution =
                traitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE());
        return inputDistribution.satisfies(requiredDistribution);
    }

    /**
     * Builds the distribution required for grouping on {@code keys}.
     *
     * @param keys the grouping key indices; may be empty
     * @return a strict hash distribution on {@code keys}, or the singleton distribution when there
     *     are no keys
     */
    private static FlinkRelDistribution createDistribution(int[] keys) {
        if (keys.length > 0) {
            List<Integer> fields = IntStream.of(keys).boxed().collect(Collectors.toList());
            return FlinkRelDistribution.hash(fields, true);
        }
        return FlinkRelDistribution.SINGLETON();
    }

    /**
     * Replaces the matched aggregate and exchange with local aggregate, new distribution and
     * global aggregate.
     *
     * @param call the rule call holding the matched operands
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        StreamPhysicalGroupAggregate originalAgg = call.rel(0);
        // Operand 2 is the input below the exchange. Typing it as RelNode (rather than Object)
        // is required to call getTraitSet()/getRowType() below.
        RelNode realInput = call.rel(2);
        boolean needRetraction = !ChangelogPlanUtils.isInsertOnly((StreamPhysicalRel) realInput);
        FlinkRelMetadataQuery fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery());
        RelModifiedMonotonicity monotonicity = fmq.getRelModifiedMonotonicity(originalAgg);
        boolean[] aggCallNeedRetractions =
                AggregateUtil.deriveAggCallNeedRetractions(
                        originalAgg.grouping().length,
                        originalAgg.aggCalls(),
                        needRetraction,
                        monotonicity);

        // The local aggregate inherits the real input's traits, but its changelog traits are
        // overridden to INSERT_ONLY / NONE.
        RelTraitSet localAggTraitSet =
                realInput
                        .getTraitSet()
                        .plus(ModifyKindSetTrait.INSERT_ONLY())
                        .plus(UpdateKindTrait.NONE());
        StreamPhysicalLocalGroupAggregate localHashAgg =
                new StreamPhysicalLocalGroupAggregate(
                        originalAgg.getCluster(),
                        localAggTraitSet,
                        realInput,
                        originalAgg.grouping(),
                        originalAgg.aggCalls(),
                        aggCallNeedRetractions,
                        needRetraction,
                        originalAgg.partialFinalType());

        // The global aggregate groups on fields 0..n-1 of the local aggregate's output. This
        // relies on the local aggregate emitting the grouping keys first.
        int[] globalGrouping = IntStream.range(0, originalAgg.grouping().length).toArray();
        FlinkRelDistribution globalDistribution = createDistribution(globalGrouping);
        RelNode newInput =
                FlinkExpandConversionRule.satisfyDistribution(
                        FlinkConventions.STREAM_PHYSICAL(), localHashAgg, globalDistribution);

        RelTraitSet globalAggProvidedTraitSet = originalAgg.getTraitSet();
        StreamPhysicalGlobalGroupAggregate globalAgg =
                new StreamPhysicalGlobalGroupAggregate(
                        originalAgg.getCluster(),
                        globalAggProvidedTraitSet,
                        newInput,
                        originalAgg.getRowType(),
                        globalGrouping,
                        originalAgg.aggCalls(),
                        aggCallNeedRetractions,
                        realInput.getRowType(),
                        needRetraction,
                        originalAgg.partialFinalType(),
                        Option.empty(),
                        originalAgg.hints());

        call.transformTo(globalAgg);
    }

    /**
     * Rule configuration.
     *
     * <p>The operand tree is {@code StreamPhysicalGroupAggregate -> StreamPhysicalExchange -> any
     * RelNode}.
     */
    @Value.Immutable(singleton = false)
    public interface TwoStageOptimizedAggregateRuleConfig extends RelRule.Config {

        /** Default configuration matching aggregate -> exchange -> any input. */
        TwoStageOptimizedAggregateRuleConfig DEFAULT =
                ImmutableTwoStageOptimizedAggregateRule.TwoStageOptimizedAggregateRuleConfig
                        .builder()
                        .build()
                        .withOperandSupplier(
                                b0 ->
                                        b0.operand(StreamPhysicalGroupAggregate.class)
                                                .oneInput(
                                                        b1 ->
                                                                b1.operand(
                                                                                StreamPhysicalExchange
                                                                                        .class)
                                                                        .oneInput(
                                                                                b2 ->
                                                                                        b2.operand(
                                                                                                        RelNode
                                                                                                                .class)
                                                                                                .anyInputs())))
                        .withDescription("TwoStageOptimizedAggregateRule");

        @Override
        default TwoStageOptimizedAggregateRule toRule() {
            return new TwoStageOptimizedAggregateRule(this);
        }
    }
}

