package IncrementalAnytimeExactBeliefPropagation;

import IncrementalAnytimeExactBeliefPropagation.Model.Model;
import IncrementalAnytimeExactBeliefPropagation.Model.Node.FactorNode;
import IncrementalAnytimeExactBeliefPropagation.Model.Node.VariableNode;
import com.sri.ai.expresso.api.Expression;
import com.sri.ai.expresso.core.DefaultExtensionalMultiSet;
import com.sri.ai.grinder.sgdpllt.library.bounds.Bound;
import com.sri.ai.grinder.sgdpllt.library.bounds.Bounds;
import com.sri.ai.util.Util;
import com.sri.ai.util.base.PairOf;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:IncrementalAnytimeExactBeliefPropagation/IncrementalAnytimeBeliefPropagationWithSeparatorConditioning.class */
/**
 * Incremental anytime exact belief propagation with separator conditioning ("S-BP").
 *
 * <p>The model is expanded one factor at a time. Each factor is taken from the node of the next
 * {@link PartitionTree} produced by the iterator supplied at construction. After each expansion a
 * normalized {@link Bound} for the model's query is returned, computed in one of two ways:
 * <ul>
 *   <li>{@link #expandAndComputeInferenceByRebuildingPartitionTree()} rebuilds the partition tree
 *       from scratch;</li>
 *   <li>{@link #expandAndComputeInference()} updates the current tree in place and recomputes only
 *       the subtrees flagged with {@code recomputeBound}.</li>
 * </ul>
 *
 * <p>Variables shared by two or more child subtrees of a partition-tree node form that node's
 * separator. They are summed out at that node and kept out of the summation (conditioned on) at
 * every level below it.
 */
public class IncrementalAnytimeBeliefPropagationWithSeparatorConditioning {
    /** The model being explored; the expand* methods call {@code model.ExpandModel} on it. */
    private final Model model;
    /**
     * Snapshot of {@code model.AllExplored()} taken by {@link #inference()}. When true, that pass
     * never replaces a variable's message by a simplex bound.
     */
    private boolean allExplored = false;
    /**
     * Root of the current partition tree.
     *
     * <p>This is {@code null} if the iterator given to the constructor was empty. It is replaced
     * wholesale by {@link #inference()}.
     */
    public PartitionTree partitionTree;
    /** Supplies, in expansion order, the partition-tree nodes whose factors are added to the model. */
    private final Iterator<PartitionTree> partitionTreeIterator;

    /**
     * Creates the engine.
     *
     * @param model the model to expand and run inference on
     * @param it    partition-tree nodes in expansion order; its first element (if any) becomes the
     *              initial {@link #partitionTree} root, and every later element's node is expected
     *              to be a {@link FactorNode}
     */
    public IncrementalAnytimeBeliefPropagationWithSeparatorConditioning(Model model, Iterator<PartitionTree> it) {
        this.model = model;
        this.partitionTreeIterator = it;
        this.partitionTree = it.hasNext() ? it.next() : null;
    }

    /**
     * Reports whether there is nothing left to expand.
     *
     * <p>Note that this reflects only the partition-tree iterator, not the {@code allExplored}
     * snapshot used by {@link #inference()}.
     *
     * @return {@code true} once the partition-tree iterator is exhausted
     */
    public boolean isAllExplored() {
        return !this.partitionTreeIterator.hasNext();
    }

    /**
     * Expands the model with the next factor and recomputes the bound from scratch.
     *
     * <p>This goes through {@link #inference()}, which builds a new partition tree.
     *
     * @return the normalized bound, or {@code null} if there is nothing left to expand
     */
    public Bound expandAndComputeInferenceByRebuildingPartitionTree() {
        if (!this.partitionTreeIterator.hasNext()) {
            return null;
        }
        PartitionTree next = this.partitionTreeIterator.next();
        this.model.ExpandModel((FactorNode) next.node);
        return inference().normalize(this.model.getTheory(), this.model.getContext());
    }

