package com.sri.ai.grinder.sgdpllt.library;

import com.google.common.annotations.Beta;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.sri.ai.expresso.api.Expression;
import com.sri.ai.expresso.helper.Expressions;
import com.sri.ai.grinder.sgdpllt.api.Context;
import com.sri.ai.grinder.sgdpllt.library.boole.And;
import com.sri.ai.grinder.sgdpllt.library.boole.Not;
import com.sri.ai.grinder.sgdpllt.rewriter.api.Simplifier;
import com.sri.ai.util.Util;
import com.sri.ai.util.base.BinaryFunction;
import com.sri.ai.util.base.Pair;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

/**
 * Simplifier and static helpers for expressions whose functor is {@code "="}.
 *
 * <p>Instances simplify an equality through {@link #applySimplifier(Expression, Context)};
 * the static methods build, inspect and simplify equality expressions.
 */
@Beta
public class Equality implements Simplifier {

    /** Symbol used as the functor of equality applications. */
    public static final Expression FUNCTOR = Expressions.makeSymbol("=");

    /** Binary function that builds the equality of its two arguments via {@link #make(Object...)}. */
    public static BinaryFunction<Expression, Expression, Expression> MAKE_PAIR_EQUALITY = new BinaryFunction<Expression, Expression, Expression>() {
        @Override
        public Expression apply(Expression expression, Expression expression2) {
            return Equality.make(expression, expression2);
        }
    };

    /**
     * Simplifies the given expression by delegating to {@link #simplify(Expression, Context)}.
     *
     * @param expression the expression to simplify
     * @param context the context consulted for uniquely named constants
     * @return the simplified expression
     */
    @Override
    public Expression applySimplifier(Expression expression, Context context) {
        return simplify(expression, context);
    }

    /**
     * Returns the boolean symbol for the equality's truth value when it can be determined
     * from consecutive argument pairs, or the expression itself otherwise.
     *
     * @param expression the equality to evaluate
     * @param context the context providing the uniquely-named-constant predicate and variable types
     * @return a symbol made from {@code true}/{@code false} if the result is known, else {@code expression}
     */
    public static Expression equalityResultIfItIsKnown(Expression expression, Context context) {
        Boolean known = equalityResultIfItIsKnownOrNull(expression, context.getIsUniquelyNamedConstantPredicate(), context);
        if (known == null) {
            return expression;
        }
        return Expressions.makeSymbol(known);
    }

