package com.sri.ai.grinder.sgdpllt.theory.tuple.rewriter;

import com.sri.ai.expresso.api.Expression;
import com.sri.ai.expresso.helper.Expressions;
import com.sri.ai.expresso.type.TupleType;
import com.sri.ai.grinder.helper.GrinderUtil;
import com.sri.ai.grinder.sgdpllt.api.Context;
import com.sri.ai.grinder.sgdpllt.library.FunctorConstants;
import com.sri.ai.grinder.sgdpllt.rewriter.api.Simplifier;
import com.sri.ai.util.base.Pair;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/* loaded from: input_file:com/sri/ai/grinder/sgdpllt/theory/tuple/rewriter/TupleValuedFreeVariablesSimplifier.class */
/**
 * Simplifier that replaces each free variable of tuple type by a tuple of fresh component
 * variables, evaluates the result with the theory of the given context (extended with the
 * component variables), and finally rewrites any component variable still present in the
 * evaluated expression back into {@code get(tupleVariable, index)}.
 *
 * <p>Instances remember the last expression they changed in a mutable field; nothing in this class
 * synchronizes access to it.
 */
public class TupleValuedFreeVariablesSimplifier implements Simplifier {

    /** Most recent result of {@link #simplify} that differed from its input; {@code null} until then. */
    private Expression lastSimplifiedExpression;

    /**
     * Simplifies {@code expression} via {@link #simplify}, unless it is (by identity) the last
     * expression this instance produced, or a sub-expression of it according to
     * {@link Expressions#isSubExpressionOf}; in those cases it is returned unchanged.
     * Presumably this avoids re-processing output this simplifier already produced — TODO confirm.
     *
     * @param expression the expression to simplify
     * @param context    the context providing variable types and the theory
     * @return the simplified expression, or {@code expression} itself if nothing changed or it was skipped
     */
    @Override
    public Expression applySimplifier(Expression expression, Context context) {
        boolean alreadySimplified =
                expression == lastSimplifiedExpression
                        || (lastSimplifiedExpression != null
                                && Expressions.isSubExpressionOf(expression, lastSimplifiedExpression));
        if (alreadySimplified) {
            return expression;
        }
        Expression result = simplify(expression, context);
        if (result != expression) {
            lastSimplifiedExpression = result;
        }
        return result;
    }

    /**
     * Splits the tuple-typed free variables of {@code expression} into component variables,
     * evaluates, and maps surviving component variables back to {@code get} applications.
     *
     * @param expression the expression to simplify
     * @param context    the context providing variable types and the theory used for evaluation
     * @return the simplified expression, or {@code expression} itself if it has no free variable
     *         whose type is a tuple type
     */
    public static Expression simplify(Expression expression, Context context) {
        Map<Expression, Expression> freeVariablesAndTypes = Expressions.freeVariablesAndTypes(expression, context);

        // LinkedHashMap keeps the encounter order of the free variables so component variables are
        // generated deterministically (the default toMap collector yields a HashMap).
        Map<Expression, TupleType> tupleVariablesAndTypes = freeVariablesAndTypes.entrySet().stream()
                .filter(entry -> entry.getValue() != null && TupleType.isTupleType(entry.getValue()))
                .collect(Collectors.toMap(
                        Map.Entry::getKey,
                        entry -> (TupleType) GrinderUtil.fromTypeExpressionToItsIntrinsicMeaning(entry.getValue(), context),
                        (first, second) -> first, // keys come from a Map, so this never fires
                        LinkedHashMap::new));
        if (tupleVariablesAndTypes.isEmpty()) {
            return expression;
        }

        Map<Expression, List<Pair<Expression, Integer>>> componentsByTupleVariable =
                constructComponentMap(tupleVariablesAndTypes, expression, context);

        // Replace every occurrence of a tuple variable by the tuple of its component variables.
        Expression withComponentTuples = expression.replaceAllOccurrences(subExpression -> {
            List<Pair<Expression, Integer>> components = componentsByTupleVariable.get(subExpression);
            return components == null ? subExpression : constructComponentTuple(components);
        }, context);

        Context extendedContext =
                extendContextWithComponentVariables(context, tupleVariablesAndTypes, componentsByTupleVariable);
        Expression evaluated = context.getTheory().evaluate(withComponentTuples, extendedContext);

        // Component variables are not meaningful outside this method: turn any that survived
        // evaluation back into get(tupleVariable, index).
        Map<Expression, Pair<Expression, Integer>> tupleVariableAndIndexByComponent =
                createReverseLookupMap(componentsByTupleVariable);
        return evaluated.replaceAllOccurrences(subExpression -> {
            Pair<Expression, Integer> tupleVariableAndIndex = tupleVariableAndIndexByComponent.get(subExpression);
            return tupleVariableAndIndex == null
                    ? subExpression
                    : Expressions.apply(FunctorConstants.GET, tupleVariableAndIndex.first, tupleVariableAndIndex.second);
        }, extendedContext);
    }

