package com.sri.ai.grinder.sgdpllt.library.set.invsupport;

import com.sri.ai.expresso.api.Expression;
import com.sri.ai.expresso.api.IndexExpressionsSet;
import com.sri.ai.expresso.api.IntensionalSet;
import com.sri.ai.expresso.api.Type;
import com.sri.ai.expresso.core.ExtensionalIndexExpressionsSet;
import com.sri.ai.expresso.helper.Expressions;
import com.sri.ai.expresso.helper.SubExpressionsDepthFirstIterator;
import com.sri.ai.expresso.type.FunctionType;
import com.sri.ai.grinder.helper.GrinderUtil;
import com.sri.ai.grinder.sgdpllt.api.Context;
import com.sri.ai.grinder.sgdpllt.library.Disequality;
import com.sri.ai.grinder.sgdpllt.library.Equality;
import com.sri.ai.grinder.sgdpllt.library.FunctorConstants;
import com.sri.ai.grinder.sgdpllt.library.boole.And;
import com.sri.ai.grinder.sgdpllt.library.boole.ForAll;
import com.sri.ai.grinder.sgdpllt.library.boole.Implication;
import com.sri.ai.grinder.sgdpllt.library.indexexpression.IndexExpressions;
import com.sri.ai.grinder.sgdpllt.library.set.Sets;
import com.sri.ai.grinder.sgdpllt.rewriter.api.Simplifier;
import com.sri.ai.util.Util;
import com.sri.ai.util.base.Pair;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/sri/ai/grinder/sgdpllt/library/set/invsupport/InversionSimplifier.class */
public class InversionSimplifier implements Simplifier {
    /**
     * Instance-level entry point required by {@link Simplifier}; it forwards to the
     * static {@link #simplify(Expression, Context)} of this class.
     *
     * @param expression the expression to simplify
     * @param context the context passed through to {@code simplify}
     * @return whatever {@code simplify(expression, context)} returns
     */
    @Override // com.sri.ai.grinder.sgdpllt.rewriter.api.Simplifier
    public Expression applySimplifier(Expression expression, Context context) {
        final Expression simplified = InversionSimplifier.simplify(expression, context);
        return simplified;
    }

    /**
     * Tries to apply the inversion rewrite to {@code expression}.
     *
     * <p>The rewrite is attempted only when
     * {@code isSummationIndexedByFunctionOfQuantifiers} accepts the expression and
     * {@code isInversionPossible} finds a usable quantifier order; in every other case
     * the input expression is returned unchanged.
     *
     * @param expression the candidate expression
     * @param context the context in which the expression is analysed
     * @return the result of {@code applyInversion}, or {@code expression} itself when
     *         inversion does not apply
     */
    public static Expression simplify(Expression expression, Context context) {
        if (!isSummationIndexedByFunctionOfQuantifiers(expression, context)) {
            return expression;
        }

        Pair<Expression, FunctionType> indexAndType = getIndexAndFunctionType(expression, context);
        Expression summationIndex = indexAndType.first;
        FunctionType summationIndexType = indexAndType.second;

        // Nested quantifiers in the order they appear, starting with the expression itself.
        List<Expression> originalOrder = new ArrayList<>();
        collectQuantifiers(expression, originalOrder);

        // Filled in by isInversionPossible with the order to use for the rewrite.
        List<Expression> inversionOrder = new ArrayList<>();
        if (!isInversionPossible(expression, summationIndex, summationIndexType, originalOrder, inversionOrder, context)) {
            return expression;
        }
        return applyInversion(expression, summationIndex, summationIndexType, originalOrder, inversionOrder, context);
    }