    /**
     * Expands the model with the next factor and updates the existing partition tree in place.
     *
     * <p>Only flagged subtrees are recomputed.
     *
     * @return the root's normalized bound, or {@code null} if there is nothing left to expand
     */
    public Bound expandAndComputeInference() {
        if (!this.partitionTreeIterator.hasNext()) {
            return null;
        }
        PartitionTree next = this.partitionTreeIterator.next();
        this.model.ExpandModel((FactorNode) next.node);
        // NOTE(review): assumes `next` is a node already linked into this.partitionTree; confirm
        // against the iterator's producer.
        updatePartitionTree(next);
        return this.partitionTree.node.getBound().normalize(this.model.getTheory(), this.model.getContext());
    }

    /**
     * Propagates the addition of {@code expanded}'s factor through the tree.
     *
     * <p>The factor and its variables are added to every ancestor, the separators and cutsets are
     * fixed up, and the flagged bounds are recomputed.
     */
    private void updatePartitionTree(PartitionTree expanded) {
        FactorNode factor = (FactorNode) expanded.node;
        Collection<VariableNode> variablesOfFactor = this.model.getVariablesOfAFactor(factor);
        updateSetOfFactorsInPartitionTree(expanded, factor);
        updateSetOfVariablesInPartitionTree(expanded, variablesOfFactor);
        updateCutSet(expanded, factor);
        updateBounds();
    }

    /** Adds {@code factorNode} to {@code setOfFactors} of {@code partitionTree} and every ancestor. */
    private void updateSetOfFactorsInPartitionTree(PartitionTree partitionTree, FactorNode factorNode) {
        for (PartitionTree current = partitionTree; current != null; current = current.parent) {
            current.setOfFactors.add(factorNode);
        }
    }

    /** Adds {@code collection} to {@code setOfVariables} of {@code partitionTree} and every ancestor. */
    private void updateSetOfVariablesInPartitionTree(PartitionTree partitionTree, Collection<VariableNode> collection) {
        for (PartitionTree current = partitionTree; current != null; current = current.parent) {
            current.setOfVariables.addAll(collection);
            // A node's own variable is excluded from its setOfVariables.
            current.setOfVariables.remove(current.node);
        }
    }

    /**
     * Places the expanded factor's variables into the separators of the tree.
     *
     * <p>Only the factor's "new" variables are placed: those that are neither its children nor its
     * parent in the partition tree.
     */
    private void updateCutSet(PartitionTree partitionTree, FactorNode factorNode) {
        // Work on a private copy. This list is pruned here and drained by addingToCutSet, and that
        // must not alter whatever collection the model hands out (it may be the model's own storage).
        Collection<VariableNode> newVariables = new ArrayList<>(this.model.getVariablesOfAFactor(factorNode));
        for (PartitionTree child : partitionTree.children) {
            newVariables.remove(child.node);
        }
        if (partitionTree.parent != null) {
            newVariables.remove(partitionTree.parent.node);
        }
        addingToCutSet(partitionTree, newVariables, null);
    }

    /**
     * Distributes the {@code pending} variables into separators along the path from the root down
     * to {@code treeNode}, then refreshes the affected side subtrees.
     *
     * <p>Ancestors are handled first (the recursive call precedes this level's work). At each
     * level, pending variables that also occur in a child subtree other than {@code cameFrom} move
     * into that level's separator and leave {@code pending}. Each pending variable therefore lands
     * at the highest such level.
     *
     * @param treeNode the current level; a {@code null} value is ignored
     * @param pending  variables still to place; mutated (placed variables are removed)
     * @param cameFrom the child through which the recursion arrived (excluded), or {@code null}
     */
    private void addingToCutSet(PartitionTree treeNode, Collection<VariableNode> pending, PartitionTree cameFrom) {
        if (treeNode == null) {
            return;
        }
        if (treeNode.parent != null) {
            addingToCutSet(treeNode.parent, pending, treeNode);
            // Everything conditioned on above (parent's separator and cutset) is inherited here and
            // therefore no longer belongs to this level's own separator.
            treeNode.cutsetOfAllLevelsAbove.addAll(treeNode.parent.separator);
            treeNode.cutsetOfAllLevelsAbove.addAll(treeNode.parent.cutsetOfAllLevelsAbove);
            treeNode.separator.removeAll(treeNode.cutsetOfAllLevelsAbove);
        }
        List<VariableNode> placedHere = new ArrayList<>();
        for (PartitionTree child : treeNode.children) {
            if (!child.equals(cameFrom)) {
                Set<VariableNode> shared = new HashSet<>(pending);
                shared.retainAll(child.setOfVariables);
                placedHere.addAll(shared);
            }
        }
        pending.removeAll(placedHere);
        treeNode.separator.addAll(placedHere);
        treeNode.recomputeBound = true;
        for (PartitionTree child : treeNode.children) {
            if (!child.equals(cameFrom)) {
                updateLASandSeparator(child);
            }
        }
    }

