/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.translator.opconventer;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexNode;
import org.apache.hadoop.hive.ql.exec.ColumnInfo;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAntiJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSemiJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortExchange;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.opconventer.HiveOpConverter;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.opconventer.HiveOpConverterUtils;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.opconventer.HiveRelNodeVisitor;
import org.apache.hadoop.hive.ql.parse.JoinCond;
import org.apache.hadoop.hive.ql.parse.JoinType;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.JoinCondDesc;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;

/**
 * Translates a Calcite join ({@link HiveJoin}, {@link HiveMultiJoin}, {@link HiveSemiJoin} or
 * {@link HiveAntiJoin}) into a Hive {@link JoinOperator}.
 *
 * <p>Every input of the join is dispatched first. The first operator of each translated input is
 * expected to be a {@link ReduceSinkOperator}; the join operator is created on top of those
 * ReduceSinks.
 */
class JoinVisitor extends HiveRelNodeVisitor<RelNode> {
    JoinVisitor(HiveOpConverter hiveOpConverter) {
        super(hiveOpConverter);
    }

    /**
     * Converts {@code joinRel} into a {@link JoinOperator}.
     *
     * <p>Each input of {@code joinRel} is expected to be a {@link HiveSortExchange}; its key
     * expressions become the join keys of that input.
     *
     * @param joinRel the join to translate
     * @return the join operator together with its derived table alias and virtual-column positions
     * @throws SemanticException if the join kind is unsupported or a filter cannot be translated
     */
    @Override
    HiveOpConverter.OpAttr visit(RelNode joinRel) throws SemanticException {
        final int inputCount = joinRel.getInputs().size();
        final String tabAlias = this.hiveOpConverter.getHiveDerivedTableAlias();

        // 1. Translate every input; remember its ReduceSink and its table alias.
        HiveOpConverter.OpAttr[] inputs = new HiveOpConverter.OpAttr[inputCount];
        List<Operator<?>> children = new ArrayList<>(inputCount);
        String[] baseSrc = new String[inputCount];
        for (int i = 0; i < inputCount; ++i) {
            inputs[i] = this.hiveOpConverter.dispatch(joinRel.getInput(i));
            children.add(inputs[i].inputs.get(0));
            baseSrc[i] = inputs[i].tabAlias;
        }

        // 2. Tag each ReduceSink with the position of its input.
        for (int tag = 0; tag < children.size(); ++tag) {
            ReduceSinkOperator reduceSinkOp = (ReduceSinkOperator) children.get(tag);
            reduceSinkOp.getConf().setTag(tag);
        }

        // 3. Merge the virtual-column positions of the inputs. For semi/anti joins only the left
        //    input's columns are output, so only its vcols are kept. Otherwise each later input's
        //    positions are shifted by the total width of the inputs before it.
        Set<Integer> newVcolsInCalcite = new HashSet<>(inputs[0].vcolsInCalcite);
        if (!outputsLeftInputOnly(joinRel)) {
            int shift = inputs[0].inputs.get(0).getSchema().getSignature().size();
            for (int i = 1; i < inputCount; ++i) {
                newVcolsInCalcite.addAll(HiveCalciteUtil.shiftVColsSet(inputs[i].vcolsInCalcite, shift));
                shift += inputs[i].inputs.get(0).getSchema().getSignature().size();
            }
        }

        if (LOG.isDebugEnabled()) {
            LOG.debug("Translating operator rel#" + joinRel.getId() + ":" + joinRel.getRelTypeName() + " with row type: [" + joinRel.getRowType() + "]");
        }

        // 4. Join keys come from the sort exchange below each input.
        ExprNodeDesc[][] joinExpressions = new ExprNodeDesc[inputCount][];
        for (int i = 0; i < inputCount; ++i) {
            joinExpressions[i] = ((HiveSortExchange) joinRel.getInput(i)).getKeyExpressions();
        }

        // 5. Translate the join filters into one list of conjuncts per join condition.
        List<RexNode> joinFilters = getJoinFilters(joinRel);
        List<List<ExprNodeDesc>> filterExpressions = new ArrayList<>(joinFilters.size());
        for (RexNode joinFilter : joinFilters) {
            List<ExprNodeDesc> filterExpressionsForInput = new ArrayList<>();
            if (joinFilter != null) {
                for (RexNode conj : RelOptUtil.conjunctions(joinFilter)) {
                    filterExpressionsForInput.add(HiveOpConverterUtils.convertToExprNode(conj, joinRel, null, newVcolsInCalcite));
                }
            }
            filterExpressions.add(filterExpressionsForInput);
        }

        JoinOperator joinOp = this.genJoin(joinRel, joinExpressions, filterExpressions, children, baseSrc, tabAlias);
        return new HiveOpConverter.OpAttr(tabAlias, newVcolsInCalcite, joinOp);
    }

