[Fix] Support qwen3-next MTP+DP (#10392)
This commit is contained in:
@@ -170,6 +170,7 @@ class ModelConfig:
|
|||||||
|
|
||||||
if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
|
if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
|
||||||
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
|
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
|
||||||
|
self.hf_config.num_nextn_predict_layers = 1
|
||||||
|
|
||||||
# Check model type
|
# Check model type
|
||||||
self.is_generation = is_generation_model(
|
self.is_generation = is_generation_model(
|
||||||
|
|||||||
@@ -185,10 +185,9 @@ class LogitsMetadata:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dp_local_start_pos = cumtokens[dp_rank - 1]
|
dp_local_start_pos = cumtokens[dp_rank - 1]
|
||||||
dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
|
|
||||||
|
|
||||||
self.dp_local_start_pos = dp_local_start_pos
|
self.dp_local_start_pos = dp_local_start_pos
|
||||||
self.dp_local_num_tokens = dp_local_num_tokens
|
self.dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
|
||||||
|
|
||||||
hidden_size = get_dp_hidden_size()
|
hidden_size = get_dp_hidden_size()
|
||||||
dtype = get_dp_dtype()
|
dtype = get_dp_dtype()
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ if _is_npu:
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensor_size_bytes(t: torch.Tensor):
|
||||||
|
return np.prod(t.shape) * t.dtype.itemsize
|
||||||
|
|
||||||
|
|
||||||
class ReqToTokenPool:
|
class ReqToTokenPool:
|
||||||
"""A memory pool that maps a request to its token locations."""
|
"""A memory pool that maps a request to its token locations."""
|
||||||
|
|
||||||
@@ -158,16 +162,23 @@ class MambaPool:
|
|||||||
intermediate_ssm_state_cache,
|
intermediate_ssm_state_cache,
|
||||||
intermediate_conv_window_cache,
|
intermediate_conv_window_cache,
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Mamba Cache is allocated. "
|
||||||
|
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||||
|
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
||||||
|
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
|
||||||
|
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.mamba_cache = (conv_state, temporal_state)
|
self.mamba_cache = (conv_state, temporal_state)
|
||||||
|
logger.info(
|
||||||
|
f"Mamba Cache is allocated. "
|
||||||
|
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||||
|
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
|
||||||
|
)
|
||||||
self.size = size
|
self.size = size
|
||||||
self.free_slots = list(range(size))
|
self.free_slots = list(range(size))
|
||||||
self.mem_usage = self.get_mamba_size() / GB
|
self.mem_usage = self.get_mamba_size() / GB
|
||||||
logger.info(
|
|
||||||
f"Mamba Cache is allocated. "
|
|
||||||
f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, "
|
|
||||||
f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB "
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_mamba_params_all_layers(self):
|
def get_mamba_params_all_layers(self):
|
||||||
return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
|
return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
|
||||||
@@ -176,10 +187,7 @@ class MambaPool:
|
|||||||
return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
|
return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
|
||||||
|
|
||||||
def get_mamba_size(self):
|
def get_mamba_size(self):
|
||||||
return (
|
return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
|
||||||
np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize
|
|
||||||
+ np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize
|
|
||||||
)
|
|
||||||
|
|
||||||
def available_size(self):
|
def available_size(self):
|
||||||
return len(self.free_slots)
|
return len(self.free_slots)
|
||||||
@@ -492,10 +500,10 @@ class MHATokenToKVPool(KVCache):
|
|||||||
assert hasattr(self, "v_buffer")
|
assert hasattr(self, "v_buffer")
|
||||||
k_size_bytes = 0
|
k_size_bytes = 0
|
||||||
for k_cache in self.k_buffer:
|
for k_cache in self.k_buffer:
|
||||||
k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
|
k_size_bytes += get_tensor_size_bytes(k_cache)
|
||||||
v_size_bytes = 0
|
v_size_bytes = 0
|
||||||
for v_cache in self.v_buffer:
|
for v_cache in self.v_buffer:
|
||||||
v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
|
v_size_bytes += get_tensor_size_bytes(v_cache)
|
||||||
return k_size_bytes, v_size_bytes
|
return k_size_bytes, v_size_bytes
|
||||||
|
|
||||||
# for disagg
|
# for disagg
|
||||||
@@ -1077,7 +1085,7 @@ class MLATokenToKVPool(KVCache):
|
|||||||
assert hasattr(self, "kv_buffer")
|
assert hasattr(self, "kv_buffer")
|
||||||
kv_size_bytes = 0
|
kv_size_bytes = 0
|
||||||
for kv_cache in self.kv_buffer:
|
for kv_cache in self.kv_buffer:
|
||||||
kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
|
kv_size_bytes += get_tensor_size_bytes(kv_cache)
|
||||||
return kv_size_bytes
|
return kv_size_bytes
|
||||||
|
|
||||||
# for disagg
|
# for disagg
|
||||||
@@ -1240,9 +1248,9 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|||||||
assert hasattr(self, "v_buffer")
|
assert hasattr(self, "v_buffer")
|
||||||
kv_size_bytes = 0
|
kv_size_bytes = 0
|
||||||
for k_cache in self.k_buffer:
|
for k_cache in self.k_buffer:
|
||||||
kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
|
kv_size_bytes += get_tensor_size_bytes(k_cache)
|
||||||
for v_cache in self.v_buffer:
|
for v_cache in self.v_buffer:
|
||||||
kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
|
kv_size_bytes += get_tensor_size_bytes(v_cache)
|
||||||
return kv_size_bytes
|
return kv_size_bytes
|
||||||
|
|
||||||
def get_kv_buffer(self, layer_id: int):
|
def get_kv_buffer(self, layer_id: int):
|
||||||
|
|||||||
@@ -85,8 +85,11 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
|
|||||||
if input_embeds is None:
|
if input_embeds is None:
|
||||||
input_embeds = self.model.embed_tokens(input_ids)
|
input_embeds = self.model.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
hidden_states = forward_batch.spec_info.hidden_states
|
||||||
|
# Some idle batch has 0 batch size. GemmaRMSNorm.forward would fail due to bs=0.
|
||||||
|
if not forward_batch.forward_mode.is_idle():
|
||||||
input_embeds = self.pre_fc_norm_embedding(input_embeds)
|
input_embeds = self.pre_fc_norm_embedding(input_embeds)
|
||||||
hidden_states = self.pre_fc_norm_hidden(forward_batch.spec_info.hidden_states)
|
hidden_states = self.pre_fc_norm_hidden(hidden_states)
|
||||||
hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
|
hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
|
||||||
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
|
|||||||
Reference in New Issue
Block a user