Функция дерево решений python

Использование дерева принятия решений с помощью Scikit-Learn в Python

Дерево решений – один из наиболее часто и широко используемых алгоритмов контролируемого машинного обучения, который может выполнять задачи, как регрессии, так и классификации. Интуиция, лежащая в основе алгоритма decision tree, проста, но при этом очень эффективна.

Для каждого атрибута в наборе данных алгоритм дерева решений формирует узел, в котором наиболее важный атрибут помещается в корневой узел. Для оценки мы начинаем с корневого узла и продвигаемся вниз по дереву, следуя за соответствующим узлом, который соответствует нашему условию или «решению». Этот процесс продолжается до тех пор, пока не будет достигнут конечный узел, содержащий прогноз или результат дерева решений.

  1. Этот человек – близкий друг или просто знакомый? Если человек просто знакомый, то отклоните предложение; если человек друг, переходите к следующему шагу.
  2. Человек впервые просит машину? Если да, одолжите ему машину, в противном случае переходите к следующему шагу.
  3. Была ли машина повреждена при последнем возврате машины? Если да, то откажите; если нет, одолжите машину.

Дерево решений для вышеупомянутого сценария выглядит так:

Дерево решений

Преимущества decision tree

  1. decision tree могут использоваться для прогнозирования как непрерывных, так и дискретных значений, т.е. они хорошо работают, как для задач регрессии, так и для классификации.
  2. Для их обучения требуется относительно меньше усилий.
  3. Их можно использовать для классификации нелинейно разделимых данных.
  4. Они очень быстрые и эффективные по сравнению с KNN и другими алгоритмами классификации.

Реализация

В этом разделе мы реализуем алгоритм дерева решений с использованием библиотеки Scikit-Learn в Python. В следующих примерах мы решим как классификационные, так и регрессионные задачи с помощью дерева решений.

Примечание. Задачи классификации и регрессии выполнялись в Jupyter iPython Notebook.

Читайте также:  Next js typescript install

1. Схема принятия решений для классификации

В этом разделе мы предскажем, является ли банкнота подлинной или поддельной, в зависимости от четырех различных атрибутов изображения банкноты. Атрибуты – это дисперсия изображения, преобразованного вейвлетом, кратность изображения, энтропия и асимметрия изображения.

Набор данных

Остальные шаги по реализации этого алгоритма в Scikit-Learn идентичны любой типичной задаче машинного обучения: мы импортируем библиотеки и наборы данных, проведем некоторый анализ данных, разделим данные на наборы для обучения и тестирования, обучим алгоритм, сделаем прогнозы, и, наконец, мы оценим производительность алгоритма на нашем наборе данных.

Импорт библиотек

Следующий скрипт импортирует необходимые библиотеки:

import pandas as pd import numpy as np import matplotlib.pyplot as plt %matplotlib inline
Импорт набора данных

Поскольку наш файл находится в формате CSV, мы будем использовать метод panda read_csv для чтения нашего файла данных CSV. Для этого выполните следующий сценарий:

dataset = pd.read_csv("D:/Datasets/bill_authentication.csv")

В этом случае файл bill_authentication.csv находится в папке «Datasets» на диске «D». Вы должны изменить этот путь в соответствии с настройками вашей собственной системы.

Анализ данных

Выполните следующую команду, чтобы увидеть количество строк и столбцов в нашем наборе данных:

На выходе будет «(1372,5)», что означает, что в нашем наборе данных 1372 записи и 5 атрибутов.

Выполните следующую команду, чтобы проверить первые пять записей набора данных:

Результат будет выглядеть так:

Дисперсия Асимметрия Curtosis Entropy Класс
0 3,62160 8,6661 -2.8073 -0,44699 0
1 4,54590 8,1674 -2,4586 -1,46210 0
2 3,86600 -2,6383 1,9242 0,10645 0
3 3,45660 9,5228 -4,0112 -3,59440 0
4 0,32924 -4,4552 4,5718 -0,98880 0
Подготовка данных

В этом разделе мы разделим наши данные на атрибуты и метки, а затем разделим полученные данные на обучающие и тестовые наборы. Таким образом мы можем обучить наш алгоритм на одном наборе данных, а затем протестировать его на совершенно другом наборе данных, который алгоритм еще не видел. Это дает вам более точное представление о том, как на самом деле будет работать ваш обученный алгоритм.

