/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim;

import java.io.Serializable;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.feature.OffsetInstance;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.IterativelyReweightedLeastSquaresModel;
import org.apache.spark.ml.optim.WeightedLeastSquares;
import org.apache.spark.ml.optim.WeightedLeastSquares$;
import org.apache.spark.ml.optim.WeightedLeastSquaresModel;
import org.apache.spark.ml.util.OptionalInstrumentation;
import org.apache.spark.ml.util.OptionalInstrumentation$;
import org.apache.spark.rdd.RDD;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.java8.JFunction2;

@ScalaSignature(bytes="\u0006\u0001i4Q\u0001E\t\u0001'mA\u0001\"\n\u0001\u0003\u0006\u0004%\ta\n\u0005\tY\u0001\u0011\t\u0011)A\u0005Q!AQ\u0006\u0001BC\u0002\u0013\u0005a\u0006\u0003\u0005?\u0001\t\u0005\t\u0015!\u00030\u0011!y\u0004A!b\u0001\n\u0003\u0001\u0005\u0002\u0003#\u0001\u0005\u0003\u0005\u000b\u0011B!\t\u0011\u0015\u0003!Q1A\u0005\u0002\u0019C\u0001b\u0012\u0001\u0003\u0002\u0003\u0006Ia\u000f\u0005\t\u0011\u0002\u0011)\u0019!C\u0001\u0013\"AQ\n\u0001B\u0001B\u0003%!\n\u0003\u0005O\u0001\t\u0015\r\u0011\"\u0001G\u0011!y\u0005A!A!\u0002\u0013Y\u0004\"\u0002)\u0001\t\u0003\t\u0006\"B-\u0001\t\u0003Q\u0006b\u00028\u0001#\u0003%\ta\u001c\u0002\"\u0013R,'/\u0019;jm\u0016d\u0017PU3xK&<\u0007\u000e^3e\u0019\u0016\f7\u000f^*rk\u0006\u0014Xm\u001d\u0006\u0003%M\tQa\u001c9uS6T!\u0001F\u000b\u0002\u00055d'B\u0001\f\u0018\u0003\u0015\u0019\b/\u0019:l\u0015\tA\u0012$\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u00025\u0005\u0019qN]4\u0014\u0007\u0001a\"\u0005\u0005\u0002\u001eA5\taDC\u0001 
\u0003\u0015\u00198-\u00197b\u0013\t\tcD\u0001\u0004B]f\u0014VM\u001a\t\u0003;\rJ!\u0001\n\u0010\u0003\u0019M+'/[1mSj\f'\r\\3\u0002\u0019%t\u0017\u000e^5bY6{G-\u001a7\u0004\u0001U\t\u0001\u0006\u0005\u0002*U5\t\u0011#\u0003\u0002,#\tIr+Z5hQR,G\rT3bgR\u001c\u0016/^1sKNlu\u000eZ3m\u00035Ig.\u001b;jC2lu\u000eZ3mA\u0005a!/Z<fS\u001eDGOR;oGV\tq\u0006E\u0003\u001eaIB\u0003(\u0003\u00022=\tIa)\u001e8di&|gN\r\t\u0003gYj\u0011\u0001\u000e\u0006\u0003kM\tqAZ3biV\u0014X-\u0003\u00028i\tqqJ\u001a4tKRLen\u001d;b]\u000e,\u0007\u0003B\u000f:wmJ!A\u000f\u0010\u0003\rQ+\b\u000f\\33!\tiB(\u0003\u0002>=\t1Ai\\;cY\u0016\fQB]3xK&<\u0007\u000e\u001e$v]\u000e\u0004\u0013\u0001\u00044ji&sG/\u001a:dKB$X#A!\u0011\u0005u\u0011\u0015BA\"\u001f\u0005\u001d\u0011un\u001c7fC:\fQBZ5u\u0013:$XM]2faR\u0004\u0013\u0001\u0003:fOB\u000b'/Y7\u0016\u0003m\n\u0011B]3h!\u0006\u0014\u0018-\u001c\u0011\u0002\u000f5\f\u00070\u0013;feV\t!\n\u0005\u0002\u001e\u0017&\u0011AJ\b\u0002\u0004\u0013:$\u0018\u0001C7bq&#XM\u001d\u0011\u0002\u0007Q|G.\u0001\u0003u_2\u0004\u0013A\u0002\u001fj]&$h\bF\u0004S'R+fk\u0016-\u0011\u0005%\u0002\u0001\"B\u0013\u000e\u0001\u0004A\u0003\"B\u0017\u000e\u0001\u0004y\u0003\"B 
\u000e\u0001\u0004\t\u0005\"B#\u000e\u0001\u0004Y\u0004\"\u0002%\u000e\u0001\u0004Q\u0005\"\u0002(\u000e\u0001\u0004Y\u0014a\u00014jiR\u00191L\u00184\u0011\u0005%b\u0016BA/\u0012\u0005\u0019JE/\u001a:bi&4X\r\\=SK^,\u0017n\u001a5uK\u0012dU-Y:u'F,\u0018M]3t\u001b>$W\r\u001c\u0005\u0006?:\u0001\r\u0001Y\u0001\nS:\u001cH/\u00198dKN\u00042!\u001933\u001b\u0005\u0011'BA2\u0016\u0003\r\u0011H\rZ\u0005\u0003K\n\u00141A\u0015#E\u0011\u001d9g\u0002%AA\u0002!\fQ!\u001b8tiJ\u0004\"!\u001b7\u000e\u0003)T!a[\n\u0002\tU$\u0018\u000e\\\u0005\u0003[*\u0014qc\u00149uS>t\u0017\r\\%ogR\u0014X/\\3oi\u0006$\u0018n\u001c8\u0002\u001b\u0019LG\u000f\n3fM\u0006,H\u000e\u001e\u00133+\u0005\u0001(F\u00015rW\u0005\u0011\bCA:y\u001b\u0005!(BA;w\u0003%)hn\u00195fG.,GM\u0003\u0002x=\u0005Q\u0011M\u001c8pi\u0006$\u0018n\u001c8\n\u0005e$(!E;oG\",7m[3e-\u0006\u0014\u0018.\u00198dK\u0002")
/*
 * Iteratively reweighted least squares (IRLS) driver.
 *
 * Starting from an initial weighted-least-squares model, each iteration
 * (1) maps every input instance through reweightFunc together with the model
 * of the previous iteration, (2) fits a fresh WeightedLeastSquares model on
 * the mapped instances, and (3) stops once the largest absolute change of any
 * coefficient or of the intercept is below tol, or once maxIter iterations
 * have run.
 *
 * This file is decompiler output for a Scala class; the Scala-specific
 * holders (IntRef / ObjectRef) are kept so the lambdas observe the same
 * mutable state the original closures did.
 */
