|
|
|
|
@@ -264,15 +264,23 @@ class _SinglePassGatherer(ABC):
|
|
|
|
|
return _DetailSinglePassGatherer(
|
|
|
|
|
server_args, expert_location_metadata, rank
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if server_args.expert_distribution_recorder_mode == "stat_approx":
|
|
|
|
|
if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
|
|
|
|
|
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
if server_args.enable_deepep_moe:
|
|
|
|
|
if server_args.deepep_mode == "normal":
|
|
|
|
|
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
|
|
|
|
|
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
|
|
|
|
|
elif server_args.deepep_mode == "low_latency":
|
|
|
|
|
return _DeepepLowLatencySinglePassGatherer(
|
|
|
|
|
expert_location_metadata, rank
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
|
|
|
|
|
|
|
|
|
|
def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
|
|
|
|
|
@@ -347,7 +355,9 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
|
|
|
|
self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids
|
|
|
|
|
self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], : topk_ids.shape[1]] = (
|
|
|
|
|
topk_ids
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def on_deepep_dispatch_normal(
|
|
|
|
|
self,
|
|
|
|
|
@@ -380,7 +390,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
|
|
|
|
|
class _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer):
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
self._objects_of_layer = {}
|
|
|
|
|
@@ -409,29 +419,63 @@ def _list_sum(a: List, b: List) -> List:
|
|
|
|
|
return [x + y for x, y in zip(a, b, strict=True)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
|
|
|
|
|
# pretty slow, but we will use the DeepEP Gatherer in production
|
|
|
|
|
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
|
|
|
|
topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
|
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
|
|
|
|
def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
self._enable_global_physical_experts = enable_global_physical_experts
|
|
|
|
|
self._data = torch.zeros(
|
|
|
|
|
(
|
|
|
|
|
self._expert_location_metadata.num_layers,
|
|
|
|
|
(
|
|
|
|
|
self._expert_location_metadata.num_physical_experts
|
|
|
|
|
if enable_global_physical_experts
|
|
|
|
|
else self._expert_location_metadata.num_local_physical_experts
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
dtype=torch.int,
|
|
|
|
|
device="cuda",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
global_physical_count = [
|
|
|
|
|
0
|
|
|
|
|
] * self._expert_location_metadata.num_physical_experts
|
|
|
|
|
for token_record in topk_ids_list:
|
|
|
|
|
for global_physical_expert_idx in token_record:
|
|
|
|
|
global_physical_count[global_physical_expert_idx] += 1
|
|
|
|
|
|
|
|
|
|
self._on_layer_data(layer_idx, global_physical_count)
|
|
|
|
|
def reset(self):
|
|
|
|
|
self._data[...] = 0
|
|
|
|
|
|
|
|
|
|
def collect(self) -> Dict:
|
|
|
|
|
global_physical_count = super()._collect_objects(
|
|
|
|
|
pad_len=self._expert_location_metadata.num_physical_experts
|
|
|
|
|
if self._enable_global_physical_experts:
|
|
|
|
|
global_physical_count = self._data
|
|
|
|
|
else:
|
|
|
|
|
# Can optimize if bottleneck
|
|
|
|
|
global_physical_count = _convert_local_to_global_physical_count(
|
|
|
|
|
self._data,
|
|
|
|
|
rank=self._rank,
|
|
|
|
|
num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
|
|
|
|
|
num_physical_experts=self._expert_location_metadata.num_physical_experts,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return dict(global_physical_count=global_physical_count)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
|
|
|
|
|
class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs, enable_global_physical_experts=True)
|
|
|
|
|
|
|
|
|
|
# can optimize (e.g. fuse / compile)
|
|
|
|
|
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
|
|
|
|
topk_ids = topk_ids.flatten()
|
|
|
|
|
mask = topk_ids != -1
|
|
|
|
|
self._data[layer_idx, :].scatter_add_(
|
|
|
|
|
dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer):
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
if torch.distributed.get_rank() == 0:
|
|
|
|
|
logger.info(
|
|
|
|
|
"DeepepNormalSinglePassGatherer gathers approximate statistics. "
|
|
|
|
|
"If used with small batch size, consider using expert_distribution_recorder_mode=stat."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def on_deepep_dispatch_normal(
|
|
|
|
|
self,
|
|
|
|
|
layer_idx: int,
|
|
|
|
|
@@ -456,17 +500,9 @@ class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
|
|
|
|
|
return dict(global_physical_count=global_physical_count)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
|
|
|
|
|
class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
self._data = torch.zeros(
|
|
|
|
|
(
|
|
|
|
|
self._expert_location_metadata.num_layers,
|
|
|
|
|
self._expert_location_metadata.num_local_physical_experts,
|
|
|
|
|
),
|
|
|
|
|
dtype=torch.int,
|
|
|
|
|
device="cuda",
|
|
|
|
|
)
|
|
|
|
|
super().__init__(*args, **kwargs, enable_global_physical_experts=False)
|
|
|
|
|
|
|
|
|
|
def on_deepep_dispatch_low_latency(
|
|
|
|
|
self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
|
|
|
|
|
@@ -474,19 +510,6 @@ class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
|
|
|
|
|
# Most naive implementation, can optimize later
|
|
|
|
|
self._data[layer_idx, :] += local_physical_count_of_layer
|
|
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
|
self._data[...] = 0
|
|
|
|
|
|
|
|
|
|
def collect(self) -> Dict:
|
|
|
|
|
# Can optimize if bottleneck
|
|
|
|
|
global_physical_count = _convert_local_to_global_physical_count(
|
|
|
|
|
self._data,
|
|
|
|
|
rank=self._rank,
|
|
|
|
|
num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
|
|
|
|
|
num_physical_experts=self._expert_location_metadata.num_physical_experts,
|
|
|
|
|
)
|
|
|
|
|
return dict(global_physical_count=global_physical_count)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_local_to_global_physical_count(
|
|
|
|
|
local_physical_count: torch.Tensor,
|
|
|
|
|
@@ -525,6 +548,7 @@ class _Accumulator(ABC):
|
|
|
|
|
def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
|
|
|
|
|
return {
|
|
|
|
|
"stat": _StatAccumulator,
|
|
|
|
|
"stat_approx": _StatAccumulator,
|
|
|
|
|
"per_pass": _DetailAccumulator,
|
|
|
|
|
"per_token": _DetailAccumulator,
|
|
|
|
|
}[server_args.expert_distribution_recorder_mode]
|
|
|
|
|
|