    /**
     * Returns true when {@code joinRel} is a binary semi or anti join. Only the left input's
     * columns become output columns for such joins.
     */
    private static boolean outputsLeftInputOnly(RelNode joinRel) {
        if (joinRel instanceof HiveMultiJoin || !(joinRel instanceof Join)) {
            return false;
        }
        Join join = (Join) joinRel;
        return join.isSemiJoin() || join.getJoinType() == JoinRelType.ANTI;
    }

    /**
     * Returns the join filters of {@code joinRel}, one entry per join condition.
     *
     * <p>A single-filter list is built with {@link Collections#singletonList} rather than
     * {@code ImmutableList.of}: the latter rejects null, while callers treat a null entry as
     * "no filter".
     *
     * @throws SemanticException if {@code joinRel} is not one of the supported join classes
     */
    private static List<RexNode> getJoinFilters(RelNode joinRel) throws SemanticException {
        if (joinRel instanceof HiveJoin) {
            return Collections.singletonList(((HiveJoin) joinRel).getJoinFilter());
        }
        if (joinRel instanceof HiveMultiJoin) {
            return ((HiveMultiJoin) joinRel).getJoinFilters();
        }
        if (joinRel instanceof HiveSemiJoin) {
            return Collections.singletonList(((HiveSemiJoin) joinRel).getJoinFilter());
        }
        if (joinRel instanceof HiveAntiJoin) {
            return Collections.singletonList(((HiveAntiJoin) joinRel).getJoinFilter());
        }
        throw new SemanticException("Can't handle join type: " + joinRel.getClass().getName());
    }

