¿Qué son los árboles de clanes?
Árboles del clan es un algoritmo de aprendizaje automático versátil capaz de realizar tareas de clasificación y regresión. Son algoritmos muy poderosos, capaces de adaptarse a un conjunto de datos complejo. Además, los árboles de decisión son componentes básicos de los bosques aleatorios, que se encuentran entre los algoritmos de aprendizaje automático más potentes disponibles en la actualidad.
Entrenamiento y visualización del árbol de decisiones
Para construir su primer árbol de decisión en el ejemplo R, continuaremos en este tutorial de árbol de decisión de la siguiente manera:
Paso 1) Importar los datos
Si tiene curiosidad sobre el destino del titanio, puede ver este video YouTube. El objetivo de este conjunto de datos es predecir quién tiene más probabilidades de sobrevivir a la colisión con el iceberg. El conjunto de datos contiene 13 variables y 1309 observaciones. La variable X controla el conjunto de datos.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
Producción:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
Producción:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
Desde la salida de cabeza y cola, puede notar que los datos no se mueven. ¡Esta es una gran pregunta! Cuando divide sus datos entre un conjunto de trenes y un conjunto de prueba, selecciona uno el pasajero de la clase 1 y 2 (no hay ningún pasajero de la clase 3 en el 80 por ciento superior de las observaciones), lo que significa que el algoritmo nunca verá las características de los pasajeros de la clase 3. El resultado será una mala predicción.
Para solucionar este problema, puede utilizar el ejemplo funcional ().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Código de árbol R Decisión Explicación
- ejemplo (1: nrow (titanic)): Cree una lista aleatoria de índices de 1 a 1309 (es decir, el número máximo de filas).
Producción:
## [1] 288 874 1078 633 887 992
Utilizará este índice para cambiar el conjunto de datos titánico.
titanic <- titanic[shuffle_index, ] head(titanic)
Producción:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Paso 2) Conjunto de datos claro
La estructura de los datos muestra que NA tiene varias variables. La limpieza de datos se debe realizar de la siguiente manera
- Presione home.dest variables, cabina, nombre, X y boleto
- Creación de variables factoriales para pclass y supervivencia
- Golpea la ONU
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Explicación del Código
- select (-c (home.dest, cabin, name, X, ticket)): selecciona variables innecesarias
- pclass = factor (pclass, niveles = c (1,2,3), labels = c (‘Superior’, ‘Medio’, ‘Inferior’)): agregue una etiqueta a la variable pclass. 1 se convierte en Superior, 2 en Medio y 3 en Inferior
- factor (superviviente, niveles = c (0,1), etiquetas = c (‘No’, ‘Sí’)): Etiqueta la variable superviviente. 1 Sí No y 2 Sí
- na.omit (): Elimina las observaciones de NA
Producción:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Paso 3) Crea un tren / conjunto de prueba
Antes de entrenar su modelo, debe seguir dos pasos:
- Crear un tren y un conjunto de prueba: entrena el modelo en el conjunto de trenes y prueba la predicción en el conjunto de prueba (es decir, datos sin precedentes)
- Instale rpart.plot desde la consola
La práctica común es compartir los datos en un 80/20, el 80 por ciento de los datos sirve para entrenar el modelo y el 20 por ciento para hacer predicciones. Necesita crear dos marcos de datos separados. No desea tocar el conjunto de prueba hasta que haya terminado de construir su modelo. Puede crear un nombre de función create_train_test () que tenga tres argumentos.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Explicación del Código
- función (datos, tamaño = 0.8, tren = VERDADERO): agregue los argumentos a la función
- n_row = nrow (datos): Número de filas en el conjunto de datos
- total_row = size * n_row: Devuelve la novena fila para construir el tren
- train_sample <- 1: total_row: Seleccione la primera fila a la novena fila
- if (train == TRUE) {} else {}: Si la condición es verdadera, devuelve el conjunto de trenes y el otro conjunto de prueba.
Puede probar su función y comprobar la característica.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
Producción:
## [1] 836 8
dim(data_test)
Producción:
## [1] 209 8
El conjunto de datos del tren contiene 1046 capas y el conjunto de datos de prueba contiene 262 capas.
Utilice la función prop.table () junto con una tabla () para verificar que el proceso de aleatorización sea correcto.
prop.table(table(data_train$survived))
Producción:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Producción:
## ## No Yes ## 0.5789474 0.4210526
En ambos conjuntos de datos, el número de supervivientes es el mismo, alrededor del 40 por ciento.
Instalar rpart.plot
rpart.plot no está disponible en las bibliotecas de conda. Puedes instalarlo desde la consola:
install.packages("rpart.plot")
Paso 4) Construye el modelo
Estás listo para construir el modelo. La sintaxis de la función del árbol de decisiones de Rpart es:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Utiliza el método de clase porque predice una clase.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Explicación del Código
- rpart (): Función para adaptar el modelo. Los argumentos son:
- sobrevivir ~: Fórmula de árboles de decisión
- data = data_train: conjunto de datos
- método = ‘clase’: Ajuste de un modelo binario
- rpart.plot (fit, extra = 106): traza el árbol. Las funciones adicionales se establecen en 101 para mostrar la probabilidad de segunda clase (útil para respuestas binarias). Puede consultar el viñeta para obtener más información sobre las alternativas.
Producción:
Comienza en el nodo raíz (profundidad 0 sobre 3, parte superior del gráfico):
- En la parte superior, es la probabilidad general de supervivencia. Muestra el porcentaje del pasajero que sobrevivió al accidente. El 41 por ciento del pasajero sobrevivió.
- Este nodo pregunta si el género del pasajero es masculino. Si es así, entonces baja al nodo secundario izquierdo de la raíz (profundidad 2). El 63 por ciento son hombres con una probabilidad de supervivencia del 21 por ciento.
- En el segundo nodo, se pregunta si el pasajero masculino tiene más de 3,5 años. En caso afirmativo, la probabilidad de supervivencia es del 19 por ciento.
- De modo que sigue comprendiendo los factores que influyen en la probabilidad de su supervivencia.
Tenga en cuenta que una de las muchas cualidades de los árboles de decisión es que requieren poca preparación de datos. En particular, no requieren centralización o escalado de funciones.
Por defecto, la función rpart () usa la Guinea medida de impureza para dividir el billete. Cuanto mayor sea el coeficiente de Gini, más escenarios diferentes en el nodo.
Paso 5) Haz una predicción
Puede predecir su conjunto de datos de prueba. Para hacer una predicción, puede utilizar la función de predicción (). La sintaxis básica predicha para el árbol de decisión R es:
predict(fitted_model, df, type="class") arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Desea predecir qué pasajeros tienen más probabilidades de sobrevivir a la colisión desde el equipo de prueba. Es decir, sabrá entre esos 209 pasajeros, si sobrevivirá o no.
predict_unseen <-predict(fit, data_test, type="class")
Explicación del Código
- predicción (ajuste, prueba_datos, tipo = «clase»): predice la clase (0/1) del conjunto de prueba
Pruebe el pasajero que no lo hizo y los que sí lo hicieron.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Explicación del Código
- table (data_test $ surviv, predic_unseen): cree una tabla para contar cuántos pasajeros se clasifican como sobrevivientes y murieron en comparación con la clasificación correcta de árboles de decisión en R
Producción:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
El modelo predijo correctamente 106 pasajeros muertos, pero 15 supervivientes fueron clasificados como muertos. Por analogía, el modelo clasificó erróneamente a 30 pasajeros como supervivientes y muertos.
Paso 6) Medición de desempeño
Puede calcular una medida de precisión para la tarea de clasificación con el matriz de confusión:
El es matriz de confusión es una mejor opción para evaluar el desempeño de la clasificación. La idea general es el número de veces que los casos verdaderos se clasifican como falsos.
Cada fila de una matriz representa una confusión de objetivo real y cada columna representa un objetivo previsto. El primer conjunto de esta matriz considera pasajeros muertos (Falsa clase): 106 fueron correctamente clasificados como muertos (Absolutamente negativo), aunque el otro fue clasificado incorrectamente como superviviente (Falso positivo). La segunda fila considera a los sobrevivientes, la clase positiva fue 58 (Muy positivo), y el Absolutamente negativo tenía 30 años.
Usted puede prueba de precisión de la matriz de confusión:
La proporción de verdaderamente positivo y verdaderamente negativo está más allá de la suma de la matriz. Con R, puede codificar de la siguiente manera:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Explicación del Código
- sum (diagonal (table_mat)): La suma de la diagonal
- sum (table_mat): la suma de la matriz.
Puede imprimir la precisión del conjunto de prueba:
print(paste('Accuracy for test', accuracy_Test))
Producción:
## [1] "Accuracy for test 0.784688995215311"
Tiene una puntuación del 78 por ciento para el conjunto de pruebas. Puede replicar el mismo ejercicio con el conjunto de datos de entrenamiento.
Paso 7) Ajuste de los hiperparámetros
El árbol de decisiones en R tiene varios parámetros que controlan aspectos del ajuste. En la biblioteca del árbol de decisiones de rpart, puede controlar los parámetros utilizando la función rpart.control (). En el siguiente código, ingresa los parámetros que necesitará. Puede consultar el viñeta para otros parámetros.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Continuaremos de la siguiente manera:
- Toma una función para restaurar la precisión
- Sintonice a la profundidad máxima
- Ajuste la cantidad mínima de muestras que debe tener un nodo antes de que pueda dividirse
- Ajuste el número mínimo de muestras que debe tener un nodo hoja
Puede escribir una función para mostrar la precisión. Simplemente devuelva el código que utilizó anteriormente:
- predecir: predic_unseen <- predecir (ajuste, prueba_datos, tipo = "clase")
- Tabla de producción: table_mat <- table (existente data_test $, predic_unseen)
- Precisión calculada: precision_Test <- sum (diag (table_mat)) / sum (table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type="class") table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Puede intentar ajustar los parámetros y ver si puede mejorar el modelo más allá del valor predeterminado. Como recordatorio, debe obtener una precisión superior a 0,78
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
Producción:
## [1] 0.7990431
Con el siguiente parámetro:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Obtienes un mayor rendimiento que el modelo anterior. ¡Felicidades!
Resumen
Podemos resumir las funciones para entrenar el algoritmo del árbol de decisión en R.
Biblioteca | Propósito | función | sonó | parámetros | detalles |
---|---|---|---|---|---|
rpart | Árbol de clasificación de trenes en R. | rpart () | sonó | fórmula, df, método | |
rpart | Mástil de regresión de tren | rpart () | anova | fórmula, df, método | |
rpart | Trucha los arboles | rpart.plot () | modelo ajustado | ||
bonn | profecía | predicción () | sonó | modelo ajustado, tipo | |
bonn | profecía | predicción () | problema | modelo ajustado, tipo | |
bonn | profecía | predicción () | vector | modelo ajustado, tipo | |
rpart | Parámetros de control | rpart.control () | minplit | Establecer el número mínimo de observaciones en el nodo antes de que el algoritmo se divida | |
minbucket | Establezca el número mínimo de observaciones en la última nota, es decir, la hoja | ||||
máxima profundidad | Establezca la profundidad máxima de un nodo del árbol final. La profundidad 0 se trata en el nodo raíz | ||||
rpart | Modelo de tren con parámetro de control | rpart () | fórmula, df, método, control |
Nota: Entrene el modelo con datos de entrenamiento y pruebe el rendimiento en un conjunto de datos sin precedentes, es decir, un conjunto de prueba.