Arbol de Decisión en Apache Spark con Python

Cargar datos

# Cargar un dataframe
df = sqlContext.read.format("com.databricks.spark.csv").options(delimiter='\t',header='true',inferschema='true').load("/databricks-datasets/power-plant/data")
display(df)

AT V AP RH PE
14.96 41.76 1024.07 73.17 463.26
25.18 62.96 1020.04 59.08 444.37
5.11 39.4 1012.16 92.14 488.56
20.86 57.32 1010.24 76.64 446.48

 

Generar conjunto de entrenamiento y test

#Definir una semilla
seed = 1800009193L
# Generar un grupo de entrenamiento y otro de prueba con una proporción 80-20
(split20DF, split80DF) = df.randomSplit([.2, .8], seed=seed)
# Cachear los conjuntos de datos
testSetDF = split20DF.cache()
trainingSetDF = split80DF.cache()
display(trainingSetDF)

 

Generar el modelo

from pyspark.ml import Pipeline
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import VectorAssembler

# Definir un vector de ensamblado para que las variables de entrada se queden en una sola "features"
vectorizer = VectorAssembler()
vectorizer.setInputCols(["AT", "V", "AP", "RH"])
vectorizer.setOutputCol("features")
# Definir molelo de arbol de regresión
dt = DecisionTreeRegressor()
# Definir los parametros del modelo:
# - Predicted_PE: columna que almacenará las predicciones estimadas
# - features: columna que almacena el vector de variables predictoras
# - PE: columna que almacena la predicción real 
# - 8 niveles de profundidad
dt.setPredictionCol("Predicted_PE").setMaxBins(100).setFeaturesCol("features").setLabelCol("PE").setMaxDepth(8)
# Crear una 'pipeline' en la cual hay 2 elementos, un 'Vector Assembler' y un modelo 'Decision Tree', accesibles mediante el atributo 'stages'.
pipeline = Pipeline(stages=[vectorizer, dt])
# Ajustar el modelo (Ejecutar)
model = pipeline.fit(trainingSetDF)
# Visualizar los resultados
vectAssembler = model.stages[0]
dtModel = model.stages[1]
print("Nodos: " + str(dtModel.numNodes))
print("Profundidad: "+ str(dtModel.depth)) # summary only
print(dtModel.toDebugString)
Nodos: 503 
Profundidad: 8 
DecisionTreeRegressionModel (uid=DecisionTreeRegressor_4f21b2e2b1f92f4c08f3) of depth 8 with 503 nodes 
If (feature 0 <= 17.75) 
  If (feature 0 <= 11.86) 
    If (feature 0 <= 8.81) 
      If (feature 0 <= 6.92) 
        If (feature 1 <= 40.55) 
          If (feature 1 <= 39.99) 
            If (feature 0 <= 5.23) 
              If (feature 2 <= 1003.91) 
                Predict: 478.4114285714286 
              Else (feature 2 > 1003.91) 
                Predict: 488.2392936802974 
            Else (feature 0 > 5.23) 
              If (feature 1 <= 39.81) 
                Predict: 485.9372307692308 
              Else (feature 1 > 39.81)
              ...

 

Predicción del modelo

predictions = model.transform(testSetDF)
display(predictions)
AT V AP RH PE features Predicted_PE
1.81 39.42 1026.92 76.97 490.55 [1,4,[],[1.81,39.42,1026.92,76.97]] 488.2392936802974
3.2 41.31 997.67 98.84 489.86 [1,4,[],[3.2,41.31,997.67,98.84]] 488.81500000000005
3.38 41.31 998.79 97.76 489.11 [1,4,[],[3.38,41.31,998.79,97.76]] 488.81500000000005
3.4 39.64 1011.1 83.43 459.86 [1,4,[],[3.4,39.64,1011.1,83.43]] 488.2392936802974
3.51 35.47 1017.53 86.56 489.07 [1,4,[],[3.51,35.47,1017.53,86.56]] 488.2392936802974
3.63 38.44 1016.16 87.38 487.87 [1,4,[],[3.63,38.44,1016.16,87.38]] 488.2392936802974

 

Evaluar el modelo

# Cargar libreria de evaluación
from pyspark.ml.evaluation import RegressionEvaluator
# Evaluación mediante el metodo de regresion
regEval = RegressionEvaluator(predictionCol="Predicted_PE", labelCol="PE", metricName="rmse")
# Evaluación 1 - RMSE:  Error cuadrático medio
rmse = regEval.evaluate(predictions)
print(" Error cuadrático medio (RMSE): %.2f" % rmse)
## Error cuadrático medio (RMSE): 3.60
# Evaluación 2 - r2: Coeficiente de determinación
r2 = regEval.evaluate(predictions, {regEval.metricName: "r2"})
print("coeficiente de determinación (r2): {0:.2f}".format(r2))
## Coeficiente de determinación (r2): 0.96
 # Almacenar los datos en una tabla para poder mostrar estadisticas
sqlContext.sql("DROP TABLE IF EXISTS Power_Plant_RMSE_Evaluation")
dbutils.fs.rm("dbfs:/user/hive/warehouse/Power_Plant_RMSE_Evaluation", True)
predictions.selectExpr("PE", "Predicted_PE", "PE - Predicted_PE Residual_Error", "(PE - Predicted_PE) / {} Within_RSME".format(rmse)).registerTempTable("Power_Plant_RMSE_Evaluation")

Otros artículos que pueden ser de interés:

Autor: Diego Calvo