/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.segment.join;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.primitives.Ints;
import com.google.inject.Inject;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.druid.common.guava.GuavaUtils;
import org.apache.druid.java.util.common.Cacheable;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.Query;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.filter.Filter;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.planning.DataSourceAnalysis;
import org.apache.druid.query.planning.PreJoinableClause;
import org.apache.druid.segment.SegmentReference;
import org.apache.druid.segment.filter.FalseFilter;
import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.join.Equality;
import org.apache.druid.segment.join.HashJoinSegment;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.segment.join.Joinable;
import org.apache.druid.segment.join.JoinableClause;
import org.apache.druid.segment.join.JoinableFactory;
import org.apache.druid.segment.join.filter.JoinFilterAnalyzer;
import org.apache.druid.segment.join.filter.JoinFilterPreAnalysis;
import org.apache.druid.segment.join.filter.JoinFilterPreAnalysisKey;
import org.apache.druid.segment.join.filter.JoinableClauses;
import org.apache.druid.segment.join.filter.rewrite.JoinFilterRewriteConfig;
import org.apache.druid.utils.JvmUtils;

/**
 * Wraps a {@link JoinableFactory} with helpers for turning pre-joinable clauses into a segment-mapping
 * function and for computing cache keys of join data sources.
 */
public class JoinableFactoryWrapper {
    /** Identifier byte handed to {@link CacheKeyBuilder} when building join data source cache keys. */
    private static final byte JOIN_OPERATION = 1;
    private static final Logger log = new Logger(JoinableFactoryWrapper.class);
    private final JoinableFactory joinableFactory;

    /**
     * @param joinableFactory factory used to build joinables and per-clause join cache keys; must not be null
     * @throws NullPointerException if {@code joinableFactory} is null
     */
    @Inject
    public JoinableFactoryWrapper(JoinableFactory joinableFactory) {
        this.joinableFactory = Preconditions.checkNotNull(joinableFactory, "joinableFactory");
    }

    /**
     * @return the wrapped factory
     */
    public JoinableFactory getJoinableFactory() {
        return this.joinableFactory;
    }

    /**
     * Builds a function that wraps a base segment in a {@link HashJoinSegment} applying {@code clauses}.
     * <p>
     * If the query reports its required columns and join-to-filter rewriting is enabled in its
     * {@link JoinFilterRewriteConfig}, leading join clauses that can be expressed as filters on the base table
     * are converted (see {@link #convertJoinsToFilters}); the resulting filters are AND-ed with {@code baseFilter}
     * and fully-converted clauses are dropped. A join filter pre-analysis is then computed once for the remaining
     * clauses and shared by every segment the returned function is applied to.
     * <p>
     * Construction of the function runs inside {@link JvmUtils#safeAccumulateThreadCpuTime} with
     * {@code cpuTimeAccumulator}.
     *
     * @param baseFilter         optional filter on the base table; may be null
     * @param clauses            pre-joinable clauses; when empty, {@link Function#identity()} is returned
     * @param cpuTimeAccumulator accumulator passed to {@link JvmUtils#safeAccumulateThreadCpuTime}
     * @param query              query supplying the rewrite config, required columns, virtual columns and filter
     * @return a function mapping a base segment to its joined view
     */
    public Function<SegmentReference, SegmentReference> createSegmentMapFn(
            @Nullable Filter baseFilter,
            List<PreJoinableClause> clauses,
            AtomicLong cpuTimeAccumulator,
            Query<?> query
    ) {
        return JvmUtils.safeAccumulateThreadCpuTime(cpuTimeAccumulator, () -> {
            if (clauses.isEmpty()) {
                return Function.identity();
            }

            final JoinableClauses joinableClauses = JoinableClauses.createClauses(clauses, this.joinableFactory);
            final JoinFilterRewriteConfig filterRewriteConfig = JoinFilterRewriteConfig.forQuery(query);

            // Pick off any join clauses that can be converted into filters on the base table.
            final Set<String> requiredColumns = query.getRequiredColumns();
            final Filter baseFilterToUse;
            final List<JoinableClause> clausesToUse;

            if (requiredColumns != null && filterRewriteConfig.isEnableRewriteJoinToFilter()) {
                final Pair<List<Filter>, List<JoinableClause>> conversionResult = convertJoinsToFilters(
                        joinableClauses.getJoinableClauses(),
                        requiredColumns,
                        // The value budget is a long in the config but the conversion works with an int.
                        Ints.checkedCast(Math.min(filterRewriteConfig.getFilterRewriteMaxSize(), Integer.MAX_VALUE))
                );
                baseFilterToUse = Filters.maybeAnd(
                        Lists.newArrayList(Iterables.concat(Collections.singleton(baseFilter), conversionResult.lhs))
                ).orElse(null);
                clausesToUse = conversionResult.rhs;
            } else {
                baseFilterToUse = baseFilter;
                clausesToUse = joinableClauses.getJoinableClauses();
            }

            // Analyze the remaining join clauses once; the result is shared by every mapped segment.
            final JoinFilterPreAnalysis joinFilterPreAnalysis = JoinFilterAnalyzer.computeJoinFilterPreAnalysis(
                    new JoinFilterPreAnalysisKey(
                            filterRewriteConfig,
                            clausesToUse,
                            query.getVirtualColumns(),
                            Filters.maybeAnd(Arrays.asList(baseFilterToUse, Filters.toFilter(query.getFilter())))
                                   .orElse(null)
                    )
            );

            return baseSegment -> new HashJoinSegment(
                    baseSegment,
                    baseFilterToUse,
                    GuavaUtils.firstNonNull(clausesToUse, ImmutableList.of()),
                    joinFilterPreAnalysis
            );
        });
    }