    /**
     * Checks each pair of consecutive arguments; returns {@code null} as soon as a pair is
     * undetermined, {@code false} as soon as a pair is known to differ, and {@code true} if
     * every pair is known to be equal (including when there are fewer than two arguments).
     */
    private static Boolean equalityResultIfItIsKnownOrNull(Expression expression, Predicate<Expression> predicate, Context context) {
        int lastPairStart = expression.numberOfArguments() - 1;
        // '<' rather than '!=': with zero arguments the bound is -1 and '!=' would never be met.
        for (int i = 0; i < lastPairStart; i++) {
            Boolean pairResult = equalityOfPairIfItIsKnownOrNull(expression.get(i), expression.get(i + 1), predicate, context);
            if (pairResult == null) {
                return null;
            }
            if (!pairResult.booleanValue()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Decides whether two terms are equal, if that can be known.
     *
     * <ul>
     * <li>both satisfy {@code predicate}: the result of {@code equals};</li>
     * <li>exactly one satisfies {@code predicate}: {@code null} (unknown);</li>
     * <li>neither does and they are {@code equals}: {@code true};</li>
     * <li>both are variables with non-null, distinct registered type expressions: {@code false};</li>
     * <li>otherwise: {@code null}.</li>
     * </ul>
     */
    private static Boolean equalityOfPairIfItIsKnownOrNull(Expression expression, Expression expression2, Predicate<Expression> predicate, Context context) {
        if (predicate.apply(expression)) {
            if (predicate.apply(expression2)) {
                return Boolean.valueOf(expression.equals(expression2));
            }
            return null;
        }
        if (predicate.apply(expression2)) {
            return null;
        }
        if (expression.equals(expression2)) {
            return true;
        }
        boolean bothVariables = context.isVariable(expression) && context.isVariable(expression2);
        if (bothVariables
                && Util.notNullAndDistinct(
                        context.getTypeExpressionOfRegisteredSymbol(expression),
                        context.getTypeExpressionOfRegisteredSymbol(expression2))) {
            return false;
        }
        return null;
    }

    /**
     * Swaps the first two arguments when the first one is a uniquely named constant.
     *
     * @param expression an expression with at least two arguments
     * @param context the context deciding what is a uniquely named constant
     * @return {@code make(second, first)} if the first argument is a constant, else {@code expression}
     */
    public static Expression makeSureFirstArgumentIsNotAConstant(Expression expression, Context context) {
        if (context.isUniquelyNamedConstant(expression.get(0))) {
            return make(expression.get(1), expression.get(0));
        }
        return expression;
    }

    /**
     * Builds the conjunction of the equalities between corresponding immediate sub-expressions
     * of the two given expressions.
     *
     * @param expression the first expression
     * @param expression2 the second expression
     * @return the result of {@code And.make} over {@link #listOfEqualitiesOfSubExpressions}
     */
    public static Expression conditionForSubExpressionsEquality(Expression expression, Expression expression2) {
        List<Expression> equalities = listOfEqualitiesOfSubExpressions(expression, expression2);
        return And.make((List<? extends Expression>) equalities);
    }

    /**
     * Zips the immediate sub-expressions of both expressions with {@link #MAKE_PAIR_EQUALITY}.
     *
     * @param expression the first expression
     * @param expression2 the second expression
     * @return the list produced by {@code Util.zipWith} over the two sub-expression lists
     */
    public static List<Expression> listOfEqualitiesOfSubExpressions(Expression expression, Expression expression2) {
        List<Expression> firstSubExpressions = Util.listFrom(expression.getImmediateSubExpressionsIterator());
        List<Expression> secondSubExpressions = Util.listFrom(expression2.getImmediateSubExpressionsIterator());
        return Util.zipWith(MAKE_PAIR_EQUALITY, firstSubExpressions, secondSubExpressions);
    }

    /**
     * Makes an equality over the given arguments.
     *
     * <p>A single {@link List} argument is treated as the list of arguments. If fewer than two
     * distinct arguments remain after wrapping, {@code Expressions.TRUE} is returned.
     *
     * @param objArr the arguments, or a single list holding them
     * @return {@code Expressions.TRUE} or an application of {@code "="} to the wrapped arguments
     */
    public static Expression make(Object... objArr) {
        Object[] arguments = objArr;
        if (arguments.length == 1 && (arguments[0] instanceof List)) {
            arguments = ((List<?>) arguments[0]).toArray();
        }
        List<Expression> wrapped = Expressions.wrap(arguments);
        Set<Expression> distinct = new LinkedHashSet<>(wrapped);
        if (distinct.size() < 2) {
            return Expressions.TRUE;
        }
        return Expressions.apply("=", wrapped);
    }

    /**
     * Makes a binary equality, simplifying the trivially decidable cases.
     *
     * @param expression the first term
     * @param expression2 the second term
     * @param context the context deciding what is a uniquely named constant
     * @return {@code TRUE} if the terms are {@code equals}, {@code FALSE} if both are (distinct)
     *         uniquely named constants, otherwise {@code make(expression, expression2)}
     */
    public static Expression makeWithConstantSimplification(Expression expression, Expression expression2, Context context) {
        if (expression.equals(expression2)) {
            return Expressions.TRUE;
        }
        if (context.isUniquelyNamedConstant(expression) && context.isUniquelyNamedConstant(expression2)) {
            return Expressions.FALSE;
        }
        return make(expression, expression2);
    }

    /**
     * Simplifies an equality by removing duplicate arguments and detecting conflicting constants.
     *
     * @param expression the equality whose arguments are examined
     * @param context the context deciding what is a uniquely named constant
     * @return {@code FALSE} if two different uniquely named constants occur; {@code TRUE} if exactly
     *         one distinct argument remains; {@code expression} itself if nothing was removed;
     *         otherwise a new {@code "="} application over the distinct arguments (non-constants
     *         in encounter order, followed by the constant if any)
     */
    public static Expression makeWithConstantSimplification(Expression expression, Context context) {
        List<Expression> arguments = expression.getArguments();
        Set<Expression> distinctTerms = new LinkedHashSet<>();
        Expression constant = null;
        for (Expression argument : arguments) {
            if (!context.isUniquelyNamedConstant(argument)) {
                distinctTerms.add(argument);
            } else if (constant == null) {
                constant = argument;
            } else if (!argument.equals(constant)) {
                // Two different uniquely named constants can never be equal.
                return Expressions.FALSE;
            }
        }
        if (constant != null) {
            distinctTerms.add(constant);
        }
        if (distinctTerms.size() == 1) {
            return Expressions.TRUE;
        }
        if (arguments.size() == distinctTerms.size()) {
            return expression;
        }
        return Expressions.apply("=", distinctTerms.toArray());
    }

    /**
     * Separates the arguments into non-constants and uniquely named constants.
     *
     * @param expression the expression whose arguments are separated
     * @param context the context providing the uniquely-named-constant predicate
     * @return a pair of (non-constant arguments, first constant), or {@code null} when no
     *         argument is a uniquely named constant
     */
    public static Pair<List<Expression>, Expression> getVariablesListAndConstantOrNullIfNoConstant(Expression expression, Context context) {
        List<Expression> variables = new LinkedList<>();
        List<Expression> constants = new LinkedList<>();
        Util.collectOrReturnFalseIfElementDoesNotFitEither(
                expression.getArguments(),
                variables, Predicates.not(context.getIsUniquelyNamedConstantPredicate()),
                constants, context.getIsUniquelyNamedConstantPredicate());
        if (constants.isEmpty()) {
            return null;
        }
        return Pair.make(variables, Util.getFirst(constants));
    }

    /**
     * Separates the arguments into the set of non-constants and the set of uniquely named constants.
     *
     * @param expression the expression whose arguments are separated
     * @param context the context providing the uniquely-named-constant predicate
     * @return a pair of insertion-ordered sets: (non-constants, constants)
     */
    public static Pair<Set<Expression>, Set<Expression>> getVariablesListAndConstantsList(Expression expression, Context context) {
        Set<Expression> variables = new LinkedHashSet<>();
        Set<Expression> constants = new LinkedHashSet<>();
        Util.collectOrReturnFalseIfElementDoesNotFitEither(
                expression.getArguments(),
                variables, Predicates.not(context.getIsUniquelyNamedConstantPredicate()),
                constants, context.getIsUniquelyNamedConstantPredicate());
        return Pair.make(variables, constants);
    }

    /**
     * Tells whether the expression has functor {@code "="}.
     *
     * @param expression the expression to test
     * @return {@code expression.hasFunctor("=")}
     */
    public static boolean isEquality(Expression expression) {
        return expression.hasFunctor("=");
    }

    /**
     * Collects the distinct arguments of an equality, provided there is more than one of them.
     *
     * @param expression the expression to inspect
     * @return the distinct arguments in encounter order; empty if {@code expression} is not an
     *         equality or if all its arguments are the same
     */
    public static Set<Expression> getSymbolsBoundToSomethingElse(Expression expression) {
        Set<Expression> symbols = new LinkedHashSet<>();
        if (isEquality(expression)) {
            symbols.addAll(expression.getArguments());
            if (symbols.size() == 1) {
                // A single distinct argument is not bound to anything other than itself.
                symbols.clear();
            }
        }
        return symbols;
    }

    /**
     * Puts a binary expression of the form {@code constant = nonConstant} into the order
     * {@code nonConstant = constant}, keeping the original functor.
     *
     * @param expression the expression to normalize
     * @param context the context deciding what is a uniquely named constant
     * @return the reordered expression, or {@code expression} if the pattern does not apply
     */
    public static Expression normalize(Expression expression, Context context) {
        boolean constantThenNonConstant =
                expression.numberOfArguments() == 2
                && context.isUniquelyNamedConstant(expression.get(0))
                && !context.isUniquelyNamedConstant(expression.get(1));
        if (constantThenNonConstant) {
            return Expressions.makeExpressionOnSyntaxTreeWithLabelAndSubTrees(expression.getFunctor(), expression.get(1), expression.get(0));
        }
        return expression;
    }

    /**
     * Builds the conjunction of pairwise {@code "="} applications over two lists.
     *
     * @param list the left-hand terms
     * @param list2 the right-hand terms
     * @return {@code Expressions.FALSE} if the lists differ in size, otherwise the result of
     *         {@code And.make} over {@code Expressions.makePairwiseApplications("=", list, list2)}
     */
    public static Expression makePairwiseEquality(List<Expression> list, List<Expression> list2) {
        if (list.size() != list2.size()) {
            return Expressions.FALSE;
        }
        return And.make((List<? extends Expression>) Expressions.makePairwiseApplications("=", list, list2));
    }

    /**
     * Splits a literal into a pair whose first component involves {@code expression} and whose
     * second component holds the rest.
     *
     * <ul>
     * <li>{@code TRUE}/{@code FALSE}: the literal is returned in both components.</li>
     * <li>Disequality: the literal goes in the first component if {@code expression} is one of its
     *     arguments, else in the second; the other component is {@code TRUE}.</li>
     * <li>Equality not containing {@code expression}: {@code (TRUE, literal)}.</li>
     * <li>Equality containing {@code expression}: {@code (FALSE, FALSE)} if it has more than one
     *     distinct uniquely named constant; {@code (TRUE, TRUE)} if no other argument remains;
     *     otherwise the equality of {@code expression} with the last remaining argument, paired
     *     with the equality over all remaining arguments.</li>
     * </ul>
     *
     * @param expression the term to separate out
     * @param expression2 the literal to split
     * @param context the context providing the uniquely-named-constant predicate
     * @return the pair described above
     * @throws Error if {@code expression2} is none of the literal kinds listed
     */
    public static Pair<Expression, Expression> separateVariableLiteral(Expression expression, Expression expression2, Context context) {
        if (expression2.equals(Expressions.TRUE) || expression2.equals(Expressions.FALSE)) {
            return Pair.make(expression2, expression2);
        }

        if (expression2.hasFunctor(FunctorConstants.DISEQUALITY)) {
            if (expression2.getArguments().contains(expression)) {
                return Pair.make(expression2, Expressions.TRUE);
            }
            return Pair.make(Expressions.TRUE, expression2);
        }

        if (!expression2.hasFunctor("=")) {
            throw new Error(expression2 + " is not an equality literal as required by Equality.separateVariableLiteral");
        }

        if (!expression2.getArguments().contains(expression)) {
            return Pair.make(Expressions.TRUE, expression2);
        }

        Pair<Set<Expression>, Set<Expression>> variablesAndConstants = getVariablesListAndConstantsList(expression2, context);
        Set<Expression> constants = variablesAndConstants.second;
        if (constants.size() > 1) {
            return Pair.make(Expressions.FALSE, Expressions.FALSE);
        }

        Set<Expression> variables = variablesAndConstants.first;
        variables.remove(expression);
        List<Expression> remaining = new LinkedList<>(variables);
        remaining.addAll(constants);
        if (remaining.isEmpty()) {
            return Pair.make(Expressions.TRUE, Expressions.TRUE);
        }
        Expression other = Util.getLast(remaining);
        return Pair.make(make(expression, other), make(remaining));
    }

    /**
     * For a two-argument expression, returns the argument on the opposite side of
     * {@code expression2}.
     *
     * @param expression the expression to inspect
     * @param expression2 the term to look for
     * @return the other argument, or {@code null} if {@code expression} does not have exactly two
     *         arguments or neither of them equals {@code expression2}
     */
    public static Expression getWhatExpressionIsComparedToIfUniqueOrNull(Expression expression, Expression expression2) {
        if (expression.numberOfArguments() != 2) {
            return null;
        }
        if (expression.get(0).equals(expression2)) {
            return expression.get(1);
        }
        if (expression.get(1).equals(expression2)) {
            return expression.get(0);
        }
        return null;
    }

    /**
     * Simplifies the expression if it is an equality or a disequality.
     *
     * @param expression the expression to simplify
     * @param context the context passed on to the simplification
     * @return the result of {@link #simplify} for equalities, of {@code Disequality.simplify} for
     *         disequalities, and {@code expression} otherwise
     */
    public static Expression simplifyIfEqualityOrDisequality(Expression expression, Context context) {
        if (isEquality(expression)) {
            return simplify(expression, context);
        }
        if (Disequality.isDisequality(expression)) {
            return Disequality.simplify(expression, context);
        }
        return expression;
    }

    /**
     * Simplifies an equality based on its uniquely named constants.
     *
     * <p>Returns {@code TRUE} if all arguments are equal. Otherwise the arguments are split by
     * {@code Util.collect} into two insertion-ordered sets (presumably constants matching the
     * context's predicate and the remaining arguments -- confirm against {@code Util.collect}):
     * <ul>
     * <li>more than one distinct constant: {@code FALSE};</li>
     * <li>the single constant is {@code TRUE}: the conjunction of the other arguments;</li>
     * <li>the single constant is {@code FALSE}: the conjunction of the negated other arguments;</li>
     * <li>otherwise: {@code expression} unchanged.</li>
     * </ul>
     *
     * @param expression the equality to simplify
     * @param context the context providing the uniquely-named-constant predicate
     * @return the simplified expression as described above
     */
    public static Expression simplify(Expression expression, Context context) {
        if (Util.allEqual(expression.getArguments())) {
            return Expressions.TRUE;
        }

        Set<Expression> constants = new LinkedHashSet<>();
        Set<Expression> others = new LinkedHashSet<>();
        Util.collect(expression.getArguments(), constants, context.getIsUniquelyNamedConstantPredicate(), others);

        if (constants.size() > 1) {
            return Expressions.FALSE;
        }
        if (constants.size() == 1 && constants.contains(Expressions.TRUE)) {
            List<Expression> conjuncts = new ArrayList<>(others);
            return And.make((List<? extends Expression>) conjuncts);
        }
        if (constants.size() == 1 && constants.contains(Expressions.FALSE)) {
            List<Expression> negatedConjuncts = new ArrayList<>(Util.mapIntoArrayList(others, other -> Not.make(other)));
            return And.make((List<? extends Expression>) negatedConjuncts);
        }
        return expression;
    }

    /**
     * Simplifies an expression under the assumption that two terms are equal.
     *
     * @param expression the expression to simplify
     * @param expression2 the first term of the assumed equality
     * @param expression3 the second term of the assumed equality
     * @return {@code TRUE} if both terms are arguments of {@code expression}, else {@code expression}
     */
    public static Expression simplifyGivenEquality(Expression expression, Expression expression2, Expression expression3) {
        List<Expression> arguments = expression.getArguments();
        if (arguments.contains(expression2) && arguments.contains(expression3)) {
            return Expressions.TRUE;
        }
        return expression;
    }

    /**
     * Simplifies an expression under the assumption that two terms are different.
     *
     * @param expression the expression to simplify
     * @param expression2 the first term of the assumed disequality
     * @param expression3 the second term of the assumed disequality
     * @return {@code FALSE} if both terms are arguments of {@code expression}, else {@code expression}
     */
    public static Expression simplifyGivenDisequality(Expression expression, Expression expression2, Expression expression3) {
        List<Expression> arguments = expression.getArguments();
        if (arguments.contains(expression2) && arguments.contains(expression3)) {
            return Expressions.FALSE;
        }
        return expression;
    }

    /**
     * Applies this simplifier to the expression.
     *
     * <p>Delegates to {@link #applySimplifier(Expression, Context)} rather than to {@code apply}
     * itself, since a self-call with identical arguments would recurse without bound.
     * NOTE(review): assumes the inherited {@code apply} contract is plain delegation to
     * {@code applySimplifier} -- confirm against {@code Simplifier}.
     *
     * @param expression the expression to simplify
     * @param context the context consulted during simplification
     * @return the simplified expression
     */
    @Override
    public Expression apply(Expression expression, Context context) {
        return applySimplifier(expression, context);
    }
}
