/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.msq.logical.stages;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexNode;
import org.apache.druid.error.DruidException;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.msq.exec.StageProcessor;
import org.apache.druid.msq.kernel.HashShuffleSpec;
import org.apache.druid.msq.kernel.ShuffleSpec;
import org.apache.druid.msq.logical.LogicalInputSpec;
import org.apache.druid.msq.logical.StageMaker;
import org.apache.druid.msq.logical.stages.AbstractFrameProcessorStage;
import org.apache.druid.msq.logical.stages.AbstractShuffleStage;
import org.apache.druid.msq.logical.stages.LogicalStage;
import org.apache.druid.msq.logical.stages.SegmentMapStage;
import org.apache.druid.msq.querykit.QueryKitUtils;
import org.apache.druid.msq.querykit.common.SortMergeJoinStageProcessor;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.ExpressionParser;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.planner.querygen.DruidQueryGenerator;
import org.apache.druid.sql.calcite.planner.querygen.SourceDescProducer;
import org.apache.druid.sql.calcite.rel.DruidJoinQueryRel;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.rel.logical.DruidJoin;

/**
 * Builds the MSQ logical stage for a {@link DruidJoin} node.
 *
 * <p>When the planner context selects {@link JoinAlgorithm#SORT_MERGE} for the join, the inputs are
 * hash-shuffled on the join keys and combined by a {@link SortMergeStage}. Otherwise the join is planned
 * as a broadcast join, expressed as a {@link SegmentMapStage}.
 */
public class JoinStage {
    /**
     * Builds a logical stage for the join node at the top of {@code stack}.
     *
     * @param inputStages already-built logical stages for the join inputs; index 0 is treated as the base (left) input
     * @param stack       node stack whose current node is a {@link DruidJoin}
     * @return a {@link SortMergeStage} when the join algorithm is {@link JoinAlgorithm#SORT_MERGE},
     *         otherwise a broadcast-join {@link SegmentMapStage}
     */
    public static LogicalStage buildJoinStage(List<LogicalStage> inputStages, DruidQueryGenerator.DruidNodeStack stack) {
        DruidJoin join = (DruidJoin) stack.getNode();
        if (join.getJoinAlgorithm(stack.getPlannerContext()) == JoinAlgorithm.SORT_MERGE) {
            return buildMergeJoin(inputStages, stack, join);
        }
        return buildBroadcastJoin(inputStages, stack, join);
    }

    /**
     * Builds a broadcast join.
     *
     * <p>Input 0 is read as-is. Every other input {@code i} is registered with index {@code i} and the
     * {@link LogicalInputSpec.InputProperty#BROADCAST} property. The join's source description is then
     * derived from the inputs' source descriptions.
     */
    private static LogicalStage buildBroadcastJoin(List<LogicalStage> inputStages, DruidQueryGenerator.DruidNodeStack stack, DruidJoin join) {
        PlannerContext plannerContext = stack.getPlannerContext();
        List<LogicalInputSpec> inputDescs = new ArrayList<>(inputStages.size());
        inputDescs.add(LogicalInputSpec.of(inputStages.get(0)));
        for (int i = 1; i < inputStages.size(); ++i) {
            inputDescs.add(LogicalInputSpec.of(inputStages.get(i), i, LogicalInputSpec.InputProperty.BROADCAST));
        }
        SourceDescProducer.SourceDesc joinSourceDesc =
                join.getSourceDesc(plannerContext, Lists.transform(inputDescs, LogicalInputSpec::getSourceDesc));
        return new SegmentMapStage(joinSourceDesc, inputDescs);
    }

    /**
     * Builds a sort-merge join.
     *
     * <p>The join output signature is the left input's logical columns followed by the right input's
     * logical columns, renamed with an unused prefix. The join condition is translated against that
     * signature and analyzed. Each input is then wrapped in a {@link ShuffleStage} keyed on its side of
     * the equi-join keys.
     *
     * @throws DruidException if the join does not have exactly two inputs, or if the condition cannot be translated
     */
    private static LogicalStage buildMergeJoin(List<LogicalStage> inputStages, DruidQueryGenerator.DruidNodeStack stack, DruidJoin join) {
        // Only inputs 0 and 1 contribute to the output signature, and partition keys are looked up per
        // input, so any other input count cannot be planned correctly.
        if (inputStages.size() != 2) {
            throw DruidException.defensive("Sort-merge join expects exactly 2 inputs, got [%s]", inputStages.size());
        }
        LogicalStage left = inputStages.get(0);
        LogicalStage right = inputStages.get(1);

        // NOTE(review): the prefix is chosen against left.getRowSignature(), while the combined signature
        // below uses left.getLogicalRowSignature(); confirm the former covers every logical column name.
        String prefix = findUnusedJoinPrefix(left.getRowSignature());
        RowSignature signature = RowSignature.builder()
                .addAll(left.getLogicalRowSignature())
                .addAll(right.getLogicalRowSignature().withPrefix(prefix))
                .build();

        PlannerContext plannerContext = stack.getPlannerContext();
        String conditionExpression = translateJoinCondition(plannerContext, signature, join).getExpression();
        JoinConditionAnalysis analysis = JoinConditionAnalysis.forExpression(
                conditionExpression,
                plannerContext.parseExpression(conditionExpression),
                prefix
        );
        List<List<KeyColumn>> partitionKeys =
                SortMergeJoinStageProcessor.toKeyColumns(SortMergeJoinStageProcessor.validateCondition(analysis));

        // Built eagerly: a Lists.transform view would re-wrap the shuffle stages on every access.
        List<LogicalInputSpec> inputs = new ArrayList<>(inputStages.size());
        for (int i = 0; i < inputStages.size(); ++i) {
            inputs.add(LogicalInputSpec.of(new ShuffleStage(inputStages.get(i), partitionKeys.get(i))));
        }
        return new SortMergeStage(signature, inputs, prefix, analysis, DruidJoinQueryRel.toDruidJoinType(join.getJoinType()));
    }

