Hybrid Attention Mechanisms: Kimi Linear & DeepSeek Sparse
Executive Summary
This note explores the integration of Kimi Delta Attention (KDA) and DeepSeek Sparse Attention (DSA) to create efficient, high-performance Transformer architectures.
- KDA (from Kimi Linear) offers efficiency with fine-grained forgetting.
- DSA (from DeepSeek-V3.2-Exp) enables precise long-context retrieval via sparse access.
1. Kimi Delta Attention (KDA)
Source: Kimi Linear: An Expressive, Efficient Attention Architecture (arXiv:2510.26692)
At a Glance
- Complexity: Linear time & memory.
- Core Innovation: Channel-wise gated DeltaNet with DPLR-style decay.
- Best For: Efficiently compressing global context into a fixed-size state.
Mechanism
- Feature Projection: The input is projected to Query (Q), Key (K), Value (V), and a forget gate (g, computed from a learned β projection).
- Data-Dependent Decay: Unlike standard linear attention, KDA computes a channel-wise decay rate from the input token. This allows the model to selectively "forget" irrelevant information per feature dimension.
- Chunkwise Computation: Tokens are processed in chunks (e.g., 64) to leverage GPU Tensor Cores, balancing the sequential nature of RNNs with the parallel efficiency of Transformers.
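The effect of channel-wise decay can be seen in a toy scalar analogue (plain Python, with made-up constant decay rates rather than learned, input-dependent ones): a channel whose gate is near 1 preserves an old write, while a channel whose gate is near 0 discards it almost immediately.

```python
def run_decay(g, inputs):
    """Scalar analogue of one state channel: s_t = g * s_{t-1} + x_t."""
    s = 0.0
    for x in inputs:
        s = g * s + x
    return s

inputs = [1.0] + [0.0] * 9        # a single write at t=0, then silence
slow = run_decay(0.9, inputs)     # "remembering" channel: 0.9**9 ~ 0.387
fast = run_decay(0.1, inputs)     # "forgetting" channel: 0.1**9 ~ 1e-9
```

Because the gate is per-channel in KDA, both behaviors coexist in one state matrix, which is what lets the model keep some features long-term while aggressively clearing others.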
Mathematical Formulation

$$S_t = \mathrm{Diag}(g_t)\, S_{t-1} + k_t v_t^{\top}, \qquad o_t = S_t^{\top} q_t$$

Note on $g_t$
The decay rate $g_t \in (0,1)^{d}$ is broadcast across the rows of the state matrix $S_t$. This channel-wise gating is what differentiates KDA from simple DeltaNet or RWKV, allowing for more expressive memory management. (Kimi Linear's full update also applies a delta-rule correction term; the simplified gated form above matches the reference code in this note.)
PyTorch Implementation

```python
import torch
import torch.nn as nn


class KimiDeltaAttention(nn.Module):
    """
    KDA: Gated DeltaNet with Data-Dependent Decay.
    Based on Kimi Linear (arXiv:2510.26692).
    """
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Projections: Q, K, V, and Beta (Decay/Gate)
        self.qkv_gate = nn.Linear(dim, 3 * dim + dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, state: torch.Tensor = None):
        B, N, D = x.shape
        H, HD = self.num_heads, self.head_dim
        # 1. Project inputs
        proj = self.qkv_gate(x)
        q, k, v, beta = torch.split(proj, [D, D, D, D], dim=-1)
        # 2. Reshape & activation
        q, k, v = [t.view(B, N, H, HD) for t in (q, k, v)]
        g = torch.sigmoid(beta.view(B, N, H, HD))  # Decay rate in (0, 1)
        # 3. Initialize state (B, H, HD, HD)
        if state is None:
            state = torch.zeros(B, H, HD, HD, device=x.device, dtype=x.dtype)
        outputs = []
        # 4. Recurrent update (simplified; the real impl uses chunkwise DPLR)
        for t in range(N):
            q_t, k_t, v_t = q[:, t], k[:, t], v[:, t]
            g_t = g[:, t].unsqueeze(-1)  # Broadcast over the value dim
            # S_t = Diag(g_t) S_{t-1} + k_t v_t^T
            state = state * g_t + torch.einsum('bhd,bhm->bhdm', k_t, v_t)
            # o_t = S_t^T q_t
            out_t = torch.einsum('bhd,bhdm->bhm', q_t, state)
            outputs.append(out_t)
        output = torch.stack(outputs, dim=1).reshape(B, N, D)
        return self.out_proj(output), state
```

2. DeepSeek Sparse Attention (DSA)
Source: DeepSeek-V3.2-Exp Technical Report
At a Glance
- Complexity: Sparse access (Top-K).
- Core Innovation: Query-dependent "Lightning Indexer" + FlashMLA kernels.
- Best For: Retrieving specific "needle-in-a-haystack" details from massive contexts.
Mechanism
- Lightning Indexer: A lightweight, compressed attention branch (often FP8) is used to quickly estimate token relevance.
- Query-Dependent Scoring: Unlike static sparse methods, DSA computes relevance scores dynamically between the current query and all past compressed keys.
- Top-K Selection: The indices of the top-k most relevant tokens are selected.
- Sparse Gather: Full-precision KV pairs are fetched only for these selected indices.
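The selection-and-gather steps above can be sketched in plain Python (illustrative helper names; a real implementation does this with `torch.topk` and fused FlashMLA kernels over batched tensors):

```python
def topk_indices(scores, k):
    """Return the indices of the k largest scores, highest first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

def sparse_gather(kv_cache, indices):
    """Fetch full-precision KV pairs only for the selected positions."""
    return [kv_cache[i] for i in indices]

scores = [0.1, 2.5, -0.3, 1.7, 0.9]   # indexer scores for 5 past tokens
idx = topk_indices(scores, k=2)        # positions 1 and 3 win
selected = sparse_gather(["kv0", "kv1", "kv2", "kv3", "kv4"], idx)
```

The point of the two-stage design is that the cheap indexer touches every position, while the expensive full-precision attention touches only the k gathered ones.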
Implementation Detail
Real-world DSA relies heavily on RoPE (Rotary Positional Embeddings) in the indexer to maintain relative position awareness, and specialized FlashMLA CUDA kernels to avoid materializing full attention matrices.
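For reference, rotary embedding itself is just a per-pair rotation of the feature vector by a position-dependent angle. A minimal standalone sketch (pure Python, not the production FP8 indexer path; the pairing convention is one common variant):

```python
import math

def rope(vec, pos, base=10000.0):
    """Apply rotary position embedding to an even-length vector.
    Each pair (vec[2i], vec[2i+1]) is rotated by angle pos * base**(-2i/d)."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out
```

Because rotations preserve dot-product structure up to the position difference, the indexer's scores depend on relative rather than absolute positions.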
Mathematical Formulation

The lightning indexer assigns each past token $s$ a relevance score for query token $t$:

$$I_{t,s} = \sum_{j} w_{t,j}\, \mathrm{ReLU}\!\left(q^{I}_{t,j} \cdot k^{I}_{s}\right)$$

Full attention is then computed only over the top-$k$ scoring positions:

$$o_t = \mathrm{Attn}\!\left(q_t,\; \{(k_s, v_s) : I_{t,s} \in \mathrm{Top}\text{-}k\}\right)$$
PyTorch Implementation (Conceptual)
```python
class DeepSeekSparseAttention(nn.Module):
    """
    DSA: Query-Dependent Lightning Indexer + Sparse Gather.
    Ref: DeepSeek-V3.2-Exp
    """
    def __init__(self, dim: int, num_heads: int = 8, k_sparse: int = 64):
        super().__init__()
        self.k_sparse = k_sparse
        self.num_heads = num_heads
        # Full-precision projections
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        # Lightning Indexer (compressed; FP8 in the real implementation)
        self.idx_dim = dim // 4
        self.idx_q = nn.Linear(dim, self.idx_dim)
        self.idx_k = nn.Linear(dim, self.idx_dim)

    def forward(self, x: torch.Tensor):
        B, N, D = x.shape
        # 1. Lightning Indexer (score computation)
        q_idx = self.idx_q(x)
        k_idx = self.idx_k(x)
        # (Apply RoPE here in a real implementation)
        scores = torch.bmm(q_idx, k_idx.transpose(1, 2))
        # Mask future positions so Top-K selection stays causal
        causal = torch.triu(torch.ones(N, N, dtype=torch.bool, device=x.device), 1)
        scores = scores.masked_fill(causal, float('-inf'))
        # 2. Top-K selection
        k = min(self.k_sparse, N)
        _, top_indices = torch.topk(scores, k, dim=-1)
        # 3. Sparse attention (simulated)
        # In practice: call FlashMLA.sparse_attention(q, k, v, top_indices)
        return x  # Placeholder for the full sparse op
```

3. Hybrid Integration Strategies
Combining KDA's global compression with DSA's precise retrieval offers a "best of both worlds" architecture.
Option A: Sequential Stack (Compress → Refine)
Logic: Use KDA layers to maintain a running summary of the context, followed by DSA layers to "zoom in" on specific details when needed.
```python
class SequentialHybrid(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.kda = KimiDeltaAttention(dim)
        self.dsa = DeepSeekSparseAttention(dim, num_heads=8, k_sparse=64)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, state=None):
        res = x
        # Global compression
        x, state = self.kda(x, state=state)
        x = self.norm(x + res)
        res = x
        # Local refinement
        x = self.dsa(x)
        x = self.norm(x + res)
        return x, state
```

Option B: Parallel Gated Merge
Logic: Run both branches in parallel and let the model learn when to rely on memory (KDA) vs. retrieval (DSA).
```python
class ParallelHybrid(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.kda = KimiDeltaAttention(dim)
        self.dsa = DeepSeekSparseAttention(dim, num_heads=8)
        self.gate = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, x, state=None):
        out_kda, state = self.kda(x, state=state)
        out_dsa = self.dsa(x)
        alpha = self.gate(x)  # Dynamic per-token weighting
        return alpha * out_kda + (1 - alpha) * out_dsa, state
```

Option C: Layer-wise Interleaving (Recommended)
Logic: Interleave layers in a fixed ratio (e.g., 3 KDA : 1 DSA). This is the approach suggested by the Kimi Linear paper for optimal performance/efficiency balance.
```python
class LayerWiseHybrid(nn.Module):
    def __init__(self, dim, depth=12):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(depth):
            # Every 4th layer is sparse attention (3 KDA : 1 DSA)
            if (i + 1) % 4 == 0:
                self.layers.append(DeepSeekSparseAttention(dim, num_heads=8, k_sparse=64))
            else:
                self.layers.append(KimiDeltaAttention(dim))

    def forward(self, x, state=None):
        # Note: threading one state through all KDA layers is a simplification;
        # a full implementation keeps a separate recurrent state per KDA layer.
        for layer in self.layers:
            if isinstance(layer, KimiDeltaAttention):
                x, state = layer(x, state)
            else:
                x = layer(x)
        return x, state
```

Project Structure
Recommended folder structure for a hybrid implementation:
```
hybrid_attention/
├── attention/
│   ├── __init__.py
│   ├── kda.py        # KimiDeltaAttention (DPLR)
│   ├── dsa.py        # DeepSeekSparseAttention (FlashMLA)
│   ├── hybrid.py     # Fusion modules
│   └── kernels.py    # Triton/CUDA kernels
├── requirements.txt
└── README.md
```
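A matching requirements.txt might look like the following; the packages and version bounds are assumptions for this layout, not pins taken from either paper:

```text
torch>=2.1      # core model code (attention/, hybrid.py)
triton>=2.1     # custom chunkwise/DPLR kernels in kernels.py
einops          # optional: readable tensor reshapes
# FlashMLA (DeepSeek's sparse-attention kernels) is typically built from source
# and must match your CUDA toolchain.
```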