    /**
     * Greedily selects the products that can be moved in front of the summation.
     *
     * <p>Candidates are the consecutive elements of {@code list}, starting at index 1,
     * that have the {@code PRODUCT} functor. Each candidate is tentatively appended to the
     * products accepted so far, {@code list2} is rebuilt for that trial, and the candidate
     * is kept only if {@code isInvertible} approves the trial order.
     *
     * @param expression the summation being rewritten
     * @param expression2 the summation's index
     * @param functionType the function type of that index
     * @param list quantifiers in their original order
     * @param list2 output list; overwritten on every trial and, on success, left holding
     *        the final inversion order
     * @param context the context used for the invertibility checks
     * @return {@code true} if at least one product was accepted
     */
    private static boolean isInversionPossible(Expression expression, Expression expression2, FunctionType functionType, List<Expression> list, List<Expression> list2, Context context) {
        List<Expression> candidates = new ArrayList<>();
        int position = 1;
        while (position < list.size() && list.get(position).hasFunctor(FunctorConstants.PRODUCT)) {
            candidates.add(list.get(position));
            position++;
        }

        List<Expression> accepted = new ArrayList<>();
        for (Expression candidate : candidates) {
            List<Expression> trial = new ArrayList<>(accepted);
            trial.add(candidate);
            updateInversionOrder(list, trial, list2);
            if (isInvertible(expression, expression2, functionType, list, list2, context)) {
                accepted.add(candidate);
            }
        }

        if (accepted.isEmpty()) {
            return false;
        }
        // Leave list2 reflecting exactly the accepted products, not the last trial.
        updateInversionOrder(list, accepted, list2);
        return true;
    }

    /**
     * Rebuilds {@code list3} as: every element of {@code list2}, followed by the elements
     * of {@code list} that {@code list2} does not contain, in {@code list}'s order.
     *
     * @param list the full sequence of quantifiers
     * @param list2 the elements to place first
     * @param list3 the list to overwrite with the combined order
     */
    private static void updateInversionOrder(List<Expression> list, List<Expression> list2, List<Expression> list3) {
        list3.clear();
        list3.addAll(list2);
        list.stream()
                .filter(quantifier -> !list2.contains(quantifier))
                .forEachOrdered(list3::add);
    }

