Support XiaomiMiMo inference with mtp (#6059)
This commit is contained in:
@@ -782,12 +782,15 @@ class ModelRunner:
|
||||
distributed=get_world_group().world_size > 1,
|
||||
cpu_group=get_world_group().cpu_group,
|
||||
)
|
||||
if self.use_mla_backend:
|
||||
num_layers = (
|
||||
self.model_config.num_hidden_layers
|
||||
if not self.is_draft_worker
|
||||
else self.model_config.hf_config.num_nextn_predict_layers
|
||||
if self.is_draft_worker:
|
||||
num_layers = getattr(
|
||||
self.model_config.hf_config,
|
||||
"num_nextn_predict_layers",
|
||||
self.num_effective_layers,
|
||||
)
|
||||
else:
|
||||
num_layers = self.num_effective_layers
|
||||
if self.use_mla_backend:
|
||||
# FIXME: pipeline parallelism is not compatible with mla backend
|
||||
assert self.pp_size == 1
|
||||
cell_size = (
|
||||
@@ -799,7 +802,7 @@ class ModelRunner:
|
||||
cell_size = (
|
||||
self.model_config.get_num_kv_heads(get_attention_tp_size())
|
||||
* self.model_config.head_dim
|
||||
* self.num_effective_layers
|
||||
* num_layers
|
||||
* 2
|
||||
* torch._utils._element_size(self.kv_cache_dtype)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user