Hybrid attention
TL;DR: Forked PyTorch and Triton internals. Changed attention to a linear first layer, a quadratic middle layer, and a linear last layer. Inference got much faster with a small perplexity hit in tests.
Full attention O(n²): 17.96 s / 5.6 tok/s
HybridAttention O(n·W + n·D): 0.35 s / 286.6 tok/s
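The two benchmark lines above can be sanity-checked against each other; both the wall-clock ratio and the throughput ratio should give roughly the same speedup (a quick arithmetic sketch, assuming both runs generated the same number of tokens):

```python
# Reported numbers from the benchmark above
full_time_s, hybrid_time_s = 17.96, 0.35
full_tok_s, hybrid_tok_s = 5.6, 286.6

# Both ratios should tell the same story (~51x)
time_speedup = full_time_s / hybrid_time_s      # ~51.3
throughput_speedup = hybrid_tok_s / full_tok_s  # ~51.2

print(round(time_speedup, 1), round(throughput_speedup, 1))
```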
I have been building a small Rust-focused language model from scratch in PyTorch. This is not a finetune: it is byte-level, trained from random initialization on a Rust-heavy corpus assembled here: https://codeberg.org/JohannaJuntos/Sisyphus
Model and training setup
The model has 25.6M parameters with a 512 context length. It uses a byte-level vocabulary of 256, with 8 layers, 8 heads, and 512-dimensional embeddings. Positional embeddings are learned, and the embedding and LM head weights are tied.
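Assuming a standard GPT-style block (Q/K/V/output projections, a 4x MLP, two LayerNorms), the stated hyperparameters land almost exactly on the 25.6M figure; a back-of-the-envelope sketch (the exact block layout is my assumption, not the repo's code):

```python
# Rough parameter count for the stated config.
# Block layout (4x MLP, two LayerNorms, tied head) is an assumption.
vocab, d, layers, ctx = 256, 512, 8, 512

tok_emb = vocab * d           # byte-level token embedding
pos_emb = ctx * d             # learned positional embedding
attn = 4 * d * d              # Q, K, V, and output projections
mlp = 2 * d * (4 * d)         # up- and down-projection, 4x hidden
norms = 2 * (2 * d)           # two LayerNorms (weight + bias)
per_layer = attn + mlp + norms

final_norm = 2 * d
lm_head = 0                   # tied to tok_emb, so no extra weights

total = tok_emb + pos_emb + layers * per_layer + final_norm + lm_head
print(f"{total / 1e6:.1f}M")  # close to the reported 25.6M
```

Weight tying matters here: an untied LM head would add another vocab × d = 131k parameters.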
Training ran for 30k steps on a 173.5M-byte Rust corpus using a single RTX 4060 Ti 8GB. Final metrics were a train loss of 0.5834, a validation loss of 0.8217, and a perplexity of 2.15. The best validation loss occurred around step 18.5k, which suggests some late overfitting or a plateau.
Architecture
The model is a GPT-style decoder, but replaces standard full attention with a HybridAttention block in each layer. This combines local windowed causal attention with a GRU-like recurrent state path, along with a learned gate that mixes the two.
The local path handles short-range syntax, while the recurrent path carries compressed long-range state. The gate bias is initialized to favor local attention early in training.
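A toy version of that mixing, as a NumPy stand-in rather than the actual forked kernels: the GRU-like path is simplified here to a leaky running state, and the gate is a single scalar bias instead of a learned per-head gate. A positive bias pushes the sigmoid toward 1, i.e. toward the local path, matching the initialization described above.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def local_causal_attention(x, window):
    """Each position attends only to the last `window` positions (inclusive)."""
    T, D = x.shape
    out = np.zeros_like(x)
    for t in range(T):
        lo = max(0, t - window + 1)
        ctx = x[lo:t + 1]                    # (w, D)
        scores = ctx @ x[t] / np.sqrt(D)     # (w,)
        probs = np.exp(scores - scores.max())
        probs /= probs.sum()
        out[t] = probs @ ctx
    return out

def recurrent_path(x, decay=0.9):
    """Stand-in for the GRU-like path: a leaky running summary of the past."""
    out = np.zeros_like(x)
    h = np.zeros(x.shape[1])
    for t in range(x.shape[0]):
        h = decay * h + (1 - decay) * x[t]
        out[t] = h
    return out

def hybrid_attention(x, window=4, gate_bias=2.0):
    # gate_bias > 0 biases the sigmoid toward the local path early on
    g = sigmoid(gate_bias)
    return g * local_causal_attention(x, window) + (1 - g) * recurrent_path(x)
```

The real block presumably learns the gate per head and per channel; this sketch only shows why initializing the bias positive makes the block behave like windowed attention at the start of training.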
Inference uses Triton kernels and custom torch.library ops.
Corpus
The biggest gain came from corpus expansion.
The run started with about 31MB from official Rust sources and major projects such as rustc, cargo, rust-analyzer, tokio, serde, ripgrep, clap, and axum. The corpus was expanded to 173.5M bytes by cloning the top 500 crates, with 461 successful clones.
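A minimal sketch of the assembly step for a byte-level corpus like this, assuming the cloned crates all sit under one root directory and the goal is a single flat training file (the layout and file separator are my assumptions, not the repo's exact pipeline):

```python
from pathlib import Path

def build_corpus(root: str, out_path: str) -> int:
    """Concatenate every .rs file under `root` into one byte-level
    training file. Returns the total corpus size in bytes."""
    total = 0
    with open(out_path, "wb") as out:
        for p in sorted(Path(root).rglob("*.rs")):
            data = p.read_bytes()
            out.write(data)
            out.write(b"\n")  # separator between source files (assumption)
            total += len(data) + 1
    return total
```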
This expansion had more impact than any architectural change.
Inference performance
Full attention runs at about 5.6 tokens per second, while HybridAttention with a KV cache reaches 286.6 tokens per second. That is roughly a 51x speedup with no visible quality loss.
The KV cache keeps a hot window of 64 tokens in VRAM, while older tokens are compressed to 8-bit magnitude and angle and can be selectively promoted back to full precision. This changes the effective complexity from quadratic to near-linear for this setup.
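The post does not spell out the exact 8-bit magnitude-and-angle scheme. One plausible reading treats consecutive element pairs of a cached K/V vector as complex numbers and quantizes the polar form, one byte each for magnitude and angle. A NumPy sketch under that assumption (the pairing and the per-vector scale are both guesses):

```python
import numpy as np

def compress_kv(v):
    """Quantize an even-length float vector to 8-bit magnitude + angle,
    treating element pairs (v[0], v[1]), (v[2], v[3]), ... as complex."""
    z = v[0::2] + 1j * v[1::2]
    mag, ang = np.abs(z), np.angle(z)                # ang in [-pi, pi]
    scale = mag.max() if mag.max() > 0 else 1.0
    q_mag = np.round(mag / scale * 255).astype(np.uint8)
    q_ang = np.round((ang + np.pi) / (2 * np.pi) * 255).astype(np.uint8)
    return q_mag, q_ang, scale

def decompress_kv(q_mag, q_ang, scale):
    """Promote a compressed entry back toward full precision."""
    mag = q_mag.astype(np.float64) / 255 * scale
    ang = q_ang.astype(np.float64) / 255 * 2 * np.pi - np.pi
    z = mag * np.exp(1j * ang)
    out = np.empty(2 * len(z))
    out[0::2], out[1::2] = z.real, z.imag
    return out
```

At 2 bytes per element pair this is a 4x reduction over fp16, which is consistent with keeping only a small hot window at full precision.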
Quality
Surface Rust syntax looks decent, and imports and function signatures are often plausible. Semantics are still weak, and repetition and recursive patterns are common. It looks like Rust, but it does not reason well yet.
What seems interesting
This project combines byte-level, Rust-only pretraining from scratch, a hybrid local-attention-plus-recurrent architecture, large-scale corpus expansion across the Rust ecosystem, and a practical KV cache paging strategy that delivers large speedups on consumer GPUs.
Next steps
I plan to run ablations comparing hybrid attention against local-only and recurrent-only variants, evaluate checkpoints around step 18.5k versus the final model, and add syntax-level validation such as parsing and compiling generated code. I also want to explore scaling context length from 256 up to 2048 and test whether switching from byte level to BPE becomes worthwhile now that the corpus is larger.
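The syntax-level validation step can be a small harness that feeds each generated sample to rustc and counts how many at least type-check; a sketch (using `--emit=metadata` to skip codegen, and returning None when rustc is not on PATH):

```python
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional

def rustc_accepts(code: str) -> Optional[bool]:
    """True if rustc accepts `code`, False if it rejects it,
    or None when no rustc toolchain is available."""
    if shutil.which("rustc") is None:
        return None
    with tempfile.TemporaryDirectory() as d:
        src = Path(d) / "sample.rs"
        src.write_text(code)
        # --emit=metadata runs parsing and type checking without codegen
        result = subprocess.run(
            ["rustc", "--edition", "2021", "--emit=metadata",
             "--out-dir", d, str(src)],
            capture_output=True,
        )
        return result.returncode == 0
```

A compile rate over a few hundred samples would complement perplexity nicely, since it directly measures the surface-syntax strength described above.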
Questions
For small code models, which evaluations have been most useful beyond perplexity?
Has anyone seen hybrid local-plus-recurrent attention work well for code generation?
Given this setup, would you prioritize more tokens, longer context, or clean ablations first?