    /**
     * Checks whether the quantifier order in {@code list2} is acceptable, by building a
     * universally quantified formula and asking the theory to evaluate it.
     *
     * <p>The formula has the shape: for all (two renamed copies of the indices of the
     * quantifiers placed before {@code expression} in {@code list2}), if both renamed
     * copies of those quantifiers' conditions hold and the two index tuples differ, then
     * the intersection of the two correspondingly renamed copies of the set returned by
     * {@code SetOfArgumentTuplesForFunctionOccurringInExpression.compute} equals
     * {@code Sets.EMPTY_SET}.
     *
     * @param expression the summation; its position in {@code list2} splits the order
     * @param expression2 the summation's index, passed to {@code compute}
     * @param functionType the function type passed to {@code compute}
     * @param list quantifiers in original order; the head of the last one is the body used
     * @param list2 the candidate quantifier order being tested
     * @param context the context the final formula is evaluated in
     * @return {@code true} only if the evaluation result equals {@code Expressions.TRUE}
     */
    private static boolean isInvertible(Expression expression, Expression expression2, FunctionType functionType, List<Expression> list, List<Expression> list2, Context context) {
        int indexOf = list2.indexOf(expression);
        Expression head = getHead(list.get(list.size() - 1));
        // Extend the context with the index expressions of every quantifier in list;
        // context2 is used below for generating fresh names and for the substitutions.
        Context context2 = context;
        Iterator<Expression> it = list.iterator();
        while (it.hasNext()) {
            context2 = context2.extendWith(getIndexExpressions(it.next()));
        }
        // Re-wrap the head in the quantifiers that follow the summation in list2,
        // going from last to first so that the last one ends up innermost.
        Expression expression3 = head;
        List<Expression> subList = list2.subList(indexOf + 1, list2.size());
        for (int size = subList.size() - 1; size >= 0; size--) {
            expression3 = quantifyE(expression3, subList.get(size));
        }
        Expression compute = SetOfArgumentTuplesForFunctionOccurringInExpression.compute(expression2, functionType, expression3);
        // Indices of the quantifiers placed before the summation in list2.
        List<Expression> subList2 = list2.subList(0, indexOf);
        ArrayList arrayList = new ArrayList();
        Iterator<Expression> it2 = subList2.iterator();
        while (it2.hasNext()) {
            arrayList.add(getIndexAndType(it2.next()).first);
        }
        // arrayList2 accumulates every name handed to primedUntilUnique so far (original
        // indices plus all generated ones); arrayList3 and arrayList4 receive the first
        // and second generated name for each index. Statement order matters here.
        ArrayList arrayList2 = new ArrayList(arrayList);
        ArrayList arrayList3 = new ArrayList();
        ArrayList arrayList4 = new ArrayList();
        for (int i = 0; i < arrayList.size(); i++) {
            Expression expression4 = (Expression) arrayList.get(i);
            Expression primedUntilUnique = Expressions.primedUntilUnique(expression4, Expressions.makeTuple(arrayList2), context2);
            arrayList3.add(primedUntilUnique);
            arrayList2.add(primedUntilUnique);
            Expression primedUntilUnique2 = Expressions.primedUntilUnique(expression4, Expressions.makeTuple(arrayList2), context2);
            arrayList4.add(primedUntilUnique2);
            arrayList2.add(primedUntilUnique2);
        }
        // Two copies of each leading quantifier's condition, one per set of generated names.
        ArrayList arrayList5 = new ArrayList();
        ArrayList arrayList6 = new ArrayList();
        for (int i2 = 0; i2 < subList2.size(); i2++) {
            Expression condition = getCondition(subList2.get(i2));
            arrayList5.add(replaceAll(condition, arrayList, arrayList3, context2));
            arrayList6.add(replaceAll(condition, arrayList, arrayList4, context2));
        }
        // Antecedent: both condition copies hold and the two index tuples are different.
        Expression make = Disequality.make(Expressions.makeTuple(arrayList3), Expressions.makeTuple(arrayList4));
        ArrayList arrayList7 = new ArrayList();
        arrayList7.addAll(arrayList5);
        arrayList7.addAll(arrayList6);
        arrayList7.add(make);
        // Consequent: the two renamed copies of `compute` have an empty intersection.
        Expression make2 = Implication.make(And.make((List<? extends Expression>) arrayList7), Equality.make(Sets.makeIntersection(replaceAll(compute, arrayList, arrayList3, context2), replaceAll(compute, arrayList, arrayList4, context2)), Sets.EMPTY_SET));
        // Index expressions for the generated names, reusing each original index's domain.
        ArrayList arrayList8 = new ArrayList();
        ArrayList arrayList9 = new ArrayList();
        for (int i3 = 0; i3 < subList2.size(); i3++) {
            Expression expression5 = getIndexAndType(subList2.get(i3)).second;
            arrayList8.add(IndexExpressions.makeIndexExpression((Expression) arrayList3.get(i3), expression5));
            arrayList9.add(IndexExpressions.makeIndexExpression((Expression) arrayList4.get(i3), expression5));
        }
        ArrayList arrayList10 = new ArrayList();
        arrayList10.addAll(arrayList8);
        arrayList10.addAll(arrayList9);
        // Wrap the implication in one ForAll per generated index; iterating backwards
        // makes the first entry of arrayList10 the outermost quantifier.
        Expression expression6 = make2;
        for (int size2 = arrayList10.size() - 1; size2 >= 0; size2--) {
            expression6 = ForAll.make((Expression) arrayList10.get(size2), expression6);
        }
        // Note: evaluated against the caller's context, not the extended context2.
        return Expressions.TRUE.equals(context.getTheory().evaluate(expression6, context));
    }

    /**
     * Re-creates the quantifier {@code expression2} around a new body: the intensional
     * set that is its first argument gets {@code expression} as head, and the quantifier's
     * functor is applied to the resulting set.
     *
     * @param expression the new head
     * @param expression2 the quantifier whose functor and set are reused
     * @return the functor of {@code expression2} applied to the re-headed set
     */
    private static Expression quantifyE(Expression expression, Expression expression2) {
        Expression functor = expression2.getFunctor();
        Expression reheadedSet = getIntensionalSet(expression2).setHead(expression);
        return Expressions.apply(functor, reheadedSet);
    }

