backends/mlx: runtime MoE expert-sort for decode (issue #20554)#20685
AxelNoun wants to merge 1 commit into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20685
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 86ee91c with merge base 0f3303f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @AxelNoun!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
This PR needs a
Replace the compile-time sort_experts: bool flag in SwitchMLP with a runtime decision made inside two new custom ops (moe_gather_inputs, moe_scatter_outputs). A single exported .pte now handles both prefill (sorted, coalesced gather_mm) and decode (unsorted, no argsort overhead) without requiring separate exports.

Changes:
- schema.fbs: sorted_indices: bool -> IntOrVid (required) on GatherMmNode/GatherQmmNode; required fields moved before optionals
- MLXInterpreter.h: resolve_int(n.sorted_indices, st) != 0 (cf. kth)
- custom_ops.py: moe_gather_inputs, moe_scatter_outputs + register_fake; gather_mm/gather_qmm sorted_indices: bool -> Optional[Tensor]
- op_helpers.py: emit_floordiv helper (alongside emit_ceil_div)
- ops.py: _moe_gather_inputs_handler, _moe_scatter_outputs_handler; updated _gather_mm/_gather_qmm handlers for IntOrVid
- switch.py: SwitchMLP gains sort_cutoff; forward replaces if/else block with the two new ops; SwitchLinear sorted_indices: bool -> Optional[Tensor]
- mlx_source_transformations.py + export.py: sort_experts -> sort_cutoff
- test_ops.py: MoeGatherInputsTest, MoeScatterOutputsTest with expected_node_counts; GatherMmTest/GatherQmmTest extended for sorted_indices=Tensor configs

Test plan:
- Windows: python backends/mlx/test/validate_moe_20554.py (all passed)
- CI: test-mlx job on macos-14-xlarge (run_all_tests)

Fixes pytorch#20554

PR authored with Claude.

Co-authored-by: Cursor <cursoragent@cursor.com>
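The runtime decision described in the commit message can be illustrated with a small pure-Python sketch. This is not the PR's implementation: the function names are borrowed from the commit message, but the signatures, the list-based data layout, and the assumption that M is the number of rows with one expert id each are all hypothetical. The only detail taken from the PR is the condition `(M - 1) // sort_cutoff`, which is non-zero exactly when M exceeds `sort_cutoff` (for M >= 1).

```python
from typing import List, Tuple


def moe_gather_inputs(
    x: List[List[float]], expert_ids: List[int], sort_cutoff: int
) -> Tuple[List[List[float]], List[int], List[int], int]:
    """Group rows by expert only when the batch is large enough to benefit.

    Assumptions (hypothetical): one expert id per row of x, M >= 1, and
    sorting is enabled when (M - 1) // sort_cutoff is non-zero.
    """
    m = len(expert_ids)
    sort_flag = int((m - 1) // sort_cutoff != 0)
    if sort_flag:
        # Prefill-like path: stable argsort so rows for one expert are contiguous.
        order = sorted(range(m), key=lambda i: expert_ids[i])
    else:
        # Decode-like path: keep the original order and skip the argsort.
        order = list(range(m))
    x_out = [x[i] for i in order]
    ids_out = [expert_ids[i] for i in order]
    return x_out, ids_out, order, sort_flag


def moe_scatter_outputs(y: List[List[float]], order: List[int]) -> List[List[float]]:
    """Undo the permutation applied by moe_gather_inputs."""
    out: List[List[float]] = [[] for _ in order]
    for pos, src in enumerate(order):
        out[src] = y[pos]
    return out


if __name__ == "__main__":
    x = [[1.0], [2.0], [3.0], [4.0]]
    ids = [2, 0, 1, 0]
    xs, ids_sorted, order, flag = moe_gather_inputs(x, ids, sort_cutoff=2)
    print(flag, ids_sorted, moe_scatter_outputs(xs, order) == x)
```

With four rows and `sort_cutoff=2` the sorted path is taken; with a single row the inputs pass through unchanged and the flag is 0, which is the property that lets one exported graph serve both cases.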
d709ef4 to 86ee91c
    b: "IntOrVid",
) -> "IntOrVid":
    """Emit ``a // b`` (floor division), folding when both operands are
    static (issue #20554). Used for ``cond = (M - 1) // sort_cutoff``:
Clean up comments, removing things like "issue 20554". Comments should be clear and concise, with code self-documenting.
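As an illustration of the folding behaviour the docstring describes, and of a docstring without issue references, here is a toy sketch. `Vid`, `ToyProgram`, and the node tuple layout are hypothetical stand-ins, not the backend's real builder types; only the name `emit_floordiv` and the idea of folding when both operands are static come from the diff.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass(frozen=True)
class Vid:
    """Hypothetical handle for a value that is only known at run time."""

    idx: int


IntOrVid = Union[int, Vid]


class ToyProgram:
    """Hypothetical stand-in for the program builder; records emitted nodes."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, IntOrVid, IntOrVid, Vid]] = []
        self._next = 0

    def new_vid(self) -> Vid:
        vid = Vid(self._next)
        self._next += 1
        return vid


def emit_floordiv(P: ToyProgram, a: IntOrVid, b: IntOrVid) -> IntOrVid:
    """Emit ``a // b``, folding to a constant when both operands are static."""
    if isinstance(a, int) and isinstance(b, int):
        return a // b  # both static: no node is emitted
    out = P.new_vid()
    P.nodes.append(("floordiv", a, b, out))
    return out


if __name__ == "__main__":
    P = ToyProgram()
    print(emit_floordiv(P, 7, 4), len(P.nodes))    # static operands fold
    m_minus_1 = P.new_vid()
    print(emit_floordiv(P, m_minus_1, 4), len(P.nodes))  # dynamic operand emits a node
```

For M >= 1 and sort_cutoff >= 1, `(M - 1) // sort_cutoff` is 0 exactly when M <= sort_cutoff, so a single floor division is enough to produce the sort/no-sort condition.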
      x: torch.Tensor,
      indices: torch.Tensor,
-     sorted_indices: bool = False,
+     sorted_indices: Optional[torch.Tensor] = None,
w: Tid (required); // Quantized weight matrix [E, out, in_packed]
scales: Tid (required); // Quantization scales [E, out, in//gs]
x: Tid (required);
w: Tid (required);
Put back comments like "// Quantized weight matrix [E, out, in_packed]". We should only modify sorted_indices.
To preserve backward compatibility (BC), could "sorted_indices: IntOrVid (required)" instead be "sorted_indices_flag: IntOrVid // optional", added as the last field? If sorted_indices_flag is provided, sorted_indices is ignored.
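The fallback rule suggested in this comment could look roughly like the sketch below. It is a hedged illustration only: `GatherMmNodeSketch` is a hypothetical in-memory view of a deserialized node, the new field is assumed to be already resolved to an integer (the real field would be an IntOrVid resolved by the interpreter), and `None` stands for "field absent in an older file".

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GatherMmNodeSketch:
    """Hypothetical view of a deserialized gather node."""

    sorted_indices: bool = False               # legacy compile-time field
    sorted_indices_flag: Optional[int] = None  # proposed optional field; None when absent


def use_sorted_path(node: GatherMmNodeSketch) -> bool:
    """Prefer the new flag when present; otherwise fall back to the legacy bool."""
    if node.sorted_indices_flag is not None:
        return node.sorted_indices_flag != 0
    return node.sorted_indices


if __name__ == "__main__":
    old_file_node = GatherMmNodeSketch(sorted_indices=True)
    new_file_node = GatherMmNodeSketch(sorted_indices=True, sorted_indices_flag=0)
    print(use_sorted_path(old_file_node), use_sorted_path(new_file_node))
```

Appending the optional field last keeps older serialized files readable, since they simply lack the new field and take the legacy branch.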
return (x, y)

class ReinplaceChainModel(nn.Module):
I think you need to rebase on main? This PR is removing code that was checked in?
@@ -0,0 +1,282 @@
#!/usr/bin/env python3
Is this file needed? What is it testing beyond our test_ops.py file?
  """
- Gumbel-max sampling from softmax(logits / temperature), with top-k and
- top-p (nucleus) filtering.
+ Gumbel-max sampling from softmax(logits / temperature), with top-p (nucleus).
These changes look related to old base?
one_const = emit_lifted_constant(P, 1, torch.int32)

P.emit(IdCopyNode(x=P.slot_to_tid(x_input_slot), out=P.slot_to_tid(out_slots[0])))
Could we directly write to P.slot_to_tid(out_slots[i]) above and avoid these IdCopyNode?
@claude review this code
Claude finished @metascroy's task in 3m 4s

Code Review

I reviewed the diff against

🔴 Blocking: PR is built on a stale base (confirms @metascroy's suspicion)

The diff deletes code that already exists on

I verified these still exist on

The PR's

Action: rebase onto current

🟠 Schema backward-compatibility (re: @metascroy's schema.fbs comment)

Two concerns on

🟡 Comment style (repo CLAUDE.md: "minimal comments; code self-documenting")

Many new comments embed

🟡
Thanks @AxelNoun! This is a good start, but it looks like it's based on an old base, so it was a little hard to review. Left some comments :)
Summary

Replace the compile-time `sort_experts: bool` flag in `SwitchMLP` with a runtime decision inside two new custom ops (`moe_gather_inputs`, `moe_scatter_outputs`). A single exported `.pte` now handles both prefill (sorted, coalesced `gather_mm`) and decode (unsorted, no argsort overhead) without separate exports.

Key changes:
- `schema.fbs`: `sorted_indices: bool` → `IntOrVid` (required) on `GatherMmNode`/`GatherQmmNode`; required fields before optionals
- `MLXInterpreter.h`: `resolve_int(n.sorted_indices, st) != 0` (cf. `kth`)
- `custom_ops.py`: `moe_gather_inputs`, `moe_scatter_outputs`; `gather_mm`/`gather_qmm` `sorted_indices: Optional[Tensor]`
- `ops.py`: new MoE handlers + updated gather handlers for `IntOrVid`
- `switch.py`: `sort_cutoff` replaces compile-time sort branch
- `test_ops.py`: MoE + GatherMm/GatherQmm tests with `sorted_indices=Tensor` configs

`MLXLoader.{h,cpp}` and FlatBuffer bindings are regenerated automatically by `generate.py` + `flatc` during the CMake build on Mac CI; they are not included in this commit, per repo convention.

Test plan
- `python backends/mlx/test/validate_moe_20554.py` (all passed)
- `test-mlx` job on `macos-14-xlarge` (`run_all_tests`, covering `gather_mm`, `gather_qmm`, `moe_gather_inputs`, `moe_scatter_outputs`)

Fixes #20554

PR authored with Claude.