    /**
     * Translates the join condition against the combined join signature.
     *
     * <p>A join-scoped {@link VirtualColumnRegistry} is installed on {@code plannerContext} only while
     * translating. It is always cleared afterwards, even if translation throws, so the shared planner
     * context is not left holding this join's registry.
     *
     * @throws DruidException if translation registered virtual columns (not supported here yet),
     *                        or if the condition could not be translated at all
     */
    private static DruidExpression translateJoinCondition(PlannerContext plannerContext, RowSignature signature, DruidJoin join) {
        VirtualColumnRegistry virtualColumnRegistry = VirtualColumnRegistry.create(
                signature,
                plannerContext.getExpressionParser(),
                plannerContext.getPlannerConfig().isForceExpressionVirtualColumns()
        );
        final DruidExpression condition;
        plannerContext.setJoinExpressionVirtualColumnRegistry(virtualColumnRegistry);
        try {
            condition = Expressions.toDruidExpression(plannerContext, signature, join.getCondition());
        } finally {
            plannerContext.setJoinExpressionVirtualColumnRegistry(null);
        }
        if (!virtualColumnRegistry.isEmpty()) {
            throw DruidException.defensive("Not sure how to handle this right now - it should be fixed");
        }
        if (condition == null) {
            throw DruidException.defensive("Unable to translate join condition [%s]", join.getCondition());
        }
        return condition;
    }

    /**
     * Returns the prefix for the right-hand join columns.
     *
     * <p>The prefix is produced by {@link Calcites#findUnusedPrefixForDigits} from base {@code "j"},
     * checked against the given signature's column names, with {@code "0"} appended.
     */
    private static String findUnusedJoinPrefix(RowSignature rowSignature) {
        List<String> leftColumnNames = rowSignature.getColumnNames();
        return Calcites.findUnusedPrefixForDigits("j", leftColumnNames) + "0";
    }

    /**
     * Hash-shuffles one join input on its join key columns.
     *
     * <p>The physical signature comes from {@link QueryKitUtils#sortableSignature}. The logical signature
     * is the wrapped input's.
     */
    static class ShuffleStage
    extends AbstractShuffleStage {
        protected final List<KeyColumn> keyColumns;

        public ShuffleStage(LogicalStage inputStage, List<KeyColumn> keyColumns) {
            super(QueryKitUtils.sortableSignature(inputStage.getLogicalRowSignature(), keyColumns), LogicalInputSpec.of(inputStage));
            this.keyColumns = keyColumns;
        }

        /** Returns the row signature of the single wrapped input. */
        @Override
        public RowSignature getLogicalRowSignature() {
            return ((LogicalInputSpec) this.inputSpecs.get(0)).getRowSignature();
        }

        /**
         * Returns a hash shuffle clustered by {@link #keyColumns}.
         *
         * <p>NOTE(review): {@code new ClusterBy(keyColumns, 0)} presumably means no bucket-by columns, and
         * {@code new HashShuffleSpec(clusterBy, 1)} presumably means a single partition. Confirm whether
         * one partition is intended.
         */
        @Override
        public ShuffleSpec buildShuffleSpec() {
            ClusterBy clusterBy = new ClusterBy(this.keyColumns, 0);
            return new HashShuffleSpec(clusterBy, 1);
        }

        /** Returns {@code null}; presumably this means the stage cannot absorb further nodes. Confirm against the LogicalStage contract. */
        @Override
        public LogicalStage extendWith(DruidQueryGenerator.DruidNodeStack stack) {
            return null;
        }
    }

    /**
     * Frame-processor stage that merges the shuffled join inputs with a {@link SortMergeJoinStageProcessor}.
     */
    public static class SortMergeStage
    extends AbstractFrameProcessorStage {
        private final String rightPrefix;
        private final JoinConditionAnalysis conditionAnalysis;
        private final JoinType joinType;

        /**
         * @param signature         combined output signature (left columns, then prefixed right columns)
         * @param inputs            the join inputs, in left/right order
         * @param rightPrefix       prefix applied to right-hand column names
         * @param conditionAnalysis analyzed join condition
         * @param joinType          Druid join type
         */
        public SortMergeStage(RowSignature signature, List<LogicalInputSpec> inputs, String rightPrefix, JoinConditionAnalysis conditionAnalysis, JoinType joinType) {
            super(signature, inputs);
            this.rightPrefix = rightPrefix;
            this.conditionAnalysis = conditionAnalysis;
            this.joinType = joinType;
        }

        /** Returns {@code null}; presumably this means the stage cannot absorb further nodes. Confirm against the LogicalStage contract. */
        @Override
        public LogicalStage extendWith(DruidQueryGenerator.DruidNodeStack stack) {
            return null;
        }

        /** Creates the sort-merge join processor from this stage's prefix, condition analysis and join type. */
        @Override
        public StageProcessor<?, ?> buildStageProcessor(StageMaker stageMaker) {
            return new SortMergeJoinStageProcessor(this.rightPrefix, this.conditionAnalysis, this.joinType);
        }
    }
}

