Support 2x8xH100 for Llama 4 (#5159)
This commit is contained in:
@@ -27,6 +27,13 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
dp_gather_partial,
|
||||||
|
dp_scatter,
|
||||||
|
get_attention_dp_size,
|
||||||
|
get_attention_tp_rank,
|
||||||
|
get_attention_tp_size,
|
||||||
|
)
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@@ -38,6 +45,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
||||||
from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
|
from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
|
||||||
@@ -143,20 +151,24 @@ class Llama4Attention(nn.Module):
|
|||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.use_rope = int((layer_id + 1) % 4 != 0)
|
self.use_rope = int((layer_id + 1) % 4 != 0)
|
||||||
self.use_qk_norm = config.use_qk_norm and self.use_rope
|
self.use_qk_norm = config.use_qk_norm and self.use_rope
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
|
self.dp_size = get_attention_dp_size()
|
||||||
|
attn_tp_rank = get_attention_tp_rank()
|
||||||
|
attn_tp_size = get_attention_tp_size()
|
||||||
|
|
||||||
self.total_num_heads = num_heads
|
self.total_num_heads = num_heads
|
||||||
assert self.total_num_heads % tp_size == 0
|
assert self.total_num_heads % attn_tp_size == 0
|
||||||
self.num_heads = self.total_num_heads // tp_size
|
self.num_heads = self.total_num_heads // attn_tp_size
|
||||||
self.total_num_kv_heads = num_kv_heads
|
self.total_num_kv_heads = num_kv_heads
|
||||||
if self.total_num_kv_heads >= tp_size:
|
if self.total_num_kv_heads >= attn_tp_size:
|
||||||
# Number of KV heads is greater than TP size, so we partition
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
# the KV heads across multiple tensor parallel GPUs.
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
assert self.total_num_kv_heads % tp_size == 0
|
assert self.total_num_kv_heads % attn_tp_size == 0
|
||||||
else:
|
else:
|
||||||
# Number of KV heads is less than TP size, so we replicate
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
# the KV heads across multiple tensor parallel GPUs.
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
assert tp_size % self.total_num_kv_heads == 0
|
assert attn_tp_size % self.total_num_kv_heads == 0
|
||||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
|
||||||
self.head_dim = config.head_dim
|
self.head_dim = config.head_dim
|
||||||
self.q_size = self.num_heads * self.head_dim
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
@@ -183,6 +195,8 @@ class Llama4Attention(nn.Module):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("qkv_proj", prefix),
|
prefix=add_prefix("qkv_proj", prefix),
|
||||||
|
tp_rank=attn_tp_rank,
|
||||||
|
tp_size=attn_tp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
@@ -191,6 +205,9 @@ class Llama4Attention(nn.Module):
|
|||||||
bias=bias_o_proj,
|
bias=bias_o_proj,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("o_proj", prefix),
|
prefix=add_prefix("o_proj", prefix),
|
||||||
|
tp_rank=attn_tp_rank,
|
||||||
|
tp_size=attn_tp_size,
|
||||||
|
reduce_results=False,
|
||||||
)
|
)
|
||||||
is_neox_style = True
|
is_neox_style = True
|
||||||
is_gguf = quant_config and quant_config.get_name() == "gguf"
|
is_gguf = quant_config and quant_config.get_name() == "gguf"
|
||||||
@@ -274,6 +291,9 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
rope_theta = config.rope_theta
|
rope_theta = config.rope_theta
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
max_position_embeddings = config.max_position_embeddings
|
max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.dp_size = get_attention_dp_size()
|
||||||
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
|
|
||||||
self.self_attn = Llama4Attention(
|
self.self_attn = Llama4Attention(
|
||||||
config=config,
|
config=config,
|
||||||
@@ -316,21 +336,58 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Self Attention
|
if hidden_states.shape[0] == 0:
|
||||||
if residual is None:
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
|
||||||
else:
|
else:
|
||||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
# Self Attention
|
||||||
hidden_states = self.self_attn(
|
if residual is None:
|
||||||
positions=positions,
|
residual = hidden_states
|
||||||
hidden_states=hidden_states,
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
forward_batch=forward_batch,
|
else:
|
||||||
)
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gather
|
||||||
|
if get_tensor_model_parallel_world_size() > 1:
|
||||||
|
# all gather and all reduce
|
||||||
|
if self.dp_size != 1:
|
||||||
|
if self.attn_tp_rank == 0:
|
||||||
|
hidden_states += residual
|
||||||
|
hidden_states, local_hidden_states = (
|
||||||
|
forward_batch.gathered_buffer,
|
||||||
|
hidden_states,
|
||||||
|
)
|
||||||
|
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||||
|
dp_scatter(residual, hidden_states, forward_batch)
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
|
hidden_states, residual
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
|
hidden_states, residual
|
||||||
|
)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
|
||||||
hidden_states = self.feed_forward(hidden_states)
|
hidden_states = self.feed_forward(hidden_states)
|
||||||
|
|
||||||
|
# TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
|
||||||
|
# Scatter
|
||||||
|
if self.dp_size != 1:
|
||||||
|
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||||
|
# be careful about this!
|
||||||
|
hidden_states, global_hidden_states = (
|
||||||
|
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||||
|
hidden_states,
|
||||||
|
)
|
||||||
|
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
@@ -350,6 +407,7 @@ class Llama4Model(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("embed_tokens", prefix),
|
prefix=add_prefix("embed_tokens", prefix),
|
||||||
|
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||||
)
|
)
|
||||||
self.layers = make_layers(
|
self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
@@ -385,7 +443,8 @@ class Llama4Model(nn.Module):
|
|||||||
forward_batch,
|
forward_batch,
|
||||||
residual,
|
residual,
|
||||||
)
|
)
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
if not forward_batch.forward_mode.is_idle():
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
if len(aux_hidden_states) == 0:
|
if len(aux_hidden_states) == 0:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -394,7 +453,6 @@ class Llama4Model(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Llama4ForCausalLM(LlamaForCausalLM):
|
class Llama4ForCausalLM(LlamaForCausalLM):
|
||||||
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
|
|||||||
Reference in New Issue
Block a user