    /**
     * Builds the inverted expression for the quantifier order in {@code list2}.
     *
     * <p>The result is, from the outside in: one {@code PRODUCT} per quantifier that
     * precedes the summation in {@code list2}; a {@code SUM} over {@code expression2} with
     * a reduced type; the quantifiers that follow the summation in {@code list2}; and
     * finally the innermost head with applications of {@code expression2} rewritten to
     * drop the removed argument positions.
     *
     * <p>An argument position is removed when some application of {@code expression2} in
     * the head mentions (directly or in a sub-expression) the index of a leading product
     * at that position, or when the position was never recorded by the variable check
     * below.
     *
     * @param expression the original summation
     * @param expression2 the summation's function-typed index
     * @param functionType the type of that index
     * @param list quantifiers in their original (outermost-first) order
     * @param list2 quantifiers in the inversion order; must contain {@code expression}
     * @param context the context to extend with all quantifier indices
     * @return the rewritten expression
     */
    private static Expression applyInversion(Expression expression, Expression expression2, FunctionType functionType, List<Expression> list, List<Expression> list2, Context context) {
        int summationPosition = list2.indexOf(expression);
        List<Expression> productsBeforeSummation = list2.subList(0, summationPosition);
        Expression innermostHead = getHead(list.get(list.size() - 1));

        Context extended = context;
        for (Expression quantifier : list) {
            extended = extended.extendWith(getIndexExpressions(quantifier));
        }
        final Context headContext = extended;

        List<Expression> invertedIndices = new ArrayList<>();
        for (Expression product : productsBeforeSummation) {
            invertedIndices.add(getIndexAndType(product).first);
        }

        HashSet<Integer> positionsToDrop = new HashSet<>();
        HashSet<Integer> positionsSeenAsVariable = new HashSet<>();
        new SubExpressionsDepthFirstIterator(innermostHead).forEachRemaining(application -> {
            if (application.hasFunctor(expression2)) {
                for (int i = 0; i < application.numberOfArguments(); i++) {
                    // NOTE(review): this tests argument 0 on every iteration rather than
                    // argument i. Kept as-is because the intended semantics could not be
                    // confirmed from this file - verify against the upstream project.
                    if (headContext.getTheory().isVariable(application.get(0), headContext)) {
                        positionsSeenAsVariable.add(i);
                    }
                    // The nested lambda uses its own parameter name; reusing the outer
                    // lambda's parameter name does not compile.
                    if (invertedIndices.contains(application.get(i))
                            || Util.thereExists(new SubExpressionsDepthFirstIterator(application.get(i)), nested -> invertedIndices.contains(nested))) {
                        positionsToDrop.add(i);
                    }
                }
            }
        });

        // Positions never recorded above are dropped too; the surviving positions supply
        // the argument types of the reduced index type.
        List<Expression> remainingArgumentTypes = new ArrayList<>();
        for (int position = 0; position < functionType.getArity(); position++) {
            if (!positionsSeenAsVariable.contains(position)) {
                positionsToDrop.add(position);
            }
            if (!positionsToDrop.contains(position)) {
                remainingArgumentTypes.add(Expressions.parse(functionType.getArgumentTypes().get(position).getName()));
            }
        }

        Expression codomain = Expressions.parse(functionType.getCodomain().getName());
        Expression reducedIndexExpression = IndexExpressions.makeIndexExpression(expression2, remainingArgumentTypes.size() == 0 ? codomain : FunctionType.make(codomain, remainingArgumentTypes));

        // Rewrite every application of the index so it keeps only the surviving arguments;
        // with none left, the application collapses to the bare index symbol.
        Expression invertedBody = innermostHead.replaceAllOccurrences(subExpression -> {
            if (!subExpression.hasFunctor(expression2)) {
                return subExpression;
            }
            List<Expression> keptArguments = new ArrayList<>();
            for (int i = 0; i < subExpression.numberOfArguments(); i++) {
                if (!positionsToDrop.contains(i)) {
                    keptArguments.add(subExpression.get(i));
                }
            }
            return keptArguments.size() > 0 ? Expressions.apply(expression2, keptArguments) : expression2;
        }, headContext);

        // Re-wrap the quantifiers that follow the summation. list2 lists them
        // outermost-first, so wrap from the last one backwards to keep their original
        // nesting (the same direction isInvertible uses when it rebuilds this part).
        for (int i = list2.size() - 1; i > summationPosition; i--) {
            Expression quantifier = list2.get(i);
            invertedBody = Expressions.apply(quantifier.getFunctor(), IntensionalSet.intensionalMultiSet(getIndexExpressions(quantifier), invertedBody, getCondition(quantifier)));
        }

        Expression result = Expressions.apply(FunctorConstants.SUM, IntensionalSet.intensionalMultiSet(new ExtensionalIndexExpressionsSet(reducedIndexExpression), invertedBody, getCondition(expression)));

        // Leading products go outside the summation; wrapping backwards keeps the first
        // one outermost.
        for (int i = productsBeforeSummation.size() - 1; i >= 0; i--) {
            Expression product = productsBeforeSummation.get(i);
            result = Expressions.apply(FunctorConstants.PRODUCT, IntensionalSet.intensionalMultiSet(getIndexExpressions(product), result, getCondition(product)));
        }
        return result;
    }

