Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Наивный Байес

Наивный Байес использует теорему Байеса: P(yx)=P(xy)P(y)P(x)P(y|x) = \frac{P(x|y)P(y)}{P(x)}. Предполагается условная независимость признаков, поэтому P(xy)=j=1pP(xjy)P(x|y)=\prod_{j=1}^{p} P(x_j|y). В случае GaussianNB каждый признак моделируется нормальным распределением N(μy,j,σy,j2)\mathcal{N}(\mu_{y,j}, \sigma_{y,j}^2). На практике важен гиперпараметр var_smoothing: он добавляет маленькую стабилизирующую константу к оценкам дисперсий и может заметно улучшить качество на более сложных данных.

Используемые библиотеки

Используем numpy, pandas, seaborn, matplotlib. Из sklearn — датасет load_digits, train_test_split, StratifiedKFold, GridSearchCV, Pipeline, GaussianNB и метрики классификации.

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
sns.set_theme(style="whitegrid", palette="deep")
plt.rcParams["figure.dpi"] = 120
PRIMARY_COLOR = "#1f77b4"
HEATMAP_CMAP = "coolwarm"
CONFUSION_CMAP = "Blues"

Датасет: описание и частичная распечатка

Датасет Digits из sklearn заметно сложнее, чем Iris: 1797 изображений рукописных цифр от 0 до 9, каждое изображение представлено матрицей 8×8 пикселей (64 числовых признака). Это уже 10-классовая задача классификации с более высокой размерностью и похожими между собой классами, поэтому она лучше подходит для демонстрации ограничений и сильных сторон наивного Байеса.

digits = load_digits(as_frame=True)
data = digits.frame
print(f"Размерность: {data.shape}")
data.head()
Размерность: (1797, 65)
Loading...
fig, axes = plt.subplots(2, 5, figsize=(8, 4))
for ax, image, label in zip(axes.ravel(), digits.images[:10], digits.target[:10]):
    ax.imshow(image, cmap="gray_r")
    ax.set_title(f"Цифра {label}")
    ax.axis("off")
plt.suptitle("Примеры изображений из Digits")
plt.tight_layout()
plt.show()
<Figure size 960x480 with 10 Axes>

Предварительная обработка

Отделяем признаки от целевой переменной и разбиваем выборку на train/test (80/20) со стратификацией. Масштабирование для GaussianNB не обязательно: модель отдельно оценивает средние и дисперсии признаков внутри каждого класса. Подбор гиперпараметров будем выполнять только на обучающей части, чтобы не допускать утечки данных.

features = data.drop(columns=["target"])
target = data["target"]

X_train, X_test, y_train, y_test = train_test_split(
    features, target, test_size=0.2, random_state=42, stratify=target
)

Тепловая карта корреляций

Полная корреляционная матрица для 64 пикселей была бы слишком громоздкой, поэтому возьмём первые 16 признаков. Уже здесь видно, что пиксели не являются независимыми — это нарушает базовое предположение наивного Байеса, но модель всё равно часто работает удивительно хорошо на практике.

plt.figure(figsize=(7, 5))
subset_cols = features.columns[:16]
correlation = data[list(subset_cols) + ["target"]].corr()
sns.heatmap(correlation, annot=False, cmap=HEATMAP_CMAP, linewidths=0.3)
plt.title("Корреляции (подмножество пикселей)")
plt.tight_layout()
plt.show()
<Figure size 840x600 with 2 Axes>

Подбор гиперпараметров

У GaussianNB мало настраиваемых параметров, но var_smoothing часто влияет на устойчивость модели. Проверим логарифмическую сетку значений от 10-12 до 100 на кросс-валидации StratifiedKFold. Оптимизируем метрику f1_macro, чтобы качество оценивалось равномерно по всем 10 классам.

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

nb_pipeline = Pipeline(
    steps=[
        ("nb", GaussianNB()),
    ]
)

param_grid = {
    "nb__var_smoothing": np.logspace(-12, 0, 13),
}

grid = GridSearchCV(
    nb_pipeline,
    param_grid=param_grid,
    cv=cv,
    scoring="f1_macro",
    n_jobs=-1,
    refit=True,
)

