Support bf16/fp16 activations in CPU SDPA (#20611)
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20611
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures as of commit 6f86ec1 with merge base 8965e51.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@GregoryComer has imported this pull request. If you are a Meta employee, you can view this in D110246161.
Nothing uses bf16 yet, so the slow path doesn't harm or regress anything, since no one is using it, right?
Also, maybe we should add more tests in follow-ups, e.g. long sequences and fp16 large-value/overflow cases.
That's correct. This isn't wired up e2e yet and is just the groundwork for the SDPA. I'll have MKL + Kleidi (or some other sane ARM GEMM) wired up before enabling it e2e.
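For the fp16 large-value/overflow follow-up tests mentioned above, here is a minimal sketch of the failure mode such tests would target. It is plain Python using only `struct`, not the operator's code. `round_to_fp16`, `round_to_bf16`, and `dot` are hypothetical helpers, and whether the SDPA kernel accumulates in f32 is an assumption, since the PR does not say.

```python
import struct


def round_to_fp16(x: float) -> float:
    """Round to IEEE half precision; values beyond the fp16 range become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack finite values that are too large for fp16.
        return float("inf") if x > 0 else float("-inf")


def round_to_bf16(x: float) -> float:
    """Round an f32-representable value to bf16 (top 16 bits of f32, nearest-even).

    NaN inputs are not handled specially in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot(a, b, round_fn=lambda v: v):
    """Dot product that rounds every product and partial sum with round_fn.

    The default leaves values in Python doubles, standing in for a wide
    (assumed f32) accumulator.
    """
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_fn(acc + round_fn(x * y))
    return acc


q = [300.0] * 4
k = [300.0] * 4

fp16_score = dot(q, k, round_to_fp16)  # each product (90000) exceeds the fp16 maximum of 65504
bf16_score = dot(q, k, round_to_bf16)  # bf16 keeps the f32 exponent range, so it stays finite
wide_score = dot(q, k)                 # wide accumulator gives the exact 360000.0
```

The sketch shows why fp16 and bf16 need different test cases: fp16 has more mantissa bits but overflows at moderate magnitudes, while bf16 keeps the f32 range at coarser precision.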
Summary:
Add support for reduced-precision (bf16 + f16) activations in the LLM SDPA op. Bf16 support is the main goal, but I wired up f16 at the same time, as it's straightforward.

This initial implementation just exposes the actual operator-level support and does not yet wire e2e through ET-LLM or anything else. It also initially only supports our simple CPU BLAS implementation. I'll wire up MKL as a follow up. I'm probably going to leverage Kleidi for the ARM GEMM kernels. Apple's Accelerate doesn't support BF16 gemms as far as I can tell. Kleidi will be fast on SME2 (M4/A18+) but won't hit AMX on M1-M3. I think that's fine, though. The i8mm/dot kernels should be sufficient. We can use BNNS if we need to close the gap.

## Perplexity
Measured on wikitext, M4 Max, ~40k token sanity check, 1000 token windows

| f32 | bf16 | delta |
| -- | -- | -- |
| 26.453 | 26.671 | +0.22 (+0.82%) |

## Performance
It is slow because it uses the fallback GEMM (as described above). Will wire Kleidi + MKL as a follow-up. Nothing uses bf16 SDPA yet, so this is fine.

(SDPA op timing - qwen3-1.7b)

| Device | Context | Phase | f32 (ms) | bf16 (ms) | bf16/f32 |
|--------|--------:|---------|---------:|----------:|---------:|
| M4 Max | 128 | Prefill | 1.855 | 10.180 | 5.5× |
| M4 Max | 128 | Decode | 0.112 | 0.074 | 0.66× |
| M4 Max | 1024 | Prefill | 76.930 | 1254.800 | 16.3× |
| M4 Max | 1024 | Decode | 0.800 | 1.201 | 1.5× |
| S25 | 128 | Prefill | 3.924 | 43.760 | 11.2× |
| S25 | 128 | Decode | 0.245 | 0.329 | 1.3× |
| S25 | 1024 | Prefill | 168.340 | 2790.700 | 16.6× |
| S25 | 1024 | Decode | 1.614 | 2.698 | 1.7× |

Differential Revision: D110246161
Pulled By: GregoryComer
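The bf16/f32 column and the perplexity delta reported above can be recomputed from the raw figures. The short script below is only a consistency check on the numbers quoted in the summary; the variable names are illustrative and nothing here comes from the benchmark harness itself.

```python
# (device, context, phase, f32_ms, bf16_ms, reported_ratio), copied from the summary.
timings = [
    ("M4 Max", 128, "Prefill", 1.855, 10.180, 5.5),
    ("M4 Max", 128, "Decode", 0.112, 0.074, 0.66),
    ("M4 Max", 1024, "Prefill", 76.930, 1254.800, 16.3),
    ("M4 Max", 1024, "Decode", 0.800, 1.201, 1.5),
    ("S25", 128, "Prefill", 3.924, 43.760, 11.2),
    ("S25", 128, "Decode", 0.245, 0.329, 1.3),
    ("S25", 1024, "Prefill", 168.340, 2790.700, 16.6),
    ("S25", 1024, "Decode", 1.614, 2.698, 1.7),
]

# Ratio > 1 means bf16 is slower than f32 for that row.
ratios = [bf16_ms / f32_ms for (_, _, _, f32_ms, bf16_ms, _) in timings]

for (device, context, phase, _, _, reported), ratio in zip(timings, ratios):
    print(f"{device:7s} ctx={context:<5d} {phase:8s} computed={ratio:.2f}x reported={reported}x")

# Perplexity: absolute and relative change from f32 to bf16.
ppl_f32, ppl_bf16 = 26.453, 26.671
ppl_delta = ppl_bf16 - ppl_f32
ppl_delta_pct = 100.0 * ppl_delta / ppl_f32
print(f"perplexity delta = {ppl_delta:+.2f} ({ppl_delta_pct:+.2f}%)")
```

The prefill rows dominate the slowdown, which is consistent with the note that the fallback GEMM is the bottleneck, while decode stays within roughly 2× of f32.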
@GregoryComer has exported this pull request. If you are a Meta employee, you can view the originating Diff in D110246161.