    /**
     * Appends {@code expression} and each successively nested head to {@code list} for as
     * long as the current expression is a function application on an intensional set with
     * a single index (any functor).
     *
     * @param expression the outermost candidate quantifier
     * @param list receives the quantifiers, outermost first
     */
    private static void collectQuantifiers(Expression expression, List<Expression> list) {
        Expression current = expression;
        while (isFunctionOnIntensionalSetWithSingleIndex(null, current)) {
            list.add(current);
            current = getHead(current);
        }
    }

    /**
     * Tests whether {@code expression} is a {@code SUM} over a single-index intensional
     * set whose index type has the {@code FUNCTION_TYPE} functor and whose head is a
     * {@code PRODUCT} over a single-index intensional set.
     *
     * @param expression the expression to inspect
     * @param context not used by this check
     * @return {@code true} when all three conditions hold
     */
    private static boolean isSummationIndexedByFunctionOfQuantifiers(Expression expression, Context context) {
        if (!isFunctionOnIntensionalSetWithSingleIndex(FunctorConstants.SUM, expression)) {
            return false;
        }
        Expression indexType = getIndexAndType(expression).second;
        if (indexType == null || !indexType.hasFunctor(FunctorConstants.FUNCTION_TYPE)) {
            return false;
        }
        return isFunctionOnIntensionalSetWithSingleIndex(FunctorConstants.PRODUCT, getHead(expression));
    }

    /**
     * Tests whether {@code expression} has exactly one argument, that argument is an
     * intensional set, and the set declares exactly one index.
     *
     * @param obj the required functor; when {@code null}, any function application with
     *        arguments is accepted (otherwise the check falls back to
     *        {@code expression.hasFunctor(obj)})
     * @param expression the expression to inspect
     * @return {@code true} when the functor test and the single-index test both pass
     */
    private static boolean isFunctionOnIntensionalSetWithSingleIndex(Object obj, Expression expression) {
        boolean functorAccepted = (obj == null && Expressions.isFunctionApplicationWithArguments(expression))
                || expression.hasFunctor(obj);
        if (!functorAccepted || expression.numberOfArguments() != 1) {
            return false;
        }
        Expression soleArgument = expression.get(0);
        if (!Sets.isIntensionalSet(soleArgument)) {
            return false;
        }
        IndexExpressionsSet indexExpressions = ((IntensionalSet) soleArgument).getIndexExpressions();
        return IndexExpressions.getIndices(indexExpressions).size() == 1;
    }

    /**
     * Returns the first argument of {@code expression}, cast to {@link IntensionalSet}.
     * The cast is unchecked by this method and fails with a {@code ClassCastException}
     * if the argument is of another kind.
     */
    private static IntensionalSet getIntensionalSet(Expression expression) {
        Expression firstArgument = expression.get(0);
        return (IntensionalSet) firstArgument;
    }

    /** Returns the index expressions of the intensional set that is the first argument of {@code expression}. */
    private static IndexExpressionsSet getIndexExpressions(Expression expression) {
        IntensionalSet set = getIntensionalSet(expression);
        return set.getIndexExpressions();
    }