    /**
     * Computes a cache key for a join data source.
     * <p>
     * The key covers the optional join base-table filter and, for every pre-joinable clause, the clause's
     * joinable cache key, original condition expression, prefix and join type.
     *
     * @param dataSourceAnalysis analysis of the join data source
     * @return the key bytes, or {@link Optional#empty()} if any clause's data source cannot produce a cache key
     * @throws IAE if the analysis contains no join clauses
     */
    public Optional<byte[]> computeJoinDataSourceCacheKey(DataSourceAnalysis dataSourceAnalysis) {
        final List<PreJoinableClause> clauses = dataSourceAnalysis.getPreJoinableClauses();
        if (clauses.isEmpty()) {
            throw new IAE(
                    "No join clauses to build the cache key for data source [%s]",
                    dataSourceAnalysis.getDataSource()
            );
        }

        final CacheKeyBuilder keyBuilder = new CacheKeyBuilder(JOIN_OPERATION);
        dataSourceAnalysis.getJoinBaseTableFilter().ifPresent(keyBuilder::appendCacheable);

        for (PreJoinableClause clause : clauses) {
            final Optional<byte[]> bytes = this.joinableFactory.computeJoinCacheKey(
                    clause.getDataSource(),
                    clause.getCondition()
            );
            if (!bytes.isPresent()) {
                // One uncacheable clause makes the whole join uncacheable.
                log.debug("skipping caching for join since [%s] does not support caching", clause.getDataSource());
                return Optional.empty();
            }
            keyBuilder.appendByteArray(bytes.get());
            keyBuilder.appendString(clause.getCondition().getOriginalExpression());
            keyBuilder.appendString(clause.getPrefix());
            keyBuilder.appendString(clause.getJoinType().name());
        }
        return Optional.of(keyBuilder.build());
    }

    /**
     * Converts a leading run of join clauses into filters on the base table.
     * <p>
     * Clauses are examined in order with {@link #convertJoinToFilter}. Any produced filter is collected. A clause
     * that is fully converted is dropped. The first clause that is not fully converted is kept and ends the run;
     * it and every later clause are returned unchanged.
     *
     * @param clauses            join clauses in join order
     * @param requiredColumns    columns the query needs from the joined result
     * @param maxNumFilterValues value budget handed to {@link #convertJoinToFilter} for each clause
     * @return a pair of (filters to AND onto the base filter, join clauses still to be executed)
     */
    @VisibleForTesting
    static Pair<List<Filter>, List<JoinableClause>> convertJoinsToFilters(
            List<JoinableClause> clauses,
            Set<String> requiredColumns,
            int maxNumFilterValues
    ) {
        final List<Filter> filterList = new ArrayList<>();
        final List<JoinableClause> clausesToUse = new ArrayList<>();

        // Later join clauses may depend on columns from earlier ones. A multiset is used because several clauses
        // can reference the same column and requirements are removed one clause at a time.
        final HashMultiset<String> columnsRequiredByJoinClauses = HashMultiset.create();
        for (JoinableClause clause : clauses) {
            for (String column : clause.getCondition().getRequiredColumns()) {
                columnsRequiredByJoinClauses.add(column, 1);
            }
        }

        final Set<String> rightPrefixes = clauses.stream()
                                                 .map(JoinableClause::getPrefix)
                                                 .collect(Collectors.toSet());

        boolean atStart = true;
        for (JoinableClause clause : clauses) {
            if (!atStart) {
                clausesToUse.add(clause);
                continue;
            }

            // Remove this clause's own requirements: a clause may rely on itself.
            for (String column : clause.getCondition().getRequiredColumns()) {
                columnsRequiredByJoinClauses.remove(column, 1);
            }

            final JoinClauseToFilterConversion conversion = convertJoinToFilter(
                    clause,
                    Sets.union(requiredColumns, columnsRequiredByJoinClauses.elementSet()),
                    maxNumFilterValues,
                    rightPrefixes
            );

            if (conversion.getConvertedFilter() != null) {
                filterList.add(conversion.getConvertedFilter());
            }
            // A partially converted clause is kept (its filter still prunes base rows) and ends the leading run.
            if (!conversion.isJoinClauseFullyConverted()) {
                clausesToUse.add(clause);
                atStart = false;
            }
        }
        return Pair.of(filterList, clausesToUse);
    }

