
Spark session setup
import findspark
findspark.init()
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
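If you need more control over resources, the builder accepts explicit configuration before getOrCreate. A minimal sketch (the app name and memory value here are illustrative, not taken from the original run):
# Illustrative configuration; getOrCreate() returns the existing session if one is already running
spark = (SparkSession.builder
         .appName('bank-marketing-rf')
         .config('spark.driver.memory', '4g')
         .getOrCreate())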
Libraries to use
#Pyspark libraries/modules
from pyspark.sql import *
from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.sql.functions import row_number, monotonically_increasing_id
from pyspark.sql.types import StringType, DoubleType, IntegerType, ArrayType, DateType
#Python libraries/modules
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")
from scipy import stats
from scipy.stats import kstest
#Pyspark Machine Learning libraries/modules
from pyspark.mllib.stat import Statistics
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler, MinMaxScaler
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.mllib.linalg import Vectors
import re
This time we revisit the case published earlier, in which I ran an exploratory analysis on a dataset of potential bank clients.
If you haven't seen it, I invite you to take a look at the previous notebook.
Now that the dataset is clean, we proceed to fit a Random Forest classification model.
# Load the cleaned dataset from the previous notebook and convert it to a Spark DataFrame
pd_df = pd.read_excel('bank_full_final.xlsx')
df = spark.createDataFrame(pd_df)
df = df.drop('Target_no', 'Target_yes')  # drop the leftover one-hot target columns
print('rows', df.count(), len(df.columns), 'columns')
df.printSchema()
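Going through pandas is fine at this size, but it materializes the whole file on the driver. For larger data, Spark can read the file directly; a sketch assuming a CSV export of the same table (the file name is hypothetical):
# Hypothetical CSV export of the same cleaned table, read natively by Spark
df = (spark.read
      .option('header', True)
      .option('inferSchema', True)
      .csv('bank_full_final.csv'))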
Encoding the Target variable
# Map the string target to a binary label: 'no' -> 0, everything else ('yes') -> 1
df = df.withColumn('Target', F.when(F.col('Target') == 'no', 0).otherwise(1))
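Bank marketing targets tend to be imbalanced, so it is worth checking the class distribution before modeling; a quick check:
# Class distribution of the binary target
df.groupBy('Target').count().show()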
Train/Test split (hold-out)
train, test = df.randomSplit([0.80,0.20])
#test = test.drop('Target')
print('Training set rows {}'.format(train.count()))
print('Validation set rows {}'.format(test.count()))
print('Training set columns {}'.format(len(train.columns)))
print('Validation set columns {}'.format(len(test.columns)))
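randomSplit shuffles differently on every run; passing a seed makes the hold-out reproducible (the seed value is arbitrary):
# Reproducible 80/20 split; 42 is an arbitrary seed
train, test = df.randomSplit([0.80, 0.20], seed=42)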
Feature selection
# Keep only the numeric (non-string) columns as candidate features
feature = [c for c, t in df.dtypes if t != 'string']
feature.remove('Target')  # the label must not leak into the features
feature
assembler = VectorAssembler(inputCols=feature, outputCol='features')
train = assembler.transform(train)
test = assembler.transform(test)
train.select('features').take(1)
test.show(1,vertical = True)
train.select('features').show(2)
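Note that the feature list above keeps only numeric columns. The StringIndexer and OneHotEncoder imported at the top could fold the categorical columns into the feature vector as well; a sketch, assuming Spark 3.x (where OneHotEncoder accepts lists of columns):
# Sketch: index and one-hot encode the remaining string columns, then assemble everything
cat_cols = [c for c, t in df.dtypes if t == 'string']
indexers = [StringIndexer(inputCol=c, outputCol=c + '_idx', handleInvalid='keep')
            for c in cat_cols]
encoder = OneHotEncoder(inputCols=[c + '_idx' for c in cat_cols],
                        outputCols=[c + '_ohe' for c in cat_cols])
assembler_full = VectorAssembler(inputCols=feature + [c + '_ohe' for c in cat_cols],
                                 outputCol='features')
pipeline = Pipeline(stages=indexers + [encoder, assembler_full])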
from pyspark.ml.stat import Correlation
# Pearson correlation matrix over the assembled feature vector
matrix = Correlation.corr(train.select('features'), 'features')
matrix_np = matrix.collect()[0]['pearson(features)'].values
matrix_np = matrix_np.reshape(len(feature),len(feature))
fig, ax = plt.subplots(figsize=(20,20))
ax = sns.heatmap(matrix_np, cmap="YlGnBu")
ax.xaxis.set_ticklabels(feature, rotation=270)
ax.yaxis.set_ticklabels(feature, rotation=0)
ax.set_title("Correlation Matrix")
plt.tight_layout()
plt.show()
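Reading a heatmap of this size by eye is error-prone; the same matrix can be scanned programmatically for strongly correlated pairs (the 0.8 cutoff is arbitrary):
# Print feature pairs whose absolute Pearson correlation exceeds an arbitrary cutoff
for i in range(len(feature)):
    for j in range(i + 1, len(feature)):
        if abs(matrix_np[i, j]) > 0.8:
            print(feature[i], feature[j], round(matrix_np[i, j], 3))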
rf = RandomForestClassifier(labelCol='Target', featuresCol='features', numTrees=50)
rf_model = rf.fit(train)  # keep the fitted model separate from the estimator
pred = rf_model.transform(test)
pred.select('Target', 'prediction', 'probability').show(5)
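The BinaryClassificationEvaluator imported at the top can compute the area under the ROC curve directly on the Spark DataFrame, a useful cross-check for the sklearn-based curve built below:
# AUC computed in Spark itself, from the model's rawPrediction column
evaluator = BinaryClassificationEvaluator(labelCol='Target',
                                          rawPredictionCol='rawPrediction',
                                          metricName='areaUnderROC')
print('AUC (Spark): {:.3f}'.format(evaluator.evaluate(pred)))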
pred_pd = pred.select(['Target', 'prediction', 'probability']).toPandas()
pred_pd.head()
# DenseVector -> plain list so the probabilities survive the pandas round-trip
pred_pd['probability'] = pred_pd['probability'].map(lambda x: list(x))
# One-hot encode the true label to line up with the two-column probability output
pred_pd['encoded_Target'] = pred_pd['Target'].map(lambda x: np.eye(2)[int(x)])
y_pred = np.array(pred_pd['probability'].tolist())
y_true = np.array(pred_pd['encoded_Target'].tolist())
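As an aside, on Spark 3.0+ the probability vector can also be unpacked without leaving Spark; a sketch:
# Sketch (Spark >= 3.0): extract P(class = 1) as a plain double column
from pyspark.ml.functions import vector_to_array
pred.withColumn('p1', vector_to_array('probability')[1]).select('Target', 'p1').show(5)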
from sklearn.metrics import auc, roc_curve
# Column 0 treats class 0 as the positive label; for a binary problem the AUC is the same either way
fpr, tpr, threshold = roc_curve(y_score=y_pred[:, 0], y_true=y_true[:, 0])
roc_auc = auc(fpr, tpr)  # don't shadow sklearn's auc function with the result
print('AUC: {:.3f}'.format(roc_auc))
plt.subplots(figsize=(15, 15))  # a bare plt.figure() first would leave an extra empty figure
plt.plot([0, 1], [0, 1], '--', color='orange')  # chance line
plt.plot(fpr, tpr, label='auc = {:.3f}'.format(roc_auc))
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title('ROC curve')
plt.legend(loc='lower right')
plt.grid()
plt.show()
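A natural next step, given the CrossValidator and ParamGridBuilder imports at the top, is to tune the forest instead of fixing 50 trees. A sketch (grid values and fold count are illustrative):
# Sketch: 3-fold cross-validation over a small, illustrative hyperparameter grid
rf_cv = RandomForestClassifier(labelCol='Target', featuresCol='features')
grid = (ParamGridBuilder()
        .addGrid(rf_cv.numTrees, [25, 50, 100])
        .addGrid(rf_cv.maxDepth, [5, 10])
        .build())
cv = CrossValidator(estimator=rf_cv,
                    estimatorParamMaps=grid,
                    evaluator=BinaryClassificationEvaluator(labelCol='Target'),
                    numFolds=3)
cv_model = cv.fit(train)  # best model available as cv_model.bestModel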