/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hive.druid.org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import org.apache.hive.druid.com.google.common.collect.ImmutableList;
import org.apache.hive.druid.org.apache.calcite.linq4j.Ord;
import org.apache.hive.druid.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.hive.druid.org.apache.calcite.plan.RelRule;
import org.apache.hive.druid.org.apache.calcite.rel.RelNode;
import org.apache.hive.druid.org.apache.calcite.rel.core.Aggregate;
import org.apache.hive.druid.org.apache.calcite.rel.core.AggregateCall;
import org.apache.hive.druid.org.apache.calcite.rel.core.RelFactories;
import org.apache.hive.druid.org.apache.calcite.rel.core.Union;
import org.apache.hive.druid.org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.hive.druid.org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.hive.druid.org.apache.calcite.rel.rules.TransformationRule;
import org.apache.hive.druid.org.apache.calcite.rel.type.RelDataType;
import org.apache.hive.druid.org.apache.calcite.sql.SqlAggFunction;
import org.apache.hive.druid.org.apache.calcite.sql.fun.SqlAnyValueAggFunction;
import org.apache.hive.druid.org.apache.calcite.sql.fun.SqlBitOpAggFunction;
import org.apache.hive.druid.org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.hive.druid.org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.hive.druid.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.hive.druid.org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.hive.druid.org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.hive.druid.org.apache.calcite.tools.RelBuilder;
import org.apache.hive.druid.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.hive.druid.org.apache.calcite.util.ImmutableBitSet;

/**
 * Planner rule that pushes an {@link Aggregate} past a {@code UNION ALL}.
 *
 * <p>The rule matches an {@code Aggregate} whose single input is a {@link Union}. It rewrites
 * {@code Aggregate(UnionAll(i1, ..., in))} into
 * {@code Aggregate(UnionAll(Aggregate(i1), ..., Aggregate(in)))}. Each input aggregate groups on
 * the original group key and computes the original calls. The top aggregate rolls the partial
 * results up, summing {@code COUNT} partials with {@code SUM0}.
 *
 * <p>The rule does nothing in any of these cases:
 * <ul>
 *   <li>the union is not {@code ALL};</li>
 *   <li>any aggregate call is DISTINCT, or its function class is not supported;</li>
 *   <li>every union input is already unique on the group key.</li>
 * </ul>
 */