    /**
     * Attempts to express a single join clause as a filter on left-hand columns.
     * <p>
     * Only INNER joins whose condition has at least one equi-condition and no non-equi conditions are considered;
     * anything else yields {@code (null, false)}. For each equi-condition whose left side is a plain identifier not
     * starting with any right-hand prefix, the non-null values of the right column are fetched (bounded by the
     * remaining value budget) and an {@link InDimFilter} on the left column is added.
     *
     * @param clause             the join clause to convert
     * @param requiredColumns    columns still needed after this clause; if any belongs to this clause, the clause
     *                           cannot be fully converted
     * @param maxNumFilterValues total number of values this clause's filters may use, shared across its conditions
     * @param rightPrefixes      prefixes of all join clauses; left columns starting with one are not filtered on
     * @return the converted filter (possibly null) and whether the clause can be dropped entirely
     */
    @VisibleForTesting
    static JoinClauseToFilterConversion convertJoinToFilter(
            JoinableClause clause,
            Set<String> requiredColumns,
            int maxNumFilterValues,
            Set<String> rightPrefixes
    ) {
        final boolean convertible = clause.getJoinType() == JoinType.INNER
                                    && clause.getCondition().getNonEquiConditions().isEmpty()
                                    && !clause.getCondition().getEquiConditions().isEmpty();
        if (!convertible) {
            return new JoinClauseToFilterConversion(null, false);
        }

        final List<Filter> filters = new ArrayList<>();
        int numValues = maxNumFilterValues;
        // If any right-side column of this clause is still required, the clause must stay.
        boolean joinClauseFullyConverted = requiredColumns.stream().noneMatch(clause::includesColumn);

        for (Equality condition : clause.getCondition().getEquiConditions()) {
            final String leftColumn = condition.getLeftExpr().getBindingIfIdentifier();
            if (leftColumn == null) {
                // Left side is an expression, not a column: give up on the whole clause.
                return new JoinClauseToFilterConversion(null, false);
            }

            // Only left base-table columns are filtered on; skip columns produced by a join clause.
            if (rightPrefixes.stream().anyMatch(leftColumn::startsWith)) {
                joinClauseFullyConverted = false;
                continue;
            }

            final Joinable.ColumnValuesWithUniqueFlag columnValuesWithUniqueFlag =
                    clause.getJoinable().getNonNullColumnValues(condition.getRightColumn(), numValues);

            if (columnValuesWithUniqueFlag.getColumnValues().isEmpty()) {
                // NOTE(review): an empty all-unique result presumably means the right column has no non-null values,
                // so the INNER join matches nothing -- confirm against Joinable#getNonNullColumnValues.
                if (columnValuesWithUniqueFlag.isAllUnique()) {
                    return new JoinClauseToFilterConversion(FalseFilter.instance(), true);
                }
                joinClauseFullyConverted = false;
                continue;
            }

            numValues -= columnValuesWithUniqueFlag.getColumnValues().size();
            filters.add(Filters.toFilter(new InDimFilter(leftColumn, columnValuesWithUniqueFlag.getColumnValues())));
            // Non-unique right values mean the join can duplicate rows, which a filter cannot reproduce.
            if (!columnValuesWithUniqueFlag.isAllUnique()) {
                joinClauseFullyConverted = false;
            }
        }

        return new JoinClauseToFilterConversion(Filters.maybeAnd(filters).orElse(null), joinClauseFullyConverted);
    }

    /**
     * Result of {@link #convertJoinToFilter}: an optional filter plus whether the clause was fully replaced by it.
     */
    private static final class JoinClauseToFilterConversion {
        @Nullable
        private final Filter convertedFilter;
        private final boolean joinClauseFullyConverted;

        public JoinClauseToFilterConversion(@Nullable Filter convertedFilter, boolean joinClauseFullyConverted) {
            this.convertedFilter = convertedFilter;
            this.joinClauseFullyConverted = joinClauseFullyConverted;
        }

        /** @return the filter derived from the clause, or null if none was produced */
        @Nullable
        public Filter getConvertedFilter() {
            return this.convertedFilter;
        }

        /** @return true if the join clause can be dropped because the filter fully replaces it */
        public boolean isJoinClauseFullyConverted() {
            return this.joinClauseFullyConverted;
        }
    }
}