grid.fit(X_train, y_train)

cv_results = (
    pd.DataFrame(grid.cv_results_)[["param_nb__var_smoothing", "mean_test_score", "std_test_score"]]
    .rename(columns={"param_nb__var_smoothing": "var_smoothing"})
)
cv_results["var_smoothing"] = cv_results["var_smoothing"].astype(float)
cv_results = cv_results.sort_values("var_smoothing")

print("Лучшие параметры по CV")
print(grid.best_params_)
print(f"Best CV f1_macro: {grid.best_score_:.3f}")
cv_results
Лучшие параметры по CV
{'nb__var_smoothing': np.float64(0.1)}
Best CV f1_macro: 0.923
Loading...
plt.figure(figsize=(7, 4))
plt.semilogx(cv_results["var_smoothing"], cv_results["mean_test_score"], marker="o", color=PRIMARY_COLOR)
plt.fill_between(
    cv_results["var_smoothing"],
    cv_results["mean_test_score"] - cv_results["std_test_score"],
    cv_results["mean_test_score"] + cv_results["std_test_score"],
    color=PRIMARY_COLOR,
    alpha=0.15,
)
plt.xlabel("var_smoothing")
plt.ylabel("Средний f1_macro (CV=5)")
plt.title("Подбор var_smoothing для GaussianNB")
plt.tight_layout()
plt.show()
<Figure size 840x480 with 1 Axes>

Обучение модели

Используем лучшую конфигурацию, найденную на кросс-валидации, и обучаем её на всей обучающей выборке.

best_model = grid.best_estimator_
best_model.fit(X_train, y_train)
model = best_model

Прогнозы модели

На тестовой выборке посмотрим несколько метрик:

  • Accuracy — общая доля верных предсказаний;

  • Precision/Recall/F1 weighted — учитывают размеры классов;

  • F1 macro — особенно полезна здесь, потому что мы подбирали параметры по равномерному качеству на всех цифрах.

y_pred = model.predict(X_test)

print("Gaussian Naive Bayes Metrics")
print(f"Accuracy:    {accuracy_score(y_test, y_pred):.3f}")
print(f"Precision:   {precision_score(y_test, y_pred, average='weighted', zero_division=0):.3f}")
print(f"Recall:      {recall_score(y_test, y_pred, average='weighted', zero_division=0):.3f}")
print(f"F1 weighted: {f1_score(y_test, y_pred, average='weighted', zero_division=0):.3f}")
print(f"F1 macro:    {f1_score(y_test, y_pred, average='macro', zero_division=0):.3f}")
Gaussian Naive Bayes Metrics
Accuracy:    0.922
Precision:   0.923
Recall:      0.922
F1 weighted: 0.922
F1 macro:    0.921

Графики выходных результатов

График 1. Матрица ошибок 10×10. На диагонали — верные предсказания. Вне диагонали будут пары цифр, которые модель чаще путает из-за визуального сходства.

График 2. Recall по классам. Это доля объектов каждого истинного класса, распознанных правильно. График помогает быстро увидеть, какие цифры особенно сложны для GaussianNB.

cm = confusion_matrix(y_test, y_pred)
class_names = [str(name) for name in digits.target_names]

plt.figure(figsize=(7, 5))
sns.heatmap(cm, annot=False, cmap=CONFUSION_CMAP, xticklabels=class_names, yticklabels=class_names)
plt.title("Матрица ошибок")
plt.xlabel("Прогноз")
plt.ylabel("Факт")
plt.tight_layout()
plt.show()
<Figure size 840x600 with 2 Axes>
per_class_recall = []
for cls in np.unique(y_test):
    mask = y_test == cls
    per_class_recall.append(np.mean(y_pred[mask] == y_test[mask]))

plt.figure(figsize=(8, 4))
plt.bar(class_names, per_class_recall, color=PRIMARY_COLOR)
plt.ylim(0, 1)
plt.ylabel("Recall по классу")
plt.xlabel("Класс")
plt.title("Recall по классам")
plt.tight_layout()
plt.show()
<Figure size 960x480 with 1 Axes>