package com.sri.ai.expresso.type;

import com.google.common.annotations.Beta;
import com.sri.ai.expresso.api.Expression;
import com.sri.ai.expresso.api.Symbol;
import com.sri.ai.expresso.api.Type;
import com.sri.ai.expresso.helper.Expressions;
import com.sri.ai.grinder.api.Registry;
import com.sri.ai.grinder.helper.AssignmentsIterator;
import com.sri.ai.grinder.sgdpllt.core.DefaultRegistry;
import com.sri.ai.grinder.sgdpllt.library.FunctorConstants;
import com.sri.ai.util.Util;
import com.sri.ai.util.base.NullaryFunction;
import com.sri.ai.util.collect.FunctionIterator;
import com.sri.ai.util.math.Rational;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.StringJoiner;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A {@link Type} describing functions from a (possibly empty) list of argument types to a
 * codomain type.
 *
 * <p>Besides the instance API, this class offers static helpers for building and taking apart
 * the {@link Expression} form of a function type, i.e. an application of
 * {@link FunctorConstants#FUNCTION_TYPE} to either one argument (the codomain) or two arguments
 * (the domain and the codomain, where a domain with several types is wrapped in an application
 * of {@link FunctorConstants#TUPLE_TYPE}).
 *
 * <p>The string form and the data needed by {@link #iterator()} are computed lazily and cached
 * in instance fields; no synchronization is performed around those caches.
 */
@Beta
public class FunctionType extends AbstractType {
    private static final long serialVersionUID = 1L;

    private final Type codomain;
    private final List<Type> argumentTypes;

    /** Lazily computed result of {@link #toString()}. */
    private String cachedString;

    /**
     * Registry used while enumerating; also acts as the "iteration cache is ready" marker, so it
     * is only assigned once {@link #codomainVariables} and {@link #genericLambda} are set.
     */
    private Registry cachedIterateRegistry;

    /** Placeholder symbols C1..Cn, one per combination of argument values. */
    private List<Expression> codomainVariables;

    /** Lambda template whose branches return the placeholders in {@link #codomainVariables}. */
    private Expression genericLambda;

    /**
     * Creates a function type.
     *
     * @param codomain      the type of the function's result
     * @param argumentTypes the types of the function's arguments, in order (may be empty)
     */
    public FunctionType(Type codomain, Type... argumentTypes) {
        this.codomain = codomain;
        // Arrays.asList is a view over the caller's array; copy it so that later writes to that
        // array cannot alter this type (or make it disagree with the cached string form).
        this.argumentTypes = Collections.unmodifiableList(new ArrayList<>(Arrays.asList(argumentTypes)));
    }

    /**
     * @return the type of the function's result
     */
    public Type getCodomain() {
        return this.codomain;
    }

    /**
     * @return the number of argument types
     */
    public int getArity() {
        return getArgumentTypes().size();
    }

    /**
     * @return an unmodifiable list of the argument types, in order
     */
    public List<Type> getArgumentTypes() {
        return this.argumentTypes;
    }

    /**
     * @return the same text as {@link #toString()}
     */
    @Override
    public String getName() {
        return toString();
    }

    /**
     * Enumerates the functions of this type as lambda expressions.
     *
     * <p>A generic lambda of the form
     * {@code (lambda A1 in T1, ... : if A1 = v and ... then C1 else ... else Cn)} is built once
     * and cached; each element of the returned iterator is that lambda with every placeholder
     * {@code Ci} replaced by a value taken from an assignment to the codomain variables.
     *
     * @return an iterator over the lambda expressions of this type
     * @throws Error if the codomain is not discrete or some argument type is not finite
     */
    @Override
    public Iterator<Expression> iterator() {
        boolean enumerable = getCodomain().isDiscrete() && getArgumentTypes().stream().allMatch(Type::isFinite);
        if (!enumerable) {
            throw new Error("Only function types with left-finite argument types and a discrete codomain can be enumerated.");
        }
        if (this.cachedIterateRegistry == null) {
            initializeIterationCache();
        }
        return FunctionIterator.functionIterator(
                new AssignmentsIterator(this.codomainVariables, this.cachedIterateRegistry),
                codomainAssignment -> {
                    Expression function = this.genericLambda;
                    for (Expression codomainVariable : this.codomainVariables) {
                        function = function.replaceFirstOccurrence(
                                codomainVariable,
                                (Expression) codomainAssignment.get(codomainVariable),
                                this.cachedIterateRegistry);
                    }
                    return function;
                });
    }

    /**
     * Builds the registry, the codomain placeholder variables and the generic lambda used by
     * {@link #iterator()}.
     *
     * <p>Everything is assembled in local variables and the fields are written only at the end,
     * so a failure part-way through leaves the cache untouched and the next call simply retries.
     */
    private void initializeIterationCache() {
        Registry registry = new DefaultRegistry();
        int numberOfCodomainVariables = argumentCombinationCount();
        registry = registry.makeCloneWithAddedType(getCodomain());
        Expression codomainTypeExpression = Expressions.parse(getCodomain().getName());

        // One placeholder C1..Cn per combination of argument values, each typed by the codomain.
        List<Expression> placeholders = new ArrayList<>(numberOfCodomainVariables);
        LinkedHashMap<Expression, Expression> symbolsAndTypes = new LinkedHashMap<>();
        for (int i = 1; i <= numberOfCodomainVariables; i++) {
            Symbol placeholder = Expressions.makeSymbol("C" + i);
            placeholders.add(placeholder);
            symbolsAndTypes.put(placeholder, codomainTypeExpression);
        }

        // One lambda parameter A1..Ak per argument type; the lambda header is built alongside.
        List<Expression> argumentVariables = new ArrayList<>();
        StringJoiner lambdaHeader = new StringJoiner(", ", "(lambda ", " : ");
        for (int i = 0; i < getArgumentTypes().size(); i++) {
            Type argumentType = getArgumentTypes().get(i);
            registry = registry.makeCloneWithAddedType(argumentType);
            Symbol argumentVariable = Expressions.makeSymbol("A" + (i + 1));
            Expression argumentTypeExpression = Expressions.parse(argumentType.getName());
            argumentVariables.add(argumentVariable);
            symbolsAndTypes.put(argumentVariable, argumentTypeExpression);
            lambdaHeader.add(argumentVariable + " in " + argumentTypeExpression);
        }
        registry = registry.setSymbolsAndTypes(symbolsAndTypes);

        // The assignment numbered numberOfCodomainVariables becomes the final bare "Cn";
        // every other assignment becomes an "if <conditions> then Ci" branch.
        StringJoiner lambdaBody = new StringJoiner(" else ", "", ")");
        AtomicInteger assignmentCounter = new AtomicInteger(0);
        new AssignmentsIterator(argumentVariables, registry).forEachRemaining(assignment -> {
            int index = assignmentCounter.incrementAndGet();
            if (index == numberOfCodomainVariables) {
                lambdaBody.add("C" + numberOfCodomainVariables);
                return;
            }
            StringJoiner branch = new StringJoiner(" and ", "if ", " then C" + index);
            for (Expression argumentVariable : argumentVariables) {
                branch.add(argumentVariable + " = " + assignment.get(argumentVariable));
            }
            lambdaBody.add(branch.toString());
        });
        Expression lambda = Expressions.parse(lambdaHeader.toString() + lambdaBody.toString());

        this.codomainVariables = placeholders;
        this.genericLambda = lambda;
        // Assigned last: a non-null registry is what tells iterator() the cache is complete.
        this.cachedIterateRegistry = registry;
    }

    /**
     * Multiplies the cardinalities of all argument types and narrows the product to an
     * {@code int} (1 when there are no argument types).
     */
    private int argumentCombinationCount() {
        Rational product = Rational.ONE;
        for (Type argumentType : this.argumentTypes) {
            product = product.multiply(argumentType.cardinality().rationalValue());
        }
        return product.intValue();
    }

    /**
     * Membership test; this implementation returns {@code false} for every expression.
     */
    @Override
    public boolean contains(Expression expression) {
        return false;
    }

    /**
     * @return {@code false}; sampling is not supported for function types
     */
    @Override
    public boolean isSampleUniquelyNamedConstantSupported() {
        return false;
    }

    /**
     * Unsupported for function types.
     *
     * @throws Error always
     */
    @Override
    public Expression sampleUniquelyNamedConstant(Random random) {
        throw new Error("Cannot sample uniquely named constant from function type that is infinite and/or defined by variables: " + getName());
    }

    /**
     * @return for a finite type, the codomain cardinality raised to the product of the argument
     *         type cardinalities; otherwise {@link Expressions#INFINITY}
     */
    @Override
    public Expression cardinality() {
        if (!isFinite()) {
            return Expressions.INFINITY;
        }
        Rational codomainSize = this.codomain.cardinality().rationalValue();
        return Expressions.makeSymbol(codomainSize.pow(argumentCombinationCount()));
    }

    /**
     * @return {@code true} iff the codomain and every argument type are discrete
     */
    @Override
    public boolean isDiscrete() {
        return this.codomain.isDiscrete() && this.argumentTypes.stream().allMatch(Type::isDiscrete);
    }

    /**
     * @return {@code true} iff the codomain and every argument type are finite
     */
    @Override
    public boolean isFinite() {
        return this.codomain.isFinite() && this.argumentTypes.stream().allMatch(Type::isFinite);
    }

    /**
     * Returns the (cached) text of the function type expression: the functor applied to the
     * codomain alone when there are no argument types, to the single argument type and the
     * codomain when there is one, and to a tuple type of the argument types and the codomain
     * otherwise.
     */
    @Override
    public String toString() {
        if (this.cachedString == null) {
            int arity = getArgumentTypes().size();
            Expression typeExpression;
            if (arity == 0) {
                typeExpression = Expressions.apply(FunctorConstants.FUNCTION_TYPE, getCodomain());
            } else if (arity == 1) {
                typeExpression = Expressions.apply(FunctorConstants.FUNCTION_TYPE, getArgumentTypes().get(0), getCodomain());
            } else {
                typeExpression = Expressions.apply(FunctorConstants.FUNCTION_TYPE, Expressions.apply(FunctorConstants.TUPLE_TYPE, getArgumentTypes()), getCodomain());
            }
            this.cachedString = typeExpression.toString();
        }
        return this.cachedString;
    }

    /**
     * @return a new insertion-ordered set holding the codomain, the codomain's embedded types,
     *         and then each argument type followed by its embedded types
     */
    @Override
    public Set<Type> getEmbeddedTypes() {
        Set<Type> embeddedTypes = new LinkedHashSet<>();
        embeddedTypes.add(this.codomain);
        embeddedTypes.addAll(this.codomain.getEmbeddedTypes());
        for (Type argumentType : this.argumentTypes) {
            embeddedTypes.add(argumentType);
            embeddedTypes.addAll(argumentType.getEmbeddedTypes());
        }
        return embeddedTypes;
    }

    /**
     * Varargs convenience for {@link #make(Expression, List)}.
     *
     * @param codomainType  expression for the codomain type
     * @param argumentTypes expressions for the argument types, in order
     * @return the function type expression
     */
    public static Expression make(Expression codomainType, Expression... argumentTypes) {
        // A List-typed local selects the List overload rather than recursing into this one.
        List<Expression> argumentTypeList = Arrays.asList(argumentTypes);
        return make(codomainType, argumentTypeList);
    }

    /**
     * Builds a function type expression.
     *
     * @param codomainType  expression for the codomain type
     * @param argumentTypes expressions for the argument types, in order
     * @return the functor applied to the codomain only (no argument types), to the single
     *         argument type and the codomain (one argument type), or to a tuple type of the
     *         argument types and the codomain (several argument types)
     */
    public static Expression make(Expression codomainType, List<Expression> argumentTypes) {
        if (argumentTypes.size() == 0) {
            return Expressions.apply(FunctorConstants.FUNCTION_TYPE, codomainType);
        }
        if (argumentTypes.size() == 1) {
            return Expressions.apply(FunctorConstants.FUNCTION_TYPE, argumentTypes.get(0), codomainType);
        }
        return Expressions.apply(FunctorConstants.FUNCTION_TYPE, Expressions.apply(FunctorConstants.TUPLE_TYPE, argumentTypes), codomainType);
    }

    /**
     * Extracts the codomain from a function type expression.
     *
     * @param functionTypeExpression an expression accepted by {@link #assertFunctionType}
     * @return the only argument when the expression has one argument, otherwise the second one
     */
    public static Expression getCodomain(Expression functionTypeExpression) {
        assertFunctionType(functionTypeExpression);
        int codomainIndex = functionTypeExpression.numberOfArguments() == 1 ? 0 : 1;
        return functionTypeExpression.get(codomainIndex);
    }

    /**
     * Extracts the argument types from a function type expression.
     *
     * @param functionTypeExpression an expression accepted by {@link #assertFunctionType}
     * @return a new list: empty for the one-argument form; otherwise the components of the first
     *         argument if it is a tuple type, or the first argument itself
     */
    public static List<Expression> getArgumentList(Expression functionTypeExpression) {
        assertFunctionType(functionTypeExpression);
        List<Expression> argumentTypes = new ArrayList<>();
        if (functionTypeExpression.numberOfArguments() == 2) {
            Expression domain = functionTypeExpression.get(0);
            if (domain.hasFunctor(FunctorConstants.TUPLE_TYPE)) {
                argumentTypes.addAll(domain.getArguments());
            } else {
                argumentTypes.add(domain);
            }
        }
        return argumentTypes;
    }

    /**
     * @return {@code true} iff the expression has the function type functor and exactly one or
     *         two arguments
     */
    public static boolean isFunctionType(Expression expression) {
        if (!expression.hasFunctor(FunctorConstants.FUNCTION_TYPE)) {
            return false;
        }
        int numberOfArguments = expression.numberOfArguments();
        return numberOfArguments == 1 || numberOfArguments == 2;
    }

    /**
     * Asserts, through {@link Util#myAssert}, that the expression has the function type functor
     * and one or two arguments.
     */
    public static void assertFunctionType(Expression expression) {
        Util.myAssert(expression.hasFunctor(FunctorConstants.FUNCTION_TYPE), (NullaryFunction<String>) () -> {
            return "Functor in expression " + expression + " should be a functional type (that is, have functor '->')";
        });
        Util.myAssert(expression.numberOfArguments() == 1 || expression.numberOfArguments() == 2, (NullaryFunction<String>) () -> {
            return "Function type has illegal number of arguments (should be 1 or 2), has " + expression.numberOfArguments() + " for " + expression;
        });
    }
}
