diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 486e9b918..83c8f1e89 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -98,6 +98,7 @@ def get_normalized_target_modules( ) -> set[str]: """ Mapping a list of target module name to names of the normalized LoRA weights. + Handles both base module names (e.g., "gate_proj") and prefixed module names (e.g., "feed_forward.gate_proj"). """ params_mapping = { "q_proj": "qkv_proj", @@ -109,7 +110,8 @@ def get_normalized_target_modules( result = set() for name in target_modules: - normalized_name = params_mapping.get(name, name) + base_name = name.split(".")[-1] + normalized_name = params_mapping.get(base_name, base_name) result.add(normalized_name) return result