package anytimeExactBeliefPropagation;

import anytimeExactBeliefPropagation.Model.Model;
import anytimeExactBeliefPropagation.Model.Node.FactorNode;
import anytimeExactBeliefPropagation.Model.Node.VariableNode;
import com.sri.ai.expresso.core.DefaultExtensionalMultiSet;
import com.sri.ai.grinder.sgdpllt.library.bounds.Bound;
import com.sri.ai.grinder.sgdpllt.library.bounds.Bounds;
import com.sri.ai.util.Util;
import com.sri.ai.util.base.PairOf;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:anytimeExactBeliefPropagation/IncrementalBeliefPropagationWithConditioning.class */
/**
 * Computes a {@link Bound} for the model's query by recursive message passing over a
 * {@link PartitionTree} built from the model.
 *
 * <p>At every tree node the "separator" is computed: the variables attached (in the explored
 * graph) to factors of at least two different child partitions. Separator variables not already
 * identified by an ancestor are passed down the recursion and summed out at the level where they
 * were first found. While the model is not fully explored, a variable node that the model does not
 * report as exhausted contributes a simplex bound instead of being expanded further.
 */
public class IncrementalBeliefPropagationWithConditioning {
    private final Model model;
    /** Cached value of {@code model.AllExplored()}, refreshed by every call to {@link #inference()}. */
    private boolean allExplored = false;
    /** Partition tree built by the most recent call to {@link #inference()}. */
    public PartitionTree partitionTree;

    public IncrementalBeliefPropagationWithConditioning(Model model) {
        this.model = model;
    }

    /**
     * Calls {@code model.ExpandModel(it)} and then recomputes the bound.
     *
     * @param it factors handed to {@code model.ExpandModel}
     * @return the new bound, or {@code null} if {@code it} has no further elements
     */
    public Bound expandAndComputeInference(Iterator<FactorNode> it) {
        if (!it.hasNext()) {
            return null;
        }
        this.model.ExpandModel(it);
        return inference();
    }

    /**
     * Calls {@code model.SetExploredGraphToEntireGraph()} and computes the bound on the result.
     *
     * @return the bound for the query
     */
    public Bound InferenceOverEntireModel() {
        this.model.SetExploredGraphToEntireGraph();
        return inference();
    }

    /**
     * Rebuilds {@link #partitionTree} from the model's query, refreshes the cached explored flag,
     * and returns the message at the root variable (with no separator variables inherited).
     *
     * @return the bound for the query
     */
    public Bound inference() {
        this.partitionTree = new PartitionTree(this.model.getQuery(), this.model);
        this.allExplored = this.model.AllExplored();
        return variableMessage(this.partitionTree, new HashSet<>());
    }

    /**
     * Computes the message sent from the variable at the root of {@code partitionTree}.
     *
     * <p>If the model is not fully explored and does not report this variable as exhausted, the
     * result is a simplex over the variable. Otherwise it is the product of the children's factor
     * messages with this level's new separator variables summed out.
     *
     * @param partitionTree subtree whose root must be a variable node
     * @param set separator variables already identified at levels above this one
     * @return the bound, or {@code null} (after printing an error) if the root is not a variable
     */
    public Bound variableMessage(PartitionTree partitionTree, Set<VariableNode> set) {
        if (!partitionTree.node.isVariable()) {
            Util.println("error in S-BP!!!");
            return null;
        }
        // Checked before any separator work: when this branch is taken the separator would be
        // computed (with one graph lookup per child factor) and then discarded.
        if (!this.allExplored && !this.model.isExhausted((VariableNode) partitionTree.node)) {
            return Bounds.simplex(
                    Util.arrayList(partitionTree.node.getValue()),
                    this.model.getTheory(),
                    this.model.getContext(),
                    this.model.isExtensional());
        }
        PairOf<Set<VariableNode>> separators =
                computeSeparatorOnThisLevelAndSeparatorOnLevelsBelow(partitionTree, set);
        Set<VariableNode> separatorOnThisLevel = separators.first;
        Set<VariableNode> separatorOnLevelsBelow = separators.second;

        Bound[] childBounds = new Bound[partitionTree.children.size()];
        int i = 0;
        for (PartitionTree child : partitionTree.children) {
            childBounds[i++] = factorMessage(child, separatorOnLevelsBelow);
        }
        Bound product = Bounds.boundProduct(
                this.model.getTheory(), this.model.getContext(), this.model.isExtensional(), childBounds);

        Set<Object> valuesToSumOut = valuesOf(separatorOnThisLevel);
        return product.summingBound(
                toExtensionalMultiSet(valuesToSumOut), this.model.getContext(), this.model.getTheory());
    }