    /**
     * Flags {@code partitionTree} for recomputation and pushes the parent's conditioning set down.
     *
     * <p>The parent's conditioning set is its cutset of all levels above plus its separator. If that
     * adds anything new to this node's cutset, the set is merged into the cutset, removed from this
     * node's separator, and the same update is applied recursively to the children.
     */
    private void updateLASandSeparator(PartitionTree partitionTree) {
        partitionTree.recomputeBound = true;
        if (partitionTree.parent == null) {
            return;
        }
        Set<VariableNode> conditionedAbove = new HashSet<>(partitionTree.parent.cutsetOfAllLevelsAbove);
        conditionedAbove.addAll(partitionTree.parent.separator);
        if (thisSetIncreasestheLAS(conditionedAbove, partitionTree)) {
            partitionTree.cutsetOfAllLevelsAbove.addAll(conditionedAbove);
            partitionTree.separator.removeAll(conditionedAbove);
            for (PartitionTree child : partitionTree.children) {
                updateLASandSeparator(child);
            }
        }
    }

    /** @return {@code true} if some element of {@code collection} is not yet in the node's cutset of all levels above */
    private boolean thisSetIncreasestheLAS(Collection<VariableNode> collection, PartitionTree partitionTree) {
        for (VariableNode variable : collection) {
            if (!partitionTree.cutsetOfAllLevelsAbove.contains(variable)) {
                return true;
            }
        }
        return false;
    }

    /** Recomputes every flagged bound in the current partition tree, starting at the root. */
    private void updateBounds() {
        updateBounds(this.partitionTree);
    }

    /**
     * Recomputes the bound of {@code treeNode} if it is flagged, clearing the flag.
     *
     * <p>A variable not yet exhausted in the model gets a simplex bound over its own value, and its
     * subtree is not visited. Otherwise flagged children are recomputed first, then this node's
     * factor or variable message.
     */
    private void updateBounds(PartitionTree treeNode) {
        if (!treeNode.recomputeBound) {
            return;
        }
        treeNode.recomputeBound = false;
        if (treeNode.node.isVariable() && !this.model.isExhausted((VariableNode) treeNode.node)) {
            treeNode.node.setBound(Bounds.simplex(Util.arrayList(treeNode.node.getValue()), this.model.getTheory(), this.model.getContext(), this.model.isExtensional()));
            return;
        }
        for (PartitionTree child : treeNode.children) {
            updateBounds(child);
        }
        if (treeNode.node.isFactor()) {
            treeNode.node.setBound(factorMessage(treeNode));
        }
        if (treeNode.node.isVariable()) {
            treeNode.node.setBound(variableMessage(treeNode));
        }
    }

