Skip to content

Support bf16/fp16 activations in CPU SDPA (#20611)#20611

Open
GregoryComer wants to merge 1 commit into
pytorch:mainfrom
GregoryComer:sdpa-half
Open

Support bf16/fp16 activations in CPU SDPA (#20611)#20611
GregoryComer wants to merge 1 commit into
pytorch:mainfrom
GregoryComer:sdpa-half

Conversation

@GregoryComer

@GregoryComer GregoryComer commented Jun 29, 2026

Copy link
Copy Markdown
Member

Summary:
Add support for reduced-precision (bf16 + f16) activations in the LLM SDPA op. Bf16 support is the main goal, but I wired up f16 at the same time, as it's straightforward.

This initial implementation just exposes the actual operator-level support and does not yet wire e2e through ET-LLM or anything else. It also initially only supports our simple CPU BLAS implementation. I'll wire up MKL as a follow up. I'm probably going to leverage Kleidi for the ARM GEMM kernels. Apple's Accelerate doesn't support BF16 gemms as far as I can tell. Kleidi will be fast on SME2 (M4/A18+) but won't hit AMX on M1-M3. I think that's fine, though. The i8mm/dot kernels should be sufficient. We can use BNNS if we need to close the gap.

Perplexity

Measured on wikitext, M4 Max, ~40k token sanity check, 1000 token windows

f32 bf16 delta
26.453 26.671 +0.22 (+0.82%)

Performance

It is slow because it uses the fallback GEMM (as described above). Will wire Kleidi + MKL as a follow-up. Nothing uses bf16 SDPA yet, so this is fine.

(SDPA op timing - qwen3-1.7b)

Device Context Phase f32 (ms) bf16 (ms) bf16/f32
M4 Max 128 Prefill 1.855 10.180 5.5×
M4 Max 128 Decode 0.112 0.074 0.66×
M4 Max 1024 Prefill 76.930 1254.800 16.3×
M4 Max 1024 Decode 0.800 1.201 1.5×
S25 128 Prefill 3.924 43.760 11.2×
S25 128 Decode 0.245 0.329 1.3×
S25 1024 Prefill 168.340 2790.700 16.6×
S25 1024 Decode 1.614 2.698 1.7×

Differential Revision: D110246161

Pulled By: GregoryComer

@pytorch-bot

pytorch-bot Bot commented Jun 29, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20611

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 6f86ec1 with merge base 8965e51 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 29, 2026
@GregoryComer GregoryComer added the release notes: ops & kernels Changes to the opset and any new / changed kernel implementations label Jun 29, 2026
@GregoryComer GregoryComer marked this pull request as ready for review June 29, 2026 23:45
@GregoryComer GregoryComer marked this pull request as draft June 29, 2026 23:51
@GregoryComer GregoryComer marked this pull request as ready for review June 30, 2026 02:00
@meta-codesync

meta-codesync Bot commented Jun 30, 2026

Copy link
Copy Markdown
Contributor

@GregoryComer has imported this pull request. If you are a Meta employee, you can view this in D110246161.

@psiddh

psiddh commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

Add support for reduced-precision (bf16 + f16) activations in the LLM SDPA op. Bf16 support is the main goal, but I wired up f16 at the same time, as it's straightforward.

This initial implementation just exposes the actual operator-level support and does not yet wire e2e through ET-LLM or anything else. It also initially only supports our simple CPU BLAS implementation. I'll wire up MKL as a follow up. I'm probably going to leverage Kleidi for the ARM GEMM kernels. Apple's Accelerate doesn't support BF16 gemms as far as I can tell. Kleidi will be fast on SME2 (M4/A18+) but won't hit AMX on M1-M3. I think that's fine, though. The i8mm/dot kernels should be sufficient. We can use BNNS if we need to close the gap.

Perplexity

Measured on wikitext, M4 Max, ~40k token sanity check, 1000 token windows

f32 bf16 delta
26.453 26.671 +0.22 (+0.82%)

Performance

It is slow because it uses the fallback GEMM (as described above). Will wire Kleidi + MKL as a follow-up. Nothing uses bf16 SDPA yet, so this is fine.

(SDPA op timing - qwen3-1.7b)

Device Context Phase f32 (ms) bf16 (ms) bf16/f32
M4 Max 128 Prefill 1.855 10.180 5.5×
M4 Max 128 Decode 0.112 0.074 0.66×
M4 Max 1024 Prefill 76.930 1254.800 16.3×
M4 Max 1024 Decode 0.800 1.201 1.5×
S25 128 Prefill 3.924 43.760 11.2×
S25 128 Decode 0.245 0.329 1.3×
S25 1024 Prefill 168.340 2790.700 16.6×
S25 1024 Decode 1.614 2.698 1.7×

nothing uses bf16 yet, so the slow path harms / regresses anything as no one is using it , right ?

Comment thread extension/llm/custom_ops/op_sdpa_impl.h
Comment thread extension/llm/custom_ops/op_sdpa_impl.h
Comment thread extension/llm/custom_ops/op_sdpa_impl.h
Comment thread extension/llm/custom_ops/op_sdpa_impl.h Outdated
@psiddh

psiddh commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

Also maybe we should add more tests in follow-ups, e.g. long sequences, fp16 large-value/overflow cases.

@GregoryComer

GregoryComer commented Jul 2, 2026

Copy link
Copy Markdown
Member Author

nothing uses bf16 yet, so the slow path harms / regresses anything as no one is using it , right ?

That's correct. This isn't wired up e2e yet and is just the groundwork for the SDPA. I'll have MKL + Kleidi (or some other sane ARM GEMM) wired up before enabling it e2e.

@meta-codesync meta-codesync Bot changed the title Support bf16/fp16 activations in CPU SDPA Support bf16/fp16 activations in CPU SDPA (#20611) Jul 2, 2026
GregoryComer added a commit to GregoryComer/executorch that referenced this pull request Jul 2, 2026
Summary:
Add support for reduced-precision (bf16 + f16) activations in the LLM SDPA op. Bf16 support is the main goal, but I wired up f16 at the same time, as it's straightforward.

This initial implementation just exposes the actual operator-level support and does not yet wire e2e through ET-LLM or anything else. It also initially only supports our simple CPU BLAS implementation. I'll wire up MKL as a follow up. I'm probably going to leverage Kleidi for the ARM GEMM kernels. Apple's Accelerate doesn't support BF16 gemms as far as I can tell. Kleidi will be fast on SME2 (M4/A18+) but won't hit AMX on M1-M3. I think that's fine, though. The i8mm/dot kernels should be sufficient. We can use BNNS if we need to close the gap.

## Perplexity
Measured on wikitext, M4 Max, ~40k token sanity check, 1000 token windows
| f32 | bf16 | delta |
| -- | -- | -- |
| 26.453 | 26.671 | +0.22 (+0.82%) |

## Performance
It is slow because it uses the fallback GEMM (as described above). Will wire Kleidi + MKL as a follow-up. Nothing uses bf16 SDPA yet, so this is fine.

(SDPA op timing - qwen3-1.7b)
  | Device | Context | Phase   | f32 (ms) | bf16 (ms) | bf16/f32 |
  |--------|--------:|---------|---------:|----------:|---------:|
  | M4 Max | 128     | Prefill |    1.855 |    10.180 |    5.5×  |
  | M4 Max | 128     | Decode  |    0.112 |     0.074 |    0.66× |
  | M4 Max | 1024    | Prefill |   76.930 |  1254.800 |   16.3×  |
  | M4 Max | 1024    | Decode  |    0.800 |     1.201 |    1.5×  |
  | S25    | 128     | Prefill |    3.924 |    43.760 |   11.2×  |
  | S25    | 128     | Decode  |    0.245 |     0.329 |    1.3×  |
  | S25    | 1024    | Prefill |  168.340 |  2790.700 |   16.6×  |
  | S25    | 1024    | Decode  |    1.614 |     2.698 |    1.7×  |


Differential Revision: D110246161

Pulled By: GregoryComer
Summary:
Add support for reduced-precision (bf16 + f16) activations in the LLM SDPA op. Bf16 support is the main goal, but I wired up f16 at the same time, as it's straightforward.

This initial implementation just exposes the actual operator-level support and does not yet wire e2e through ET-LLM or anything else. It also initially only supports our simple CPU BLAS implementation. I'll wire up MKL as a follow up. I'm probably going to leverage Kleidi for the ARM GEMM kernels. Apple's Accelerate doesn't support BF16 gemms as far as I can tell. Kleidi will be fast on SME2 (M4/A18+) but won't hit AMX on M1-M3. I think that's fine, though. The i8mm/dot kernels should be sufficient. We can use BNNS if we need to close the gap.

## Perplexity
Measured on wikitext, M4 Max, ~40k token sanity check, 1000 token windows
| f32 | bf16 | delta |
| -- | -- | -- |
| 26.453 | 26.671 | +0.22 (+0.82%) |

## Performance
It is slow because it uses the fallback GEMM (as described above). Will wire Kleidi + MKL as a follow-up. Nothing uses bf16 SDPA yet, so this is fine.

(SDPA op timing - qwen3-1.7b)
  | Device | Context | Phase   | f32 (ms) | bf16 (ms) | bf16/f32 |
  |--------|--------:|---------|---------:|----------:|---------:|
  | M4 Max | 128     | Prefill |    1.855 |    10.180 |    5.5×  |
  | M4 Max | 128     | Decode  |    0.112 |     0.074 |    0.66× |
  | M4 Max | 1024    | Prefill |   76.930 |  1254.800 |   16.3×  |
  | M4 Max | 1024    | Decode  |    0.800 |     1.201 |    1.5×  |
  | S25    | 128     | Prefill |    3.924 |    43.760 |   11.2×  |
  | S25    | 128     | Decode  |    0.245 |     0.329 |    1.3×  |
  | S25    | 1024    | Prefill |  168.340 |  2790.700 |   16.6×  |
  | S25    | 1024    | Decode  |    1.614 |     2.698 |    1.7×  |

Pull Request resolved: pytorch#20611

Differential Revision: D110246161

Pulled By: GregoryComer
@meta-codesync

meta-codesync Bot commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

@GregoryComer has exported this pull request. If you are a Meta employee, you can view the originating Diff in D110246161.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported release notes: ops & kernels Changes to the opset and any new / changed kernel implementations

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants