package org.jpmml.evaluator.general_regression;

import com.google.common.base.Function;
import com.google.common.base.Predicate;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.HasValue;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.Matrix;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.general_regression.BaseCumHazardTables;
import org.dmg.pmml.general_regression.BaselineCell;
import org.dmg.pmml.general_regression.BaselineStratum;
import org.dmg.pmml.general_regression.Categories;
import org.dmg.pmml.general_regression.Category;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.dmg.pmml.general_regression.PCell;
import org.dmg.pmml.general_regression.PPCell;
import org.dmg.pmml.general_regression.Parameter;
import org.dmg.pmml.general_regression.ParameterCell;
import org.dmg.pmml.general_regression.ParameterList;
import org.dmg.pmml.general_regression.Predictor;
import org.dmg.pmml.general_regression.PredictorList;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.HasParsedValueMapping;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.InvalidElementException;
import org.jpmml.evaluator.MatrixUtil;
import org.jpmml.evaluator.MisplacedAttributeException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.MissingValueException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.PMMLAttributes;
import org.jpmml.evaluator.PMMLElements;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.UnsupportedElementException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.ValueUtil;

/* loaded from: classes7.dex */
public class GeneralRegressionModelEvaluator extends ModelEvaluator<GeneralRegressionModel> {
    // Per-evaluator, non-serialized (transient) views of the model's parsed structures.
    // NOTE(review): presumably filled on demand by the matching getXxx() accessors from the
    // static caches declared below - those accessors are not visible in this part of the file.

