package com.sri.ai.grinder.sgdpllt.core.solver;

import com.google.common.base.Function;
import com.sri.ai.expresso.api.Expression;
import com.sri.ai.expresso.helper.Expressions;
import com.sri.ai.expresso.helper.SubExpressionsDepthFirstIterator;
import com.sri.ai.grinder.sgdpllt.api.Context;
import com.sri.ai.grinder.sgdpllt.api.MultiIndexQuantifierEliminator;
import com.sri.ai.grinder.sgdpllt.group.AssociativeCommutativeGroup;
import com.sri.ai.grinder.sgdpllt.group.AssociativeCommutativeSemiRing;
import com.sri.ai.grinder.sgdpllt.library.boole.And;
import com.sri.ai.grinder.sgdpllt.library.controlflow.IfThenElse;
import com.sri.ai.util.Util;
import com.sri.ai.util.base.PairOf;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import org.apache.commons.lang3.StringUtils;

/**
 * Symbolic generalized variable elimination, SGVE(T), over an associative commutative semiring.
 *
 * <p>{@link #solve} repeatedly picks one index and eliminates it from only those factors of the body in which it
 * occurs; that single-index elimination is delegated to {@link #subSolver}. It then recurses on the remaining indices,
 * using the product of the untouched factors and the sub-problem's solution. When no useful split of the factors
 * exists, the whole problem is handed to {@link #subSolver}.
 */
public class AbstractSGVET extends AbstractMultiIndexQuantifierEliminator {

    /** Eliminator used for single-index sub-problems and for problems that cannot be decomposed. */
    protected MultiIndexQuantifierEliminator subSolver;

    /** When {@code true}, a trace of each elimination step is printed to standard output. */
    public boolean basicOutput = false;

    /**
     * A candidate elimination step. It pairs one index with a split of the body's factors into those in which the
     * index occurs as a sub-expression ({@code first}) and those in which it does not ({@code second}).
     */
    public static class Partition {
        /** Single-element list holding the index this partition eliminates. */
        private final List<Expression> index;
        /** All candidate indices except {@link #index}. */
        private final List<Expression> remainingIndices;
        /** {@code first}: factors containing the index; {@code second}: factors not containing it. */
        private final PairOf<List<Expression>> expressionsOnIndexAndNot;

        public Partition(Expression index, List<Expression> remainingIndices, PairOf<List<Expression>> expressionsOnIndexAndNot) {
            this.index = Util.list(index);
            this.remainingIndices = remainingIndices;
            this.expressionsOnIndexAndNot = expressionsOnIndexAndNot;
        }

        /**
         * Returns whether this partition fails to split the factors, i.e. one side of the split is empty.
         * Eliminating the index of such a partition gives no decomposition benefit.
         */
        public boolean isTrivial() {
            return expressionsOnIndexAndNot.first.isEmpty() || expressionsOnIndexAndNot.second.isEmpty();
        }
    }

    /**
     * @param multiIndexQuantifierEliminator eliminator used for the sub-problems SGVE(T) cannot decompose further
     */
    public AbstractSGVET(MultiIndexQuantifierEliminator multiIndexQuantifierEliminator) {
        this.subSolver = multiIndexQuantifierEliminator;
    }

    /** Returns whether the context's theory considers {@code expression} a variable. */
    public boolean isVariable(Expression expression, Context context) {
        return context.getTheory().isVariable(expression, context);
    }

    /** Interrupts this eliminator and propagates the interruption to {@link #subSolver}. */
    @Override
    public void interrupt() {
        super.interrupt();
        this.subSolver.interrupt();
    }