    /**
     * Creates, for each tuple variable, one fresh variable per tuple component, paired with its
     * 1-based component index. Names are {@code tupleVariable.toString() + index}, made unique by
     * {@link Expressions#makeUniqueVariable} against {@code expression} and {@code context}.
     *
     * <p>NOTE(review): uniqueness is checked only against {@code expression} and {@code context},
     * not against names generated earlier in this loop. It looks like e.g. component 11 of {@code X}
     * and component 1 of {@code X1} could both be named {@code X11} — confirm and guard if needed.
     *
     * @param map        tuple variables and their tuple types
     * @param expression the expression the new names must not clash with
     * @param context    the context the new names must not clash with
     * @return tuple variable to its list of (component variable, 1-based index), in {@code map} order
     */
    private static Map<Expression, List<Pair<Expression, Integer>>> constructComponentMap(Map<Expression, TupleType> map, Expression expression, Context context) {
        Map<Expression, List<Pair<Expression, Integer>>> componentsByTupleVariable = new LinkedHashMap<>();
        for (Map.Entry<Expression, TupleType> entry : map.entrySet()) {
            Expression tupleVariable = entry.getKey();
            int arity = entry.getValue().getArity();
            List<Pair<Expression, Integer>> components = new ArrayList<>(arity);
            for (int index = 1; index <= arity; index++) {
                Expression component = Expressions.makeUniqueVariable(tupleVariable.toString() + index, expression, context);
                components.add(new Pair<>(component, index));
            }
            componentsByTupleVariable.put(tupleVariable, components);
        }
        return componentsByTupleVariable;
    }

    /**
     * Builds the tuple expression of the given component variables, in list order.
     *
     * @param list (component variable, index) pairs as produced by {@link #constructComponentMap}
     * @return the tuple of the component variables
     */
    private static Expression constructComponentTuple(List<Pair<Expression, Integer>> list) {
        List<Expression> componentVariables = list.stream()
                .map(pair -> pair.first)
                .collect(Collectors.toList());
        return Expressions.makeTuple(componentVariables);
    }

    /**
     * Extends {@code context} with every component variable, typed by the name of the
     * corresponding tuple element type, and with all element types of the given tuple types.
     *
     * @param context the context to extend
     * @param map     tuple variables and their tuple types
     * @param map2    tuple variables and their component variables, as from {@link #constructComponentMap}
     * @return the extended context
     */
    private static Context extendContextWithComponentVariables(Context context, Map<Expression, TupleType> map, Map<Expression, List<Pair<Expression, Integer>>> map2) {
        Map<String, String> typeNameByComponentName = new LinkedHashMap<>();
        for (Map.Entry<Expression, TupleType> entry : map.entrySet()) {
            TupleType tupleType = entry.getValue();
            for (Pair<Expression, Integer> component : map2.get(entry.getKey())) {
                // Component indices are 1-based; element-type lists are 0-based.
                String typeName = tupleType.getElementTypes().get(component.second - 1).getName();
                typeNameByComponentName.put(component.first.toString(), typeName);
            }
        }
        return (Context) GrinderUtil.extendRegistryWith(
                typeNameByComponentName,
                map.values().stream()
                        .flatMap(tupleType -> tupleType.getElementTypes().stream())
                        .collect(Collectors.toCollection(LinkedHashSet::new)),
                context);
    }

    /**
     * Inverts the component map: each component variable maps to (its tuple variable, its 1-based index).
     *
     * @param map tuple variables and their component variables
     * @return component variable to (tuple variable, index)
     */
    private static Map<Expression, Pair<Expression, Integer>> createReverseLookupMap(Map<Expression, List<Pair<Expression, Integer>>> map) {
        Map<Expression, Pair<Expression, Integer>> tupleVariableAndIndexByComponent = new LinkedHashMap<>();
        for (Map.Entry<Expression, List<Pair<Expression, Integer>>> entry : map.entrySet()) {
            Expression tupleVariable = entry.getKey();
            for (Pair<Expression, Integer> component : entry.getValue()) {
                tupleVariableAndIndexByComponent.put(component.first, new Pair<>(tupleVariable, component.second));
            }
        }
        return tupleVariableAndIndexByComponent;
    }

    /**
     * Applies this simplifier. Delegates directly to {@link #applySimplifier}; the decompiled
     * original called {@code apply} on itself, which recurses forever when compiled from source.
     */
    @Override
    public Expression apply(Expression expression, Context context) {
        return applySimplifier(expression, context);
    }
}
