package com.sri.ai.grinder.sgdpllt.theory.tuple.rewriter;

import com.google.common.base.Predicate;
import com.sri.ai.expresso.api.CountingFormula;
import com.sri.ai.expresso.api.Expression;
import com.sri.ai.expresso.api.IndexExpressionsSet;
import com.sri.ai.expresso.api.IntensionalSet;
import com.sri.ai.expresso.api.LambdaExpression;
import com.sri.ai.expresso.api.QuantifiedExpression;
import com.sri.ai.expresso.core.DefaultCountingFormula;
import com.sri.ai.expresso.core.DefaultLambdaExpression;
import com.sri.ai.expresso.core.ExtensionalIndexExpressionsSet;
import com.sri.ai.expresso.helper.Expressions;
import com.sri.ai.expresso.helper.SubExpressionsDepthFirstIterator;
import com.sri.ai.expresso.type.TupleType;
import com.sri.ai.grinder.sgdpllt.api.Context;
import com.sri.ai.grinder.sgdpllt.library.FunctorConstants;
import com.sri.ai.grinder.sgdpllt.library.boole.ForAll;
import com.sri.ai.grinder.sgdpllt.library.boole.ThereExists;
import com.sri.ai.grinder.sgdpllt.library.indexexpression.IndexExpressions;
import com.sri.ai.grinder.sgdpllt.rewriter.api.Simplifier;
import com.sri.ai.util.Util;
import com.sri.ai.util.base.Pair;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/* loaded from: input_file:com/sri/ai/grinder/sgdpllt/theory/tuple/rewriter/TupleQuantifierSimplifier.class */
/**
 * Simplifier that rewrites a {@link QuantifiedExpression} having at least one index whose declared
 * type satisfies {@code TupleType.isTupleType}. Each such index {@code X} is replaced, inside the
 * quantified body, by a tuple of fresh symbols {@code (X_1, ..., X_n)} (one per argument of the
 * tuple type), and the single index expression {@code X in T} is replaced by the index expressions
 * {@code X_i in T.get(i - 1)}.
 *
 * <p>Expressions that are not quantified, or that have no tuple-typed index, are returned
 * unchanged (the same instance).
 */
public class TupleQuantifierSimplifier implements Simplifier {

    @Override // com.sri.ai.grinder.sgdpllt.rewriter.api.Simplifier
    public Expression applySimplifier(Expression expression, Context context) {
        return simplify(expression, context);
    }

    /**
     * Rewrites {@code expression} if it is a quantified expression with at least one tuple-typed
     * index; otherwise returns {@code expression} itself.
     *
     * @param expression the expression to simplify
     * @param context    the context handed to {@code replaceAllOccurrences} during the rewrite
     * @return the rewritten expression, or {@code expression} when nothing applies
     * @throws UnsupportedOperationException if the quantified expression is not a universal or
     *         existential quantification, counting formula, lambda expression or intensional set
     * @throws IllegalStateException if a universal or existential quantification has more than
     *         one index
     */
    public static Expression simplify(Expression expression, Context context) {
        if (!(expression instanceof QuantifiedExpression)) {
            return expression;
        }
        QuantifiedExpression quantifiedExpression = (QuantifiedExpression) expression;
        LinkedHashMap<Expression, Expression> indexToTypeMap =
                IndexExpressions.getIndexToTypeMapWithDefaultNull(quantifiedExpression);

        // Only indices with a known (non-null) type that is a tuple type trigger a rewrite.
        List<Map.Entry<Expression, Expression>> tupleTypedIndices = indexToTypeMap.entrySet().stream()
                .filter(entry -> entry.getValue() != null && TupleType.isTupleType(entry.getValue()))
                .collect(Collectors.toList());
        if (tupleTypedIndices.isEmpty()) {
            return expression;
        }

        Map<Expression, Expression> indexToTupleOfVars =
                createTuplesOfVarsForTupleTypes(quantifiedExpression, tupleTypedIndices);
        return rewriteQuantifiedExpression(quantifiedExpression, indexToTypeMap, indexToTupleOfVars, context);
    }