Читайте также:  Bootstrap css class for links

Чтобы разделить данные на атрибуты и метки, выполните следующий код:

X = dataset.drop('Class', axis=1) y = dataset['Class']

Здесь переменная X содержит все столбцы из набора данных, кроме столбца «Класс», который является меткой. Переменная y содержит значения из столбца «Класс». Переменная X – это наш набор атрибутов, а переменная y содержит соответствующие метки.

Последний шаг предварительной обработки – разделить наши данные на обучающие и тестовые наборы. Библиотека model_selection Scikit-Learn содержит метод train_test_split, который мы будем использовать для случайного разделения данных на наборы для обучения и тестирования. Для этого выполните следующий код:

from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20)

В приведенном выше коде параметр test_size указывает соотношение набора тестов, которое мы используем для разделения 20% данных на набор тестов и 80% для обучения.

Обучение и прогнозирование

После того, как данные были разделены на наборы для обучения и тестирования, последний шаг – обучить алгоритм дерева решений на этих данных и сделать прогнозы. Scikit-Learn содержит древовидную библиотеку, которая имеет встроенные классы и методы для различных алгоритмов. Поскольку здесь мы собираемся выполнить задачу классификации, мы будем использовать класс DecisionTreeClassifier для этого примера. Метод подгонки этого класса вызывается для обучения алгоритма на обучающих данных, которые передаются в качестве параметра методу подгонки. Выполните следующий скрипт для обучения алгоритма:

from sklearn.tree import DecisionTreeClassifier classifier = DecisionTreeClassifier() classifier.fit(X_train, y_train)

Теперь, когда наш классификатор обучен, давайте сделаем прогнозы на основе тестовых данных. Для прогнозирования используется метод прогнозирования класса DecisionTreeClassifier. Взгляните на следующий код для использования:

y_pred = classifier.predict(X_test)
Оценка алгоритма

На этом этапе мы обучили наш алгоритм и сделали некоторые прогнозы. Теперь посмотрим, насколько точен наш алгоритм. Для задач классификации часто используются такие показатели, как матрица неточностей, точность, отзыв и оценка F1. К счастью для нас, библиотека метрик Scikit = -Learn содержит методы классификации_report и confusion_matrix, которые можно использовать для расчета этих показателей для нас:

from sklearn.metrics import classification_report, confusion_matrix print(confusion_matrix(y_test, y_pred)) print(classification_report(y_test, y_pred))

Это даст следующую оценку:

[[142 2] 2 129]] precision recall f1-score support 0 0.99 0.99 0.99 144 1 0.98 0.98 0.98 131 avg / total 0.99 0.99 0.99 275

Из матрицы неточностей видно, что из 275 тестовых примеров наш алгоритм неправильно классифицировал только 4. Это точность 98,5%. Не плохо!

Читайте также:  Версии php которые поддерживаются

2. Дерево решений для регрессии

Процесс решения проблемы регрессии с помощью дерева решений с использованием Scikit Learn очень похож на процесс классификации. Однако для регрессии мы используем класс DecisionTreeRegressor древовидной библиотеки. Матрицы оценки регрессии также отличаются от матриц классификации. В остальном процесс почти такой же.

Набор данных

Мы будем использовать этот набор данных, чтобы попытаться спрогнозировать потребление газа (в миллионах галлонов) в 48 штатах США на основе налога на газ (в центах), дохода на душу населения (в долларах), асфальтированных дорог (в милях) и доли населения с водительское удостоверение.

Первые два столбца в приведенном выше наборе данных не предоставляют никакой полезной информации, поэтому они были удалены из файла набора данных.

Теперь давайте применим наш алгоритм дерева решений к этим данным, чтобы попытаться предсказать потребление газа на основе этих данных.

Импорт библиотек
import pandas as pd import numpy as np import matplotlib.pyplot as plt %matplotlib inline
Импорт набора данных
dataset = pd.read_csv('D:\Datasets\petrol_consumption.csv')
Анализ данных

Мы снова будем использовать функцию head dataframe, чтобы увидеть, как на самом деле выглядят наши данные:

Источник

Оцените статью