    /**
     * Quantifies {@code body} over {@code indices} under {@code indicesCondition}, using {@code group}'s operation.
     *
     * @param group must be an {@link AssociativeCommutativeSemiRing}; it is cast to one unconditionally
     * @param indices the indices to eliminate
     * @param indicesCondition condition restricting the indices' values
     * @param body expression to quantify; it is first evaluated by the context's theory
     * @param context the current context
     * @return the solution computed via elimination steps and/or {@link #subSolver}
     */
    @Override
    public Expression solve(AssociativeCommutativeGroup group, List<Expression> indices, Expression indicesCondition, Expression body, Context context) {
        checkInterrupted();

        Expression normalizedBody = context.getTheory().evaluate(body, context);
        if (getDebug()) {
            System.out.println("SGVE(T) input: " + normalizedBody);
            System.out.println("Width        : " + width(normalizedBody, context));
        }

        AssociativeCommutativeSemiRing semiRing = (AssociativeCommutativeSemiRing) group;

        Partition partition = null;
        if (!indices.isEmpty()) {
            Expression factoredBody = factoredConditionalsWithAbsorbingElseClause(semiRing, normalizedBody, context);
            partition = pickPartition(semiRing, factoredBody, indices, context);
        }

        if (partition == null) {
            if (this.basicOutput) {
                System.out.println("No partition");
            }
            return this.subSolver.solve(group, indices, indicesCondition, normalizedBody, context);
        }

        Expression indexSubProblem = product(semiRing, partition.expressionsOnIndexAndNot.first, context);
        if (this.basicOutput) {
            System.out.println("Eliminating: " + Util.getFirst(partition.index));
            System.out.println("From       : " + indexSubProblem);
            System.out.println("Width      : " + width(indexSubProblem, context) + " out of " + indices.size() + " indices");
        }

        // The indices condition is folded into the index sub-problem: where it fails, the factor becomes the
        // multiplicative absorbing element.
        Expression constrainedIndexSubProblem = IfThenElse.make(indicesCondition, indexSubProblem, semiRing.multiplicativeAbsorbingElement());
        Expression indexSubProblemSolution = this.subSolver.solve(group, partition.index, constrainedIndexSubProblem, context);
        if (this.basicOutput) {
            System.out.println("Solution   : " + indexSubProblemSolution + StringUtils.LF);
        }

        // The eliminated index's contribution becomes one more factor of the remaining problem.
        partition.expressionsOnIndexAndNot.second.add(indexSubProblemSolution);
        Expression remainingSubProblem = product(semiRing, partition.expressionsOnIndexAndNot.second, context);

        // NOTE(review): `context` itself is passed as the indices condition here, as in the original code.
        // The original condition has already been folded into the index sub-problem above; confirm this is intended.
        Expression remainingSolution = solve(group, partition.remainingIndices, context, remainingSubProblem, context);
        return semiRing.multiply(remainingSolution, context);
    }

    /**
     * Picks the minimum-width non-trivial partition of {@code expression}'s factors, or returns {@code null} if
     * there is none.
     *
     * <p>Trivial partitions are excluded before minimizing. Otherwise an index occurring in no factor (empty
     * product, hence minimal width) would always win, and decomposition would be abandoned for every other index.
     * Ties keep the first candidate in iteration order.
     */
    private Partition pickPartition(AssociativeCommutativeSemiRing semiRing, Expression expression, Collection<Expression> indices, Context context) {
        if (indices.isEmpty()) {
            return null;
        }
        Function<Expression, Partition> partitionForIndex = makePartition(indices, semiRing.getFactors(expression));
        Function<Partition, Integer> widthOfPartition = width(semiRing, context);

        Partition best = null;
        int bestWidth = Integer.MAX_VALUE;
        for (Expression index : indices) {
            Partition candidate = partitionForIndex.apply(index);
            if (candidate.isTrivial()) {
                continue; // cannot split the product, so eliminating this index first does not help
            }
            int candidateWidth = widthOfPartition.apply(candidate);
            if (best == null || candidateWidth < bestWidth) {
                best = candidate;
                bestWidth = candidateWidth;
            }
        }
        return best;
    }

    /** Returns a function mapping an index to its {@link Partition} of {@code factors} (see {@link #pickPartitionForIndex}). */
    public Function<Expression, Partition> makePartition(Collection<Expression> indices, List<Expression> factors) {
        return index -> pickPartitionForIndex(index, indices, factors);
    }

    /**
     * Builds the partition for {@code index}. The remaining indices are {@code indices} without {@code index}
     * (computed non-destructively). The factors are split by whether {@code index} is a sub-expression of each one.
     */
    public Partition pickPartitionForIndex(Expression index, Collection<Expression> indices, List<Expression> factors) {
        return new Partition(
                index,
                Util.removeNonDestructively(indices, index),
                Util.collectToLists(factors, factor -> Expressions.isSubExpressionOf(index, factor)));
    }

