Support both approximate and exact expert distribution collection (#6964)
This commit is contained in:
@@ -264,15 +264,23 @@ class _SinglePassGatherer(ABC):
|
||||
return _DetailSinglePassGatherer(
|
||||
server_args, expert_location_metadata, rank
|
||||
)
|
||||
|
||||
if server_args.expert_distribution_recorder_mode == "stat_approx":
|
||||
if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
|
||||
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if server_args.enable_deepep_moe:
|
||||
if server_args.deepep_mode == "normal":
|
||||
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
|
||||
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
|
||||
elif server_args.deepep_mode == "low_latency":
|
||||
return _DeepepLowLatencySinglePassGatherer(
|
||||
expert_location_metadata, rank
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
|
||||
|
||||
def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
|
||||
@@ -347,7 +355,9 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
|
||||
)
|
||||
|
||||
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
||||
self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids
|
||||
self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], : topk_ids.shape[1]] = (
|
||||
topk_ids
|
||||
)
|
||||
|
||||
def on_deepep_dispatch_normal(
|
||||
self,
|
||||
@@ -380,7 +390,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
|
||||
)
|
||||
|
||||
|
||||
class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
|
||||
class _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._objects_of_layer = {}
|
||||
@@ -409,29 +419,63 @@ def _list_sum(a: List, b: List) -> List:
|
||||
return [x + y for x, y in zip(a, b, strict=True)]
|
||||
|
||||
|
||||
class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
|
||||
# pretty slow, but we will use the DeepEP Gatherer in production
|
||||
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
||||
topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
|
||||
torch.cuda.synchronize()
|
||||
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
||||
def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._enable_global_physical_experts = enable_global_physical_experts
|
||||
self._data = torch.zeros(
|
||||
(
|
||||
self._expert_location_metadata.num_layers,
|
||||
(
|
||||
self._expert_location_metadata.num_physical_experts
|
||||
if enable_global_physical_experts
|
||||
else self._expert_location_metadata.num_local_physical_experts
|
||||
),
|
||||
),
|
||||
dtype=torch.int,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
global_physical_count = [
|
||||
0
|
||||
] * self._expert_location_metadata.num_physical_experts
|
||||
for token_record in topk_ids_list:
|
||||
for global_physical_expert_idx in token_record:
|
||||
global_physical_count[global_physical_expert_idx] += 1
|
||||
|
||||
self._on_layer_data(layer_idx, global_physical_count)
|
||||
def reset(self):
|
||||
self._data[...] = 0
|
||||
|
||||
def collect(self) -> Dict:
|
||||
global_physical_count = super()._collect_objects(
|
||||
pad_len=self._expert_location_metadata.num_physical_experts
|
||||
)
|
||||
if self._enable_global_physical_experts:
|
||||
global_physical_count = self._data
|
||||
else:
|
||||
# Can optimize if bottleneck
|
||||
global_physical_count = _convert_local_to_global_physical_count(
|
||||
self._data,
|
||||
rank=self._rank,
|
||||
num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
|
||||
num_physical_experts=self._expert_location_metadata.num_physical_experts,
|
||||
)
|
||||
|
||||
return dict(global_physical_count=global_physical_count)
|
||||
|
||||
|
||||
class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
|
||||
class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs, enable_global_physical_experts=True)
|
||||
|
||||
# can optimize (e.g. fuse / compile)
|
||||
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
||||
topk_ids = topk_ids.flatten()
|
||||
mask = topk_ids != -1
|
||||
self._data[layer_idx, :].scatter_add_(
|
||||
dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
|
||||
)
|
||||
|
||||
|
||||
class _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
logger.info(
|
||||
"DeepepNormalSinglePassGatherer gathers approximate statistics. "
|
||||
"If used with small batch size, consider using expert_distribution_recorder_mode=stat."
|
||||
)
|
||||
|
||||
def on_deepep_dispatch_normal(
|
||||
self,
|
||||
layer_idx: int,
|
||||
@@ -456,17 +500,9 @@ class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
|
||||
return dict(global_physical_count=global_physical_count)
|
||||
|
||||
|
||||
class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
|
||||
class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._data = torch.zeros(
|
||||
(
|
||||
self._expert_location_metadata.num_layers,
|
||||
self._expert_location_metadata.num_local_physical_experts,
|
||||
),
|
||||
dtype=torch.int,
|
||||
device="cuda",
|
||||
)
|
||||
super().__init__(*args, **kwargs, enable_global_physical_experts=False)
|
||||
|
||||
def on_deepep_dispatch_low_latency(
|
||||
self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
|
||||
@@ -474,19 +510,6 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
|
||||
# Most naive implementation, can optimize later
|
||||
self._data[layer_idx, :] += local_physical_count_of_layer
|
||||
|
||||
def reset(self):
|
||||
self._data[...] = 0
|
||||
|
||||
def collect(self) -> Dict:
|
||||
# Can optimize if bottleneck
|
||||
global_physical_count = _convert_local_to_global_physical_count(
|
||||
self._data,
|
||||
rank=self._rank,
|
||||
num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
|
||||
num_physical_experts=self._expert_location_metadata.num_physical_experts,
|
||||
)
|
||||
return dict(global_physical_count=global_physical_count)
|
||||
|
||||
|
||||
def _convert_local_to_global_physical_count(
|
||||
local_physical_count: torch.Tensor,
|
||||
@@ -525,6 +548,7 @@ class _Accumulator(ABC):
|
||||
def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
|
||||
return {
|
||||
"stat": _StatAccumulator,
|
||||
"stat_approx": _StatAccumulator,
|
||||
"per_pass": _DetailAccumulator,
|
||||
"per_token": _DetailAccumulator,
|
||||
}[server_args.expert_distribution_recorder_mode]
|
||||
|
||||
@@ -460,22 +460,25 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states = state.hidden_states_mlp_input
|
||||
|
||||
if router_logits is not None:
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
correction_bias=self.correction_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
state.topk_idx_local = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
|
||||
@@ -255,17 +255,20 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
router_logits = state.pop("router_logits")
|
||||
hidden_states = state.hidden_states_mlp_input
|
||||
if router_logits is not None:
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=self.renormalize,
|
||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=self.renormalize,
|
||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
state.topk_idx_local = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
|
||||
@@ -182,7 +182,7 @@ class ServerArgs:
|
||||
eplb_rebalance_num_iterations: int = 1000
|
||||
eplb_rebalance_layers_per_chunk: Optional[int] = None
|
||||
expert_distribution_recorder_mode: Optional[
|
||||
Literal["stat", "per_pass", "per_token"]
|
||||
Literal["stat", "stat_approx", "per_pass", "per_token"]
|
||||
] = None
|
||||
expert_distribution_recorder_buffer_size: Optional[int] = None
|
||||
enable_expert_distribution_metrics: bool = False
|
||||
|
||||
Reference in New Issue
Block a user