/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalTableScan;
import org.apache.calcite.rel.rules.QueryOptimizationRules;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.mapping.Mappings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Rewrites an outer join into a tighter join type when filters in the plan reject the
 * NULL-extended rows that the outer join would otherwise produce.
 *
 * <p>The rule fires on any node whose (single) child operand is a {@link Join}. It performs these
 * conversions:
 * <ul>
 *   <li>a LEFT join whose right-side join columns are all null-rejected becomes INNER;</li>
 *   <li>a FULL join whose left-side join columns are all null-rejected becomes LEFT, or INNER
 *       when the right-side join columns are null-rejected as well.</li>
 * </ul>
 *
 * <p>A predicate counts as null-rejecting when it is {@code col IS NOT NULL} or a (non NULL-safe)
 * binary comparison of a column against a literal; see {@link #isCandidateFilterPred}. Filters are
 * gathered from the whole plan under the planner root. A filter column is tied to a join column by
 * matching both its index and its field name.
 * NOTE(review): this is a heuristic. A filter that is not an ancestor of the join (e.g. one below
 * it, or in a sibling subtree) can still match by index and name. Confirm this is acceptable.
 */
public class OuterJoinOptViaNullRejectionRule extends QueryOptimizationRules {
    /**
     * Text of join conditions already examined; a join whose condition text is present is skipped.
     * This set is static (shared by all instances) and is cleared by every constructor call. It is a
     * plain {@link HashSet}, so it is not safe for concurrent use.
     */
    public static Set<String> visitedJoinMemo = new HashSet<String>();
    static final Logger HEAVYDBLOGGER = LoggerFactory.getLogger(OuterJoinOptViaNullRejectionRule.class);

    /**
     * Creates the rule and clears the shared memo of visited join conditions.
     *
     * @param relBuilderFactory builder factory handed to the base rule
     */
    public OuterJoinOptViaNullRejectionRule(RelBuilderFactory relBuilderFactory) {
        super(operand(RelNode.class, operand(Join.class, null, any())),
                relBuilderFactory,
                "OuterJoinOptViaNullRejectionRule");
        clearMemo();
    }

    /** Empties the shared {@link #visitedJoinMemo}. */
    void clearMemo() {
        visitedJoinMemo.clear();
    }

    /**
     * Examines the matched join. When the plan's filters null-reject the side(s) that the outer
     * join would pad with NULLs, the join is replaced by a LEFT or INNER join under the matched
     * parent.
     *
     * <p>The plan is left unchanged in any of these cases:
     * <ul>
     *   <li>the join is not a {@link LogicalJoin}, or its condition text was already visited;</li>
     *   <li>the condition is not a call, or the join is INNER, SEMI or ANTI;</li>
     *   <li>the left input is a {@link LogicalProject}, or the right input is not a
     *       {@link LogicalTableScan};</li>
     *   <li>no column-to-column join keys are found, or the plan contains no {@link LogicalFilter};</li>
     *   <li>an examined filter condition is a top-level OR;</li>
     *   <li>a filter column is in both the left and the right join-column sets.</li>
     * </ul>
     *
     * @param call rule call; {@code rel(0)} is the parent node and {@code rel(1)} the join
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        RelNode parentNode = call.rel(0);
        RelNode matchedJoin = call.rel(1);
        // The operand accepts any Join, but only LogicalJoin is handled (and copied) below.
        if (!(matchedJoin instanceof LogicalJoin)) {
            return;
        }
        LogicalJoin join = (LogicalJoin) matchedJoin;

        // Set.add returns false when this condition text has been seen before.
        if (!visitedJoinMemo.add(join.getCondition().toString())) {
            return;
        }
        if (!(join.getCondition() instanceof RexCall)) {
            return;
        }
        JoinRelType joinType = join.getJoinType();
        if (joinType == JoinRelType.INNER || joinType == JoinRelType.SEMI || joinType == JoinRelType.ANTI) {
            return;
        }
        if (unwrapHepVertex(join.getLeft()) instanceof LogicalProject) {
            return;
        }
        if (!(unwrapHepVertex(join.getRight()) instanceof LogicalTableScan)) {
            return;
        }

        RexCall joinCond = (RexCall) join.getCondition();
        Set<Integer> leftJoinCols = new HashSet<>();
        Set<Integer> rightJoinCols = new HashSet<>();
        Map<Integer, String> leftJoinColToColNameMap = new HashMap<>();
        Map<Integer, String> rightJoinColToColNameMap = new HashMap<>();
        // Filled by addJoinCols; not consulted further by this method.
        Set<Integer> originalLeftJoinCols = new HashSet<>();
        Set<Integer> originalRightJoinCols = new HashSet<>();
        Map<Integer, String> originalLeftJoinColToColNameMap = new HashMap<>();
        Map<Integer, String> originalRightJoinColToColNameMap = new HashMap<>();

        // Only a single comparison or a top-level AND of conjuncts is analysed.
        List<RexNode> joinConjuncts = new ArrayList<>();
        if (joinCond.getKind() == SqlKind.EQUALS) {
            joinConjuncts.add(joinCond);
        } else if (joinCond.getKind() == SqlKind.AND) {
            joinConjuncts.addAll(joinCond.getOperands());
        }
        for (RexNode conjunct : joinConjuncts) {
            // addJoinCols ignores anything that is not a two-column predicate. ON-clause predicates
            // against literals are deliberately NOT treated as null-rejecting: an outer join still
            // emits NULL-extended rows for rows that fail its ON clause.
            if (conjunct instanceof RexCall) {
                addJoinCols((RexCall) conjunct, join, leftJoinCols, rightJoinCols,
                        leftJoinColToColNameMap, rightJoinColToColNameMap,
                        originalLeftJoinCols, originalRightJoinCols,
                        originalLeftJoinColToColNameMap, originalRightJoinColToColNameMap);
            }
        }
        if (leftJoinCols.isEmpty() || rightJoinCols.isEmpty()) {
            return;
        }

        List<LogicalFilter> collectedFilterNodes = new ArrayList<>();
        collectFilterCondition(call.getPlanner().getRoot(), collectedFilterNodes);
        if (collectedFilterNodes.isEmpty()) {
            return;
        }

        Set<Integer> nullRejectedLeftJoinCols = new HashSet<>();
        Set<Integer> nullRejectedRightJoinCols = new HashSet<>();
        for (LogicalFilter filter : collectedFilterNodes) {
            RexNode filterCondition = filter.getCondition();
            if (!(filterCondition instanceof RexCall)) {
                continue;
            }
            RexCall filterCall = (RexCall) filterCondition;
            if (filterCall.getKind() == SqlKind.OR) {
                // A disjunction does not guarantee that any single column is null-rejected.
                return;
            }
            List<RexNode> filterConjuncts = new ArrayList<>();
            if (filterCall.getKind() == SqlKind.AND) {
                filterConjuncts.addAll(filterCall.getOperands());
            } else {
                filterConjuncts.add(filterCall);
            }
            for (RexNode conjunct : filterConjuncts) {
                if (!(conjunct instanceof RexCall)) {
                    continue;
                }
                RexCall pred = (RexCall) conjunct;
                if (!isCandidateFilterPred(pred) || !(pred.getOperands().get(0) instanceof RexInputRef)) {
                    continue;
                }
                int colId = ((RexInputRef) pred.getOperands().get(0)).getIndex();
                if (leftJoinCols.contains(colId) && rightJoinCols.contains(colId)) {
                    // Cannot attribute the column to one side; give up.
                    return;
                }
                addNullRejectedJoinCols(pred, filter, nullRejectedLeftJoinCols, nullRejectedRightJoinCols,
                        leftJoinColToColNameMap, rightJoinColToColNameMap);
            }
        }

        boolean leftNullRejected = !nullRejectedLeftJoinCols.isEmpty()
                && leftJoinCols.equals(nullRejectedLeftJoinCols);
        boolean rightNullRejected = !nullRejectedRightJoinCols.isEmpty()
                && rightJoinCols.equals(nullRejectedRightJoinCols);

        JoinRelType newJoinType = null;
        if (joinType == JoinRelType.FULL) {
            // NOTE(review): FULL with only the right side null-rejected could become RIGHT; not done
            // here, presumably because downstream code does not expect RIGHT joins. Confirm.
            if (leftNullRejected) {
                newJoinType = rightNullRejected ? JoinRelType.INNER : JoinRelType.LEFT;
            }
        } else if (joinType == JoinRelType.LEFT && rightNullRejected) {
            newJoinType = JoinRelType.INNER;
        }
        if (newJoinType == null) {
            return;
        }

        LogicalJoin newJoinNode = join.copy(join.getTraitSet(), join.getCondition(),
                join.getLeft(), join.getRight(), newJoinType, join.isSemiJoinDone());
        // The Join operand is the parent's only child operand; this assumes it matched input 0.
        // NOTE(review): this mutates the matched parent in place instead of building a copy; Calcite
        // rules normally avoid mutating matched nodes. Confirm this is intended for the planner used.
        parentNode.replaceInput(0, newJoinNode);
        call.transformTo(parentNode);
    }

    /**
     * Records the column pair of a two-column predicate {@code $i op $j} as join columns.
     *
     * <p>Non-matching predicates are ignored. A predicate is ignored unless it has exactly two
     * operands and both are {@link RexInputRef}s.
     *
     * <p>The raw operand indexes (in the order written) and their join-row field names always go
     * into the {@code original*} collections. The pair is then ordered so the smaller index is
     * treated as the left column. If that left column can be traced through a {@link LogicalProject}
     * in the join's left subtree, nothing further is recorded. Otherwise the pair is added to
     * {@code leftJoinCols}/{@code rightJoinCols} and the name maps.
     *
     * @param joinCond predicate to inspect
     * @param joinOp join whose row type supplies field names and whose inputs are traced
     * @param leftJoinCols receives the left join column index
     * @param rightJoinCols receives the right join column index (or its traced offset)
     * @param leftJoinColToColNameMap receives left index -> join-row field name
     * @param rightJoinColToColNameMap receives right index -> join-row field name
     * @param originalLeftJoinCols receives the index of operand 0
     * @param originalRightJoinCols receives the index of operand 1
     * @param originalLeftJoinColToColNameMap receives operand-0 index -> join-row field name
     * @param originalRightJoinColToColNameMap receives operand-1 index -> join-row field name
     */
    void addJoinCols(RexCall joinCond, LogicalJoin joinOp, Set<Integer> leftJoinCols, Set<Integer> rightJoinCols, Map<Integer, String> leftJoinColToColNameMap, Map<Integer, String> rightJoinColToColNameMap, Set<Integer> originalLeftJoinCols, Set<Integer> originalRightJoinCols, Map<Integer, String> originalLeftJoinColToColNameMap, Map<Integer, String> originalRightJoinColToColNameMap) {
        List<RexNode> operands = joinCond.getOperands();
        if (operands.size() != 2
                || !(operands.get(0) instanceof RexInputRef)
                || !(operands.get(1) instanceof RexInputRef)) {
            return;
        }
        RexInputRef firstOperand = (RexInputRef) operands.get(0);
        RexInputRef secondOperand = (RexInputRef) operands.get(1);
        List<String> joinFieldNames = joinOp.getRowType().getFieldNames();

        originalLeftJoinCols.add(firstOperand.getIndex());
        originalRightJoinCols.add(secondOperand.getIndex());
        originalLeftJoinColToColNameMap.put(firstOperand.getIndex(), joinFieldNames.get(firstOperand.getIndex()));
        originalRightJoinColToColNameMap.put(secondOperand.getIndex(), joinFieldNames.get(secondOperand.getIndex()));

        // Normalise so the operand with the smaller index is treated as the left-side column.
        RexInputRef leftJoinCol = firstOperand;
        RexInputRef rightJoinCol = secondOperand;
        if (firstOperand.getIndex() > secondOperand.getIndex()) {
            leftJoinCol = secondOperand;
            rightJoinCol = firstOperand;
        }

        if (traceColOffset(joinOp.getLeft(), leftJoinCol, 0) != -1) {
            // Left column is produced by a projection; the rule does not handle that case.
            return;
        }
        int tracedRightOffset = traceColOffset(joinOp.getRight(), rightJoinCol,
                joinOp.getLeft().getRowType().getFieldCount());
        int leftColOffset = leftJoinCol.getIndex();
        int rightColOffset = tracedRightOffset == -1 ? rightJoinCol.getIndex() : tracedRightOffset;

        leftJoinCols.add(leftColOffset);
        rightJoinCols.add(rightColOffset);
        leftJoinColToColNameMap.put(leftColOffset, joinFieldNames.get(leftColOffset));
        // The name is always taken at the join-level index of the right column, even when traced.
        rightJoinColToColNameMap.put(rightColOffset, joinFieldNames.get(rightJoinCol.getIndex()));
    }

    /**
     * Marks the column tested by a null-rejecting predicate as a null-rejected join column.
     *
     * <p>The column's field name in {@code targetFilter}'s row type must equal the recorded name of
     * the join column at the same index. It is added to the left set when only the left map matches,
     * and to the right set when only the right map matches. When both or neither match, nothing is
     * recorded.
     *
     * @param call candidate predicate whose first operand is the column
     * @param targetFilter filter that owns {@code call}; its row type supplies the column name
     * @param nullRejectedLeftJoinCols receives null-rejected left join column indexes
     * @param nullRejectedRightJoinCols receives null-rejected right join column indexes
     * @param leftJoinColToColNameMap left join column index -> field name
     * @param rightJoinColToColNameMap right join column index -> field name
     */
    void addNullRejectedJoinCols(RexCall call, LogicalFilter targetFilter, Set<Integer> nullRejectedLeftJoinCols, Set<Integer> nullRejectedRightJoinCols, Map<Integer, String> leftJoinColToColNameMap, Map<Integer, String> rightJoinColToColNameMap) {
        if (!isCandidateFilterPred(call) || !(call.getOperands().get(0) instanceof RexInputRef)) {
            return;
        }
        int colId = ((RexInputRef) call.getOperands().get(0)).getIndex();
        String colName = targetFilter.getRowType().getFieldNames().get(colId);
        boolean matchesLeft = nameMatches(leftJoinColToColNameMap, colId, colName);
        boolean matchesRight = nameMatches(rightJoinColToColNameMap, colId, colName);
        if (matchesLeft && !matchesRight) {
            nullRejectedLeftJoinCols.add(colId);
        } else if (matchesRight && !matchesLeft) {
            nullRejectedRightJoinCols.add(colId);
        }
    }

    /** True when {@code colNameMap} maps {@code colId} to a name equal to {@code colName}. */
    private static boolean nameMatches(Map<Integer, String> colNameMap, int colId, String colName) {
        return colNameMap.containsKey(colId) && colNameMap.get(colId).equals(colName);
    }

    /**
     * Appends every {@link LogicalFilter} under {@code curNode}, including {@code curNode} itself, in
     * pre-order. {@link HepRelVertex} wrappers are unwrapped.
     *
     * @param curNode subtree root
     * @param collectedFilterNodes list that receives the filters found
     */
    void collectFilterCondition(RelNode curNode, List<LogicalFilter> collectedFilterNodes) {
        collectNodesOfType(curNode, LogicalFilter.class, collectedFilterNodes);
    }

    /**
     * Appends every {@link LogicalProject} under {@code curNode}, including {@code curNode} itself, in
     * pre-order. {@link HepRelVertex} wrappers are unwrapped.
     *
     * @param curNode subtree root
     * @param collectedProject list that receives the projections found
     */
    void collectProjectNode(RelNode curNode, List<LogicalProject> collectedProject) {
        collectNodesOfType(curNode, LogicalProject.class, collectedProject);
    }

    /** Pre-order walk under {@code node} that appends each node assignable to {@code type}. */
    private static <T extends RelNode> void collectNodesOfType(RelNode node, Class<T> type, List<T> out) {
        RelNode current = unwrapHepVertex(node);
        if (type.isInstance(current)) {
            out.add(type.cast(current));
        }
        for (RelNode input : current.getInputs()) {
            collectNodesOfType(input, type, out);
        }
    }

    /** Returns the current rel behind a {@link HepRelVertex}, or {@code node} itself otherwise. */
    private static RelNode unwrapHepVertex(RelNode node) {
        return node instanceof HepRelVertex ? ((HepRelVertex) node).getCurrentRel() : node;
    }

    /**
     * Maps a column reference through the first {@link LogicalProject} found (pre-order) under
     * {@code curNode}.
     *
     * @param curNode subtree to search for a projection
     * @param colRef column reference; its index is rebased by {@code startOffset}
     * @param startOffset offset subtracted from {@code colRef}'s index before the lookup
     * @return {@code targetMapping.getSourceOpt(index - startOffset)}, or -1 in any of these cases:
     *     no projection is found, {@code colRef} is null, the projection has no mapping, or the
     *     rebased index is outside {@code [0, getSourceCount())}
     */
    int traceColOffset(RelNode curNode, RexInputRef colRef, int startOffset) {
        List<LogicalProject> collectedProjectNodes = new ArrayList<>();
        collectProjectNode(curNode, collectedProjectNodes);
        if (collectedProjectNodes.isEmpty() || colRef == null) {
            return -1;
        }
        Mappings.TargetMapping targetMapping = collectedProjectNodes.get(0).getMapping();
        if (targetMapping == null) {
            return -1;
        }
        int baseOffset = colRef.getIndex() - startOffset;
        // NOTE(review): bounded by the source count although getSourceOpt takes a target index.
        if (baseOffset < 0 || baseOffset >= targetMapping.getSourceCount()) {
            return -1;
        }
        return targetMapping.getSourceOpt(baseOffset);
    }

    /** True when {@code c}'s kind is in {@code SqlKind.BINARY_COMPARISON} or {@code BINARY_EQUALITY}. */
    boolean isComparisonOp(RexCall c) {
        SqlKind opKind = c.getKind();
        return SqlKind.BINARY_COMPARISON.contains(opKind) || SqlKind.BINARY_EQUALITY.contains(opKind);
    }

    /** True when {@code c} is a one-operand {@code IS NOT NULL} call. */
    boolean isNotNullFilter(RexCall c) {
        return c.op.kind == SqlKind.IS_NOT_NULL && c.operands.size() == 1;
    }

    /**
     * True when {@code c} rejects NULLs in a column.
     *
     * <p>Accepted forms are {@code IS NOT NULL}, or a binary comparison {@code $col op literal}.
     * NULL-safe comparisons are excluded.
     *
     * @param c call to classify
     * @return whether {@code c} is treated as null-rejecting
     */
    boolean isCandidateFilterPred(RexCall c) {
        if (isNotNullFilter(c)) {
            return true;
        }
        if (c.operands.size() != 2 || !isComparisonOp(c)) {
            return false;
        }
        // IS DISTINCT FROM is TRUE for a NULL column, and IS NOT DISTINCT FROM is TRUE for NULL when
        // compared against a NULL literal. Neither reliably rejects NULLs, so both are excluded.
        SqlKind kind = c.getKind();
        if (kind == SqlKind.IS_DISTINCT_FROM || kind == SqlKind.IS_NOT_DISTINCT_FROM) {
            return false;
        }
        return c.operands.get(0) instanceof RexInputRef && c.operands.get(1) instanceof RexLiteral;
    }
}

