線形アテンション最前線を網羅!PyTorch+Tritonで動くFlash Linear Attention完全ガイド
線形アテンション最前線を網羅!PyTorch+Tritonで動くFlash Linear Attention完全ガイド
ひとことでいうと
Flash Linear Attention(通称 fla)は、大規模言語モデルの「注意機構(アテンション)」を高速・省メモリで置き換えるための Python ライブラリです。GLA・DeltaNet・RWKV・Mamba2 など、2026 年 4 月時点で 30 を超えるモデルアーキテクチャが一つのパッケージにまとまっています。実装はすべて PyTorch と Triton(GPU カーネルを書くための専用言語)だけで書かれており、NVIDIA・AMD・Intel の GPU で動作します。長文処理や学習コストの削減を目指す研究者・エンジニアに向けたライブラリです。
こんな人におすすめ
1. 長い文章や長い系列データを扱いたい研究者・エンジニア
通常の Transformer が使うソフトマックスアテンションは、文章が長くなるほどメモリが「2 乗」で膨らみます。数万トークン(単語を細かく分割した単位)を超える処理が必要なとき、線形アテンションは強力な代替手段になります。fla には最先端のカーネル実装がそろっているため、すぐに試せます。
2. LLM(大規模言語モデル)の学習コストを下げたいエンジニア
fla には FusedCrossEntropy・FusedRMSNorm・FusedSwiGLU など、ピークメモリと学習スループットを改善する融合モジュールが含まれています。付属の学習フレームワーク flame(torchtitan ベース)と組み合わせれば、最小限の設定で分散学習パイプラインを構築できます。
3. 標準アテンションと線形アテンションを混ぜたハイブリッドモデルを作りたい人
設定ファイルに数行書くだけで、標準アテンション層と線形アテンション層を同じモデル内に混在させられます。Samba や Mamba など SSM(状態空間モデル)との組み合わせ実験も手軽に行えます。
インストール・使い方
動作要件
fla を動かすには以下の環境が必要です。GPU が必須で、CPU のみの環境では動作しません。
- Python 3.x(動作確認: 3.12.13)
- PyTorch 2.7.0 以上
- Triton 3.3 以上(または nightly 版)
- einops
- transformers 4.45.0 以上
- datasets 3.3.0 以上
- GPU(NVIDIA / AMD / Intel のいずれか)
Step 1: PyPI からインストールする
pip install flash-linear-attention
pip は Python のパッケージ管理ツールです。ターミナル(文字を入力してコンピュータに命令を送る画面)に上のコマンドをコピー&ペーストするだけでインストールできます。
Step 2: 最新版をソースからインストールする(推奨)
pip uninstall fla-core flash-linear-attention -y
pip install -U git+https://github.com/fla-org/flash-linear-attention
開発が活発なライブラリなので、最新の修正やモデルを使いたい場合はこちらが推奨です。既存バージョンを一度削除してから、GitHub(ソースコードの置き場)から直接インストールしています。
Step 3: 自分のプロジェクトにサブモジュールとして組み込む場合
git submodule add https://github.com/fla-org/flash-linear-attention.git 3rdparty/flash-linear-attention
ln -s 3rdparty/flash-linear-attention/fla fla
Git(バージョン管理ツール)のサブモジュール機能を使い、自分のリポジトリ(ソースコードの置き場)の一部として管理する方法です。チーム開発や再現性が求められる研究環境で役立ちます。
Step 4: AMD GPU・Intel GPU を使う場合の追加設定
AMD GPU を使う場合は Triton ROCm バックエンドを、Intel GPU の場合は Triton XPU バックエンドを別途インストールする必要があります。詳しい手順はリポジトリ内の FAQs.md に記載されています。
パッケージの構成について
v0.3.2 からパッケージは 2 つに分かれました。用途に合わせて使い分けることで、インストールサイズを最適化できます。
fla-core: カーネル本体だけを含む軽量パッケージ。PyTorch・Triton・einops のみに依存するため、transformers が不要な場面で使えます。flash-linear-attention:fla/layersやfla/modelsを含む拡張パッケージ。transformers に依存し、モデル定義や HuggingFace との連携も含まれます。
動かしてみた
公式ベンチマーク(NVIDIA GB200・CUDA 12.9・PyTorch 2.9.0 環境)によると、chunk_gla の forward+backward 実行時間は FlashAttention2 と比べて最大約 5 倍高速という結果が示されています。バッチサイズ 1・シーケンス長 8192・ヘッド数 96・次元 128 の条件で、chunk_gla が 7.67ms に対し flash_attn は 15.37ms でした。
公式 CI(継続的インテグレーション=自動テストの仕組み)は NVIDIA 4090・A100・H100 および Intel B580 上で継続的に実行されており、各プラットフォームでの動作が定期的に確認されています。
GPU 環境が用意できれば、次のコードで GLA モデルを使ったテキスト生成をすぐに試せます。
import fla
from transformers import AutoModelForCausalLM, AutoTokenizer
name = 'fla-hub/gla-1.3B-100B'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).cuda()
input_ids = tokenizer("Hello, FLA!", return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids, max_length=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
fla-hub/gla-1.3B-100B は HuggingFace Hub(学習済みモデルの公開サービス)で公開されている 13 億パラメータの GLA モデルです。.cuda() はモデルや入力データを GPU に転送する命令です。
試す前に知っておくとよいこと
- GPU が必須です。CPU のみの環境では動作しません。
- Triton カーネルは初回実行時に JIT コンパイル(その場でコードを機械語に変換する処理)が走るため、最初の 1 回だけ時間がかかります。
- PyTorch 2.7.0 以上・Triton 3.3 以上のバージョンを確認してからインストールしてください。
- 学習済みモデルのダウンロードにはネットワーク接続と十分なディスク空き容量が必要です。
デモについて
本ライブラリは Triton カーネルの JIT コンパイルと GPU ハードウェアが前提のため、CPU のみの環境やブラウザ上でのインタラクティブデモには適していません。現時点では Gradio などを使ったブラウザ上のデモは提供されていません。
ただし、HuggingFace Hub の fla-hub には複数の学習済みモデルが公開されています。GPU 環境があれば、上記の最小コードをそのままコピーして実行できます。試したいモデルが他にあれば、コード中の name の部分を fla-hub のモデル一覧から選んで書き換えるだけで切り替えられます。
実践:はじめの一歩
fla を初めて使うときに押さえておきたいポイントをまとめました。
- まず
MultiScaleRetentionを試してみる:fla.layersに含まれるシンプルな層の一つです。既存の Transformer の注意機構をそのまま差し替えられる「ドロップイン実装」として設計されています。
import torch
from fla.layers import MultiScaleRetention
batch_size, num_heads, seq_len, hidden_size = 32, 4, 2048, 1024
device, dtype = 'cuda:0', torch.bfloat16
retnet = MultiScaleRetention(
hidden_size=hidden_size,
num_heads=num_heads
).to(device=device, dtype=dtype)
x = torch.randn(batch_size, seq_len, hidden_size).to(device=device, dtype=dtype)
y, *_ = retnet(x)
print(y.shape) # torch.Size([32, 2048, 1024])
bfloat16(16 ビット浮動小数点型)を指定することで、精度をほぼ維持しながらメモリ消費を半減できます。
-
ハイブリッドモデルを構成するには: モデルコンフィグの
attnパラメータに、挿入したい層番号・ヘッド数・ウィンドウサイズを辞書(キーと値のペア)として渡すだけです。標準アテンションと線形アテンションが混在したモデルが自動的に構成されます。 -
HuggingFace の生成 API をそのまま使える: 推論時に特別な設定変更は不要です。
model.generate()やpipeline()など、使い慣れた HuggingFace の API をそのまま利用できます。 -
flameでスムーズに学習を始める:fla付属の学習フレームワークflameは torchtitan をベースにしており、分散学習やコンテキスト並列(長い文脈を複数 GPU で分割処理する手法)に対応しています。設定ファイルを最小限に抑えながら本格的な学習が行えます。 -
AMD・Intel GPU でも動く: NVIDIA に限らず AMD ROCm・Intel XPU バックエンドにも対応しているため、手元にある GPU の種類に合わせて環境を選べます。
活用アイデア
- 超長文脈テキスト処理: 法的文書・学術論文・コードベース全体など数万〜数十万トークンに及ぶ入力に線形アテンションを適用し、メモリを抑えながら処理できます。
- エッジデバイス向け軽量 LLM 開発: HGRN2 や Rodimus* など状態圧縮の強いモデルを選ぶことで、同等品質でパラメータ数と推論メモリを削減した軽量モデルを構築できます。
- カスタム分散学習パイプラインの構築:
flameとflaを組み合わせ、コンテキスト並列対応の学習パイプラインを最小コードで立ち上げられます。大規模クラスター環境での研究開発に適しています。 - 新しい注意機構のプロトタイピング:
flaの Triton カーネル実装をテンプレートとして参照することで、独自の注意機構を効率的に GPU 実装できます。アイデアを素早く形にしたい研究者に向いています。 - 既存 Transformer モデルのアーキテクチャ置換実験:
fla.layersのドロップイン実装を使い、既存モデルの注意層を線形アテンション層に置き換えて性能比較する実験が手軽に行えます。 - 教育・学習目的での実装参照: PyTorch と Triton だけで書かれたクリーンな実装は、線形アテンションや SSM の仕組みを学ぶ教材としても活用できます。
用語とポイント解説
線形アテンション(Linear Attention) 通常のアテンション(注意機構)は計算量がシーケンス長の 2 乗で増えますが、線形アテンションはカーネル関数(特定の数式による近似)を使って線形の計算量に抑える手法です。文章が 2 倍になっても計算量が 4 倍にならない点が大きな特徴です。かんたんに言うと、長い文章でも計算が爆発しないよう工夫した注意機構の亜種です。
SSM(状態空間モデル / State Space Model) 入力の情報を「状態」として圧縮しながら順番に処理する系列モデルの一種です。Mamba や RWKV が代表例で、過去の長い情報を固定サイズの状態変数に要約して保持します。かんたんに言うと、「要約しながら進む、記憶容量に制限のある処理モデル」です。
Triton NVIDIA が開発した、GPU カーネル(GPU で直接動く計算処理の単位)を記述するための専用言語(DSL)です。CUDA より高い抽象度で書けるため、Python に慣れた開発者でも GPU カーネルを実装しやすくなっています。かんたんに言うと、「GPU の速さを引き出すための高レベルなプログラミング言語」です。
チャンクアルゴリズム(Chunk Algorithm)
長い系列を小さな「チャンク(塊)」に分割し、並列処理することで線形アテンションを高速化する手法です。chunk_gla などの関数名に反映されています。かんたんに言うと、「長い文章を小分けにして同時進行で処理するテクニック」です。
ハイブリッドモデル(Hybrid Model) 標準のアテンション層と線形アテンション層・SSM 層を同じモデルの中に混在させたアーキテクチャです。それぞれの長所を組み合わせることで、精度と効率のバランスを取れます。かんたんに言うと、「普通の注意機構と省エネ注意機構を交互に組み合わせた設計のモデル」です。
fla-core
flash-linear-attention パッケージからカーネル本体だけを切り出した軽量パッケージです。PyTorch・Triton・einops のみに依存しており、transformers を必要としません。かんたんに言うと、「重い依存関係なしでカーネルだけ使いたい人向けの最小構成版」です。
JIT コンパイル(Just-In-Time Compilation) プログラムの実行時に必要な部分だけをその場で機械語(CPU や GPU が直接理解できる命令)に変換する仕組みです。Triton カーネルは初回実行時にこの処理が走るため、最初の実行だけ時間がかかります。かんたんに言うと、「最初の起動時に速さのための準備をする仕組み」です。
コンテキスト並列(Context Parallel)
非常に長い文脈(コンテキスト)を複数の GPU に分割して同時処理する並列化手法です。flame が対応しており、単一 GPU では扱いきれないほど長い系列の学習を可能にします。かんたんに言うと、「長すぎる文章を複数の GPU で手分けして処理する方法」です。
GLA(Gated Linear Attention)
ゲート機構(情報の通過量を調整する仕組み)を組み合わせた線形アテンションの一種です。fla-hub/gla-1.3B-100B として HuggingFace Hub 上に学習済みモデルが公開されており、すぐに推論を試せます。かんたんに言うと、「フィルター付きの省エネ注意機構」です。
Flash Linear Attention は、長文処理・学習コスト削減・新アーキテクチャの実験など、現代の LLM 開発における多くの課題を一つのライブラリで解決できる強力なツールです。ぜひ超長文脈テキストの処理や軽量 LLM の開発、そして新しい注意機構のプロトタイピングなどに活用してみてはいかがでしょうか。