/*
 * Decompiled with CFR 0.152.
 */
package com.o19s.es.ltr.ranker.parser;

import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree;
import com.o19s.es.ltr.ranker.normalizer.Normalizer;
import com.o19s.es.ltr.ranker.normalizer.Normalizers;
import com.o19s.es.ltr.ranker.parser.LtrRankerParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Optional;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ObjectParser;
import org.opensearch.core.xcontent.XContentParseException;
import org.opensearch.core.xcontent.XContentParser;

/**
 * Parses XGBoost models serialized as raw JSON (a top-level object holding a {@code learner} with
 * {@code feature_names}, {@code feature_types}, {@code objective} and {@code gradient_booster.model.trees})
 * into a {@link NaiveAdditiveDecisionTree}.
 *
 * <p>The parser holds no instance state.
 */
public class XGBoostRawJsonParser implements LtrRankerParser {
    /** Model type identifier handled by this parser. */
    public static final String TYPE = "model/xgboost+json+raw";

    /**
     * Parses a raw XGBoost JSON model and converts it into an additive tree ensemble whose split
     * features are expressed as ordinals of {@code set}.
     *
     * @param set   the feature set the model is scored against; every entry of the model's
     *              {@code feature_names} must exist in it
     * @param model the model JSON text
     * @return an ensemble in which every tree has weight {@code 1.0}
     * @throws IllegalArgumentException if the JSON cannot be read, or a tree splits on a feature index
     *                                  that has no entry in {@code feature_names}
     * @throws ParsingException         if the model structure is invalid
     */
    @Override
    public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) {
        XGBoostDefinition modelDefinition;
        try (XContentParser parser = JsonXContent.jsonXContent.createParser(
                NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, model)) {
            modelDefinition = XGBoostDefinition.parse(parser, set);
        } catch (IOException e) {
            throw new IllegalArgumentException("Cannot parse model", e);
        }

        XGBoostLearner learner = modelDefinition.getLearner();
        NaiveAdditiveDecisionTree.Node[] trees = learner.getTrees(set);

        // Split nodes reference features by their position in the model's feature_names list;
        // translate each position into the ordinal of that feature within the scoring feature set.
        List<String> modelFeatures = learner.getFeatureNames();
        Map<Integer, Integer> modelFeaturesReordering = new HashMap<>();
        for (int i = 0; i < modelFeatures.size(); ++i) {
            modelFeaturesReordering.put(i, set.featureOrdinal(modelFeatures.get(i)));
        }

        NaiveAdditiveDecisionTree.Node[] adjustedTrees = new NaiveAdditiveDecisionTree.Node[trees.length];
        for (int i = 0; i < trees.length; ++i) {
            adjustedTrees[i] = reorderTreeFeatures(trees[i], modelFeaturesReordering);
        }

        float[] weights = new float[trees.length];
        Arrays.fill(weights, 1.0f);
        return new NaiveAdditiveDecisionTree(adjustedTrees, weights, set.size(), learner.getObjective().getNormalizer());
    }

    /**
     * Rebuilds {@code node}'s subtree with every split feature index replaced by the value it maps to
     * in {@code modelFeaturesReordering}. Leaf nodes are returned unchanged.
     *
     * @throws IllegalArgumentException if a split references a feature index that has no mapping
     */
    private NaiveAdditiveDecisionTree.Node reorderTreeFeatures(NaiveAdditiveDecisionTree.Node node,
                                                               Map<Integer, Integer> modelFeaturesReordering) {
        if (!(node instanceof NaiveAdditiveDecisionTree.Split)) {
            return node;
        }
        NaiveAdditiveDecisionTree.Split splitNode = (NaiveAdditiveDecisionTree.Split) node;
        Integer ordinal = modelFeaturesReordering.get(splitNode.getFeature());
        if (ordinal == null) {
            // An unmapped index would otherwise surface as a NullPointerException when unboxed.
            throw new IllegalArgumentException("Split references feature index [" + splitNode.getFeature()
                + "] but the model declares only [" + modelFeaturesReordering.size() + "] feature names");
        }
        return new NaiveAdditiveDecisionTree.Split(
            reorderTreeFeatures(splitNode.getLeft(), modelFeaturesReordering),
            reorderTreeFeatures(splitNode.getRight(), modelFeaturesReordering),
            ordinal,
            splitNode.getThreshold());
    }

    /** Top-level model object: the {@code learner} plus an optional {@code version} array. */
    private static class XGBoostDefinition {
        private static final ObjectParser<XGBoostDefinition, FeatureSet> PARSER =
            new ObjectParser<>("xgboost_definition", true, XGBoostDefinition::new);

        static {
            PARSER.declareObject(XGBoostDefinition::setLearner, XGBoostLearner::parse, new ParseField("learner"));
            PARSER.declareIntArray(XGBoostDefinition::setVersion, new ParseField("version"));
        }

        private XGBoostLearner learner;
        private List<Integer> version;

        private XGBoostDefinition() {
        }

        /**
         * Reads the next token as the model's root object and validates the parsed result.
         *
         * @param parser parser positioned just before the root object
         * @param set    feature set every model feature name must belong to
         * @return the parsed definition; its learner, objective normalizer and tree list are non-null
         * @throws ParsingException if the root is not an object, parsing fails, a required field is missing,
         *                          a feature name is unknown to {@code set}, the name/type lists differ in
         *                          length, or a feature type other than {@code float} is declared
         * @throws IOException      if the underlying parser fails to read
         */
        public static XGBoostDefinition parse(XContentParser parser, FeatureSet set) throws IOException {
            XContentParser.Token startToken = parser.nextToken();
            if (startToken != XContentParser.Token.START_OBJECT) {
                throw new ParsingException(parser.getTokenLocation(),
                    "Expected [START_OBJECT] but got [" + startToken + "]");
            }

            XGBoostDefinition definition;
            try {
                definition = PARSER.apply(parser, set);
            } catch (XContentParseException e) {
                throw new ParsingException(parser.getTokenLocation(), "Unable to parse XGBoost object", e);
            }

            XGBoostLearner learner = definition.learner;
            if (learner == null) {
                throw missingField(parser, "learner");
            }
            if (learner.featureNames == null) {
                throw missingField(parser, "feature_names");
            }

            List<String> unknownFeatures = new ArrayList<>();
            for (String modelFeatureName : learner.featureNames) {
                if (!set.hasFeature(modelFeatureName)) {
                    unknownFeatures.add(modelFeatureName);
                }
            }
            if (!unknownFeatures.isEmpty()) {
                throw new ParsingException(parser.getTokenLocation(),
                    "Unknown features in model: [" + String.join(", ", unknownFeatures) + "]");
            }

            if (learner.featureTypes == null) {
                throw missingField(parser, "feature_types");
            }
            if (learner.featureNames.size() != learner.featureTypes.size()) {
                throw new ParsingException(parser.getTokenLocation(),
                    "Feature names list and feature types list must have the same length");
            }
            for (String featureType : learner.featureTypes) {
                // Null-safe comparison: a null entry is reported like any other unsupported type.
                if (!"float".equals(featureType)) {
                    throw new ParsingException(parser.getTokenLocation(),
                        "The LTR plugin only supports float feature types because Elasticsearch scores are always float32. Found feature type ["
                            + featureType + "] in model");
                }
            }

            // The caller dereferences these unconditionally; fail here with a clear message instead of an NPE.
            if (learner.objective == null) {
                throw missingField(parser, "objective");
            }
            if (learner.objective.getNormalizer() == null) {
                throw missingField(parser, "objective.name");
            }
            if (learner.gradientBooster == null) {
                throw missingField(parser, "gradient_booster");
            }
            if (learner.gradientBooster.getModel() == null) {
                throw missingField(parser, "gradient_booster.model");
            }
            if (learner.gradientBooster.getModel().getTrees() == null) {
                throw missingField(parser, "gradient_booster.model.trees");
            }
            return definition;
        }

        /** Builds the standard "missing required field" error at the parser's current location. */
        private static ParsingException missingField(XContentParser parser, String field) {
            return new ParsingException(parser.getTokenLocation(),
                "XGBoost model missing required field [" + field + "]");
        }

        public XGBoostLearner getLearner() {
            return this.learner;
        }

        public void setLearner(XGBoostLearner learner) {
            this.learner = learner;
        }

        public List<Integer> getVersion() {
            return this.version;
        }

        public void setVersion(List<Integer> version) {
            this.version = version;
        }
    }

    /** The {@code learner} object: feature metadata, objective and gradient booster. */
    static class XGBoostLearner {
        private static final ObjectParser<XGBoostLearner, FeatureSet> PARSER =
            new ObjectParser<>("xgboost_learner", true, XGBoostLearner::new);

        static {
            PARSER.declareObject(XGBoostLearner::setObjective, XGBoostObjective::parse, new ParseField("objective"));
            PARSER.declareObject(XGBoostLearner::setGradientBooster, XGBoostGradientBooster::parse,
                new ParseField("gradient_booster"));
            PARSER.declareStringArray(XGBoostLearner::setFeatureNames, new ParseField("feature_names"));
            PARSER.declareStringArray(XGBoostLearner::setFeatureTypes, new ParseField("feature_types"));
        }

        private List<String> featureNames;
        private List<String> featureTypes;
        private XGBoostGradientBooster gradientBooster;
        private XGBoostObjective objective;

        XGBoostLearner() {
        }

        public static XGBoostLearner parse(XContentParser parser, FeatureSet set) throws IOException {
            return PARSER.apply(parser, set);
        }

        /** Feature names in model order; split nodes reference features by index into this list. */
        List<String> getFeatureNames() {
            return this.featureNames;
        }

        private void setFeatureNames(List<String> featureNames) {
            this.featureNames = featureNames;
        }

        List<String> getFeatureTypes() {
            return this.featureTypes;
        }

        private void setFeatureTypes(List<String> featureTypes) {
            this.featureTypes = featureTypes;
        }

        /** Returns the parsed root nodes of the booster's trees; {@code set} is currently unused. */
        NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) {
            return this.getGradientBooster().getModel().getTrees();
        }

        public XGBoostObjective getObjective() {
            return this.objective;
        }

        public void setObjective(XGBoostObjective objective) {
            this.objective = objective;
        }

        public XGBoostGradientBooster getGradientBooster() {
            return this.gradientBooster;
        }

        public void setGradientBooster(XGBoostGradientBooster gradientBooster) {
            this.gradientBooster = gradientBooster;
        }
    }

    /** The {@code objective} object; its {@code name} selects the score normalizer. */
    static class XGBoostObjective {
        private static final ObjectParser<XGBoostObjective, FeatureSet> PARSER =
            new ObjectParser<>("xgboost_objective", true, XGBoostObjective::new);

        static {
            PARSER.declareString(XGBoostObjective::setName, new ParseField("name"));
        }

        private Normalizer normalizer;

        XGBoostObjective() {
        }

        public static XGBoostObjective parse(XContentParser parser, FeatureSet set) throws IOException {
            return PARSER.apply(parser, set);
        }

        /**
         * Maps an XGBoost objective name to the normalizer applied to the summed tree output.
         *
         * @throws IllegalArgumentException if the objective is not supported
         */
        public void setName(String name) {
            switch (name) {
                case "binary:logitraw":
                case "rank:ndcg":
                case "rank:map":
                case "rank:pairwise":
                case "reg:linear":
                // XGBoost's current name for squared-error regression (reg:linear is its deprecated alias).
                case "reg:squarederror":
                    this.normalizer = Normalizers.get("noop");
                    break;
                case "binary:logistic":
                case "reg:logistic":
                    this.normalizer = Normalizers.get("sigmoid");
                    break;
                default:
                    throw new IllegalArgumentException("Objective [" + name + "] is not a valid XGBoost objective");
            }
        }

        Normalizer getNormalizer() {
            return this.normalizer;
        }
    }

    /**
     * One tree in the raw format: parallel per-node arrays indexed by node id, with node 0 as the root
     * and a child id of {@code -1} marking a missing child.
     */
    static class XGBoostTree {
        private static final ObjectParser<XGBoostTree, FeatureSet> PARSER =
            new ObjectParser<>("xgboost_tree", true, XGBoostTree::new);

        /** Child id used by the format to mean "no child". */
        private static final int NO_CHILD = -1;

        static {
            PARSER.declareInt(XGBoostTree::setTreeId, new ParseField("id"));
            PARSER.declareIntArray(XGBoostTree::setLeftChildren, new ParseField("left_children"));
            PARSER.declareIntArray(XGBoostTree::setRightChildren, new ParseField("right_children"));
            PARSER.declareIntArray(XGBoostTree::setParents, new ParseField("parents"));
            PARSER.declareFloatArray(XGBoostTree::setSplitConditions, new ParseField("split_conditions"));
            PARSER.declareIntArray(XGBoostTree::setSplitIndices, new ParseField("split_indices"));
            PARSER.declareIntArray(XGBoostTree::setDefaultLeft, new ParseField("default_left"));
            PARSER.declareIntArray(XGBoostTree::setSplitTypes, new ParseField("split_type"));
            PARSER.declareFloatArray(XGBoostTree::setBaseWeights, new ParseField("base_weights"));
        }

        private Integer treeId;
        private List<Integer> leftChildren;
        private List<Integer> rightChildren;
        private List<Integer> parents;
        private List<Float> splitConditions;
        private List<Integer> splitIndices;
        private List<Integer> defaultLeft;
        private List<Integer> splitTypes;
        private List<Float> baseWeights;
        private NaiveAdditiveDecisionTree.Node rootNode;

        XGBoostTree() {
        }

        /**
         * Parses one tree object and builds its node structure starting from node 0.
         *
         * @throws IllegalArgumentException if the child arrays are missing, a node reference is out of range,
         *                                  a node is reachable more than once (cycle or shared subtree), or a
         *                                  per-node array lacks an entry the tree needs
         */
        public static XGBoostTree parse(XContentParser parser, FeatureSet set) throws IOException {
            XGBoostTree tree = PARSER.apply(parser, set);
            if (tree.leftChildren == null) {
                throw new IllegalArgumentException("XGBoost tree missing required field [left_children]");
            }
            if (tree.rightChildren == null) {
                throw new IllegalArgumentException("XGBoost tree missing required field [right_children]");
            }
            tree.rootNode = tree.asLibTree(0, new boolean[tree.leftChildren.size()]);
            return tree;
        }

        public Integer getTreeId() {
            return this.treeId;
        }

        public void setTreeId(Integer treeId) {
            this.treeId = treeId;
        }

        public List<Integer> getLeftChildren() {
            return this.leftChildren;
        }

        public void setLeftChildren(List<Integer> leftChildren) {
            this.leftChildren = leftChildren;
        }

        public List<Integer> getRightChildren() {
            return this.rightChildren;
        }

        public void setRightChildren(List<Integer> rightChildren) {
            this.rightChildren = rightChildren;
        }

        public List<Integer> getParents() {
            return this.parents;
        }

        public void setParents(List<Integer> parents) {
            this.parents = parents;
        }

        public List<Float> getSplitConditions() {
            return this.splitConditions;
        }

        public void setSplitConditions(List<Float> splitConditions) {
            this.splitConditions = splitConditions;
        }

        public List<Integer> getSplitIndices() {
            return this.splitIndices;
        }

        public void setSplitIndices(List<Integer> splitIndices) {
            this.splitIndices = splitIndices;
        }

        public List<Integer> getDefaultLeft() {
            return this.defaultLeft;
        }

        public void setDefaultLeft(List<Integer> defaultLeft) {
            this.defaultLeft = defaultLeft;
        }

        public List<Integer> getSplitTypes() {
            return this.splitTypes;
        }

        public void setSplitTypes(List<Integer> splitTypes) {
            this.splitTypes = splitTypes;
        }

        public List<Float> getBaseWeights() {
            return this.baseWeights;
        }

        public void setBaseWeights(List<Float> baseWeights) {
            this.baseWeights = baseWeights;
        }

        public NaiveAdditiveDecisionTree.Node getRootNode() {
            return this.rootNode;
        }

        /** A node is a split only when it has both children; otherwise it is treated as a leaf. */
        private boolean isSplit(int nodeId) {
            return this.leftChildren.get(nodeId) != NO_CHILD && this.rightChildren.get(nodeId) != NO_CHILD;
        }

        /**
         * Recursively converts node {@code nodeId} and its descendants into library nodes.
         *
         * @param visited one flag per node id; a node seen twice means the child arrays do not form a
         *                tree, which would otherwise recurse forever on a cycle
         */
        private NaiveAdditiveDecisionTree.Node asLibTree(int nodeId, boolean[] visited) {
            if (nodeId < 0 || nodeId >= this.leftChildren.size() || nodeId >= this.rightChildren.size()) {
                throw new IllegalArgumentException("Child node reference ID [" + nodeId + "] is invalid");
            }
            if (visited[nodeId]) {
                throw new IllegalArgumentException("Node [" + nodeId + "] is referenced more than once; "
                    + "the tree contains a cycle or a shared subtree");
            }
            visited[nodeId] = true;

            if (this.isSplit(nodeId)) {
                NaiveAdditiveDecisionTree.Node left = this.asLibTree(this.leftChildren.get(nodeId), visited);
                NaiveAdditiveDecisionTree.Node right = this.asLibTree(this.rightChildren.get(nodeId), visited);
                int feature = valueAt(this.splitIndices, nodeId, "split_indices");
                float threshold = valueAt(this.splitConditions, nodeId, "split_conditions");
                return new NaiveAdditiveDecisionTree.Split(left, right, feature, threshold);
            }
            // NOTE(review): leaf values are read from base_weights. XGBoost's JSON format appears to store
            // the (learning-rate scaled) leaf value in split_conditions for leaf nodes — confirm against
            // XGBoost's model schema before relying on absolute scores.
            float leafValue = valueAt(this.baseWeights, nodeId, "base_weights");
            return new NaiveAdditiveDecisionTree.Leaf(leafValue);
        }

        /** Returns {@code values[nodeId]}, failing clearly when the array is missing or too short. */
        private static <T> T valueAt(List<T> values, int nodeId, String field) {
            if (values == null || nodeId >= values.size()) {
                throw new IllegalArgumentException("Tree node [" + nodeId + "] has no entry in [" + field + "]");
            }
            return values.get(nodeId);
        }
    }

    /** The booster's {@code model} object: the list of trees plus {@code tree_info}. */
    static class XGBoostModel {
        private static final ObjectParser<XGBoostModel, FeatureSet> PARSER =
            new ObjectParser<>("xgboost_model", true, XGBoostModel::new);

        static {
            PARSER.declareObjectArray(XGBoostModel::setTrees, XGBoostTree::parse, new ParseField("trees"));
            PARSER.declareIntArray(XGBoostModel::setTreeInfo, new ParseField("tree_info"));
        }

        private NaiveAdditiveDecisionTree.Node[] trees;
        private List<Integer> treeInfo;

        XGBoostModel() {
        }

        /**
         * Parses the model object, converting any {@link IllegalArgumentException} raised while building
         * trees into a {@link ParsingException} carrying the parser location.
         */
        public static XGBoostModel parse(XContentParser parser, FeatureSet set) throws IOException {
            try {
                return PARSER.apply(parser, set);
            } catch (IllegalArgumentException e) {
                throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
            }
        }

        public List<Integer> getTreeInfo() {
            return this.treeInfo;
        }

        public void setTreeInfo(List<Integer> treeInfo) {
            this.treeInfo = treeInfo;
        }

        public NaiveAdditiveDecisionTree.Node[] getTrees() {
            return this.trees;
        }

        /** Keeps only the root node of each parsed tree, preserving list order. */
        public void setTrees(List<XGBoostTree> parsedTrees) {
            NaiveAdditiveDecisionTree.Node[] roots = new NaiveAdditiveDecisionTree.Node[parsedTrees.size()];
            for (int i = 0; i < roots.length; ++i) {
                roots[i] = parsedTrees.get(i).getRootNode();
            }
            this.trees = roots;
        }
    }

    /** The learner's {@code gradient_booster} object, wrapping the tree {@code model}. */
    static class XGBoostGradientBooster {
        private static final ObjectParser<XGBoostGradientBooster, FeatureSet> PARSER =
            new ObjectParser<>("xgboost_gradient_booster", true, XGBoostGradientBooster::new);

        static {
            PARSER.declareObject(XGBoostGradientBooster::setModel, XGBoostModel::parse, new ParseField("model"));
        }

        private XGBoostModel model;

        XGBoostGradientBooster() {
        }

        static XGBoostGradientBooster parse(XContentParser parser, FeatureSet set) throws IOException {
            return PARSER.apply(parser, set);
        }

        public XGBoostModel getModel() {
            return this.model;
        }

        public void setModel(XGBoostModel model) {
            this.model = model;
        }
    }
}