public class AggregateUnionTransposeRule
extends RelRule<Config>
implements TransformationRule {
    /**
     * Aggregate function classes whose partial results can be rolled up: with the same function,
     * or, for {@code COUNT}, with {@code SUM0}. Only key membership is consulted; the values are
     * unused.
     */
    private static final Map<Class<? extends SqlAggFunction>, Boolean> SUPPORTED_AGGREGATES = new IdentityHashMap<Class<? extends SqlAggFunction>, Boolean>();

    /** Creates an AggregateUnionTransposeRule from its configuration. */
    protected AggregateUnionTransposeRule(Config config) {
        super(config);
    }

    /**
     * Creates a rule matching the given aggregate and union classes.
     *
     * @deprecated Use {@link Config#toRule()}.
     */
    @Deprecated
    public AggregateUnionTransposeRule(Class<? extends Aggregate> aggregateClass, Class<? extends Union> unionClass, RelBuilderFactory relBuilderFactory) {
        this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class).withOperandFor(aggregateClass, unionClass));
    }

    /**
     * Creates a rule matching the given aggregate and union classes, building nodes with the
     * given factories.
     *
     * @deprecated Use {@link Config#toRule()}.
     */
    @Deprecated
    public AggregateUnionTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Union> unionClass, RelFactories.SetOpFactory setOpFactory) {
        this(aggregateClass, unionClass, RelBuilder.proto(aggregateFactory, setOpFactory));
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        Aggregate aggRel = (Aggregate)call.rel(0);
        Union union = (Union)call.rel(1);
        if (!union.all) {
            // Only valid for UNION ALL. A distinct union removes rows duplicated across its
            // inputs, but per-input aggregates would still count those duplicates.
            return;
        }
        int groupCount = aggRel.getGroupSet().cardinality();
        // The copy (no grouping sets) stands in for each per-input aggregate built below. Its
        // row type is the group keys followed by one column per aggregate call.
        List<AggregateCall> transformedAggCalls = transformAggCalls(
            aggRel.copy(aggRel.getTraitSet(), aggRel.getInput(), aggRel.getGroupSet(), null, aggRel.getAggCallList()),
            groupCount, aggRel.getAggCallList());
        if (transformedAggCalls == null) {
            // A DISTINCT or unsupported aggregate call cannot be rolled up.
            return;
        }
        RelMetadataQuery mq = call.getMetadataQuery();
        if (allInputsUniqueOnGroupKey(mq, union, aggRel.getGroupSet())) {
            // Pushing the aggregate down would not shrink any input. Bailing out also stops the
            // rule from re-firing on its own output, whose union inputs are aggregates on the key.
            return;
        }
        RelBuilder relBuilder = call.builder();
        // Aggregate every input, including ones already unique on the group key. That way all
        // union inputs share one row type (group keys followed by aggregate results).
        // Pushing a unique input through unchanged would leave the union with mismatched
        // inputs, and the top-level roll-up would read the wrong columns.
        for (RelNode input : union.getInputs()) {
            relBuilder.push(input);
            relBuilder.aggregate(relBuilder.groupKey(aggRel.getGroupSet()), aggRel.getAggCallList());
        }
        relBuilder.union(true, union.getInputs().size());
        relBuilder.aggregate(relBuilder.groupKey(aggRel.getGroupSet(), (Iterable<? extends ImmutableBitSet>)aggRel.getGroupSets()), transformedAggCalls);
        call.transformTo(relBuilder.build());
    }

    /**
     * Returns whether every input of {@code union} is definitely unique on {@code groupSet},
     * according to {@link RelMdUtil#areColumnsDefinitelyUnique}.
     */
    private static boolean allInputsUniqueOnGroupKey(RelMetadataQuery mq, Union union, ImmutableBitSet groupSet) {
        for (RelNode input : union.getInputs()) {
            if (!RelMdUtil.areColumnsDefinitelyUnique(mq, input, groupSet)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds the aggregate calls of the top-level aggregate that rolls up per-input results.
     *
     * @param input      expression whose row type is expected to be {@code groupCount} group-key
     *                   columns followed by one column per call in {@code origCalls}; used to
     *                   derive the new calls' types
     * @param groupCount number of leading group-key columns in {@code input}
     * @param origCalls  the original aggregate calls
     * @return one roll-up call per original call, or {@code null} if any call is DISTINCT or
     *         uses an aggregate function whose class is not supported
     */
    private static List<AggregateCall> transformAggCalls(RelNode input, int groupCount, List<AggregateCall> origCalls) {
        ArrayList<AggregateCall> newCalls = new ArrayList<AggregateCall>();
        for (Ord<AggregateCall> ord : Ord.zip(origCalls)) {
            RelDataType aggType;
            SqlAggFunction aggFun;
            AggregateCall origCall = (AggregateCall)ord.e;
            if (origCall.isDistinct() || !SUPPORTED_AGGREGATES.containsKey(origCall.getAggregation().getClass())) {
                return null;
            }
            if (origCall.getAggregation() == SqlStdOperatorTable.COUNT) {
                // Partial counts are summed. The type is left null so it is derived for SUM0
                // rather than copied from COUNT.
                aggFun = SqlStdOperatorTable.SUM0;
                aggType = null;
            } else {
                aggFun = origCall.getAggregation();
                aggType = origCall.getType();
            }
            // The argument is the partial result column of the i-th call, which comes after the
            // group keys. No filter (-1) is needed because the per-input aggregates already
            // applied any filter.
            AggregateCall newCall = AggregateCall.create(aggFun, origCall.isDistinct(), origCall.isApproximate(), origCall.ignoreNulls(), ImmutableList.of(Integer.valueOf(groupCount + ord.i)), -1, origCall.collation, groupCount, input, aggType, origCall.getName());
            newCalls.add(newCall);
        }
        return newCalls;
    }

    static {
        SUPPORTED_AGGREGATES.put(SqlMinMaxAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlCountAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumEmptyIsZeroAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlAnyValueAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlBitOpAggFunction.class, true);
    }

    /** Rule configuration. */
    public static interface Config
    extends RelRule.Config {
        /** Default configuration: matches a {@link LogicalAggregate} on a {@link LogicalUnion}. */
        public static final Config DEFAULT = EMPTY.as(Config.class).withOperandFor(LogicalAggregate.class, LogicalUnion.class);

        @Override
        default public AggregateUnionTransposeRule toRule() {
            return new AggregateUnionTransposeRule(this);
        }

        /**
         * Returns a configuration whose operand matches an {@code aggregateClass} node with one
         * input, that input being a {@code unionClass} node with any inputs.
         */
        default public Config withOperandFor(Class<? extends Aggregate> aggregateClass, Class<? extends Union> unionClass) {
            return this.withOperandSupplier(b0 -> b0.operand(aggregateClass).oneInput(b1 -> b1.operand(unionClass).anyInputs())).as(Config.class);
        }
    }
}

