Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)
This commit is contained in:
@@ -84,6 +84,7 @@ class _StageExecutor:
|
||||
forward_batch: ForwardBatch = inputs["forward_batch"]
|
||||
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
|
||||
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
|
||||
self._global_num_tokens = forward_batch.global_num_tokens_cpu
|
||||
|
||||
def next(self):
|
||||
assert not self.done
|
||||
@@ -91,7 +92,11 @@ class _StageExecutor:
|
||||
stage = self._stages[self._index]
|
||||
|
||||
if self._global_dp_buffer_len is not None:
|
||||
set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
|
||||
set_dp_buffer_len(
|
||||
self._global_dp_buffer_len,
|
||||
self._local_dp_buffer_len,
|
||||
self._global_num_tokens,
|
||||
)
|
||||
|
||||
with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
|
||||
for op in stage:
|
||||
|
||||
Reference in New Issue
Block a user