support Llama4 with non uniformed intermediate size across layers for… (#10047)
This commit is contained in:
@@ -104,12 +104,18 @@ class LoRAMemoryPool:
|
|||||||
return all(_can_support(x) for x in config)
|
return all(_can_support(x) for x in config)
|
||||||
|
|
||||||
def get_lora_A_shape(
|
def get_lora_A_shape(
|
||||||
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
|
self,
|
||||||
|
module_name: str,
|
||||||
|
base_model: torch.nn.Module,
|
||||||
|
max_lora_dim: int,
|
||||||
|
layer_idx: int,
|
||||||
) -> Tuple[int]:
|
) -> Tuple[int]:
|
||||||
"""
|
"""
|
||||||
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
||||||
"""
|
"""
|
||||||
input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
|
input_dim, _ = get_hidden_dim(
|
||||||
|
module_name, self.base_hf_config, base_model, layer_idx
|
||||||
|
)
|
||||||
c = get_stacked_multiply(module_name)
|
c = get_stacked_multiply(module_name)
|
||||||
if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
||||||
input_dim = divide(input_dim, self.tp_size)
|
input_dim = divide(input_dim, self.tp_size)
|
||||||
@@ -120,12 +126,18 @@ class LoRAMemoryPool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_lora_B_shape(
|
def get_lora_B_shape(
|
||||||
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
|
self,
|
||||||
|
module_name: str,
|
||||||
|
base_model: torch.nn.Module,
|
||||||
|
max_lora_dim: int,
|
||||||
|
layer_idx: int,
|
||||||
) -> Tuple[int]:
|
) -> Tuple[int]:
|
||||||
"""
|
"""
|
||||||
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
||||||
"""
|
"""
|
||||||
_, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
|
_, output_dim = get_hidden_dim(
|
||||||
|
module_name, self.base_hf_config, base_model, layer_idx
|
||||||
|
)
|
||||||
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
||||||
output_dim = divide(output_dim, self.tp_size)
|
output_dim = divide(output_dim, self.tp_size)
|
||||||
return (
|
return (
|
||||||
@@ -140,19 +152,21 @@ class LoRAMemoryPool:
|
|||||||
def init_buffer(
|
def init_buffer(
|
||||||
buffer: Dict[str, List[torch.Tensor]],
|
buffer: Dict[str, List[torch.Tensor]],
|
||||||
target_modules: Set[str],
|
target_modules: Set[str],
|
||||||
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
|
get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]],
|
||||||
):
|
):
|
||||||
for module_name in target_modules:
|
for module_name in target_modules:
|
||||||
lora_shape = get_lora_shape_fn(
|
|
||||||
module_name, base_model, self.max_lora_rank
|
|
||||||
)
|
|
||||||
buffer[module_name] = [
|
buffer[module_name] = [
|
||||||
torch.empty(
|
torch.empty(
|
||||||
lora_shape,
|
get_lora_shape_fn(
|
||||||
|
module_name,
|
||||||
|
base_model,
|
||||||
|
self.max_lora_rank,
|
||||||
|
idx,
|
||||||
|
),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
for _ in range(self.num_layer)
|
for idx in range(self.num_layer)
|
||||||
]
|
]
|
||||||
|
|
||||||
init_buffer(
|
init_buffer(
|
||||||
|
|||||||
@@ -48,14 +48,14 @@ def get_layer_id(name: str) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def get_hidden_dim(
|
def get_hidden_dim(
|
||||||
module_name: str, config: AutoConfig, base_model: torch.nn.Module
|
module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
|
||||||
) -> Tuple[int]:
|
) -> Tuple[int]:
|
||||||
"""
|
"""
|
||||||
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if hasattr(base_model, "get_hidden_dim"):
|
if hasattr(base_model, "get_hidden_dim"):
|
||||||
return base_model.get_hidden_dim(module_name)
|
return base_model.get_hidden_dim(module_name, layer_idx)
|
||||||
else:
|
else:
|
||||||
"""
|
"""
|
||||||
WARNING: get_hidden_dim() is not defined,
|
WARNING: get_hidden_dim() is not defined,
|
||||||
|
|||||||
@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
|
|||||||
def should_apply_lora(self, module_name: str) -> bool:
|
def should_apply_lora(self, module_name: str) -> bool:
|
||||||
return bool(self.lora_pattern.match(module_name))
|
return bool(self.lora_pattern.match(module_name))
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
def get_hidden_dim(self, module_name, layer_idx):
|
||||||
# return input_dim, output_dim
|
# return input_dim, output_dim
|
||||||
if module_name == "qkv_proj":
|
if module_name == "qkv_proj":
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -423,6 +423,12 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
return self.config.num_local_experts > 0
|
return self.config.num_local_experts > 0
|
||||||
return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
|
return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
|
||||||
|
|
||||||
|
def get_intermediate_size(self) -> int:
    """Return the feed-forward intermediate size used by this decoder layer.

    A layer whose ``feed_forward`` is a ``Llama4MoE`` reports
    ``config.intermediate_size``; every other layer reports
    ``config.intermediate_size_mlp``.
    """
    cfg = self.config
    uses_moe = isinstance(self.feed_forward, Llama4MoE)
    # Conditional expression keeps the attribute access lazy: only the
    # attribute for the selected branch is read.
    return cfg.intermediate_size if uses_moe else cfg.intermediate_size_mlp
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@@ -540,6 +546,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
|
|||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.model.embed_tokens
|
return self.model.embed_tokens
|
||||||
|
|
||||||
|
def get_layers(self):
    """Return the layer container stored on the wrapped ``self.model``."""
    inner_model = self.model
    return inner_model.layers
|
||||||
|
|
||||||
def _init_model(
|
def _init_model(
|
||||||
self,
|
self,
|
||||||
config: Llama4TextConfig,
|
config: Llama4TextConfig,
|
||||||
|
|||||||
@@ -961,5 +961,30 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
def set_embed(self, embed):
|
def set_embed(self, embed):
|
||||||
return self.language_model.set_embed(embed)
|
return self.language_model.set_embed(embed)
|
||||||
|
|
||||||
|
def get_hidden_dim(self, module_name, layer_idx):
    """Return ``(input_dim, output_dim)`` for a LoRA target module.

    Args:
        module_name: One of ``"qkv_proj"``, ``"o_proj"``, ``"gate_up_proj"``
            or ``"down_proj"``.
        layer_idx: Index into ``self.language_model.get_layers()``. Only the
            MLP projections (``gate_up_proj`` / ``down_proj``) consult it,
            because their width is taken from the layer's
            ``get_intermediate_size()``.

    Returns:
        A 2-tuple ``(input_dim, output_dim)``.

    Raises:
        NotImplementedError: If ``module_name`` is not one of the names above.
    """
    config = self.config

    if module_name == "qkv_proj":
        # Stacked q/k/v output: one head_dim slice per query head plus one
        # per key head and one per value head.
        stacked_heads = config.num_attention_heads + config.num_key_value_heads * 2
        return config.hidden_size, config.head_dim * stacked_heads

    if module_name == "o_proj":
        return config.head_dim * config.num_attention_heads, config.hidden_size

    if module_name in ("gate_up_proj", "down_proj"):
        # Both MLP projections must agree on the layer's intermediate size.
        # Previously gate_up_proj used the global config.intermediate_size
        # and ignored layer_idx, so its output dim could disagree with the
        # per-layer down_proj input dim.
        decoder_layer = self.language_model.get_layers()[layer_idx]
        intermediate_size = decoder_layer.get_intermediate_size()
        if module_name == "gate_up_proj":
            # gate and up projections are stacked, hence the factor of 2.
            return config.hidden_size, intermediate_size * 2
        return intermediate_size, config.hidden_size

    raise NotImplementedError(
        f"get_hidden_dim is not implemented for module '{module_name}'"
    )
|
||||||
|
|
||||||
|
|
||||||
EntryClass = Llama4ForConditionalGeneration
|
EntryClass = Llama4ForConditionalGeneration
|
||||||
|
|||||||
61
test/srt/lora/test_lora_llama4.py
Normal file
61
test/srt/lora/test_lora_llama4.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model configurations exercised by the Llama 4 LoRA bring-up test; each entry
# carries the model path and the tensor-parallel size passed to the server.
MODELS = [
    SimpleNamespace(
        tp_size=8,
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
    ),
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestLlama4LoRA(CustomTestCase):
    """Bring-up test: launch a server for each entry in ``MODELS`` with LoRA enabled."""

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST

    def test_bringup(self):
        """Launch the server with LoRA flags and fail the test if launch raises.

        The launched process is always cleaned up, whether or not the launch
        succeeded.
        """
        for model in MODELS:
            # Bind before the try block: if popen_launch_server raises,
            # the finally clause would otherwise hit an unbound local and
            # its UnboundLocalError would replace the real test failure.
            process = None
            try:
                process = popen_launch_server(
                    model.model,
                    self.base_url,
                    timeout=3 * DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        "--enable-lora",
                        "--max-lora-rank",
                        "64",
                        "--lora-target-modules",
                        "all",
                        "--tp-size",
                        str(model.tp_size),
                        "--context-length",
                        "1048576",
                        "--attention-backend",
                        "fa3",
                    ],
                )
            except Exception as e:
                print(f"Error testing {model.model}: {e}")
                self.fail(f"Test failed for {model.model}: {e}")
            finally:
                # Ensure process cleanup happens regardless of success/failure
                if process is not None and process.poll() is None:
                    print(f"Cleaning up process {process.pid}")
                    try:
                        kill_process_tree(process.pid)
                    except Exception as e:
                        # Best-effort cleanup: report but do not mask the
                        # test outcome.
                        print(f"Error killing process: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -136,6 +136,7 @@ suites = {
|
|||||||
"per-commit-8-gpu": [
|
"per-commit-8-gpu": [
|
||||||
# Disabled because it hangs on the CI.
|
# Disabled because it hangs on the CI.
|
||||||
# TestFile("ep/test_moe_ep.py", 181),
|
# TestFile("ep/test_moe_ep.py", 181),
|
||||||
|
TestFile("lora/test_lora_llama4.py", 600),
|
||||||
TestFile("test_disaggregation.py", 499),
|
TestFile("test_disaggregation.py", 499),
|
||||||
TestFile("test_disaggregation_different_tp.py", 155),
|
TestFile("test_disaggregation_different_tp.py", 155),
|
||||||
TestFile("test_full_deepseek_v3.py", 333),
|
TestFile("test_full_deepseek_v3.py", 333),
|
||||||
|
|||||||
Reference in New Issue
Block a user