    /**
     * Incremental factor message (used by the in-place update path).
     *
     * <p>Multiplies the children's stored bounds, then sums phi times that product over the
     * separator and child variables. Variables in this node's cutset of all levels above are kept
     * out of the sum.
     */
    private Bound factorMessage(PartitionTree partitionTree) {
        Set<Expression> summedOut = valuesOf(partitionTree.separator);
        Bound[] childBounds = new Bound[partitionTree.children.size()];
        int index = 0;
        for (PartitionTree child : partitionTree.children) {
            childBounds[index++] = child.node.getBound();
            summedOut.add(child.node.getValue());
        }
        for (VariableNode conditioned : partitionTree.cutsetOfAllLevelsAbove) {
            summedOut.remove(conditioned.getValue());
        }
        return productOf(childBounds).summingPhiTimesBound(multiSetOf(summedOut), partitionTree.node.getValue(), this.model.getContext(), this.model.getTheory());
    }

    /**
     * Incremental variable message (used by the in-place update path).
     *
     * <p>Multiplies the children's stored bounds and sums out this node's separator.
     */
    private Bound variableMessage(PartitionTree partitionTree) {
        Set<Expression> summedOut = valuesOf(partitionTree.separator);
        Bound[] childBounds = new Bound[partitionTree.children.size()];
        int index = 0;
        for (PartitionTree child : partitionTree.children) {
            childBounds[index++] = child.node.getBound();
        }
        return productOf(childBounds).summingBound(multiSetOf(summedOut), this.model.getContext(), this.model.getTheory());
    }

    /**
     * Marks the whole graph as explored and runs a full inference.
     *
     * @return the normalized bound computed by {@link #inference()}
     */
    public Bound inferenceOverEntireModel() {
        this.model.SetExploredGraphToEntireGraph();
        return inference().normalize(this.model.getTheory(), this.model.getContext());
    }

    /**
     * Builds a fresh partition tree rooted at the model's query and computes the root's message.
     *
     * <p>Replaces {@link #partitionTree} and refreshes the {@code allExplored} snapshot.
     *
     * @return the unnormalized bound for the query
     */
    public Bound inference() {
        this.partitionTree = new PartitionTree(this.model.getQuery(), this.model);
        this.allExplored = this.model.AllExplored();
        return variableMessage(this.partitionTree, new HashSet<VariableNode>());
    }

    /**
     * From-scratch variable message.
     *
     * <p>Sums out this level's separator from the product of the children's factor messages. The
     * children receive {@code conditionedAbove} plus this level's separator as their conditioning
     * set. An unexhausted variable (unless everything is explored) yields a simplex bound instead.
     *
     * @param partitionTree    a node whose {@code node} must be a variable
     * @param conditionedAbove variables conditioned on by the levels above; not modified
     * @return the message, or {@code null} (after printing an error) if the node is not a variable
     */
    private Bound variableMessage(PartitionTree partitionTree, Set<VariableNode> conditionedAbove) {
        if (!partitionTree.node.isVariable()) {
            // NOTE(review): a null here ends up inside the caller's bound product; consider throwing.
            Util.println("error in S-BP!!!");
            return null;
        }
        // Decide the simplex shortcut before computing separators, whose result it would discard.
        if (!this.allExplored && !this.model.isExhausted((VariableNode) partitionTree.node)) {
            return Bounds.simplex(Util.arrayList(partitionTree.node.getValue()), this.model.getTheory(), this.model.getContext(), this.model.isExtensional());
        }
        PairOf<Set<VariableNode>> separators = computeSeparatorOnThisLevelAndSeparatorOnLevelsBelow(partitionTree, conditionedAbove);
        Set<Expression> summedOut = valuesOf(separators.first);
        Bound[] childBounds = new Bound[partitionTree.children.size()];
        int index = 0;
        for (PartitionTree child : partitionTree.children) {
            childBounds[index++] = factorMessage(child, separators.second);
        }
        return productOf(childBounds).summingBound(multiSetOf(summedOut), this.model.getContext(), this.model.getTheory());
    }

