Support XiaomiMiMo inference with mtp (#6059)

This commit is contained in:
ryang
2025-05-23 05:14:49 +08:00
committed by GitHub
parent 0b07c4a99f
commit a6ae3af15e
6 changed files with 344 additions and 6 deletions

View File

@@ -782,12 +782,15 @@ class ModelRunner:
distributed=get_world_group().world_size > 1,
cpu_group=get_world_group().cpu_group,
)
if self.use_mla_backend:
num_layers = (
self.model_config.num_hidden_layers
if not self.is_draft_worker
else self.model_config.hf_config.num_nextn_predict_layers
if self.is_draft_worker:
num_layers = getattr(
self.model_config.hf_config,
"num_nextn_predict_layers",
self.num_effective_layers,
)
else:
num_layers = self.num_effective_layers
if self.use_mla_backend:
# FIXME: pipeline parallelism is not compatible with mla backend
assert self.pp_size == 1
cell_size = (
@@ -799,7 +802,7 @@ class ModelRunner:
cell_size = (
self.model_config.get_num_kv_heads(get_attention_tp_size())
* self.model_config.head_dim
* self.num_effective_layers
* num_layers
* 2
* torch._utils._element_size(self.kv_cache_dtype)
)