    /** Returns a function giving a partition's width: the variable count of the product of its index-containing factors. */
    public Function<Partition, Integer> width(AssociativeCommutativeSemiRing semiRing, Context context) {
        return partition -> Integer.valueOf(width(semiRing, partition, context));
    }

    private int width(AssociativeCommutativeSemiRing semiRing, Partition partition, Context context) {
        return width(product(semiRing, partition.expressionsOnIndexAndNot.first, context), context);
    }

    /**
     * Counts the distinct sub-expressions of {@code expression} that the context's theory considers variables.
     * The sub-expressions are those enumerated by {@link SubExpressionsDepthFirstIterator}.
     */
    private int width(Expression expression, Context context) {
        LinkedHashSet<Expression> variables = new LinkedHashSet<>();
        SubExpressionsDepthFirstIterator subExpressions = new SubExpressionsDepthFirstIterator(expression);
        while (subExpressions.hasNext()) {
            Expression subExpression = subExpressions.next();
            if (isVariable(subExpression, context)) {
                variables.add(subExpression);
            }
        }
        return variables.size();
    }

    /**
     * Splits conditional factors of the form {@code if C1 and ... and Cn then V else <absorbing>} into n factors
     * {@code if Ci then R else <absorbing>}, where R is the semiring's n-th root of V. A factor is split only when
     * that root is available.
     *
     * @return {@code expression} itself when no factor was split, otherwise the product of the new factor list
     */
    public Expression factoredConditionalsWithAbsorbingElseClause(AssociativeCommutativeSemiRing semiRing, Expression expression, Context context) {
        List<Expression> factors = semiRing.getFactors(expression);
        List<Expression> newFactors = factoredConditionalsWithAbsorbingElseClause(semiRing, factors);
        // Reference comparison: presumably the expanding helper returns its input list when nothing was expanded.
        return newFactors == factors ? expression : product(semiRing, newFactors, context);
    }

    private List<Expression> factoredConditionalsWithAbsorbingElseClause(AssociativeCommutativeSemiRing semiRing, List<Expression> factors) {
        return Util.nonDestructivelyExpandElementsIfFunctionReturnsNonNullCollection(factors, factor -> factorConditionalIfPossible(semiRing, factor));
    }

    /**
     * Returns the list of per-conjunct conditionals replacing {@code expression}, or {@code null} if
     * {@code expression} is not an absorbing-else conditional on a conjunction whose then-branch has an n-th root.
     */
    private List<Expression> factorConditionalIfPossible(AssociativeCommutativeSemiRing semiRing, Expression expression) {
        if (!IfThenElse.isIfThenElse(expression) || !elseBranchIsAbsorbing(semiRing, expression) || !conditionIsConjunction(expression)) {
            return null;
        }
        Expression nthRoot = semiRing.getNthRoot(numberOfConjuncts(expression), IfThenElse.thenBranch(expression));
        if (nthRoot == null) {
            return null;
        }
        return Util.mapIntoList(
                And.getConjuncts(IfThenElse.condition(expression)),
                conjunct -> IfThenElse.make(conjunct, nthRoot, semiRing.multiplicativeAbsorbingElement()));
    }

    /** Returns whether the condition of the conditional {@code expression} is a conjunction. */
    public boolean conditionIsConjunction(Expression expression) {
        return And.isConjunction(IfThenElse.condition(expression));
    }

    /** Returns whether the else-branch of the conditional {@code expression} is the semiring's multiplicative absorbing element. */
    public boolean elseBranchIsAbsorbing(AssociativeCommutativeSemiRing semiRing, Expression expression) {
        return IfThenElse.elseBranch(expression).equals(semiRing.multiplicativeAbsorbingElement());
    }

    /** Returns the number of arguments of the condition of the conditional {@code expression}. */
    public int numberOfConjuncts(Expression expression) {
        return IfThenElse.condition(expression).numberOfArguments();
    }

    /** Applies the semiring's multiplicative functor to {@code factors} and simplifies via {@code multiply}. */
    private Expression product(AssociativeCommutativeSemiRing semiRing, Collection<Expression> factors, Context context) {
        return semiRing.multiply(Expressions.apply(semiRing.multiplicativeFunctor(), factors), context);
    }

    @Override
    public String toString() {
        return "SGVE(T)";
    }
}
