diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 94955f414..cdf1707e8 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -104,12 +104,18 @@ class LoRAMemoryPool: return all(_can_support(x) for x in config) def get_lora_A_shape( - self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int + self, + module_name: str, + base_model: torch.nn.Module, + max_lora_dim: int, + layer_idx: int, ) -> Tuple[int]: """ Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. """ - input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model) + input_dim, _ = get_hidden_dim( + module_name, self.base_hf_config, base_model, layer_idx + ) c = get_stacked_multiply(module_name) if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES: input_dim = divide(input_dim, self.tp_size) @@ -120,12 +126,18 @@ class LoRAMemoryPool: ) def get_lora_B_shape( - self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int + self, + module_name: str, + base_model: torch.nn.Module, + max_lora_dim: int, + layer_idx: int, ) -> Tuple[int]: """ Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. """ - _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model) + _, output_dim = get_hidden_dim( + module_name, self.base_hf_config, base_model, layer_idx + ) if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES: output_dim = divide(output_dim, self.tp_size) return ( @@ -140,19 +152,21 @@ class LoRAMemoryPool: def init_buffer( buffer: Dict[str, List[torch.Tensor]], target_modules: Set[str], - get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]], + get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]], ): for module_name in target_modules: - lora_shape = get_lora_shape_fn( - module_name, base_model, self.max_lora_rank - ) buffer[module_name] = [ torch.empty( - lora_shape, + get_lora_shape_fn( + module_name, + base_model, + self.max_lora_rank, + idx, + ), dtype=self.dtype, device=device, ) - for _ in range(self.num_layer) + for idx in range(self.num_layer) ] init_buffer( diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 1067b40b0..6528e2691 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -48,14 +48,14 @@ def get_layer_id(name: str) -> int: def get_hidden_dim( - module_name: str, config: AutoConfig, base_model: torch.nn.Module + module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int ) -> Tuple[int]: """ Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. """ if hasattr(base_model, "get_hidden_dim"): - return base_model.get_hidden_dim(module_name) + return base_model.get_hidden_dim(module_name, layer_idx) else: """ WARNING: get_hidden_dim() is not defined, diff --git a/python/sglang/srt/models/gemma3n_mm.py b/python/sglang/srt/models/gemma3n_mm.py index fa9a10c85..995db2602 100644 --- a/python/sglang/srt/models/gemma3n_mm.py +++ b/python/sglang/srt/models/gemma3n_mm.py @@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel): def should_apply_lora(self, module_name: str) -> bool: return bool(self.lora_pattern.match(module_name)) - def get_hidden_dim(self, module_name): + def get_hidden_dim(self, module_name, layer_idx): # return input_dim, output_dim if module_name == "qkv_proj": return ( diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index e05d96527..2d2a60730 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -423,6 +423,12 @@ class Llama4DecoderLayer(nn.Module): return self.config.num_local_experts > 0 return (layer_id + 1) % self.config.interleave_moe_layer_step == 0 + def get_intermediate_size(self) -> int: + if isinstance(self.feed_forward, Llama4MoE): + return self.config.intermediate_size + else: + return self.config.intermediate_size_mlp + def forward( self, positions: torch.Tensor, @@ -540,6 +546,9 @@ class Llama4ForCausalLM(LlamaForCausalLM): def get_input_embeddings(self): return self.model.embed_tokens + def get_layers(self): + return self.model.layers + def _init_model( self, config: Llama4TextConfig, diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index b57d637f0..f0184390c 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -961,5 +961,30 @@ class Llama4ForConditionalGeneration(nn.Module): def set_embed(self, embed): return self.language_model.set_embed(embed) + def get_hidden_dim(self, module_name, layer_idx): + # return input_dim, output_dim + if module_name == "qkv_proj": + return ( + self.config.hidden_size, + self.config.head_dim + * ( + self.config.num_attention_heads + + self.config.num_key_value_heads * 2 + ), + ) + elif module_name == "o_proj": + return ( + self.config.head_dim * self.config.num_attention_heads, + self.config.hidden_size, + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size * 2 + elif module_name == "down_proj": + decoder_layer = self.language_model.get_layers()[layer_idx] + intermediate_size = decoder_layer.get_intermediate_size() + return intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + EntryClass = Llama4ForConditionalGeneration diff --git a/test/srt/lora/test_lora_llama4.py b/test/srt/lora/test_lora_llama4.py new file mode 100644 index 000000000..65a4b766f --- /dev/null +++ b/test/srt/lora/test_lora_llama4.py @@ -0,0 +1,61 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +MODELS = [ + SimpleNamespace( + model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + tp_size=8, + ), +] + + +class TestLlama4LoRA(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + + def test_bringup(self): + for model in MODELS: + try: + process = popen_launch_server( + model.model, + self.base_url, + timeout=3 * DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--enable-lora", + "--max-lora-rank", + "64", + "--lora-target-modules", + "all", + "--tp-size", + str(model.tp_size), + "--context-length", + "1048576", + "--attention-backend", + "fa3", + ], + ) + except Exception as e: + print(f"Error testing {model.model}: {e}") + self.fail(f"Test failed for {model.model}: {e}") + + finally: + # Ensure process cleanup happens regardless of success/failure + if process is not None and process.poll() is None: + print(f"Cleaning up process {process.pid}") + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process: {e}") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index f417af1bc..bfe867f17 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -136,6 +136,7 @@ suites = { "per-commit-8-gpu": [ # Disabled because it hangs on the CI. # TestFile("ep/test_moe_ep.py", 181), + TestFile("lora/test_lora_llama4.py", 600), TestFile("test_disaggregation.py", 499), TestFile("test_disaggregation_different_tp.py", 155), TestFile("test_full_deepseek_v3.py", 333),