/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.tuning;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.Params$class;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.ml.tuning.TrainValidationSplitParams;
import org.apache.spark.ml.tuning.TrainValidationSplitParams$class;
import org.apache.spark.ml.tuning.ValidatorParams$class;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.SeqLike;
import scala.collection.immutable.Nil$;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.AbstractFunction0;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;

@Experimental
@ScalaSignature(bytes="\u0006\u0001\u0005ea\u0001B\u0001\u0003\u00015\u0011A\u0003\u0016:bS:4\u0016\r\\5eCRLwN\\*qY&$(BA\u0002\u0005\u0003\u0019!XO\\5oO*\u0011QAB\u0001\u0003[2T!a\u0002\u0005\u0002\u000bM\u0004\u0018M]6\u000b\u0005%Q\u0011AB1qC\u000eDWMC\u0001\f\u0003\ry'oZ\u0002\u0001'\u0011\u0001aBF\r\u0011\u0007=\u0001\"#D\u0001\u0005\u0013\t\tBAA\u0005FgRLW.\u0019;peB\u00111\u0003F\u0007\u0002\u0005%\u0011QC\u0001\u0002\u001a)J\f\u0017N\u001c,bY&$\u0017\r^5p]N\u0003H.\u001b;N_\u0012,G\u000e\u0005\u0002\u0014/%\u0011\u0001D\u0001\u0002\u001b)J\f\u0017N\u001c,bY&$\u0017\r^5p]N\u0003H.\u001b;QCJ\fWn\u001d\t\u00035mi\u0011AB\u0005\u00039\u0019\u0011q\u0001T8hO&tw\r\u0003\u0005\u001f\u0001\t\u0015\r\u0011\"\u0011 \u0003\r)\u0018\u000eZ\u000b\u0002AA\u0011\u0011e\n\b\u0003E\u0015j\u0011a\t\u0006\u0002I\u0005)1oY1mC&\u0011aeI\u0001\u0007!J,G-\u001a4\n\u0005!J#AB*ue&twM\u0003\u0002'G!A1\u0006\u0001B\u0001B\u0003%\u0001%\u0001\u0003vS\u0012\u0004\u0003\"B\u0017\u0001\t\u0003q\u0013A\u0002\u001fj]&$h\b\u0006\u00020aA\u00111\u0003\u0001\u0005\u0006=1\u0002\r\u0001\t\u0005\u0006[\u0001!\tA\r\u000b\u0002_!)A\u0007\u0001C\u0001k\u0005a1/\u001a;FgRLW.\u0019;peR\u0011agN\u0007\u0002\u0001!)\u0001h\ra\u0001s\u0005)a/\u00197vKB\u0012!(\u0010\t\u0004\u001fAY\u0004C\u0001\u001f>\u0019\u0001!\u0011BP\u001c\u0002\u0002\u0003\u0005)\u0011A 
\u0003\u0007}#\u0013'\u0005\u0002A\u0007B\u0011!%Q\u0005\u0003\u0005\u000e\u0012qAT8uQ&tw\r\u0005\u0002#\t&\u0011Qi\t\u0002\u0004\u0003:L\b\"B$\u0001\t\u0003A\u0015!F:fi\u0016\u001bH/[7bi>\u0014\b+\u0019:b[6\u000b\u0007o\u001d\u000b\u0003m%CQ\u0001\u000f$A\u0002)\u00032AI&N\u0013\ta5EA\u0003BeJ\f\u0017\u0010\u0005\u0002O#6\tqJ\u0003\u0002Q\t\u0005)\u0001/\u0019:b[&\u0011!k\u0014\u0002\t!\u0006\u0014\u0018-\\'ba\")A\u000b\u0001C\u0001+\u0006a1/\u001a;Fm\u0006dW/\u0019;peR\u0011aG\u0016\u0005\u0006qM\u0003\ra\u0016\t\u00031nk\u0011!\u0017\u0006\u00035\u0012\t!\"\u001a<bYV\fG/[8o\u0013\ta\u0016LA\u0005Fm\u0006dW/\u0019;pe\")a\f\u0001C\u0001?\u0006i1/\u001a;Ue\u0006LgNU1uS>$\"A\u000e1\t\u000baj\u0006\u0019A1\u0011\u0005\t\u0012\u0017BA2$\u0005\u0019!u.\u001e2mK\")Q\r\u0001C!M\u0006\u0019a-\u001b;\u0015\u0005I9\u0007\"\u00025e\u0001\u0004I\u0017a\u00023bi\u0006\u001cX\r\u001e\t\u0003U6l\u0011a\u001b\u0006\u0003Y\u001a\t1a]9m\u0013\tq7NA\u0005ECR\fgI]1nK\")\u0001\u000f\u0001C!c\u0006yAO]1og\u001a|'/\\*dQ\u0016l\u0017\r\u0006\u0002sqB\u00111O^\u0007\u0002i*\u0011Qo[\u0001\u0006if\u0004Xm]\u0005\u0003oR\u0014!b\u0015;sk\u000e$H+\u001f9f\u0011\u0015Ix\u000e1\u0001s\u0003\u0019\u00198\r[3nC\")1\u0010\u0001C!y\u0006qa/\u00197jI\u0006$X\rU1sC6\u001cH#A?\u0011\u0005\tr\u0018BA@$\u0005\u0011)f.\u001b;\t\u000f\u0005\r\u0001\u0001\"\u0011\u0002\u0006\u0005!1m\u001c9z)\ry\u0013q\u0001\u0005\b\u0003\u0013\t\t\u00011\u0001N\u0003\u0015)\u0007\u0010\u001e:bQ\r\u0001\u0011Q\u0002\t\u0005\u0003\u001f\t)\"\u0004\u0002\u0002\u0012)\u0019\u00111\u0003\u0004\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0003\u0002\u0018\u0005E!\u0001D#ya\u0016\u0014\u0018.\\3oi\u0006d\u0007")
public class TrainValidationSplit
extends Estimator<TrainValidationSplitModel>
implements TrainValidationSplitParams {
    private final String uid;
    private final DoubleParam trainRatio;
    private final Param<Estimator<?>> estimator;
    private final Param<ParamMap[]> estimatorParamMaps;
    private final Param<Evaluator> evaluator;

    @Override
    public DoubleParam trainRatio() {
        return this.trainRatio;
    }

    @Override
    public void org$apache$spark$ml$tuning$TrainValidationSplitParams$_setter_$trainRatio_$eq(DoubleParam x$1) {
        this.trainRatio = x$1;
    }

    @Override
    public double getTrainRatio() {
        return TrainValidationSplitParams$class.getTrainRatio(this);
    }

    @Override
    public Param<Estimator<?>> estimator() {
        return this.estimator;
    }

    @Override
    public Param<ParamMap[]> estimatorParamMaps() {
        return this.estimatorParamMaps;
    }

    @Override
    public Param<Evaluator> evaluator() {
        return this.evaluator;
    }

    @Override
    public void org$apache$spark$ml$tuning$ValidatorParams$_setter_$estimator_$eq(Param x$1) {
        this.estimator = x$1;
    }

    @Override
    public void org$apache$spark$ml$tuning$ValidatorParams$_setter_$estimatorParamMaps_$eq(Param x$1) {
        this.estimatorParamMaps = x$1;
    }

    @Override
    public void org$apache$spark$ml$tuning$ValidatorParams$_setter_$evaluator_$eq(Param x$1) {
        this.evaluator = x$1;
    }

    @Override
    public Estimator<?> getEstimator() {
        return ValidatorParams$class.getEstimator(this);
    }

    @Override
    public ParamMap[] getEstimatorParamMaps() {
        return ValidatorParams$class.getEstimatorParamMaps(this);
    }

    @Override
    public Evaluator getEvaluator() {
        return ValidatorParams$class.getEvaluator(this);
    }

    @Override
    public String uid() {
        return this.uid;
    }

    public TrainValidationSplit setEstimator(Estimator<?> value) {
        return (TrainValidationSplit)this.set(this.estimator(), value);
    }

    public TrainValidationSplit setEstimatorParamMaps(ParamMap[] value) {
        return (TrainValidationSplit)this.set(this.estimatorParamMaps(), value);
    }

    public TrainValidationSplit setEvaluator(Evaluator value) {
        return (TrainValidationSplit)this.set(this.evaluator(), value);
    }

    public TrainValidationSplit setTrainRatio(double value) {
        return (TrainValidationSplit)this.set(this.trainRatio(), BoxesRunTime.boxToDouble((double)value));
    }

    @Override
    public TrainValidationSplitModel fit(DataFrame dataset) {
        StructType schema = dataset.schema();
        this.transformSchema(schema, true);
        SQLContext sqlCtx = dataset.sqlContext();
        Estimator<?> est = this.$(this.estimator());
        Evaluator eval2 = this.$(this.evaluator());
        ParamMap[] epm = this.$(this.estimatorParamMaps());
        int numModels = epm.length;
        double[] metrics = new double[epm.length];
        RDD[] rDDArray = dataset.rdd().randomSplit(new double[]{BoxesRunTime.unboxToDouble((Object)this.$(this.trainRatio())), 1.0 - BoxesRunTime.unboxToDouble((Object)this.$(this.trainRatio()))}, dataset.rdd().randomSplit$default$2());
        Option option = Array$.MODULE$.unapplySeq((Object)rDDArray);
        if (!option.isEmpty() && option.get() != null && ((SeqLike)option.get()).lengthCompare(2) == 0) {
            Tuple2 tuple2;
            Tuple2 tuple22;
            RDD training = (RDD)((SeqLike)option.get()).apply(0);
            RDD validation = (RDD)((SeqLike)option.get()).apply(1);
            Tuple2 tuple23 = tuple22 = new Tuple2((Object)training, (Object)validation);
            RDD training2 = (RDD)tuple23._1();
            RDD validation2 = (RDD)tuple23._2();
            DataFrame trainingDataset = sqlCtx.createDataFrame(training2, schema).cache();
            DataFrame validationDataset = sqlCtx.createDataFrame(validation2, schema).cache();
            this.logDebug((Function0<String>)new Serializable(this){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Train split with multiple sets of parameters."})).s((Seq)Nil$.MODULE$);
                }
            });
            Seq<?> models = est.fit(trainingDataset, epm);
            trainingDataset.unpersist();
            IntRef i = IntRef.create((int)0);
            while (i.elem < numModels) {
                double metric = eval2.evaluate(((Transformer)models.apply(i.elem)).transform(validationDataset, epm[i.elem]));
                this.logDebug((Function0<String>)new Serializable(this, epm, i, metric){
                    public static final long serialVersionUID = 0L;
                    private final ParamMap[] epm$1;
                    private final IntRef i$1;
                    private final double metric$1;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Got metric ", " for model trained with ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.metric$1), this.epm$1[this.i$1.elem]}));
                    }
                    {
                        this.epm$1 = epm$1;
                        this.i$1 = i$1;
                        this.metric$1 = metric$1;
                    }
                });
                int n = i.elem++;
                metrics[n] = metrics[n] + metric;
            }
            validationDataset.unpersist();
            this.logInfo((Function0<String>)new Serializable(this, metrics){
                public static final long serialVersionUID = 0L;
                private final double[] metrics$1;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Train validation split metrics: ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{Predef$.MODULE$.doubleArrayOps(this.metrics$1).toSeq()}));
                }
                {
                    this.metrics$1 = metrics$1;
                }
            });
            Tuple2 tuple24 = tuple2 = eval2.isLargerBetter() ? (Tuple2)Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.doubleArrayOps(metrics).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).maxBy((Function1)new Serializable(this){
                public static final long serialVersionUID = 0L;

                public final double apply(Tuple2<Object, Object> x$2) {
                    return x$2._1$mcD$sp();
                }
            }, (Ordering)Ordering.Double$.MODULE$) : (Tuple2)Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.doubleArrayOps(metrics).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).minBy((Function1)new Serializable(this){
                public static final long serialVersionUID = 0L;

                public final double apply(Tuple2<Object, Object> x$3) {
                    return x$3._1$mcD$sp();
                }
            }, (Ordering)Ordering.Double$.MODULE$);
            if (tuple2 != null) {
                Tuple2.mcDI.sp sp2;
                double bestMetric = tuple2._1$mcD$sp();
                int bestIndex = tuple2._2$mcI$sp();
                Tuple2.mcDI.sp sp3 = sp2 = new Tuple2.mcDI.sp(bestMetric, bestIndex);
                double bestMetric2 = sp3._1$mcD$sp();
                int bestIndex2 = sp3._2$mcI$sp();
                this.logInfo((Function0<String>)new Serializable(this, epm, bestIndex2){
                    public static final long serialVersionUID = 0L;
                    private final ParamMap[] epm$1;
                    private final int bestIndex$1;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Best set of parameters:\\n", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{this.epm$1[this.bestIndex$1]}));
                    }
                    {
                        this.epm$1 = epm$1;
                        this.bestIndex$1 = bestIndex$1;
                    }
                });
                this.logInfo((Function0<String>)new Serializable(this, bestMetric2){
                    public static final long serialVersionUID = 0L;
                    private final double bestMetric$1;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Best train validation split metric: ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.bestMetric$1)}));
                    }
                    {
                        this.bestMetric$1 = bestMetric$1;
                    }
                });
                Object bestModel = est.fit(dataset, epm[bestIndex2]);
                return this.copyValues(new TrainValidationSplitModel(this.uid(), (Model<?>)bestModel, metrics).setParent(this), this.copyValues$default$2());
            }
            throw new MatchError((Object)tuple2);
        }
        throw new MatchError((Object)rDDArray);
    }

    @Override
    public StructType transformSchema(StructType schema) {
        return ((PipelineStage)this.$(this.estimator())).transformSchema(schema);
    }

    @Override
    public void validateParams() {
        Params$class.validateParams(this);
        Estimator<?> est = this.$(this.estimator());
        Predef$.MODULE$.refArrayOps((Object[])this.$(this.estimatorParamMaps())).foreach((Function1)new Serializable(this, est){
            public static final long serialVersionUID = 0L;
            private final Estimator est$1;

            public final void apply(ParamMap paramMap) {
                this.est$1.copy(paramMap).validateParams();
            }
            {
                this.est$1 = est$1;
            }
        });
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public TrainValidationSplit copy(ParamMap extra) {
        void var2_2;
        TrainValidationSplit copied = (TrainValidationSplit)this.defaultCopy(extra);
        Object object = copied.isDefined(this.estimator()) ? copied.setEstimator(copied.getEstimator().copy(extra)) : BoxedUnit.UNIT;
        Object object2 = copied.isDefined(this.evaluator()) ? copied.setEvaluator(copied.getEvaluator().copy(extra)) : BoxedUnit.UNIT;
        return var2_2;
    }

    public TrainValidationSplit(String uid) {
        this.uid = uid;
        ValidatorParams$class.$init$(this);
        TrainValidationSplitParams$class.$init$(this);
    }

    public TrainValidationSplit() {
        this(Identifiable$.MODULE$.randomUID("tvs"));
    }
}

