Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)

This commit is contained in:
Trevor Morris
2025-08-15 22:08:11 -07:00
committed by GitHub
parent 87dab54824
commit eff4eb3fdd
16 changed files with 360 additions and 52 deletions

View File

@@ -84,6 +84,7 @@ class _StageExecutor:
forward_batch: ForwardBatch = inputs["forward_batch"]
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
self._global_num_tokens = forward_batch.global_num_tokens_cpu
def next(self):
assert not self.done
@@ -91,7 +92,11 @@ class _StageExecutor:
stage = self._stages[self._index]
if self._global_dp_buffer_len is not None:
set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
set_dp_buffer_len(
self._global_dp_buffer_len,
self._local_dp_buffer_len,
self._global_num_tokens,
)
with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
for op in stage: