Regresión lineal en Scala

En el siguiente post se muestran los pasos a seguir para recrear un ejemplo de regresión lineal en Scala.scala_logo

Definir el conjunto de datos

Se define el conjunto de datos sobre el que aplicar el modelo

import org.apache.spark.ml.linalg.Vectors
val df = spark.createDataFrame(Seq(
    (0, 60),
    (0, 56),
    (0, 54),
    (0, 62),
    (0, 61),
    (0, 53),
    (0, 55),
    (0, 62),
    (0, 64),
    (1, 73),
    (1, 78),
    (1, 67),
    (1, 68),
    (1, 78)
)).toDF("defecto" , "temperatura")

 

Definir el modelo mediante tuberias

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
// Definir características
val features = new VectorAssembler()
  .setInputCols(Array("temperatura"))
  .setOutputCol("features")
// Definir modelo a utilizar
val lr = new LinearRegression().setLabelCol("defecto")
// Crear una tuberia que asocie el modelo con la secuencia de tratamiento de datos
val pipeline = new Pipeline().setStages(Array(features, lr))
//Ejecutar el modelo
val model = pipeline.fit(df)

 

Mostrar resultados del modelo

val linRegModel = model.stages(1).asInstanceOf[LinearRegressionModel]
println(s"RMSE:  ${linRegModel.summary.rootMeanSquaredError}")
println(s"r2:    ${linRegModel.summary.r2}")
println(s"Model: Y = ${linRegModel.coefficients(0)} * X + ${linRegModel.intercept}")
linRegModel.summary.residuals.show()
RMSE:  0.24965353110553395
r2:    0.7285317871929219
Model: Y = 0.05114497726003437 * X + -2.8978696241921877
+--------------------+
|           residuals|
+--------------------+
|-0.17082901140987428|
|0.033750897630262955|
| 0.13604085215033157|
|-0.27311896592994334|
| -0.2219739886699088|
|  0.1871858294103661|
| 0.08489587489029748|
|-0.27311896592994334|
|-0.37540892045001195|
|  0.1642862842096786|
|-0.09143860209049315|
|  0.4711561477698849|
|  0.4200111705098504|
|-0.09143860209049315|
+--------------------+

Mostrar prediciones

val result = model.transform(data).select("temperatura", "defecto", "prediction")
result.show()

+-----------+-------+--------------------+
|temperatura|defecto|          prediction|
+-----------+-------+--------------------+
|         60|      0| 0.17082901140987428|
|         56|      0|-0.03375089763026...|
|         54|      0|-0.13604085215033157|
|         62|      0| 0.27311896592994334|
|         61|      0|  0.2219739886699088|
|         53|      0| -0.1871858294103661|
|         55|      0|-0.08489587489029748|
|         62|      0| 0.27311896592994334|
|         64|      0| 0.37540892045001195|
|         73|      1|  0.8357137157903214|
|         78|      1|  1.0914386020904931|
|         67|      1|  0.5288438522301151|
|         68|      1|  0.5799888294901496|
|         78|      1|  1.0914386020904931|
+-----------+-------+--------------------+