    /**
     * Builds the {@link JoinOperator} for {@code join} on top of the tagged ReduceSinks in
     * {@code children}.
     *
     * @param join the Calcite join being translated
     * @param joinExpressions the join key expressions of each input, indexed by input position
     * @param filterExpressions the translated filter conjuncts, one list per join condition
     * @param children the ReduceSink operator of each input, indexed by input position
     * @param baseSrc the table alias of each input
     * @param tabAlias the table alias assigned to every output column of the join
     * @return the new join operator
     * @throws SemanticException if a ReduceSink does not have exactly one parent, or a filter
     *     conjunct references columns of more than one input
     */
    private JoinOperator genJoin(RelNode join, ExprNodeDesc[][] joinExpressions, List<List<ExprNodeDesc>> filterExpressions, List<Operator<?>> children, String[] baseSrc, String tabAlias) throws SemanticException {
        // 1. Join conditions.
        JoinCondDesc[] joinCondns;
        boolean semiJoin;
        boolean noOuterJoin;
        if (join instanceof HiveMultiJoin) {
            HiveMultiJoin hmj = (HiveMultiJoin) join;
            int condCount = hmj.getJoinInputs().size();
            joinCondns = new JoinCondDesc[condCount];
            for (int i = 0; i < condCount; ++i) {
                var joinInput = hmj.getJoinInputs().get(i);
                joinCondns[i] = new JoinCondDesc(new JoinCond(joinInput.left, joinInput.right, transformJoinType(hmj.getJoinTypes().get(i))));
            }
            semiJoin = false;
            noOuterJoin = !hmj.isOuterJoin();
        } else {
            JoinRelType joinRelType = join instanceof Join ? ((Join) join).getJoinType() : JoinRelType.INNER;
            JoinType joinType = switch (joinRelType) {
                case SEMI -> JoinType.LEFTSEMI;
                case ANTI -> JoinType.ANTI;
                default -> {
                    assert join instanceof Join;
                    yield transformJoinType(joinRelType);
                }
            };
            semiJoin = joinRelType == JoinRelType.SEMI || joinRelType == JoinRelType.ANTI;
            joinCondns = new JoinCondDesc[] {new JoinCondDesc(new JoinCond(0, 1, joinType))};
            noOuterJoin = joinType != JoinType.FULLOUTER && joinType != JoinType.LEFTOUTER && joinType != JoinType.RIGHTOUTER;
        }

        // 2. Output columns and per-input value expressions. For semi/anti joins only the first
        //    input contributes output columns.
        ArrayList<ColumnInfo> outputColumns = new ArrayList<>();
        ArrayList<String> outputColumnNames = new ArrayList<>(join.getRowType().getFieldNames());
        Operator<?>[] childOps = new Operator<?>[children.size()];
        Map<String, Byte> reversedExprs = new HashMap<>();
        Map<Byte, List<ExprNodeDesc>> exprMap = new HashMap<>();
        Map<Byte, List<ExprNodeDesc>> filters = new HashMap<>();
        Map<String, ExprNodeDesc> colExprMap = new HashMap<>();
        HashMap<Integer, Set<String>> posToAliasMap = new HashMap<>();
        int outputPos = 0;
        for (int pos = 0; pos < children.size(); ++pos) {
            ReduceSinkOperator inputRS = (ReduceSinkOperator) children.get(pos);
            if (inputRS.getNumParent() != 1) {
                throw new SemanticException("RS should have single parent");
            }
            Operator<?> parent = inputRS.getParentOperators().get(0);
            ReduceSinkDesc rsDesc = inputRS.getConf();
            Byte tag = (byte) rsDesc.getTag();
            childOps[pos] = inputRS;
            if (semiJoin && pos != 0) {
                exprMap.put(tag, new ArrayList<>());
                continue;
            }
            posToAliasMap.put(pos, new HashSet<>(inputRS.getSchema().getTableNames()));
            int[] index = inputRS.getValueIndex();
            Map<String, ExprNodeDesc> descriptors = buildBacktrackFromReduceSinkForJoin(outputPos, outputColumnNames, rsDesc.getOutputKeyColumnNames(), rsDesc.getOutputValueColumnNames(), index, parent, baseSrc[pos]);
            List<ColumnInfo> parentColumns = parent.getSchema().getSignature();
            for (int i = 0; i < index.length; ++i) {
                String outputName = outputColumnNames.get(outputPos);
                ColumnInfo info = new ColumnInfo(parentColumns.get(i));
                info.setInternalName(outputName);
                info.setTabAlias(tabAlias);
                outputColumns.add(info);
                reversedExprs.put(outputName, tag);
                ++outputPos;
            }
            exprMap.put(tag, new ArrayList<>(descriptors.values()));
            colExprMap.putAll(descriptors);
        }

        // 3. Assign every filter conjunct to the input it references (the condition's left input
        //    when it references none). For outer joins, also count it in filterMap.
        List<List<ExprNodeDesc>> filtersPerInput = new ArrayList<>(children.size());
        for (int i = 0; i < children.size(); ++i) {
            filtersPerInput.add(new ArrayList<>());
        }
        int[][] filterMap = new int[children.size()][];
        for (int i = 0; i < filterExpressions.size(); ++i) {
            JoinCondDesc cond = joinCondns[i];
            int leftPos = cond.getLeft();
            int rightPos = cond.getRight();
            int condType = cond.getType();
            boolean outerJoin = condType == JoinDesc.FULL_OUTER_JOIN || condType == JoinDesc.LEFT_OUTER_JOIN || condType == JoinDesc.RIGHT_OUTER_JOIN;
            for (ExprNodeDesc expr : filterExpressions.get(i)) {
                int inputPos = updateExprNode(expr, reversedExprs, colExprMap);
                if (inputPos == -1) {
                    inputPos = leftPos;
                }
                filtersPerInput.get(inputPos).add(expr);
                if (!outerJoin) {
                    continue;
                }
                if (inputPos == leftPos) {
                    updateFilterMap(filterMap, leftPos, rightPos);
                } else {
                    updateFilterMap(filterMap, rightPos, leftPos);
                }
            }
        }
        for (int pos = 0; pos < children.size(); ++pos) {
            ReduceSinkOperator inputRS = (ReduceSinkOperator) children.get(pos);
            filters.put((byte) inputRS.getConf().getTag(), filtersPerInput.get(pos));
        }

        // 4. Descriptor and operator.
        JoinDesc desc = new JoinDesc(exprMap, outputColumnNames, noOuterJoin, joinCondns, filters, joinExpressions, null);
        desc.setReversedExprs(reversedExprs);
        desc.setFilterMap(filterMap);
        JoinOperator joinOp = (JoinOperator) OperatorFactory.getAndMakeChild(childOps[0].getCompilationOpContext(), desc, new RowSchema(outputColumns), childOps);
        joinOp.setColumnExprMap(colExprMap);
        joinOp.setPosToAliasMap(posToAliasMap);
        ((JoinDesc) joinOp.getConf()).setBaseSrc(baseSrc);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Generated " + joinOp + " with row schema: [" + joinOp.getSchema() + "]");
        }
        return joinOp;
    }

    /**
     * Maps each join output column name, starting at {@code initialPos}, to a column expression
     * that reads the corresponding ReduceSink KEY or VALUE field.
     *
     * <p>A non-negative {@code index[i]} selects key column {@code index[i]}. A negative one
     * selects value column {@code -index[i] - 1}. A {@link LinkedHashMap} is used so that
     * {@code values()} follows the column order.
     */
    private static Map<String, ExprNodeDesc> buildBacktrackFromReduceSinkForJoin(int initialPos, List<String> outputColumnNames, List<String> keyColNames, List<String> valueColNames, int[] index, Operator<?> inputOp, String tabAlias) {
        Map<String, ExprNodeDesc> columnDescriptors = new LinkedHashMap<>();
        List<ColumnInfo> inputColumns = inputOp.getSchema().getSignature();
        for (int i = 0; i < index.length; ++i) {
            ColumnInfo info = inputColumns.get(i);
            String field = index[i] >= 0
                    ? Utilities.ReduceField.KEY + "." + keyColNames.get(index[i])
                    : Utilities.ReduceField.VALUE + "." + valueColNames.get(-index[i] - 1);
            ExprNodeColumnDesc desc = new ExprNodeColumnDesc(info.getType(), field, tabAlias, info.getIsVirtualCol());
            columnDescriptors.put(outputColumnNames.get(initialPos + i), desc);
        }
        return columnDescriptors;
    }

    /**
     * Rewrites, in place, the column references of a function expression from join output column
     * names to the ReduceSink-based expressions in {@code colExprMap}. Nested function children
     * are rewritten recursively.
     *
     * @return the input position (tag) referenced by {@code expr}, or -1 if it references no
     *     column or is not a function expression
     * @throws SemanticException if {@code expr} references columns of more than one input, or a
     *     column that is not a join output column
     */
    private static int updateExprNode(ExprNodeDesc expr, Map<String, Byte> reversedExprs, Map<String, ExprNodeDesc> colExprMap) throws SemanticException {
        if (!(expr instanceof ExprNodeGenericFuncDesc)) {
            return -1;
        }
        ExprNodeGenericFuncDesc func = (ExprNodeGenericFuncDesc) expr;
        int inputPos = -1;
        List<ExprNodeDesc> newChildren = new ArrayList<>();
        for (ExprNodeDesc functionChild : func.getChildren()) {
            int pos;
            if (functionChild instanceof ExprNodeColumnDesc) {
                String colRef = functionChild.getExprString();
                Byte tag = reversedExprs.get(colRef);
                if (tag == null) {
                    // Previously an NPE from unboxing; report the offending column instead.
                    throw new SemanticException("Join filter references column " + colRef + " which is not an output column of the join");
                }
                pos = tag;
                newChildren.add(colExprMap.get(colRef));
            } else {
                pos = updateExprNode(functionChild, reversedExprs, colExprMap);
                newChildren.add(functionChild);
            }
            if (pos == -1) {
                continue;
            }
            if (inputPos == -1) {
                inputPos = pos;
            } else if (inputPos != pos) {
                throw new SemanticException("UpdateExprNode is expecting only one position for join operator convert. But there are more than one.");
            }
        }
        func.setChildren(newChildren);
        return inputPos;
    }

    /**
     * Increments the number of filters on input {@code inputPos} that relate to input
     * {@code joinPos}.
     *
     * <p>{@code filterMap[inputPos]} is a flat array of {@code (joinPos, count)} pairs. It is
     * grown by one pair the first time a given {@code joinPos} is seen.
     */
    private static void updateFilterMap(int[][] filterMap, int inputPos, int joinPos) {
        int[] map = filterMap[inputPos];
        if (map == null) {
            filterMap[inputPos] = new int[] {joinPos, 1};
            return;
        }
        for (int j = 0; j < map.length; j += 2) {
            if (map[j] == joinPos) {
                ++map[j + 1];
                return;
            }
        }
        int[] newMap = Arrays.copyOf(map, map.length + 2);
        newMap[map.length] = joinPos;
        newMap[map.length + 1] = 1;
        filterMap[inputPos] = newMap;
    }

    /**
     * Maps a Calcite join type to Hive's: FULL, LEFT and RIGHT become the corresponding outer
     * joins, and every other type becomes {@link JoinType#INNER}.
     */
    private static JoinType transformJoinType(JoinRelType type) {
        return switch (type) {
            case FULL -> JoinType.FULLOUTER;
            case LEFT -> JoinType.LEFTOUTER;
            case RIGHT -> JoinType.RIGHTOUTER;
            default -> JoinType.INNER;
        };
    }
}

