EPLB support for MTP (#7510)
This commit is contained in:
@@ -61,6 +61,10 @@ class ExpertDistributionRecorder(ABC):
|
|||||||
def with_debug_name(self, debug_name):
|
def with_debug_name(self, debug_name):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def disable_this_region(self):
|
||||||
|
yield
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
|
def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
|
||||||
yield
|
yield
|
||||||
@@ -116,6 +120,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
|
|||||||
self._expert_location_metadata = expert_location_metadata
|
self._expert_location_metadata = expert_location_metadata
|
||||||
|
|
||||||
self._recording = False
|
self._recording = False
|
||||||
|
self._disable_all = False
|
||||||
self._current_forward_pass_id = Withable()
|
self._current_forward_pass_id = Withable()
|
||||||
self._current_layer_idx = Withable()
|
self._current_layer_idx = Withable()
|
||||||
self._current_debug_name = Withable()
|
self._current_debug_name = Withable()
|
||||||
@@ -148,6 +153,16 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
|
|||||||
finally:
|
finally:
|
||||||
self._on_forward_pass_end(forward_pass_id)
|
self._on_forward_pass_end(forward_pass_id)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def disable_this_region(self):
|
||||||
|
"""Context manager to temporarily disable recording."""
|
||||||
|
previous_disable_all = self._disable_all
|
||||||
|
self._disable_all = True
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self._disable_all = previous_disable_all
|
||||||
|
|
||||||
def _on_forward_pass_start(self, forward_batch: ForwardBatch):
|
def _on_forward_pass_start(self, forward_batch: ForwardBatch):
|
||||||
if not self._recording:
|
if not self._recording:
|
||||||
return
|
return
|
||||||
@@ -189,6 +204,8 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _on_hook(self, hook_name: str, **kwargs):
|
def _on_hook(self, hook_name: str, **kwargs):
|
||||||
|
if self._disable_all:
|
||||||
|
return
|
||||||
if not (self._recording or torch.cuda.is_current_stream_capturing()):
|
if not (self._recording or torch.cuda.is_current_stream_capturing()):
|
||||||
return
|
return
|
||||||
gatherer = self._single_pass_gatherers[
|
gatherer = self._single_pass_gatherers[
|
||||||
@@ -462,6 +479,10 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
|
|||||||
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
||||||
topk_ids = topk_ids.flatten()
|
topk_ids = topk_ids.flatten()
|
||||||
mask = topk_ids != -1
|
mask = topk_ids != -1
|
||||||
|
assert self._data[layer_idx, :].shape == topk_ids.shape, (
|
||||||
|
"Shape mismatch between data and topk_ids."
|
||||||
|
"Selecting expert is not supported for multiple token prediction at the moment."
|
||||||
|
)
|
||||||
self._data[layer_idx, :].scatter_add_(
|
self._data[layer_idx, :].scatter_add_(
|
||||||
dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
|
dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,6 +28,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.expert_distribution import (
|
||||||
|
get_global_expert_distribution_recorder,
|
||||||
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
||||||
@@ -82,7 +85,6 @@ class DeepseekModelNextN(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
zero_allocator = BumpAllocator(
|
zero_allocator = BumpAllocator(
|
||||||
buffer_size=2,
|
buffer_size=2,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@@ -108,9 +110,10 @@ class DeepseekModelNextN(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
hidden_states, residual = self.decoder(
|
with get_global_expert_distribution_recorder().disable_this_region():
|
||||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
hidden_states, residual = self.decoder(
|
||||||
)
|
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||||
|
)
|
||||||
|
|
||||||
if not forward_batch.forward_mode.is_idle():
|
if not forward_batch.forward_mode.is_idle():
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user