support qwen3 dense model dp attention (#7681)

Author: Yi Zhang
Date: 2025-07-04 00:58:20 +08:00
Committed by: GitHub
Parent: 1dce6c480f
Commit: 646cef2e2e
2 changed files with 49 additions and 17 deletions

@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -264,6 +265,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
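
When DP attention is enabled, each data-parallel rank processes its own requests and therefore needs the full vocabulary embedding locally, so tensor-parallel sharding of the embedding table is turned off. A minimal sketch of that toggle is below; the dict and the `VocabParallelEmbedding` dataclass are stand-ins for illustration, not sglang's actual classes or wiring.

```python
# Sketch only: how the embedding's TP sharding is toggled by the DP-attention flag.
from dataclasses import dataclass


@dataclass
class VocabParallelEmbedding:  # stand-in for sglang's layer
    vocab_size: int
    hidden_size: int
    enable_tp: bool = True  # when False, every rank keeps the full embedding table


server_args = {"enable_dp_attention": True}  # hypothetical server config

embed_tokens = VocabParallelEmbedding(
    vocab_size=151_936,
    hidden_size=4096,
    # With DP attention, each DP rank needs the whole vocabulary locally,
    # so tensor-parallel sharding of the embedding is disabled.
    enable_tp=not server_args["enable_dp_attention"],
)
print(embed_tokens.enable_tp)  # False when DP attention is on
```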
@@ -332,7 +334,11 @@ class Qwen2Model(nn.Module):
                 }
             )
         else:
-            hidden_states, _ = self.norm(hidden_states, residual)
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

     # If this function is called, it should always initialize KV cache scale
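
The second hunk guards the final norm: under DP attention a given DP rank can receive an empty local batch, and the residual may not have been populated, so the norm is applied only when there are tokens to normalize. Below is a self-contained sketch of that guard; `RMSNormWithResidual` is a stand-in for sglang's fused RMSNorm, assumed here to return `(normed, residual)` when a residual is passed.

```python
# Sketch of the final-norm guard added in this commit, with stand-in modules.
import torch
import torch.nn as nn


class RMSNormWithResidual(nn.Module):
    # Stand-in for sglang's fused RMSNorm; not the actual implementation.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

    def forward(self, x: torch.Tensor, residual: torch.Tensor | None = None):
        if residual is None:
            return self._norm(x)
        x = x + residual
        return self._norm(x), x


norm = RMSNormWithResidual(hidden_size=8)
hidden_states = torch.empty(0, 8)  # empty local batch on this DP rank
residual = None

# Same guard as the patched Qwen2Model.forward: skip the norm on empty input.
if hidden_states.shape[0] != 0:
    if residual is None:
        hidden_states = norm(hidden_states)
    else:
        hidden_states, _ = norm(hidden_states, residual)

print(hidden_states.shape)  # torch.Size([0, 8]); left untouched when the batch is empty
```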