/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.example;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.DataSource;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;

public class GaussianDataSource
implements ConfigurableDataSource<Regressor> {
    @Config(mandatory=true, description="The number of samples to draw.")
    private int numSamples;
    @Config(description="The slope of the line.")
    private float slope = 0.0f;
    @Config(description="The y-intercept of the line.")
    private float intercept = 0.0f;
    @Config(description="The variance of the gaussian.")
    private float variance = 1.0f;
    @Config(mandatory=true, description="The minimum feature value.")
    private float xMin;
    @Config(mandatory=true, description="The maximum feature value.")
    private float xMax;
    @Config(description="The RNG seed.")
    private long seed = 12345L;
    private List<Example<Regressor>> examples;
    private final RegressionFactory factory = new RegressionFactory();

    private GaussianDataSource() {
    }

    public GaussianDataSource(int numSamples, float slope, float intercept, float variance, float xMin, float xMax, long seed) {
        this.numSamples = numSamples;
        this.slope = slope;
        this.intercept = intercept;
        this.variance = variance;
        this.xMin = xMin;
        this.xMax = xMax;
        this.seed = seed;
        this.postConfig();
    }

    public void postConfig() {
        Random rng = new Random(this.seed);
        ArrayList<ArrayExample> examples = new ArrayList<ArrayExample>(this.numSamples);
        if (this.xMax <= this.xMin) {
            throw new PropertyException("", "xMax", "xMax must be greater than xMin, found xMax = " + this.xMax + ", xMin = " + this.xMin);
        }
        if ((double)this.variance <= 0.0) {
            throw new PropertyException("", "variance", "Variance must be positive, found variance = " + this.variance);
        }
        double range = this.xMax - this.xMin;
        for (int i = 0; i < this.numSamples; ++i) {
            double input = rng.nextDouble() * range + (double)this.xMin;
            Regressor output = new Regressor("Y", rng.nextGaussian() * (double)this.variance + ((double)this.slope * input + (double)this.intercept));
            ArrayExample e = new ArrayExample((Output)output, new String[]{"X"}, new double[]{input});
            examples.add(e);
        }
        this.examples = Collections.unmodifiableList(examples);
    }

    public OutputFactory<Regressor> getOutputFactory() {
        return this.factory;
    }

    public DataSourceProvenance getProvenance() {
        return new GaussianDataSourceProvenance(this);
    }

    public Iterator<Example<Regressor>> iterator() {
        return this.examples.iterator();
    }

    public static MutableDataset<Regressor> generateDataset(int numSamples, float slope, float intercept, float variance, float xMin, float xMax, long seed) {
        GaussianDataSource source = new GaussianDataSource(numSamples, slope, intercept, variance, xMin, xMax, seed);
        return new MutableDataset((DataSource)source);
    }

    public static class GaussianDataSourceProvenance
    extends SkeletalConfiguredObjectProvenance
    implements ConfiguredDataSourceProvenance {
        private static final long serialVersionUID = 1L;

        GaussianDataSourceProvenance(GaussianDataSource host) {
            super((Configurable)host, "DataSource");
        }

        public GaussianDataSourceProvenance(Map<String, Provenance> map) {
            this(GaussianDataSourceProvenance.extractProvenanceInfo(map));
        }

        private GaussianDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
        }

        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            HashMap<String, Provenance> configuredParameters = new HashMap<String, Provenance>(map);
            String className = ((StringProvenance)ObjectProvenance.checkAndExtractProvenance(configuredParameters, (String)"class-name", StringProvenance.class, (String)GaussianDataSourceProvenance.class.getSimpleName())).getValue();
            String hostTypeStringName = ((StringProvenance)ObjectProvenance.checkAndExtractProvenance(configuredParameters, (String)"host-short-name", StringProvenance.class, (String)GaussianDataSourceProvenance.class.getSimpleName())).getValue();
            return new SkeletalConfiguredObjectProvenance.ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap());
        }
    }
}