    /** Returns the head of the intensional set that is the first argument of {@code expression}. */
    private static Expression getHead(Expression expression) {
        IntensionalSet set = getIntensionalSet(expression);
        return set.getHead();
    }

    /** Returns the condition of the intensional set that is the first argument of {@code expression}. */
    private static Expression getCondition(Expression expression) {
        IntensionalSet set = getIntensionalSet(expression);
        return set.getCondition();
    }

    /**
     * Returns the index and domain of the single typed index expression of the quantifier
     * {@code expression}.
     *
     * @param expression a quantifier whose first argument is an intensional set
     * @return the pair produced by {@code IndexExpressions.getIndexAndDomain}
     * @throws UnsupportedOperationException if the set does not have exactly one typed
     *         index expression
     */
    private static Pair<Expression, Expression> getIndexAndType(Expression expression) {
        IndexExpressionsSet indexExpressions = getIndexExpressions(expression);
        List<Expression> typedIndexExpressions = IndexExpressions.getIndexExpressionsWithType(indexExpressions);
        if (typedIndexExpressions.size() == 1) {
            return IndexExpressions.getIndexAndDomain(typedIndexExpressions.get(0));
        }
        throw new UnsupportedOperationException("Currently only support singular indices");
    }

    /**
     * Returns the single index of the quantifier {@code expression} together with its
     * type when that type is a {@link FunctionType}.
     *
     * <p>The type is looked up with {@code GrinderUtil.getTypeOfExpression} in
     * {@code context} extended by the quantifier's index expressions.
     *
     * @param expression a quantifier whose first argument is an intensional set
     * @param context the context to extend for the type lookup
     * @return a pair of the index and its function type; the second component is
     *         {@code null} when the index's type is not a {@code FunctionType}
     * @throws UnsupportedOperationException if the set does not have exactly one index
     */
    private static Pair<Expression, FunctionType> getIndexAndFunctionType(Expression expression, Context context) {
        IndexExpressionsSet indexExpressions = getIndexExpressions(expression);
        List<Expression> indices = IndexExpressions.getIndices(indexExpressions);
        if (indices.size() != 1) {
            throw new UnsupportedOperationException("Currently only support singular indices");
        }
        Expression index = indices.get(0);
        Type indexType = GrinderUtil.getTypeOfExpression(index, context.extendWith(indexExpressions));
        FunctionType indexFunctionType = indexType instanceof FunctionType ? (FunctionType) indexType : null;
        return new Pair<>(index, indexFunctionType);
    }

    /**
     * Replaces, throughout {@code expression}, every occurrence of an element of
     * {@code list} by the element at the same position in {@code list2}.
     *
     * @param expression the expression to rewrite
     * @param list the expressions to look for
     * @param list2 the replacements, index-aligned with {@code list}
     * @param context the context passed to {@code replaceAllOccurrences}
     * @return the expression returned by {@code replaceAllOccurrences}
     */
    private static Expression replaceAll(Expression expression, List<Expression> list, List<Expression> list2, Context context) {
        // The lambda parameter needs a name of its own: redeclaring it as a local inside
        // the lambda body is rejected by the compiler.
        return expression.replaceAllOccurrences(candidate -> {
            int position = list.indexOf(candidate);
            return position >= 0 ? list2.get(position) : candidate;
        }, context);
    }

    /**
     * Applies this simplifier to {@code expression}.
     *
     * <p>Delegates to {@link #applySimplifier(Expression, Context)}. The previous body
     * called {@code apply} with the same arguments, which recursed until the stack
     * overflowed.
     *
     * @param expression the expression to simplify
     * @param context the context the expression is interpreted in
     * @return the simplified expression
     */
    @Override // com.sri.ai.grinder.sgdpllt.rewriter.api.Simplifier, com.sri.ai.grinder.sgdpllt.rewriter.api.Rewriter, com.sri.ai.util.base.BinaryFunction
    public Expression apply(Expression expression, Context context) {
        return applySimplifier(expression, context);
    }
}