    /**
     * Builds, for every tuple-typed index, a tuple of fresh symbols named {@code <index>_<i>}
     * (with {@code i} starting at 1), one per argument of the index's type.
     *
     * @param quantifiedExpression the expression whose sub-expressions the new symbols must not
     *                             collide with
     * @param list                 the (index, type) entries to create tuples for
     * @return a map from each index to the tuple of symbols replacing it
     */
    private static Map<Expression, Expression> createTuplesOfVarsForTupleTypes(QuantifiedExpression quantifiedExpression, List<Map.Entry<Expression, Expression>> list) {
        Map<Expression, Expression> indexToTupleOfVars = new HashMap<>();
        LinkedHashSet<Expression> existingSubExpressions =
                Util.addAllToSet(new SubExpressionsDepthFirstIterator(quantifiedExpression));
        // A candidate symbol is acceptable when it does not already occur in the quantified expression.
        Predicate<Expression> isUnused = candidate -> !existingSubExpressions.contains(candidate);

        for (Map.Entry<Expression, Expression> entry : list) {
            Expression index = entry.getKey();
            int arity = entry.getValue().numberOfArguments();
            List<Expression> elementVars = new ArrayList<>();
            for (int i = 1; i <= arity; i++) {
                Expression proposed = Expressions.makeSymbol(index.toString() + "_" + i);
                elementVars.add(Expressions.primedUntilUnique(proposed, isUnused));
            }
            indexToTupleOfVars.put(index, Expressions.makeTuple(elementVars));
        }
        return indexToTupleOfVars;
    }

    /**
     * Dispatches to the rewrite matching the kind of quantified expression.
     *
     * @throws UnsupportedOperationException if {@code expression} is none of the supported kinds
     */
    private static Expression rewriteQuantifiedExpression(Expression expression, Map<Expression, Expression> indexToTypeMap, Map<Expression, Expression> indexToTupleOfVars, Context context) {
        if (ForAll.isForAll(expression)) {
            return rewriteForAll(expression, indexToTypeMap, indexToTupleOfVars, context);
        }
        if (ThereExists.isThereExists(expression)) {
            return rewriteThereExists(expression, indexToTypeMap, indexToTupleOfVars, context);
        }
        if (expression instanceof CountingFormula) {
            return rewriteCountingFormula((CountingFormula) expression, indexToTypeMap, indexToTupleOfVars, context);
        }
        if (expression instanceof LambdaExpression) {
            return rewriteLambdaExpression((LambdaExpression) expression, indexToTypeMap, indexToTupleOfVars, context);
        }
        if (expression instanceof IntensionalSet) {
            return rewriteIntensionalSet((IntensionalSet) expression, indexToTypeMap, indexToTupleOfVars, context);
        }
        throw new UnsupportedOperationException("Quantifer currently not supported : " + expression);
    }

    /**
     * Rewrites a universal quantification.
     *
     * @throws IllegalStateException if the quantification has more than one index
     */
    private static Expression rewriteForAll(Expression expression, Map<Expression, Expression> indexToTypeMap, Map<Expression, Expression> indexToTupleOfVars, Context context) {
        if (indexToTypeMap.size() > 1) {
            throw new IllegalStateException("We have a Universal Quantifier with > 1 index : " + expression);
        }
        Pair<IndexExpressionsSet, Expression> updated =
                update(ForAll.getBody(expression), indexToTypeMap, indexToTupleOfVars, context);
        return ForAll.make(updated.first, updated.second);
    }

    /**
     * Rewrites an existential quantification.
     *
     * @throws IllegalStateException if the quantification has more than one index
     */
    private static Expression rewriteThereExists(Expression expression, Map<Expression, Expression> indexToTypeMap, Map<Expression, Expression> indexToTupleOfVars, Context context) {
        if (indexToTypeMap.size() > 1) {
            throw new IllegalStateException("We have an Existential Quantifier with > 1 index : " + expression);
        }
        Pair<IndexExpressionsSet, Expression> updated =
                update(ThereExists.getBody(expression), indexToTypeMap, indexToTupleOfVars, context);
        return ThereExists.make(updated.first, updated.second);
    }

