Как посчитать confusion matrix в SQL
Содержание:
Зачем confusion matrix
Confusion matrix — фундамент classification metrics. Без неё непонятно, что значит «accuracy 95%»: модель угадывает positive хорошо или просто всегда говорит «negative»? Из 4 чисел (TP/FP/TN/FN) получаются precision, recall, F1, accuracy. В SQL — один запрос с CASE WHEN.
Бинарная матрица
Для бинарного классификатора порог 0.5:
WITH predictions AS (
SELECT
user_id,
actual_label,
CASE WHEN predicted_proba >= 0.5 THEN 1 ELSE 0 END AS predicted_label
FROM model_predictions
WHERE prediction_date >= CURRENT_DATE - INTERVAL '30 days'
)
SELECT
SUM(CASE WHEN actual_label = 1 AND predicted_label = 1 THEN 1 ELSE 0 END) AS tp,
SUM(CASE WHEN actual_label = 0 AND predicted_label = 1 THEN 1 ELSE 0 END) AS fp,
SUM(CASE WHEN actual_label = 0 AND predicted_label = 0 THEN 1 ELSE 0 END) AS tn,
SUM(CASE WHEN actual_label = 1 AND predicted_label = 0 THEN 1 ELSE 0 END) AS fn
FROM predictions;Метрики из матрицы
WITH cm AS (
SELECT
SUM(CASE WHEN actual = 1 AND predicted = 1 THEN 1 ELSE 0 END) AS tp,
SUM(CASE WHEN actual = 0 AND predicted = 1 THEN 1 ELSE 0 END) AS fp,
SUM(CASE WHEN actual = 0 AND predicted = 0 THEN 1 ELSE 0 END) AS tn,
SUM(CASE WHEN actual = 1 AND predicted = 0 THEN 1 ELSE 0 END) AS fn
FROM predictions
)
SELECT
tp, fp, tn, fn,
(tp + tn)::NUMERIC / NULLIF(tp + fp + tn + fn, 0) AS accuracy,
tp::NUMERIC / NULLIF(tp + fp, 0) AS precision,
tp::NUMERIC / NULLIF(tp + fn, 0) AS recall,
2.0 * tp / NULLIF(2.0 * tp + fp + fn, 0) AS f1_score,
tn::NUMERIC / NULLIF(tn + fp, 0) AS specificity
FROM cm;По сегментам
Отдельные матрицы для разных групп (например, по платформе):
SELECT
platform,
SUM(CASE WHEN actual = 1 AND predicted = 1 THEN 1 ELSE 0 END) AS tp,
SUM(CASE WHEN actual = 0 AND predicted = 1 THEN 1 ELSE 0 END) AS fp,
SUM(CASE WHEN actual = 0 AND predicted = 0 THEN 1 ELSE 0 END) AS tn,
SUM(CASE WHEN actual = 1 AND predicted = 0 THEN 1 ELSE 0 END) AS fn,
SUM(CASE WHEN actual = 1 AND predicted = 1 THEN 1 ELSE 0 END)::NUMERIC
/ NULLIF(SUM(CASE WHEN predicted = 1 THEN 1 ELSE 0 END), 0) AS precision_by_segment
FROM predictions
GROUP BY platform
ORDER BY precision_by_segment DESC;Если у iOS precision = 0.92, а у Android = 0.45 — модель смещена в одну сторону.
Multi-class
Pairwise (actual_class × predicted_class):
SELECT
actual_label AS actual,
predicted_label AS predicted,
COUNT(*) AS n
FROM predictions
GROUP BY actual_label, predicted_label
ORDER BY actual_label, predicted_label;Это сразу диагональная таблица. Для per-class precision:
WITH per_class AS (
SELECT
predicted_label AS class,
SUM(CASE WHEN actual_label = predicted_label THEN 1 ELSE 0 END) AS tp,
SUM(CASE WHEN actual_label <> predicted_label THEN 1 ELSE 0 END) AS fp
FROM predictions
GROUP BY predicted_label
)
SELECT
class,
tp,
fp,
tp::NUMERIC / NULLIF(tp + fp, 0) AS precision
FROM per_class;Частые ошибки
Ошибка 1. actual = predicted без оговорки про NULL.
Если в данных есть NULL labels — он не равен ничему. Фильтруйте actual IS NOT NULL.
Ошибка 2. Считать без NULLIF. TP/FP/TN/FN могут быть 0 в edge cases. Деление сломается.
Ошибка 3. Делать confusion matrix для probabilities без threshold. Threshold обязателен. Стандартный 0.5, но можно подбирать по precision-recall trade-off.
Ошибка 4. Игнорировать class imbalance. Если в данных 99% negative, accuracy 99% — это «всегда говори negative». Используйте precision/recall.
Ошибка 5. Сравнивать модели по разным CM. Confusion matrix on training ≠ on test. Сообщайте, какая выборка.
Связанные темы
- Как посчитать precision-recall в SQL
- Как посчитать AUC-ROC в SQL
- Как посчитать F1-score в SQL
- Как посчитать log loss в SQL
FAQ
Какой threshold?
Default 0.5. Подбирают по precision-recall curve в зависимости от cost FP/FN.
Multi-class с N классами — N×N матрица?
Да. Confusion matrix размера N×N.
Precision per class vs macro precision?
Per-class — для каждого класса свой. Macro — среднее по классам. На imbalanced — отличаются.
F1 = 2 × P × R / (P + R) — почему так?
Гармоническое среднее P и R. Штрафует низкий показатель.
Какой порог TP/FP в продукте?
Зависит от cost. Спам-фильтр: FP — плохо (рабочий email в спам). Fraud — FN плохо (мошенник пропущен).