Attention механизм для Data Scientist
Карьерник — Duolingo для аналитиков: 10 минут в день тренируй SQL, Python, A/B, статистику, метрики и ещё 3 темы собеса. 1500+ вопросов в Telegram-боте. Бесплатно.
Содержание:
Зачем attention
Attention — главная инновация, которая позволила Transformer заменить RNN. На собесе DS любого middle+ уровня попросят расписать формулу attention и объяснить интуицию.
Главная боль без понимания attention — кандидат на собесе говорит «modelo focuses на нужные части», но не может сказать, как именно. Recruiter переходит к следующей теме, оффер не выдаётся.
Эта статья — про базовый механизм без воды: формула, размерности, где применяется.
Базовая интуиция
Allow each токен «опросить» все остальные токены и собрать взвешенную информацию. Веса — не статичны, они зависят от текущего токена и выучиваются.
Аналогия: в библиотеке вы ищете книгу. У вас есть вопрос (Query). У каждой книги есть карточка с ключевыми словами (Key). Если ваш запрос совпадает с карточкой — берёте содержимое книги (Value). Финальный ответ — взвешенное сочетание содержимых, где веса — степень совпадения запроса с карточкой каждой книги.
В нейросети — то же:
- Query (Q): что мы спрашиваем (вектор из текущего токена)
- Key (K): что описывает каждый источник
- Value (V): реальное содержимое каждого источника
- Result = взвешенная сумма V с весами от схожести Q-K
Scaled dot-product attention
Стандартная формула из «Attention Is All You Need»:
Attention(Q, K, V) = softmax(Q · K^T / √d_k) · VРазмерности:
- Q: (n_query, d_k) — n_query вопросов размерности d_k
- K: (n_key, d_k) — n_key ключей размерности d_k
- V: (n_key, d_v) — n_key значений размерности d_v
- Output: (n_query, d_v)
Обычно n_query == n_key (self-attention), d_k == d_v == d_model / h (где h — число голов).
Шаги:
Скоры:
S = Q · K^T— каждая пара (Q_i, K_j) даёт scalar = dot product. Размер S: (n_query, n_key).Scaling: делим на
√d_k. Без scaling при больших d_k значения dot product растут, softmax выходит на плато (gradient → 0). Деление на √d_k стабилизирует variance.Softmax по оси key: для каждой query — распределение вероятностей по keys.
Взвешенная сумма V: результирующий вектор для каждой query.
Self-attention: Q, K, V получаются из одной последовательности через linear:
Q = X @ W_Q
K = X @ W_K
V = X @ W_VW_Q, W_K, W_V — обучаемые матрицы.
Masked attention
В decoder при генерации токен на позиции t не должен видеть будущие токены — иначе модель «cheats». Реализуется маской:
mask = torch.tril(torch.ones(n, n)) # нижний треугольник единиц
scores = (Q @ K.T) / sqrt(d_k)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = softmax(scores)После softmax -inf → 0, верхний треугольник весов = 0.
Padding mask: в батче последовательности разной длины. PAD-токены не должны влиять на attention. Маска зануляет PAD-позиции.
В практике используют комбинированную маску (causal + padding).
Cross-attention
В encoder-decoder архитектурах decoder использует cross-attention к выходу encoder.
- Q — из decoder (текущий генерируемый токен)
- K, V — из encoder (исходное предложение)
Это позволяет генерации «смотреть на всё исходное предложение» при выборе следующего токена.
Translation: "Hello world" → "Привет мир"
При генерации токена "Привет" decoder cross-attends ко всем токенам "Hello world".В decoder-only моделях (GPT) cross-attention отсутствует — есть только self-attention с causal mask.
Сложность и FlashAttention
Стандартный attention имеет сложность O(n²) по длине последовательности:
- Q · K^T создаёт матрицу (n × n)
- Softmax — n операций на каждой строке
- Умножение на V — O(n² · d)
На длинном контексте (32k токенов) attention — основное узкое место по compute и memory.
FlashAttention (Tri Dao, 2022):
- Тот же результат, но через блочное вычисление с tile-by-tile подходом
- Использует SRAM (быстрый кэш GPU) вместо HBM
- В 2–4 раза быстрее, в разы меньше memory
FlashAttention 2 (2023) и 3 (2024) — дальнейшие оптимизации. Стандарт в современном тренинге LLM.
Approximations:
- Linear attention (Performer, Linformer) — O(n) сложность через приближения. Качество обычно хуже full attention.
- Sparse attention (Longformer, BigBird) — каждый токен видит только часть других. Локальные + global tokens.
- MQA / GQA (Multi-Query / Grouped-Query Attention) — общие K/V между головами. Снижает inference cost (KV-cache меньше) при том же quality.
Частые ошибки
Считать softmax(QK)V «средним пуллингом». Это взвешенное взвешивание, веса учатся, разные для каждой query. Не среднее.
Игнорировать √d_k. Без scaling на больших размерностях обучение нестабильно. В формуле обязательно.
Не маскировать PAD. Без padding mask attention учитывает PAD как обычный токен. На длинных батчах = шум.
Применять causal mask в encoder. Encoder должен видеть всё. Causal mask — только в decoder.
Высокая размерность ≠ много голов. d_model = 768, h = 12 → каждая голова работает в d_k = 64. Многие ошибочно думают, что добавление голов добавляет compute.
Считать что attention «interpretable». Веса показывают, на что модель «смотрит», но это не объяснение. Многие исследования показали, что веса не всегда коррелируют с feature importance.
Использовать full attention на 1М токенов. Сложность O(n²) делает это невозможным без оптимизаций. Sparse / sliding window / streaming attention.
Связанные темы
- Transformer на собесе DS
- Embeddings на собесе DS
- Подготовка к собесу Data Scientist
- Loss-функции на собесе DS
- Gradient descent и оптимизаторы
FAQ
Чем attention отличается от FFN?
Attention — взаимодействие между токенами (cross-token mixing). FFN — независимая обработка каждого токена (per-token transformation). Transformer-блок чередует их.
Что такое multi-head attention?
Параллельные attention головы со своими Q/K/V проекциями. Каждая может «специализироваться» на разных паттернах. Compute не растёт, потому что d_head = d_model / h.
Self-attention или cross-attention для классификации?
Классификация — encoder-only, self-attention. Cross-attention нужен в encoder-decoder при генерации.
Что такое sparse attention?
Каждый токен видит только подмножество других (локальное окно + глобальные токены). Снижает O(n²) до O(n log n) или O(n). Используется в Longformer, BigBird для длинного контекста.
MQA vs MHA — что выбрать?
MHA (Multi-Head): каждая голова свой K/V, лучшее качество. MQA: одна K/V на все головы, быстрее inference (меньше KV-cache). GQA — компромисс. В современных LLM (Llama 3) — GQA.
Это официальная информация?
Нет. Статья основана на «Attention Is All You Need» (Vaswani 2017), FlashAttention paper (Dao 2022), документации PyTorch / Hugging Face.
Тренируйте Data Science — откройте тренажёр с 1500+ вопросами для собесов.