support qwen3 dense model dp attention (#7681)
This commit is contained in:
@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
@@ -264,6 +265,7 @@ class Qwen2Model(nn.Module):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
else:
|
||||
@@ -332,7 +334,11 @@ class Qwen2Model(nn.Module):
|
||||
}
|
||||
)
|
||||
else:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
if hidden_states.shape[0] != 0:
|
||||
if residual is None:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
else:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
|
||||
Reference in New Issue
Block a user