
backends/mlx: runtime MoE expert-sort for decode (issue #20554) #20685

Open
AxelNoun wants to merge 1 commit into pytorch:main from AxelNoun:moe-runtime-sort-20554

Conversation

@AxelNoun AxelNoun commented Jul 2, 2026

Summary

Replace the compile-time sort_experts: bool flag in SwitchMLP with a runtime decision inside two new custom ops (moe_gather_inputs, moe_scatter_outputs). A single exported .pte now handles both prefill (sorted, coalesced gather_mm) and decode (unsorted, no argsort overhead) without separate exports.

Key changes:

  • schema.fbs: sorted_indices: bool -> IntOrVid (required) on GatherMmNode/GatherQmmNode; required fields before optionals
  • MLXInterpreter.h: resolve_int(n.sorted_indices, st) != 0 (cf. kth)
  • custom_ops.py: moe_gather_inputs, moe_scatter_outputs; gather_mm/gather_qmm sorted_indices: Optional[Tensor]
  • ops.py: new MoE handlers + updated gather handlers for IntOrVid
  • switch.py: sort_cutoff replaces compile-time sort branch
  • test_ops.py: MoE + GatherMm/GatherQmm tests with sorted_indices=Tensor configs

MLXLoader.{h,cpp} and FlatBuffer bindings are regenerated automatically by generate.py + flatc during the CMake build on Mac CI — they are not included in this commit, per repo convention.
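The runtime decision described in the summary can be sketched in plain Python. This is only an illustration under stated assumptions: the function names, the list-based inputs, and the returned tuple are hypothetical stand-ins, not the signatures of the real moe_gather_inputs and moe_scatter_outputs ops.

```python
from typing import List, Tuple


def moe_gather_inputs_sketch(
    expert_ids: List[int], sort_cutoff: int
) -> Tuple[List[int], List[int], bool]:
    """Return (token order, inverse order, sorted flag) for M routed tokens.

    Hypothetical stand-in for the runtime decision; not the real op.
    """
    m = len(expert_ids)
    if m > sort_cutoff:
        # Prefill-like path: group tokens by expert with a stable argsort.
        order = sorted(range(m), key=lambda i: expert_ids[i])
        inv_order = [0] * m
        for pos, src in enumerate(order):
            inv_order[src] = pos
        return order, inv_order, True
    # Decode-like path: keep the original order and skip the argsort.
    # The empty inv_order plays the role of a sentinel that is never read.
    return list(range(m)), [], False


def moe_scatter_outputs_sketch(
    outputs: List[float], inv_order: List[int], is_sorted: bool
) -> List[float]:
    """Undo the gather-time permutation; a plain copy on the unsorted path."""
    if not is_sorted:
        return list(outputs)
    return [outputs[inv_order[i]] for i in range(len(inv_order))]
```

With four tokens and sort_cutoff=2 the sketch takes the sorted path and the scatter restores the original token order; with a single token it takes the unsorted path and no permutation is computed.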

Test plan

  • Windows: python backends/mlx/test/validate_moe_20554.py (all passed)
  • CI: test-mlx job on macos-14-xlarge (run_all_tests — covers gather_mm, gather_qmm, moe_gather_inputs, moe_scatter_outputs)

Fixes #20554

PR authored with Claude.

pytorch-bot Bot commented Jul 2, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20685

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 86ee91c with merge base 0f3303f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot commented Jul 2, 2026

Hi @AxelNoun!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

linux-foundation-easycla Bot commented Jul 2, 2026

CLA Not Signed

github-actions Bot commented Jul 2, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Replace the compile-time sort_experts: bool flag in SwitchMLP with a
runtime decision made inside two new custom ops (moe_gather_inputs,
moe_scatter_outputs). A single exported .pte now handles both prefill
(sorted, coalesced gather_mm) and decode (unsorted, no argsort overhead)
without requiring separate exports.

Changes:
- schema.fbs: sorted_indices: bool -> IntOrVid (required) on
  GatherMmNode/GatherQmmNode; required fields moved before optionals
- MLXInterpreter.h: resolve_int(n.sorted_indices, st) != 0 (cf. kth)
- custom_ops.py: moe_gather_inputs, moe_scatter_outputs + register_fake;
  gather_mm/gather_qmm sorted_indices: bool -> Optional[Tensor]
- op_helpers.py: emit_floordiv helper (alongside emit_ceil_div)
- ops.py: _moe_gather_inputs_handler, _moe_scatter_outputs_handler;
  updated _gather_mm/_gather_qmm handlers for IntOrVid
- switch.py: SwitchMLP gains sort_cutoff; forward replaces if/else block
  with the two new ops; SwitchLinear sorted_indices: bool -> Optional[Tensor]
- mlx_source_transformations.py + export.py: sort_experts -> sort_cutoff
- test_ops.py: MoeGatherInputsTest, MoeScatterOutputsTest with
  expected_node_counts; GatherMmTest/GatherQmmTest extended for
  sorted_indices=Tensor configs

Test plan:
- Windows: python backends/mlx/test/validate_moe_20554.py (all passed)
- CI: test-mlx job on macos-14-xlarge (run_all_tests)

Fixes pytorch#20554

PR authored with Claude.

Co-authored-by: Cursor <cursoragent@cursor.com>
@AxelNoun AxelNoun force-pushed the moe-runtime-sort-20554 branch from d709ef4 to 86ee91c Compare July 2, 2026 03:44
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 2, 2026
@rascani rascani requested a review from metascroy July 2, 2026 15:38
b: "IntOrVid",
) -> "IntOrVid":
"""Emit ``a // b`` (floor division), folding when both operands are
static (issue #20554). Used for ``cond = (M - 1) // sort_cutoff``:

Clean up comments, removing things like "issue 20554". Comments should be clear and concise, with code self-documenting.

x: torch.Tensor,
indices: torch.Tensor,
sorted_indices: bool = False,
sorted_indices: Optional[torch.Tensor] = None,

What does None mean?

w: Tid (required); // Quantized weight matrix [E, out, in_packed]
scales: Tid (required); // Quantization scales [E, out, in//gs]
x: Tid (required);
w: Tid (required);

Put back comments like "// Quantized weight matrix [E, out, in_packed]". We should only modify sorted_indices.

To preserve BC, could "sorted_indices: IntOrVid (required)" be "sorted_indices_flag: IntOrVid // optional" and added last? If provided, sorted_indices is ignored.

return (x, y)


class ReinplaceChainModel(nn.Module):

I think you need to rebase on main? This PR is removing code that was checked in?

@@ -0,0 +1,282 @@
#!/usr/bin/env python3

Is this file needed? What is it testing beyond our test_ops.py file?

"""
Gumbel-max sampling from softmax(logits / temperature), with top-k and
top-p (nucleus) filtering.
Gumbel-max sampling from softmax(logits / temperature), with top-p (nucleus).

These changes look related to old base?

Comment thread backends/mlx/ops.py

one_const = emit_lifted_constant(P, 1, torch.int32)

P.emit(IdCopyNode(x=P.slot_to_tid(x_input_slot), out=P.slot_to_tid(out_slots[0])))

@metascroy metascroy Jul 2, 2026

Could we directly write to P.slot_to_tid(out_slots[i]) above and avoid these IdCopyNode?

@metascroy

@claude review this code

claude Bot commented Jul 2, 2026

Claude finished @metascroy's task in 3m 4s

Code Review

I reviewed the diff against origin/main. The runtime-sort design is sound — moving the sort/no-sort choice into moe_gather_inputs and mirroring it branch-for-branch in the lowering handler is a clean way to get one .pte that serves both prefill and decode. But there is one blocking issue and several smaller ones. I've grouped them below.


🔴 Blocking: PR is built on a stale base (confirms @metascroy's suspicion)

The diff deletes code that already exists on origin/main from two other PRs. These deletions are unrelated to this feature and would regress main if merged as-is:

  • Reinplacement handlers (from [MLX] enable reinplacement for unary/binary ops #20557, merged as 0f3303f): ops.py removes _make_inplace_unary_handler, _make_inplace_binary_handler, _make_inplace_addsub_handler, _make_inplace_passthrough_handler, REINPLACEABLE_UNARY_BASE_NAMES, REINPLACEABLE_BINARY_BASE_OVERLOADS, REINPLACEABLE_EXTRA_BASE_OVERLOADS, and the _derive_edge_inplace_overload import (~330 lines). test_ops.py deletes ReinplaceChainModel / ReinplaceChainTest.
  • Sample top-k (from Add top-k support to MLX sample #20564, merged as 035b45a): custom_ops.py and ops.py remove the top_k argument and the top-k filtering chain from mlx::sample.

I verified these still exist on origin/main:

$ git show origin/main:backends/mlx/ops.py | grep -c _make_inplace_unary_handler   # -> 2
$ git show origin/main:backends/mlx/custom_ops.py | grep -n top_k                   # -> present

The PR's sample signature drops top_k — that would silently revert #20564's feature and break its callers/tests.

Action: rebase onto current origin/main so the diff contains only the MoE runtime-sort changes. After rebasing, re-verify that the ops.py diff has no _make_inplace_* or sample/top_k deletions.


🟠 Schema backward-compatibility (re: @metascroy's schema.fbs comment)

Two concerns on schema.fbs:

  1. BC break. Changing sorted_indices: bool = false to sorted_indices: IntOrVid (required) and reordering it before the optional Tid fields is a wire-incompatible change to GatherMmNode/GatherQmmNode, so existing .pte files will not deserialize. As @metascroy suggested, a BC-preserving option is to keep the existing fields untouched and add a new trailing optional field, e.g. sorted_indices_flag: IntOrVid; (added last; when absent, the runtime falls back to the old sorted_indices bool). FlatBuffers only guarantees BC when new fields are appended and old ones keep their IDs.
  2. Lost comments. The field comments (// Input activations, // Weight matrix [E, out, in], // Quantized weight matrix [E, out, in_packed], // "affine", "fp", etc.) were removed. Per @metascroy, restore them — only the sorted_indices line should change.
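The fallback behaviour suggested for the trailing optional field can be sketched as follows. This is a simplified illustration, not the interpreter's code: GatherNodeSketch and resolve_sorted are hypothetical names, and the IntOrVid value is assumed to be already resolved to a plain integer (in the real runtime it may be a value id resolved at execution time).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GatherNodeSketch:
    """Hypothetical decoded node; field names follow the review suggestion."""

    sorted_indices: bool = False  # legacy field, left unchanged
    sorted_indices_flag: Optional[int] = None  # new trailing optional field


def resolve_sorted(node: GatherNodeSketch) -> bool:
    # The new field wins when present; otherwise fall back to the legacy bool.
    if node.sorted_indices_flag is not None:
        return node.sorted_indices_flag != 0
    return node.sorted_indices
```

Under this scheme a file written before the change decodes with sorted_indices_flag absent and keeps its old meaning, while a newly exported file can set the flag and have the legacy bool ignored.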

🟡 Comment style (repo CLAUDE.md: "minimal comments; code self-documenting")

Many new comments embed issue #20554 and narrate the diff/design history rather than the code. These should be trimmed.

🟡 switch.py:135 — "What does None mean?" (@metascroy)

sorted_indices: Optional[torch.Tensor] = None — the contract (None/0 = unsorted, non-zero 0-d tensor = sorted) is only documented in custom_ops.py. Add a one-line note at the SwitchLinear.forward_raw signature stating what None means so callers don't have to trace it.
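One way to make that contract explicit is a small helper next to the signature. The sketch below is duck-typed so it runs without torch: ScalarLike, is_sorted_flag, and FakeScalar are hypothetical names introduced here, and FakeScalar only stands in for a 0-d tensor, whose .item() likewise returns a Python scalar.

```python
from typing import Optional, Protocol


class ScalarLike(Protocol):
    """Anything exposing .item(), such as a 0-d tensor."""

    def item(self) -> int: ...


def is_sorted_flag(sorted_indices: Optional[ScalarLike]) -> bool:
    """None or a zero scalar means unsorted; a non-zero scalar means sorted."""
    if sorted_indices is None:
        return False
    return sorted_indices.item() != 0


class FakeScalar:
    """Stand-in for a 0-d tensor, used only to exercise the helper."""

    def __init__(self, value: int) -> None:
        self._value = value

    def item(self) -> int:
        return self._value
```

A docstring with the same three cases on SwitchLinear.forward_raw would serve the reviewer's request equally well; the helper just shows the cases in executable form.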

🟡 ops.py: IdCopyNode fan-out (@metascroy's ops.py:1719 comment)

Both emit_sorted/emit_unsorted compute into fresh temps and then IdCopyNode into the 4 fixed out_slots. @metascroy asked whether you can write the final op directly into P.slot_to_tid(out_slots[i]) and skip the copies. The copies exist because emit_if_else requires both branches to write the same fixed output slots, so the pre-allocated out_slots must be written inside each branch. That said, the last producing op in each branch (e.g. ExpandDimsNode → out_slots[0], the flatten/take → out_slots[1]) can usually target out_slots[i] directly, eliminating most of the 5 IdCopyNodes. Worth trying, since it reduces node count on the hot MoE path. The const-passthroughs (one_const/zero_const → out_slots[2]) are hard to avoid, but the tensor ones look collapsible.
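The difference between the two emission styles can be shown with a toy node list. This is a sketch only: nodes are modelled as plain tuples, the slot names are made up, and nothing here uses the real program builder API.

```python
from typing import List, Tuple

Node = Tuple[str, str, str]  # (op name, input slot, output slot)


def branch_with_copies(out_slot: str) -> List[Node]:
    # Compute into a fresh temp, then copy into the fixed output slot.
    return [
        ("ExpandDimsNode", "x", "tmp0"),
        ("IdCopyNode", "tmp0", out_slot),
    ]


def branch_direct(out_slot: str) -> List[Node]:
    # Let the last producing op target the fixed output slot itself.
    return [("ExpandDimsNode", "x", out_slot)]


def count(nodes: List[Node], op: str) -> int:
    """Count how many nodes in the list have the given op name."""
    return sum(1 for name, _, _ in nodes if name == op)
```

Both branches still end by writing the same fixed slot, which is the property the if/else lowering needs; the direct form simply reaches it with one node fewer per output.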

🟢 test_ops.py — stale-base deletions + node-count coupling

  • The ReinplaceChainTest deletion (line ~117) is part of the stale-base problem above — it'll come back on rebase.
  • get_expected_node_counts() hard-codes IdCopyNode: 5. If you collapse the IdCopyNodes per the above, update these. Good that both prefill/decode paths and IfNode: 0 (fold) are asserted.

🟢 validate_moe_20554.py (@metascroy: "Is this file needed?")

This 282-line standalone script duplicates coverage now in test_ops.py (MoeGatherInputsTest/MoeScatterOutputsTest) and isn't wired into pytest/run_all_tests. Recommend deleting it and folding any unique assertions (e.g. specific node-count checks) into test_ops.py, so there's a single source of truth that CI actually runs.


Correctness notes (look good, for the record)

  • The moe_gather_inputs eager op branches on N while its register_fake does not (always emits the sorted-path shape) — correct, since data-dependent control flow on a symbolic dim is illegal under export, and the fake must be shape-stable.
  • cond = (M - 1) // sort_cutoff correctly maps M <= sort_cutoff → 0 (else/unsorted) and M > sort_cutoff → ≥1 (then/sorted), and folds to a single branch when M is static. 👍
  • The inv_order empty-tensor sentinel is only read on the sorted path, so the static 0-element constant on the unsorted path is safe.
  • MLXInterpreter.h resolve_int(n.sorted_indices, st) != 0 mirrors the kth pattern — consistent.
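The branch-selector arithmetic noted above can be checked exhaustively over a small range. The sketch assumes M >= 1 and sort_cutoff >= 1; the function names are illustrative and not taken from the PR.

```python
def sort_cond(m: int, sort_cutoff: int) -> int:
    """Branch selector: 0 selects the unsorted branch, >= 1 the sorted one."""
    # Assumes m >= 1 and sort_cutoff >= 1; m == 0 would floor to -1.
    return (m - 1) // sort_cutoff


def takes_sorted_branch(m: int, sort_cutoff: int) -> bool:
    """True when the selector is non-zero, i.e. when m exceeds the cutoff."""
    return sort_cond(m, sort_cutoff) != 0
```

The M >= 1 assumption matters: Python floor division rounds toward negative infinity, so an empty batch would produce a non-zero selector and pick the sorted branch.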

Summary: The feature design is good and the correctness reasoning holds up. The must-fix is the rebase — the diff currently reverts #20557 and #20564. After that, address the schema BC/comments, trim the changelog-style comments, answer the None contract in switch.py, try collapsing the IdCopyNodes, and remove validate_moe_20554.py.

@metascroy

Thanks @AxelNoun! This is a good start, but it looks like it's based on an old base, so it was a little hard to review.

Left some comments :)

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Good First Issue: Runtime MoE expert-sort for decode (MLX backend, Qwen 3.5 MoE)

3 participants