    /**
     * From-scratch factor message.
     *
     * <p>Sums phi times the product of the children's variable messages over this level's separator
     * and the child variables. Variables in {@code conditionedAbove} are excluded from the sum.
     *
     * @param partitionTree    a node whose {@code node} must be a factor
     * @param conditionedAbove variables conditioned on by the levels above; not modified
     * @return the message, or {@code null} (after printing an error) if the node is not a factor
     */
    private Bound factorMessage(PartitionTree partitionTree, Set<VariableNode> conditionedAbove) {
        if (!partitionTree.node.isFactor()) {
            // NOTE(review): a null here ends up inside the caller's bound product; consider throwing.
            Util.println("error in S-BP!!!");
            return null;
        }
        PairOf<Set<VariableNode>> separators = computeSeparatorOnThisLevelAndSeparatorOnLevelsBelow(partitionTree, conditionedAbove);
        Set<Expression> summedOut = valuesOf(separators.first);
        Bound[] childBounds = new Bound[partitionTree.children.size()];
        int index = 0;
        for (PartitionTree child : partitionTree.children) {
            childBounds[index++] = variableMessage(child, separators.second);
            summedOut.add(child.node.getValue());
        }
        for (VariableNode conditioned : conditionedAbove) {
            summedOut.remove(conditioned.getValue());
        }
        return productOf(childBounds).summingPhiTimesBound(multiSetOf(summedOut), partitionTree.node.getValue(), this.model.getContext(), this.model.getTheory());
    }

    /**
     * Computes the raw separator of a node.
     *
     * <p>The separator is the set of variables that occur (per the explored graph) in the factors
     * of at least two distinct child subtrees. A single pass suffices: a variable is a separator
     * variable exactly when an earlier child has already contributed it.
     *
     * @return a fresh, mutable set
     */
    private Set<VariableNode> computeSeparator(PartitionTree partitionTree) {
        Set<VariableNode> seenInEarlierChild = new HashSet<>();
        Set<VariableNode> separator = new HashSet<>();
        for (PartitionTree child : partitionTree.children) {
            Set<VariableNode> variablesOfChild = new HashSet<>();
            for (FactorNode factor : child.setOfFactors) {
                variablesOfChild.addAll(this.model.getExploredGraph().getAsOfB(factor));
            }
            for (VariableNode variable : variablesOfChild) {
                if (!seenInEarlierChild.add(variable)) {
                    separator.add(variable);
                }
            }
        }
        return separator;
    }

    /**
     * Splits the separator bookkeeping for one level.
     *
     * @param partitionTree    the node being processed
     * @param conditionedAbove variables already conditioned on by the levels above; not modified
     * @return {@code first}: this level's separator, minus the node's own variable and minus
     *         {@code conditionedAbove}; {@code second}: {@code first} plus {@code conditionedAbove},
     *         i.e. the conditioning set handed to the children
     */
    private PairOf<Set<VariableNode>> computeSeparatorOnThisLevelAndSeparatorOnLevelsBelow(PartitionTree partitionTree, Set<VariableNode> conditionedAbove) {
        Set<VariableNode> separatorHere = computeSeparator(partitionTree);
        if (partitionTree.node.isVariable()) {
            separatorHere.remove(partitionTree.node);
        }
        separatorHere.removeAll(conditionedAbove);
        Set<VariableNode> conditionedForChildren = new HashSet<>(separatorHere);
        conditionedForChildren.addAll(conditionedAbove);
        return new PairOf<>(separatorHere, conditionedForChildren);
    }

    /** @return the model this engine expands and runs inference on */
    public Model getModel() {
        return this.model;
    }

    /** Collects {@code getValue()} of each variable node into a fresh, mutable set. */
    private static Set<Expression> valuesOf(Iterable<? extends VariableNode> variables) {
        Set<Expression> values = new HashSet<>();
        for (VariableNode variable : variables) {
            values.add(variable.getValue());
        }
        return values;
    }

    /** Wraps {@code values} as the extensional multiset passed to the summing methods. */
    private static DefaultExtensionalMultiSet multiSetOf(Collection<Expression> values) {
        return new DefaultExtensionalMultiSet(new ArrayList<>(values));
    }

    /** Bound product of {@code bounds} under this model's theory, context and extensionality flag. */
    private Bound productOf(Bound[] bounds) {
        return Bounds.boundProduct(this.model.getTheory(), this.model.getContext(), this.model.isExtensional(), bounds);
    }
}