    // Parameter-matrix cells grouped by a String key. The visible callers look entries up by
    // target category and by the null key.
    private transient Map<String, List<PCell>> paramMatrixMap;
    // Parameter lookup; the visible callers query it with PCell#getParameterName().
    private transient BiMap<String, Parameter> parameterRegistry;
    // Two-level lookup of Row objects: the outer key is a target category (or null), the inner
    // key is a parameter name.
    private transient Map<String, Map<String, Row>> ppMatrixMap;
    // Target category names, in the order in which evaluateClassification iterates over them.
    private transient List<String> targetCategories;
    /**
     * Per-model cache whose loader stores an immutable BiMap copy of
     * {@code parseParameterRegistry(model.getParameterList())}.
     */
    private static final LoadingCache<GeneralRegressionModel, BiMap<String, Parameter>> parameterCache = CacheUtil.buildLoadingCache(new CacheLoader<GeneralRegressionModel, BiMap<String, Parameter>>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.3
        @Override // com.google.common.cache.CacheLoader
        public BiMap<String, Parameter> load(GeneralRegressionModel generalRegressionModel) {
            return ImmutableBiMap.copyOf((Map) GeneralRegressionModelEvaluator.parseParameterRegistry(generalRegressionModel.getParameterList()));
        }
    });
    /**
     * Per-model cache whose loader stores an immutable BiMap copy of
     * {@code parsePredictorRegistry(model.getFactorList())}.
     */
    public static final LoadingCache<GeneralRegressionModel, BiMap<FieldName, Predictor>> factorCache = CacheUtil.buildLoadingCache(new CacheLoader<GeneralRegressionModel, BiMap<FieldName, Predictor>>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.8
        @Override // com.google.common.cache.CacheLoader
        public BiMap<FieldName, Predictor> load(GeneralRegressionModel generalRegressionModel) {
            return ImmutableBiMap.copyOf((Map) GeneralRegressionModelEvaluator.parsePredictorRegistry(generalRegressionModel.getFactorList()));
        }
    });
    /**
     * Per-model cache whose loader stores an immutable BiMap copy of
     * {@code parsePredictorRegistry(model.getCovariateList())}.
     */
    public static final LoadingCache<GeneralRegressionModel, BiMap<FieldName, Predictor>> covariateCache = CacheUtil.buildLoadingCache(new CacheLoader<GeneralRegressionModel, BiMap<FieldName, Predictor>>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.9
        @Override // com.google.common.cache.CacheLoader
        public BiMap<FieldName, Predictor> load(GeneralRegressionModel generalRegressionModel) {
            return ImmutableBiMap.copyOf((Map) GeneralRegressionModelEvaluator.parsePredictorRegistry(generalRegressionModel.getCovariateList()));
        }
    });
    /**
     * Per-model cache whose loader stores an unmodifiable view of {@code parsePPMatrix(model)}.
     */
    private static final LoadingCache<GeneralRegressionModel, Map<String, Map<String, Row>>> ppMatrixCache = CacheUtil.buildLoadingCache(new CacheLoader<GeneralRegressionModel, Map<String, Map<String, Row>>>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.4
        @Override // com.google.common.cache.CacheLoader
        public Map<String, Map<String, Row>> load(GeneralRegressionModel generalRegressionModel) {
            return Collections.unmodifiableMap(GeneralRegressionModelEvaluator.parsePPMatrix(generalRegressionModel));
        }
    });
    /**
     * Per-model cache whose loader stores an unmodifiable view of {@code parseParamMatrix(model)}.
     */
    private static final LoadingCache<GeneralRegressionModel, Map<String, List<PCell>>> paramMatrixCache = CacheUtil.buildLoadingCache(new CacheLoader<GeneralRegressionModel, Map<String, List<PCell>>>() { // from class: org.jpmml.evaluator.general_regression.GeneralRegressionModelEvaluator.5
        @Override // com.google.common.cache.CacheLoader
        public Map<String, List<PCell>> load(GeneralRegressionModel generalRegressionModel) {
            return Collections.unmodifiableMap(GeneralRegressionModelEvaluator.parseParamMatrix(generalRegressionModel));
        }
    });

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: classes7.dex */
    /**
     * Collects the factor and covariate handlers registered through {@link #addFactor} and
     * {@link #addCovariate}, and folds their contributions into a single product in
     * {@link #evaluate}.
     */
    public static class Row {
        /** Extracts the mandatory {@code value} attribute of a {@link Category}. */
        private static final Function<Category, String> CATEGORY_VALUE = new Function<Category, String>() {
            @Override
            public String apply(Category category) {
                String categoryValue = category.getValue();
                if (categoryValue != null) {
                    return categoryValue;
                }
                throw new MissingAttributeException(category, PMMLAttributes.CATEGORY_VALUE);
            }
        };

        private List<FactorHandler> factorHandlers = new ArrayList<>();
        private List<CovariateHandler> covariateHandlers = new ArrayList<>();

        private Row() {
        }

        public List<FactorHandler> getFactorHandlers() {
            return this.factorHandlers;
        }

        public List<CovariateHandler> getCovariateHandlers() {
            return this.covariateHandlers;
        }

        /** Registers a covariate handler for the given cell. */
        public void addCovariate(PPCell pPCell) {
            CovariateHandler handler = new CovariateHandler(pPCell);
            getCovariateHandlers().add(handler);
        }

        /**
         * Registers a factor handler for the given cell. A predictor that carries a matrix gets
         * a {@link ContrastMatrixHandler} (and must then also declare its categories); any other
         * predictor gets a plain {@link FactorHandler}.
         *
         * @throws UnsupportedElementException if the predictor has a matrix but no categories
         */
        public void addFactor(PPCell pPCell, Predictor predictor) {
            List<FactorHandler> handlers = getFactorHandlers();
            Matrix contrastMatrix = predictor.getMatrix();
            FactorHandler handler;
            if (contrastMatrix != null) {
                Categories declaredCategories = predictor.getCategories();
                if (declaredCategories == null) {
                    throw new UnsupportedElementException(predictor);
                }
                // Lists.transform yields a lazy view, so a missing category value surfaces on access.
                List<String> categoryValues = Lists.transform(declaredCategories.getCategories(), CATEGORY_VALUE);
                handler = new ContrastMatrixHandler(pPCell, contrastMatrix, categoryValues);
            } else {
                handler = new FactorHandler(pPCell);
            }
            handlers.add(handler);
        }

        /**
         * Starts from a value of 1 and lets every factor handler, then every covariate handler,
         * update it with the corresponding field value.
         *
         * @return the resulting value, or {@code null} as soon as a predictor field evaluates
         *         to {@code null}; covariates are skipped when the value equals zero after the
         *         factor handlers have run
         */
        public <V extends Number> Value<V> evaluate(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
            Value<V> product = valueFactory.newValue(1.0d);
            if (!applyHandlers(getFactorHandlers(), product, evaluationContext)) {
                return null;
            }
            if (product.equals(0.0d)) {
                return product;
            }
            if (!applyHandlers(getCovariateHandlers(), product, evaluationContext)) {
                return null;
            }
            return product;
        }

        /**
         * Feeds each handler the value of its predictor field.
         *
         * @return {@code false} when a predictor field evaluated to {@code null}
         */
        private <V extends Number> boolean applyHandlers(List<? extends PredictorHandler> handlers, Value<V> product, EvaluationContext evaluationContext) {
            for (PredictorHandler handler : handlers) {
                FieldValue fieldValue = evaluationContext.evaluate(handler.getPredictorName());
                if (fieldValue == null) {
                    return false;
                }
                // The returned value is intentionally not captured.
                // NOTE(review): presumably Value operations modify the receiver - confirm.
                handler.updateProduct(product, fieldValue);
            }
            return true;
        }

        /** Base type of all handlers; binds a handler to the PPCell it was created from. */
        public abstract class PredictorHandler {
            private PPCell ppCell;

            private PredictorHandler(PPCell pPCell) {
                setPPCell(pPCell);
                FieldName predictorName = pPCell.getPredictorName();
                if (predictorName == null) {
                    throw new MissingAttributeException(pPCell, PMMLAttributes.PPCELL_PREDICTORNAME);
                }
            }

            private void setPPCell(PPCell pPCell) {
                this.ppCell = pPCell;
            }

            public PPCell getPPCell() {
                return this.ppCell;
            }

            public FieldName getPredictorName() {
                PPCell cell = getPPCell();
                return cell.getPredictorName();
            }

            /** Applies this handler's contribution for {@code fieldValue} to {@code value}. */
            public abstract <V extends Number> Value<V> updateProduct(Value<V> value, FieldValue fieldValue);
        }

        /** Handler for a categorical predictor; the cell's value names the expected category. */
        public class FactorHandler extends PredictorHandler {
            private String category;

            private FactorHandler(PPCell pPCell) {
                super(pPCell);
                String cellValue = pPCell.getValue();
                if (cellValue == null) {
                    throw new MissingAttributeException(pPCell, PMMLAttributes.PPCELL_VALUE);
                }
                setCategory(cellValue);
            }

            private void setCategory(String str) {
                this.category = str;
            }

            public String getCategory() {
                return this.category;
            }

            /** Keeps the value as is when the field value equals the cell, else multiplies by zero. */
            @Override
            public <V extends Number> Value<V> updateProduct(Value<V> value, FieldValue fieldValue) {
                if (fieldValue.equals((HasValue<?>) getPPCell())) {
                    return value;
                }
                return value.multiply2(0.0d);
            }
        }

        /** Handler for a continuous predictor; the cell's value is parsed as an exponent. */
        public class CovariateHandler extends PredictorHandler {
            private double exponent = 1.0d;

            private CovariateHandler(PPCell pPCell) {
                super(pPCell);
                String cellValue = pPCell.getValue();
                if (cellValue == null) {
                    throw new MissingAttributeException(pPCell, PMMLAttributes.PPCELL_VALUE);
                }
                setExponent(Double.parseDouble(cellValue));
            }

            private void setExponent(double d) {
                this.exponent = d;
            }

            public double getExponent() {
                return this.exponent;
            }

            /** Multiplies by the field's number, using the two-argument form unless the exponent is 1. */
            @Override
            public <V extends Number> Value<V> updateProduct(Value<V> value, FieldValue fieldValue) {
                double power = getExponent();
                if (power == 1.0d) {
                    return value.multiply2(fieldValue.asNumber().doubleValue());
                }
                return value.multiply2(fieldValue.asNumber(), power);
            }
        }

        /**
         * Factor handler that multiplies by an element of the predictor's matrix, addressed by
         * the position of the field value and of the cell's category in the category list.
         */
        public class ContrastMatrixHandler extends FactorHandler {
            private Matrix matrix;
            private List<String> categories;
            // Built on the first getIndex(FieldValue) call, using that value's data/op type.
            private List<FieldValue> parsedCategories;

            private ContrastMatrixHandler(PPCell pPCell, Matrix matrix, List<String> list) {
                super(pPCell);
                setMatrix(matrix);
                setCategories(list);
            }

            private void setMatrix(Matrix matrix) {
                this.matrix = matrix;
            }

            private void setCategories(List<String> list) {
                this.categories = list;
            }

            public Matrix getMatrix() {
                return this.matrix;
            }

            public List<String> getCategories() {
                return this.categories;
            }

            private List<FieldValue> parseCategories(final DataType dataType, final OpType opType) {
                Function<String, FieldValue> parser = new Function<String, FieldValue>() {
                    @Override
                    public FieldValue apply(String str) {
                        return FieldValueUtil.create(dataType, opType, str);
                    }
                };
                return Lists.transform(getCategories(), parser);
            }

            /** Position of the category string in the category list, or -1. */
            public int getIndex(String str) {
                List<String> values = getCategories();
                return values.indexOf(str);
            }

            /** Position of the field value among the parsed categories, or -1. */
            public int getIndex(FieldValue fieldValue) {
                List<FieldValue> parsed = this.parsedCategories;
                if (parsed == null) {
                    parsed = ImmutableList.copyOf(parseCategories(fieldValue.getDataType(), fieldValue.getOpType()));
                    this.parsedCategories = parsed;
                }
                return parsed.indexOf(fieldValue);
            }

            /**
             * @throws InvalidElementException if either index cannot be resolved, or the matrix
             *         has no element at the resulting position
             */
            @Override
            public <V extends Number> Value<V> updateProduct(Value<V> value, FieldValue fieldValue) {
                Matrix contrasts = getMatrix();
                int valueIndex = getIndex(fieldValue);
                int categoryIndex = getIndex(getCategory());
                if (valueIndex < 0 || categoryIndex < 0) {
                    throw new InvalidElementException(getPPCell());
                }
                // Both indices are shifted by one before the lookup.
                Number element = MatrixUtil.getElementAt(contrasts, valueIndex + 1, categoryIndex + 1);
                if (element == null) {
                    throw new InvalidElementException(contrasts);
                }
                return value.multiply2(element.doubleValue());
            }
        }
    }

    /**
     * Creates an evaluator for the model that {@code selectModel(pmml, GeneralRegressionModel.class)}
     * returns, by delegating to the two-argument constructor.
     */
    public GeneralRegressionModelEvaluator(PMML pmml) {
        this(pmml, (GeneralRegressionModel) selectModel(pmml, GeneralRegressionModel.class));
    }

    public GeneralRegressionModelEvaluator(PMML pmml, GeneralRegressionModel generalRegressionModel) {
        super(pmml, generalRegressionModel);
        if (generalRegressionModel.getModelType() == null) {
            throw new MissingAttributeException(generalRegressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_MODELTYPE);
        }
        if (generalRegressionModel.getParameterList() == null) {
            throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_PARAMETERLIST);
        }
        if (generalRegressionModel.getPPMatrix() == null) {
            throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_PPMATRIX);
        }
        if (generalRegressionModel.getParamMatrix() == null) {
            throw new MissingElementException(generalRegressionModel, PMMLElements.GENERALREGRESSIONMODEL_PARAMMATRIX);
        }
    }

    /**
     * Exposes a {@link ListMultimap} as a map from key to the list of cells stored under it.
     *
     * <p>The returned map is the multimap's own {@code asMap()} view, not a copy.
     *
     * @param listMultimap the multimap to view
     * @return the {@code asMap()} view, typed with {@code List} values
     */
    @SuppressWarnings("unchecked")
    private static <C extends ParameterCell> Map<String, List<C>> asMap(ListMultimap<String, C> listMultimap) {
        // ListMultimap.asMap() is declared as Map<K, Collection<V>>, which is not assignable to
        // Map<K, List<V>>. Guava documents that the values of this view are Lists, so narrowing
        // the value type through a wildcard-typed intermediate is safe.
        return (Map<String, List<C>>) (Map<String, ?>) listMultimap.asMap();
    }

    /**
     * Adds the model's offset (if any, and non-zero) to {@code value} and then hands the value
     * to {@code GeneralRegressionModelUtil.computeCumulativeLink}.
     *
     * @return the same {@code value} instance that was passed in
     * @throws MissingAttributeException if the model declares no cumulative link function
     * @throws UnsupportedAttributeException if the cumulative link function is not one of the
     *         handled constants
     */
    private <V extends Number> Value<V> computeCumulativeLink(Value<V> value, EvaluationContext evaluationContext) {
        final GeneralRegressionModel regressionModel = getModel();

        final GeneralRegressionModel.CumulativeLinkFunction function = regressionModel.getCumulativeLinkFunction();
        if (function == null) {
            throw new MissingAttributeException(regressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_CUMULATIVELINKFUNCTION);
        }

        final Double offset = getOffset(regressionModel, evaluationContext);
        final boolean hasOffset = (offset != null) && (offset.doubleValue() != 0.0d);
        if (hasOffset) {
            value.add2(offset.doubleValue());
        }

        switch (function) {
            case LOGIT:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                break;
            default:
                throw new UnsupportedAttributeException(regressionModel, function);
        }

        GeneralRegressionModelUtil.computeCumulativeLink(value, function);
        return value;
    }

    /**
     * Accumulates, over all given cells, the cell's beta - weighted by the evaluated
     * {@link Row} registered under the cell's parameter name when there is one.
     *
     * @param iterable the parameter cells to sum over
     * @param map rows keyed by parameter name
     * @return the accumulated value; {@code null} when there are no cells at all, or when a
     *         row evaluates to {@code null}
     */
    private <V extends Number> Value<V> computeDotProduct(ValueFactory<V> valueFactory, Iterable<PCell> iterable, Map<String, Row> map, EvaluationContext evaluationContext) {
        Value<V> sum = null;

        for (PCell cell : iterable) {
            final Row row = map.get(cell.getParameterName());

            // The accumulator only comes into existence once the first cell is seen.
            if (sum == null) {
                sum = valueFactory.newValue();
            }

            if (row == null) {
                // No row for this parameter: the beta is added on its own.
                sum.add2(cell.getBeta());
                continue;
            }

            final Value<V> rowValue = row.evaluate(valueFactory, evaluationContext);
            if (rowValue == null) {
                return null;
            }
            sum.add2(cell.getBeta(), rowValue.getValue());
        }

        return sum;
    }

    /**
     * Computes the dot product for a model whose rows and parameter cells are all registered
     * under the {@code null} key.
     *
     * @throws InvalidElementException if a non-empty PP matrix has no entry under {@code null},
     *         or if the parameter matrix does not consist of exactly one entry keyed by
     *         {@code null}
     */
    private <V extends Number> Value<V> computeDotProduct(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        final GeneralRegressionModel regressionModel = getModel();

        final Map<String, Map<String, Row>> ppMatrix = getPPMatrixMap();
        Map<String, Row> rows = Collections.emptyMap();
        if (!ppMatrix.isEmpty()) {
            rows = ppMatrix.get(null);
            if (rows == null) {
                throw new InvalidElementException(regressionModel.getPPMatrix());
            }
        }

        final Map<String, List<PCell>> paramMatrix = getParamMatrixMap();
        final List<PCell> cells = paramMatrix.get(null);
        final boolean singleNullKeyedEntry = (cells != null) && (paramMatrix.size() == 1);
        if (!singleNullKeyedEntry) {
            throw new InvalidElementException(regressionModel.getParamMatrix());
        }

        return computeDotProduct(valueFactory, cells, rows, evaluationContext);
    }

    /**
     * Validates the link-related attributes of the model, adds the offset (if any, and
     * non-zero) to {@code value}, applies {@code GeneralRegressionModelUtil.computeLink} and
     * finally multiplies by the number of trials when that is set and differs from one.
     *
     * @return the same {@code value} instance that was passed in
     * @throws MissingAttributeException if the link function, or a parameter that the link
     *         function needs, is absent
     * @throws MisplacedAttributeException if a parameter is present that the link function
     *         does not take
     * @throws UnsupportedAttributeException if the link function is not one of the handled
     *         constants
     */
    private <V extends Number> Value<V> computeLink(Value<V> value, EvaluationContext evaluationContext) {
        final GeneralRegressionModel regressionModel = getModel();

        final GeneralRegressionModel.LinkFunction function = regressionModel.getLinkFunction();
        if (function == null) {
            throw new MissingAttributeException(regressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_LINKFUNCTION);
        }

        final Double distParameter = regressionModel.getDistParameter();
        final Double linkParameter = regressionModel.getLinkParameter();

        // Which of the two parameters the chosen link function takes.
        final boolean needsDistParameter;
        final boolean needsLinkParameter;
        switch (function) {
            case CLOGLOG:
            case IDENTITY:
            case LOG:
            case LOGC:
            case LOGIT:
            case LOGLOG:
            case PROBIT:
                needsDistParameter = false;
                needsLinkParameter = false;
                break;
            case NEGBIN:
                needsDistParameter = true;
                needsLinkParameter = false;
                break;
            case ODDSPOWER:
            case POWER:
                needsDistParameter = false;
                needsLinkParameter = true;
                break;
            default:
                throw new UnsupportedAttributeException(regressionModel, function);
        }

        // The distribution parameter is always checked before the link parameter.
        if (needsDistParameter) {
            if (distParameter == null) {
                throw new MissingAttributeException(regressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_DISTPARAMETER);
            }
        } else if (distParameter != null) {
            throw new MisplacedAttributeException(regressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_DISTPARAMETER, distParameter);
        }

        if (needsLinkParameter) {
            if (linkParameter == null) {
                throw new MissingAttributeException(regressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_LINKPARAMETER);
            }
        } else if (linkParameter != null) {
            throw new MisplacedAttributeException(regressionModel, PMMLAttributes.GENERALREGRESSIONMODEL_LINKPARAMETER, linkParameter);
        }

        final Double offset = getOffset(regressionModel, evaluationContext);
        if (offset != null && offset.doubleValue() != 0.0d) {
            value.add2(offset.doubleValue());
        }

        // Every link function that got past the switch above is handled by the utility.
        GeneralRegressionModelUtil.computeLink(value, distParameter, linkParameter, function);

        final Integer trials = getTrials(regressionModel, evaluationContext);
        if (trials != null && trials.intValue() != 1) {
            value.multiply2(trials.intValue());
        }

        return value;
    }

    /**
     * Accumulates, over the parameter cells registered under the {@code null} key, each cell's
     * beta together with the reference point of the parameter the cell names.
     *
     * @return the accumulated value; {@code null} when there are no cells, or when a cell
     *         names a parameter that is not in the parameter registry
     * @throws InvalidElementException if the parameter matrix does not consist of exactly one
     *         entry keyed by {@code null}
     */
    private <V extends Number> Value<V> computeReferencePoint(ValueFactory<V> valueFactory) {
        final GeneralRegressionModel regressionModel = getModel();
        final BiMap<String, Parameter> parameters = getParameterRegistry();

        final Map<String, List<PCell>> paramMatrix = getParamMatrixMap();
        final List<PCell> cells = paramMatrix.get(null);
        final boolean singleNullKeyedEntry = (cells != null) && (paramMatrix.size() == 1);
        if (!singleNullKeyedEntry) {
            throw new InvalidElementException(regressionModel.getParamMatrix());
        }

        Value<V> sum = null;
        for (PCell cell : cells) {
            final Parameter parameter = parameters.get(cell.getParameterName());

            if (sum == null) {
                sum = valueFactory.newValue();
            }

            if (parameter == null) {
                return null;
            }

            sum.add2(cell.getBeta(), parameter.getReferencePoint());
        }

        return sum;
    }

    /**
     * Computes one value per target category and wraps the result in a
     * {@link ProbabilityDistribution}.
     *
     * <p>Every category except the last one is scored from its parameter cells (dot product,
     * followed by the link function, {@code exp}, or the cumulative link function, depending on
     * the model type). The last category is derived without a dot product: one minus the
     * previous value for generalized linear models, {@code exp(0)} for multinomial logistic
     * models, and a cumulative value of one for ordinal multinomial models. Multinomial
     * logistic values are normalized at the end.
     *
     * <p>For ordinal multinomial models the value stored for a category is the difference
     * between its cumulative value and the cumulative value of the preceding category. The
     * cumulative value itself is kept unmodified for the next iteration.
     *
     * @return the classification result, or the target's default result when a dot product
     *         cannot be computed
     * @throws InvalidAttributeException if the model type is a regression-only type
     * @throws InvalidElementException if the PP matrix or the parameter matrix lacks the
     *         entries required for a category
     * @throws UnsupportedAttributeException if the model type is not handled at all
     */
    private <V extends Number> Map<FieldName, ? extends Classification<V>> evaluateClassification(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        GeneralRegressionModel model = getModel();
        TargetField targetField = getTargetField();
        List<String> targetCategories = getTargetCategories();
        GeneralRegressionModel.ModelType modelType = model.getModelType();
        Map<String, Map<String, Row>> ppMatrixMap = getPPMatrixMap();
        Map<String, List<PCell>> paramMatrixMap = getParamMatrixMap();

        int categoryCount = targetCategories.size();

        // Left raw on purpose: the type parameters of ValueMap and ProbabilityDistribution are
        // declared outside this part of the file.
        ValueMap valueMap = new ValueMap(categoryCount * 2);

        // Value computed for the preceding category, before any per-category differencing.
        Value<V> previous = null;

        for (int i = 0; i < categoryCount; i++) {
            String category = targetCategories.get(i);
            Value<V> current;

            if (i < categoryCount - 1) {
                // Rows: prefer the category's own entry, fall back to the null-keyed entry.
                Map<String, Row> rows;
                if (ppMatrixMap.isEmpty()) {
                    rows = Collections.emptyMap();
                } else {
                    rows = ppMatrixMap.get(category);
                    if (rows == null) {
                        rows = ppMatrixMap.get(null);
                    }
                    if (rows == null) {
                        throw new InvalidElementException(model.getPPMatrix());
                    }
                }

                // Iterable rather than List: the ordinal case yields a concatenated view.
                Iterable<PCell> cells;
                switch (modelType) {
                    case COX_REGRESSION:
                    case REGRESSION:
                    case GENERAL_LINEAR:
                        throw new InvalidAttributeException(model, modelType);
                    case GENERALIZED_LINEAR:
                    case MULTINOMIAL_LOGISTIC: {
                        List<PCell> categoryCells = paramMatrixMap.get(category);
                        // With exactly two categories the cells may be keyed by null instead.
                        if (categoryCells == null && categoryCount == 2) {
                            categoryCells = paramMatrixMap.get(null);
                        }
                        if (categoryCells == null) {
                            throw new InvalidElementException(model.getParamMatrix());
                        }
                        cells = categoryCells;
                        break;
                    }
                    case ORDINAL_MULTINOMIAL: {
                        // Exactly one category-specific cell, followed by the null-keyed cells.
                        List<PCell> categoryCells = paramMatrixMap.get(category);
                        if (categoryCells == null || categoryCells.size() != 1) {
                            throw new InvalidElementException(model.getParamMatrix());
                        }
                        List<PCell> commonCells = paramMatrixMap.get(null);
                        if (commonCells == null) {
                            throw new InvalidElementException(model.getParamMatrix());
                        }
                        cells = Iterables.concat(categoryCells, commonCells);
                        break;
                    }
                    default:
                        throw new UnsupportedAttributeException(model, modelType);
                }

                current = computeDotProduct(valueFactory, cells, rows, evaluationContext);
                if (current == null) {
                    return TargetUtil.evaluateClassificationDefault(valueFactory, targetField);
                }

                // Only the three classification types can reach this point.
                switch (modelType) {
                    case GENERALIZED_LINEAR:
                        current = computeLink(current, evaluationContext);
                        break;
                    case MULTINOMIAL_LOGISTIC:
                        current.exp2();
                        break;
                    case ORDINAL_MULTINOMIAL:
                        current = computeCumulativeLink(current, evaluationContext);
                        break;
                    default:
                        throw new UnsupportedAttributeException(model, modelType);
                }
            } else {
                // Last category: derived, not scored.
                switch (modelType) {
                    case COX_REGRESSION:
                    case REGRESSION:
                    case GENERAL_LINEAR:
                        throw new InvalidAttributeException(model, modelType);
                    case GENERALIZED_LINEAR:
                        current = valueFactory.newValue(1.0d);
                        if (previous != null) {
                            current.subtract(previous);
                        }
                        break;
                    case MULTINOMIAL_LOGISTIC:
                        current = valueFactory.newValue(0.0d);
                        current.exp2();
                        break;
                    case ORDINAL_MULTINOMIAL:
                        current = valueFactory.newValue(1.0d);
                        break;
                    default:
                        throw new UnsupportedAttributeException(model, modelType);
                }
            }

            Value<V> stored = current;
            if (modelType == GeneralRegressionModel.ModelType.ORDINAL_MULTINOMIAL && previous != null) {
                // Difference of consecutive cumulative values. The subtraction is done on a
                // fresh value so that `current` stays cumulative for the next category;
                // subtracting in place would feed an already-differenced value forward.
                // NOTE(review): the copy goes through newValue(double) because no dedicated
                // copy operation of Value is visible here - swap it in if one exists.
                stored = valueFactory.newValue(current.getValue().doubleValue());
                stored.subtract(previous);
            }

            valueMap.put(category, stored);
            previous = current;
        }

        switch (modelType) {
            case COX_REGRESSION:
            case REGRESSION:
            case GENERAL_LINEAR:
                throw new InvalidAttributeException(model, modelType);
            case GENERALIZED_LINEAR:
            case ORDINAL_MULTINOMIAL:
                break;
            case MULTINOMIAL_LOGISTIC:
                ValueUtil.normalizeSimpleMax(valueMap);
                break;
            default:
                throw new UnsupportedAttributeException(model, modelType);
        }

        return TargetUtil.evaluateClassification(targetField, new ProbabilityDistribution(valueMap));
    }

    /**
     * Evaluates the model as a Cox regression.
     *
     * <p>Returns {@code null} when no baseline stratum matches the strata variable, when the
     * end time is greater than the applicable max time, or when either the dot product or the
     * reference point cannot be computed. Returns a regression result of {@code 0.0} when the
     * end time is smaller than the earliest baseline cell time.
     *
     * @param valueFactory factory used to create and compute result values
     * @param evaluationContext context from which the model variables are resolved
     * @return the regression result map, or {@code null} in the cases listed above
     * @throws MissingAttributeException if the end time variable or the table-level max time is absent
     * @throws MissingElementException if the hazard tables or the applicable baseline cells are absent
     */
    private <V extends Number> Map<FieldName, ?> evaluateCoxRegression(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        GeneralRegressionModel model = getModel();
        FieldName endTimeVariable = model.getEndTimeVariable();
        if (endTimeVariable == null) {
            throw new MissingAttributeException(model, PMMLAttributes.GENERALREGRESSIONMODEL_ENDTIMEVARIABLE);
        }
        BaseCumHazardTables baseCumHazardTables = model.getBaseCumHazardTables();
        if (baseCumHazardTables == null) {
            throw new MissingElementException(model, PMMLElements.GENERALREGRESSIONMODEL_BASECUMHAZARDTABLES);
        }

        // Baseline cells and max time come either from the matching stratum or from the tables element itself.
        List<BaselineCell> baselineCells;
        Double maxTime;
        FieldName baselineStrataVariable = model.getBaselineStrataVariable();
        if (baselineStrataVariable != null) {
            FieldValue strataValue = getVariable(baselineStrataVariable, evaluationContext);
            BaselineStratum baselineStratum = getBaselineStratum(baseCumHazardTables, strataValue);
            if (baselineStratum == null) {
                return null;
            }
            baselineCells = baselineStratum.getBaselineCells();
            if (baselineCells.isEmpty()) {
                throw new MissingElementException(baselineStratum, PMMLElements.BASELINESTRATUM_BASELINECELLS);
            }
            maxTime = Double.valueOf(baselineStratum.getMaxTime());
        } else {
            baselineCells = baseCumHazardTables.getBaselineCells();
            if (baselineCells.isEmpty()) {
                throw new MissingElementException(baseCumHazardTables, PMMLElements.BASECUMHAZARDTABLES_BASELINECELLS);
            }
            maxTime = baseCumHazardTables.getMaxTime();
            if (maxTime == null) {
                throw new MissingAttributeException(baseCumHazardTables, PMMLAttributes.BASECUMHAZARDTABLES_MAXTIME);
            }
        }

        Ordering<BaselineCell> byTime = Ordering.from(new Comparator<BaselineCell>() {
            @Override
            public int compare(BaselineCell left, BaselineCell right) {
                return Double.compare(left.getTime(), right.getTime());
            }
        });
        Double minTime = Double.valueOf(byTime.min(baselineCells).getTime());

        final FieldValue endTimeValue = getVariable(endTimeVariable, evaluationContext);
        if (endTimeValue.compareToValue(FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, minTime)) < 0) {
            return TargetUtil.evaluateRegression(getTargetField(), valueFactory.newValue(0.0d));
        }
        if (endTimeValue.compareToValue(FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, maxTime)) > 0) {
            return null;
        }

        // Pick the latest baseline cell whose time does not exceed the end time.
        final double endTime = endTimeValue.asNumber().doubleValue();
        BaselineCell baselineCell = byTime.max(Iterables.filter(baselineCells, new Predicate<BaselineCell>() {
            @Override
            public boolean apply(BaselineCell candidate) {
                return candidate.getTime() <= endTime;
            }
        }));

        Value<V> dotProduct = computeDotProduct(valueFactory, evaluationContext);
        Value<? extends Number> referencePoint = computeReferencePoint(valueFactory);
        if (dotProduct == null || referencePoint == null) {
            return null;
        }
        return TargetUtil.evaluateRegression(getTargetField(), dotProduct.subtract(referencePoint).exp2().multiply2(baselineCell.getCumHazard()));
    }

    /**
     * Evaluates the model as a (non-Cox) regression.
     *
     * <p>When the dot product cannot be computed, the default regression result for the target
     * field is returned. Otherwise the dot product is returned as-is for REGRESSION and
     * GENERAL_LINEAR models, and passed through {@code computeLink} for GENERALIZED_LINEAR models.
     *
     * @param valueFactory factory used to create and compute result values
     * @param evaluationContext context from which the model variables are resolved
     * @return the regression result map
     * @throws InvalidAttributeException if the model type is not a regression-capable type
     * @throws UnsupportedAttributeException if the model type is not recognized
     */
    private <V extends Number> Map<FieldName, ?> evaluateGeneralRegression(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        GeneralRegressionModel model = getModel();
        TargetField targetField = getTargetField();

        Value<V> result = computeDotProduct(valueFactory, evaluationContext);
        if (result == null) {
            return TargetUtil.evaluateRegressionDefault(valueFactory, targetField);
        }

        GeneralRegressionModel.ModelType modelType = model.getModelType();
        switch (modelType) {
            case REGRESSION:
            case GENERAL_LINEAR:
                return TargetUtil.evaluateRegression(targetField, result);
            case GENERALIZED_LINEAR:
                Value<V> linked = computeLink(result, evaluationContext);
                return TargetUtil.evaluateRegression(targetField, linked);
            case COX_REGRESSION:
            case MULTINOMIAL_LOGISTIC:
            case ORDINAL_MULTINOMIAL:
                throw new InvalidAttributeException(model, modelType);
            default:
                throw new UnsupportedAttributeException(model, modelType);
        }
    }

    /**
     * Dispatches a regression evaluation by model type: Cox regression models go to
     * {@code evaluateCoxRegression}, every other model type goes to
     * {@code evaluateGeneralRegression}.
     *
     * @param valueFactory factory used to create and compute result values
     * @param evaluationContext context from which the model variables are resolved
     * @return the regression result map produced by the selected evaluation method
     */
    private <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        GeneralRegressionModel.ModelType modelType = getModel().getModelType();
        // Switch on the enum constant directly instead of a synthetic ordinal lookup table.
        switch (modelType) {
            case COX_REGRESSION:
                return evaluateCoxRegression(valueFactory, evaluationContext);
            default:
                return evaluateGeneralRegression(valueFactory, evaluationContext);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Finds the baseline stratum that corresponds to the given field value.
     *
     * <p>If the tables element implements {@code HasParsedValueMapping}, the lookup is delegated
     * to {@code FieldValue.getMapping}. Otherwise the strata are scanned in order and the first
     * one whose value equals the field value (by string comparison) is returned.
     *
     * @param baseCumHazardTables the element holding the baseline strata
     * @param fieldValue the value of the baseline strata variable
     * @return the matching stratum, or {@code null} if the scan finds none
     * @throws MissingAttributeException if a scanned stratum has no value attribute
     */
    private static BaselineStratum getBaselineStratum(BaseCumHazardTables baseCumHazardTables, FieldValue fieldValue) {
        if (!(baseCumHazardTables instanceof HasParsedValueMapping)) {
            for (BaselineStratum candidate : baseCumHazardTables.getBaselineStrata()) {
                String stratumValue = candidate.getValue();
                if (stratumValue == null) {
                    throw new MissingAttributeException(candidate, PMMLAttributes.BASELINESTRATUM_VALUE);
                }
                if (fieldValue.equalsString(stratumValue)) {
                    return candidate;
                }
            }
            return null;
        }
        return (BaselineStratum) fieldValue.getMapping((HasParsedValueMapping) baseCumHazardTables);
    }

    /**
     * Resolves the offset: the value of the offset variable when the model declares one,
     * otherwise the model's constant offset value.
     *
     * @param generalRegressionModel the model declaring the offset variable or value
     * @param evaluationContext context from which the offset variable is resolved
     * @return the offset as reported by the variable or by the model
     */
    private static Double getOffset(GeneralRegressionModel generalRegressionModel, EvaluationContext evaluationContext) {
        FieldName offsetVariable = generalRegressionModel.getOffsetVariable();
        if (offsetVariable == null) {
            return generalRegressionModel.getOffsetValue();
        }
        return getVariable(offsetVariable, evaluationContext).asDouble();
    }

    /**
     * Returns the PP matrix map, loading it via {@code getValue(ppMatrixCache)} on first use
     * and keeping it in the {@code ppMatrixMap} field afterwards.
     */
    private Map<String, Map<String, Row>> getPPMatrixMap() {
        if (this.ppMatrixMap != null) {
            return this.ppMatrixMap;
        }
        this.ppMatrixMap = (Map) getValue(ppMatrixCache);
        return this.ppMatrixMap;
    }

    /**
     * Returns the target categories, computing them with {@code parseTargetCategories()} on
     * first use and keeping an immutable copy in the {@code targetCategories} field afterwards.
     */
    private List<String> getTargetCategories() {
        if (this.targetCategories != null) {
            return this.targetCategories;
        }
        List<String> parsed = parseTargetCategories();
        this.targetCategories = ImmutableList.copyOf(parsed);
        return this.targetCategories;
    }

    /**
     * Resolves the trials count: the value of the trials variable when the model declares one,
     * otherwise the model's constant trials value.
     *
     * @param generalRegressionModel the model declaring the trials variable or value
     * @param evaluationContext context from which the trials variable is resolved
     * @return the trials count as reported by the variable or by the model
     */
    private static Integer getTrials(GeneralRegressionModel generalRegressionModel, EvaluationContext evaluationContext) {
        FieldName trialsVariable = generalRegressionModel.getTrialsVariable();
        if (trialsVariable == null) {
            return generalRegressionModel.getTrialsValue();
        }
        return getVariable(trialsVariable, evaluationContext).asInteger();
    }

    /**
     * Evaluates the named field in the given context and insists on a non-null result.
     *
     * @param fieldName the field to evaluate
     * @param evaluationContext context in which the field is evaluated
     * @return the field's value, never {@code null}
     * @throws MissingValueException if the context evaluates the field to {@code null}
     */
    private static FieldValue getVariable(FieldName fieldName, EvaluationContext evaluationContext) {
        FieldValue value = evaluationContext.evaluate(fieldName);
        if (value != null) {
            return value;
        }
        throw new MissingValueException(fieldName);
    }

    /**
     * Groups the given cells by their parameter name.
     *
     * @param list the cells to group
     * @return a multimap from parameter name to the cells carrying that name
     */
    private static <C extends ParameterCell> ListMultimap<String, C> groupByParameterName(List<C> list) {
        return groupCells(list, new Function<C, String>() {
            // The parameter is typed as C so that the override matches Function<C, String> exactly.
            @Override
            public String apply(C cell) {
                return cell.getParameterName();
            }
        });
    }

    /**
     * Groups the given cells by their target category.
     *
     * @param list the cells to group
     * @return a multimap from target category to the cells carrying that category
     */
    private static <C extends ParameterCell> ListMultimap<String, C> groupByTargetCategory(List<C> list) {
        return groupCells(list, new Function<C, String>() {
            // The parameter is typed as C so that the override matches Function<C, String> exactly.
            @Override
            public String apply(C cell) {
                return cell.getTargetCategory();
            }
        });
    }

    /**
     * Groups the given cells under the key that the function computes for each of them.
     *
     * @param list the cells to group, visited in iteration order
     * @param function computes the grouping key of a cell
     * @return a newly created multimap from key to the cells that produced it
     */
    private static <C extends ParameterCell> ListMultimap<String, C> groupCells(List<C> list, Function<C, String> function) {
        ListMultimap<String, C> grouped = ArrayListMultimap.create();
        for (C cell : list) {
            grouped.put(function.apply(cell), cell);
        }
        return grouped;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Converts the model's PP matrix into a two-level map: target category, then parameter
     * name, then a {@code Row} built from the PP cells sharing that category and name.
     *
     * <p>Each PP cell is added to its row as a factor when its predictor name is found in the
     * factor registry, and as a covariate when it is found in the covariate registry.
     *
     * @param generalRegressionModel the model whose PP matrix is parsed
     * @return an insertion-ordered map of target category to (parameter name to row)
     * @throws MissingAttributeException if a PP cell has no predictor name
     * @throws InvalidAttributeException if a predictor name is neither a factor nor a covariate
     */
    public static Map<String, Map<String, Row>> parsePPMatrix(final GeneralRegressionModel generalRegressionModel) {
        Function<List<PPCell>, Row> rowBuilder = new Function<List<PPCell>, Row>() {
            // The registries are looked up once, using the model captured from the enclosing method.
            @SuppressWarnings("unchecked")
            private final BiMap<FieldName, Predictor> factors = (BiMap<FieldName, Predictor>) CacheUtil.getValue(generalRegressionModel, GeneralRegressionModelEvaluator.factorCache);
            @SuppressWarnings("unchecked")
            private final BiMap<FieldName, Predictor> covariates = (BiMap<FieldName, Predictor>) CacheUtil.getValue(generalRegressionModel, GeneralRegressionModelEvaluator.covariateCache);

            @Override
            public Row apply(List<PPCell> cells) {
                Row row = new Row();
                for (PPCell cell : cells) {
                    FieldName predictorName = cell.getPredictorName();
                    if (predictorName == null) {
                        throw new MissingAttributeException(cell, PMMLAttributes.PPCELL_PREDICTORNAME);
                    }
                    Predictor factor = this.factors.get(predictorName);
                    if (factor != null) {
                        row.addFactor(cell, factor);
                    } else {
                        if (this.covariates.get(predictorName) == null) {
                            throw new InvalidAttributeException(cell, PMMLAttributes.PPCELL_PREDICTORNAME, predictorName);
                        }
                        row.addCovariate(cell);
                    }
                }
                return row;
            }
        };

        ListMultimap<String, PPCell> cellsByTarget = groupByTargetCategory(generalRegressionModel.getPPMatrix().getPPCells());
        Map<String, List<PPCell>> targetGroups = asMap(cellsByTarget);

        Map<String, Map<String, Row>> result = new LinkedHashMap<String, Map<String, Row>>();
        for (Map.Entry<String, List<PPCell>> targetEntry : targetGroups.entrySet()) {
            ListMultimap<String, PPCell> cellsByParameter = groupByParameterName(targetEntry.getValue());
            Map<String, List<PPCell>> parameterGroups = asMap(cellsByParameter);

            Map<String, Row> rows = new LinkedHashMap<String, Row>();
            for (Map.Entry<String, List<PPCell>> parameterEntry : parameterGroups.entrySet()) {
                rows.put(parameterEntry.getKey(), rowBuilder.apply(parameterEntry.getValue()));
            }
            result.put(targetEntry.getKey(), rows);
        }
        return result;
    }

    /**
     * Groups the P cells of the model's parameter matrix by target category.
     *
     * @param generalRegressionModel the model whose parameter matrix is parsed
     * @return a map from target category to the P cells of that category
     */
    public static Map<String, List<PCell>> parseParamMatrix(GeneralRegressionModel generalRegressionModel) {
        List<PCell> cells = generalRegressionModel.getParamMatrix().getPCells();
        ListMultimap<String, PCell> cellsByTarget = groupByTargetCategory(cells);
        return asMap(cellsByTarget);
    }

    /**
     * Builds a name-to-parameter registry from the given parameter list.
     *
     * @param parameterList the list to index; must not be {@code null}
     * @return a new bi-map keyed by parameter name, empty when the list has no parameters
     */
    public static BiMap<String, Parameter> parseParameterRegistry(ParameterList parameterList) {
        BiMap<String, Parameter> registry = HashBiMap.create();
        if (parameterList.hasParameters()) {
            for (Parameter parameter : parameterList.getParameters()) {
                registry.put(parameter.getName(), parameter);
            }
        }
        return registry;
    }

    /**
     * Builds a name-to-predictor registry from the given predictor list.
     *
     * @param predictorList the list to index; may be {@code null}
     * @return a new bi-map keyed by predictor name, empty when the list is {@code null} or has
     *         no predictors
     */
    public static BiMap<FieldName, Predictor> parsePredictorRegistry(PredictorList predictorList) {
        BiMap<FieldName, Predictor> registry = HashBiMap.create();
        if (predictorList != null && predictorList.hasPredictors()) {
            for (Predictor predictor : predictorList.getPredictors()) {
                registry.put(predictor.getName(), predictor);
            }
        }
        return registry;
    }

    /**
     * Determines the ordered list of target categories for a categorical or ordinal target.
     *
     * <p>For GENERALIZED_LINEAR and MULTINOMIAL_LOGISTIC models without an explicit target
     * reference category, the reference category is taken to be the single target category
     * that has no entry in the parameter matrix map. When a reference category is known and
     * is present in the list, it is moved to the end of the returned list.
     *
     * @return the target categories, with the reference category (if any) placed last
     * @throws InvalidElementException if the target is neither categorical nor ordinal, if
     *         exactly one target category is declared, or if the reference category cannot be
     *         determined unambiguously from the parameter matrix
     * @throws InvalidAttributeException if the model type is a regression-only type
     * @throws UnsupportedAttributeException if the model type is not recognized
     */
    private List<String> parseTargetCategories() {
        GeneralRegressionModel model = getModel();
        TargetField targetField = getTargetField();
        switch (targetField.getOpType()) {
            case CATEGORICAL:
            case ORDINAL:
                break;
            default:
                throw new InvalidElementException(model);
        }

        List<String> categories = FieldValueUtil.getTargetCategories(targetField);
        // An empty list is tolerated; a single category is not.
        if (categories.size() == 1) {
            throw new InvalidElementException(model);
        }

        String referenceCategory = model.getTargetReferenceCategory();
        GeneralRegressionModel.ModelType modelType = model.getModelType();
        switch (modelType) {
            case COX_REGRESSION:
            case REGRESSION:
            case GENERAL_LINEAR:
                throw new InvalidAttributeException(model, modelType);
            case GENERALIZED_LINEAR:
            case MULTINOMIAL_LOGISTIC:
                if (referenceCategory == null) {
                    // Collect the categories that the parameter matrix does not mention.
                    Map<String, List<PCell>> paramMatrix = getParamMatrixMap();
                    LinkedHashSet<String> unlisted = new LinkedHashSet<String>();
                    for (String category : categories) {
                        if (!paramMatrix.containsKey(category)) {
                            unlisted.add(category);
                        }
                    }
                    if (unlisted.size() != 1) {
                        throw new InvalidElementException(model.getParamMatrix());
                    }
                    referenceCategory = Iterables.getOnlyElement(unlisted);
                }
                break;
            case ORDINAL_MULTINOMIAL:
                break;
            default:
                throw new UnsupportedAttributeException(model, modelType);
        }

        if (referenceCategory == null) {
            return categories;
        }
        List<String> reordered = new ArrayList<String>(categories);
        if (reordered.remove(referenceCategory)) {
            reordered.add(referenceCategory);
        }
        return reordered;
    }

    /**
     * Evaluates the model in the given context.
     *
     * <p>Only the FLOAT and DOUBLE math contexts are accepted. REGRESSION and CLASSIFICATION
     * mining functions are dispatched to the matching evaluation method, and the predictions
     * are then passed through {@code OutputUtil.evaluate}.
     *
     * @param modelEvaluationContext the context holding the input values
     * @return the result of {@code OutputUtil.evaluate} applied to the predictions
     * @throws InvalidAttributeException if the mining function is a known but inapplicable one
     * @throws UnsupportedAttributeException if the math context or mining function is not supported
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        GeneralRegressionModel model = ensureScorableModel();

        MathContext mathContext = model.getMathContext();
        switch (mathContext) {
            case FLOAT:
            case DOUBLE:
                break;
            default:
                throw new UnsupportedAttributeException(model, mathContext);
        }

        ValueFactory<?> valueFactory = getValueFactory();
        MiningFunction miningFunction = model.getMiningFunction();

        Map<FieldName, ?> predictions;
        switch (miningFunction) {
            case REGRESSION:
                predictions = evaluateRegression(valueFactory, modelEvaluationContext);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(valueFactory, modelEvaluationContext);
                break;
            case ASSOCIATION_RULES:
            case SEQUENCES:
            case CLUSTERING:
            case TIME_SERIES:
            case MIXED:
                throw new InvalidAttributeException(model, miningFunction);
            default:
                throw new UnsupportedAttributeException(model, miningFunction);
        }
        return OutputUtil.evaluate(predictions, modelEvaluationContext);
    }

    /**
     * Returns the parameter matrix map, loading it via {@code getValue(paramMatrixCache)} on
     * first use and keeping it in the {@code paramMatrixMap} field afterwards.
     */
    public Map<String, List<PCell>> getParamMatrixMap() {
        if (this.paramMatrixMap != null) {
            return this.paramMatrixMap;
        }
        this.paramMatrixMap = (Map) getValue(paramMatrixCache);
        return this.paramMatrixMap;
    }

    /**
     * Returns the parameter registry, loading it via {@code getValue(parameterCache)} on
     * first use and keeping it in the {@code parameterRegistry} field afterwards.
     */
    public BiMap<String, Parameter> getParameterRegistry() {
        if (this.parameterRegistry != null) {
            return this.parameterRegistry;
        }
        this.parameterRegistry = (BiMap) getValue(parameterCache);
        return this.parameterRegistry;
    }

    /**
     * Returns a short description of the model: {@code "Cox regression"} for Cox regression
     * models and {@code "General regression"} for every other model type.
     */
    @Override // org.jpmml.evaluator.Evaluator
    public String getSummary() {
        GeneralRegressionModel.ModelType modelType = getModel().getModelType();
        // Switch on the enum constant directly instead of a synthetic ordinal lookup table.
        switch (modelType) {
            case COX_REGRESSION:
                return "Cox regression";
            default:
                return "General regression";
        }
    }
}