    /**
     * Computes the message sent from the factor at the root of {@code partitionTree}.
     *
     * <p>The result multiplies the factor by the product of its children's variable messages and
     * sums out this level's new separator variables and the child variables. Variables in
     * {@code set} are excluded because they are summed out at the ancestor that introduced them.
     *
     * @param partitionTree subtree whose root must be a factor node
     * @param set separator variables already identified at levels above this one
     * @return the bound, or {@code null} (after printing an error) if the root is not a factor
     */
    public Bound factorMessage(PartitionTree partitionTree, Set<VariableNode> set) {
        if (!partitionTree.node.isFactor()) {
            Util.println("error in S-BP!!!");
            return null;
        }
        PairOf<Set<VariableNode>> separators =
                computeSeparatorOnThisLevelAndSeparatorOnLevelsBelow(partitionTree, set);
        Set<VariableNode> separatorOnLevelsBelow = separators.second;

        Set<Object> valuesToSumOut = valuesOf(separators.first);
        Bound[] childBounds = new Bound[partitionTree.children.size()];
        int i = 0;
        for (PartitionTree child : partitionTree.children) {
            childBounds[i++] = variableMessage(child, separatorOnLevelsBelow);
            valuesToSumOut.add(child.node.getValue());
        }
        // Variables in `set` belong to an ancestor's separator; that ancestor sums them out.
        for (VariableNode alreadyHandled : set) {
            valuesToSumOut.remove(alreadyHandled.getValue());
        }
        Bound product = Bounds.boundProduct(
                this.model.getTheory(), this.model.getContext(), this.model.isExtensional(), childBounds);
        return product.summingPhiTimesBound(
                toExtensionalMultiSet(valuesToSumOut),
                partitionTree.node.getValue(),
                this.model.getContext(),
                this.model.getTheory());
    }

    /**
     * Returns every variable that {@code model.getExploredGraph().getAsOfB(f)} yields for factors
     * {@code f} in at least two different child partitions of {@code partitionTree}.
     *
     * <p>Single pass: each child's variables are deduplicated first, and then a variable is a
     * separator exactly when an earlier child already contributed it. This avoids intersecting
     * every pair of children.
     */
    private Set<VariableNode> computeSeparator(PartitionTree partitionTree) {
        Set<VariableNode> seenInEarlierChildren = new HashSet<>();
        Set<VariableNode> separator = new HashSet<>();
        for (PartitionTree child : partitionTree.children) {
            Set<VariableNode> childVariables = new HashSet<>();
            for (FactorNode factor : child.setOfFactorsInsidePartition) {
                childVariables.addAll(this.model.getExploredGraph().getAsOfB(factor));
            }
            for (VariableNode variable : childVariables) {
                // add() returns false only when an earlier child already contributed this variable.
                if (!seenInEarlierChildren.add(variable)) {
                    separator.add(variable);
                }
            }
        }
        return separator;
    }

    /**
     * Splits the separator bookkeeping for one tree level.
     *
     * @param partitionTree the current subtree
     * @param set separator variables already identified above this level
     * @return {@code first}: this level's separator, excluding the root variable (if the root is
     *     a variable) and anything in {@code set}; {@code second}: {@code first} together with
     *     {@code set}, to be passed to the children
     */
    private PairOf<Set<VariableNode>> computeSeparatorOnThisLevelAndSeparatorOnLevelsBelow(
            PartitionTree partitionTree, Set<VariableNode> set) {
        Set<VariableNode> separatorOnThisLevel = computeSeparator(partitionTree);
        if (partitionTree.node.isVariable()) {
            separatorOnThisLevel.remove(partitionTree.node);
        }
        separatorOnThisLevel.removeAll(set);
        Set<VariableNode> separatorOnLevelsBelow = new HashSet<>();
        separatorOnLevelsBelow.addAll(separatorOnThisLevel);
        separatorOnLevelsBelow.addAll(set);
        return new PairOf<>(separatorOnThisLevel, separatorOnLevelsBelow);
    }

    /** Returns the {@code getValue()} of every node in {@code nodes}, deduplicated by equality. */
    private static Set<Object> valuesOf(Set<VariableNode> nodes) {
        Set<Object> values = new HashSet<>();
        for (VariableNode node : nodes) {
            values.add(node.getValue());
        }
        return values;
    }

    /**
     * Wraps {@code values} in a {@link DefaultExtensionalMultiSet}.
     *
     * <p>The values are collected as {@code Object} because the element type returned by
     * {@code getValue()} is not named in this file. The list handed to the constructor is
     * therefore built raw (an unchecked conversion), and that is confined to this one helper.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static DefaultExtensionalMultiSet toExtensionalMultiSet(Collection<?> values) {
        return new DefaultExtensionalMultiSet(new ArrayList(values));
    }
}