    /** Rewrites a counting formula, rebuilding it from the updated index expressions and body. */
    private static Expression rewriteCountingFormula(CountingFormula countingFormula, Map<Expression, Expression> indexToTypeMap, Map<Expression, Expression> indexToTupleOfVars, Context context) {
        Pair<IndexExpressionsSet, Expression> updated =
                update(countingFormula.getBody(), indexToTypeMap, indexToTupleOfVars, context);
        return new DefaultCountingFormula(updated.first, updated.second);
    }

    /** Rewrites a lambda expression, rebuilding it from the updated index expressions and body. */
    private static Expression rewriteLambdaExpression(LambdaExpression lambdaExpression, Map<Expression, Expression> indexToTypeMap, Map<Expression, Expression> indexToTupleOfVars, Context context) {
        Pair<IndexExpressionsSet, Expression> updated =
                update(lambdaExpression.getBody(), indexToTypeMap, indexToTupleOfVars, context);
        return new DefaultLambdaExpression(updated.first, updated.second);
    }

    /**
     * Rewrites an intensional set, preserving whether it is a uni-set or a multi-set.
     */
    private static Expression rewriteIntensionalSet(IntensionalSet intensionalSet, Map<Expression, Expression> indexToTypeMap, Map<Expression, Expression> indexToTupleOfVars, Context context) {
        // Head and condition are packed into one tuple so both are substituted by a single update call.
        Expression headAndCondition = Expressions.makeTuple(intensionalSet.getHead(), intensionalSet.getCondition());
        Pair<IndexExpressionsSet, Expression> updated =
                update(headAndCondition, indexToTypeMap, indexToTupleOfVars, context);
        Expression newHead = updated.second.get(0);
        Expression newCondition = updated.second.get(1);
        if (intensionalSet.isUniSet()) {
            return IntensionalSet.intensionalUniSet(updated.first, newHead, newCondition);
        }
        return IntensionalSet.intensionalMultiSet(updated.first, newHead, newCondition);
    }

    /**
     * Substitutes every tuple-typed index in {@code expression} by its tuple of variables and
     * builds the matching index expressions.
     *
     * <p>For an index with an entry in {@code indexToTupleOfVars}, one {@code in} application per
     * tuple element is produced. Any other index is kept as is: the bare index when its type is
     * {@code null}, otherwise {@code index in type}.
     *
     * @return a pair of (new index expressions, expression after substitution)
     */
    private static Pair<IndexExpressionsSet, Expression> update(Expression expression, Map<Expression, Expression> indexToTypeMap, Map<Expression, Expression> indexToTupleOfVars, Context context) {
        List<Expression> indexExpressions = new ArrayList<>();
        Expression updatedExpression = expression;
        for (Map.Entry<Expression, Expression> entry : indexToTypeMap.entrySet()) {
            Expression index = entry.getKey();
            Expression type = entry.getValue();
            Expression tupleOfVars = indexToTupleOfVars.get(index);
            if (tupleOfVars != null) {
                updatedExpression = updatedExpression.replaceAllOccurrences(index, tupleOfVars, context);
                for (int i = 0; i < type.numberOfArguments(); i++) {
                    indexExpressions.add(Expressions.apply(FunctorConstants.IN, tupleOfVars.get(i), type.get(i)));
                }
            } else if (type == null) {
                indexExpressions.add(index);
            } else {
                indexExpressions.add(Expressions.apply(FunctorConstants.IN, index, type));
            }
        }
        return new Pair<>(new ExtensionalIndexExpressionsSet(indexExpressions), updatedExpression);
    }

    /**
     * Forwards to the {@code apply} implementation inherited from {@link Simplifier}.
     *
     * <p>This method originates from a compiler-generated forwarder. Written as a plain
     * {@code apply(expression, context)} call it would invoke itself and overflow the stack, so
     * the interface implementation is invoked explicitly.
     * NOTE(review): assumes {@code Simplifier} provides a default {@code apply} - confirm.
     */
    @Override // com.sri.ai.grinder.sgdpllt.rewriter.api.Simplifier, com.sri.ai.grinder.sgdpllt.rewriter.api.Rewriter, com.sri.ai.util.base.BinaryFunction
    public Expression apply(Expression expression, Context context) {
        return Simplifier.super.apply(expression, context);
    }
}