public class IterativelyReweightedLeastSquares
implements scala.Serializable {
    /** Model handed to reweightFunc during the first iteration. */
    private final WeightedLeastSquaresModel initialModel;
    /** Maps (instance, previous model) to a pair of doubles used to build the next Instance. */
    private final Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweightFunc;
    /** Forwarded as the first constructor argument of every WeightedLeastSquares fit. */
    private final boolean fitIntercept;
    /** Forwarded as the second constructor argument of every WeightedLeastSquares fit. */
    private final double regParam;
    /** Upper bound on the number of reweighting iterations. */
    private final int maxIter;
    /** Convergence threshold on the largest absolute parameter change. */
    private final double tol;

    /**
     * Stores the configuration; no validation or computation happens here.
     *
     * @param initialModel model used as the "previous" model of the first iteration
     * @param reweightFunc function producing the two doubles for each re-built instance
     * @param fitIntercept passed through to WeightedLeastSquares
     * @param regParam     passed through to WeightedLeastSquares
     * @param maxIter      maximum number of iterations
     * @param tol          convergence threshold
     */
    public IterativelyReweightedLeastSquares(WeightedLeastSquaresModel initialModel, Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweightFunc, boolean fitIntercept, double regParam, int maxIter, double tol) {
        this.initialModel = initialModel;
        this.reweightFunc = reweightFunc;
        this.fitIntercept = fitIntercept;
        this.regParam = regParam;
        this.maxIter = maxIter;
        this.tol = tol;
    }

    /** @return the model supplied at construction time */
    public WeightedLeastSquaresModel initialModel() {
        return this.initialModel;
    }

    /** @return the reweighting function supplied at construction time */
    public Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweightFunc() {
        return this.reweightFunc;
    }

    /** @return the fitIntercept flag supplied at construction time */
    public boolean fitIntercept() {
        return this.fitIntercept;
    }

    /** @return the regularization parameter supplied at construction time */
    public double regParam() {
        return this.regParam;
    }

    /** @return the iteration cap supplied at construction time */
    public int maxIter() {
        return this.maxIter;
    }

    /** @return the convergence threshold supplied at construction time */
    public double tol() {
        return this.tol;
    }

    /**
     * Runs the reweighting loop and returns the final coefficients, intercept,
     * diagInvAtWA and the number of iterations executed.
     *
     * If maxIter is not positive the loop body never runs and the result is
     * built from the initial model with an iteration count of 0.
     *
     * @param instances input instances that are re-mapped on every iteration
     * @param instr     sink for progress messages; also forwarded to each
     *                  WeightedLeastSquares fit
     * @return the model produced by the last iteration
     * @throws MatchError if reweightFunc returns null for an instance
     */
    public IterativelyReweightedLeastSquaresModel fit(RDD<OffsetInstance> instances, OptionalInstrumentation instr) {
        boolean converged = false;
        // Holder objects rather than plain locals: the lambdas below read the
        // current value through them, and Java lambdas may only capture
        // effectively-final variables.
        final IntRef iter = IntRef.create(0);
        final ObjectRef<WeightedLeastSquaresModel> oldModel = ObjectRef.create(null);
        WeightedLeastSquaresModel model = this.initialModel();

        while (iter.elem < this.maxIter() && !converged) {
            oldModel.elem = model;

            // Re-build every instance from the pair returned by reweightFunc,
            // keeping the original feature vector.
            RDD<Instance> newInstances = instances.map((Function1<OffsetInstance, Instance> & Serializable & scala.Serializable) instance -> {
                Tuple2<Object, Object> reweighted = this.reweightFunc().apply(instance, oldModel.elem);
                if (reweighted == null) {
                    throw new MatchError(reweighted);
                }
                return new Instance(reweighted._1$mcD$sp(), reweighted._2$mcD$sp(), instance.features());
            }, ClassTag$.MODULE$.apply(Instance.class));

            // NOTE(review): the literal 0.0 / false / false are presumably
            // elasticNetParam and the two standardization flags - confirm
            // against the WeightedLeastSquares constructor.
            model = new WeightedLeastSquares(
                    this.fitIntercept(),
                    this.regParam(),
                    0.0,
                    false,
                    false,
                    WeightedLeastSquares$.MODULE$.$lessinit$greater$default$6(),
                    WeightedLeastSquares$.MODULE$.$lessinit$greater$default$7(),
                    WeightedLeastSquares$.MODULE$.$lessinit$greater$default$8()).fit(newInstances, instr);

            // axpy writes its result into its last argument (the difference is
            // read back from that vector right below). Work on a copy so the
            // previous model - which is the caller-supplied initialModel on
            // the first iteration - keeps its coefficients intact.
            DenseVector coefficientDelta = oldModel.elem.coefficients().copy();
            BLAS$.MODULE$.axpy(-1.0, (Vector) model.coefficients(), (Vector) coefficientDelta);

            // Largest absolute coefficient change.
            double maxTolOfCoefficients = 0.0;
            for (double delta : coefficientDelta.toArray()) {
                maxTolOfCoefficients = Math.max(Math.abs(maxTolOfCoefficients), Math.abs(delta));
            }
            final double maxTol = Math.max(maxTolOfCoefficients, Math.abs(oldModel.elem.intercept() - model.intercept()));

            if (maxTol < this.tol()) {
                converged = true;
                // Logged before the counter is incremented, so the reported
                // count is zero-based.
                instr.logInfo((Function0<String> & Serializable & scala.Serializable) () -> "IRLS converged in " + iter.elem + " iterations.");
            }
            instr.logInfo((Function0<String> & Serializable & scala.Serializable) () -> "Iteration " + iter.elem + " : relative tolerance = " + maxTol);

            ++iter.elem;
            if (iter.elem == this.maxIter()) {
                instr.logInfo((Function0<String> & Serializable & scala.Serializable) () -> "IRLS reached the max number of iterations: " + this.maxIter() + ".");
            }
        }
        return new IterativelyReweightedLeastSquaresModel(model.coefficients(), model.intercept(), model.diagInvAtWA(), iter.elem);
    }

    /**
     * Synthetic accessor for the default value of the second parameter of
     * {@link #fit}: an OptionalInstrumentation created for this class.
     */
    public OptionalInstrumentation fit$default$2() {
        return OptionalInstrumentation$.MODULE$.create(IterativelyReweightedLeastSquares.class);
    }
}

