Add target module validation for init adapters (#9429)
This commit is contained in:
@@ -420,20 +420,37 @@ class LoRAManager:
|
||||
):
|
||||
"""Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""
|
||||
|
||||
if target_modules is not None:
|
||||
self.target_modules = set(target_modules)
|
||||
else:
|
||||
self.target_modules = set()
|
||||
for config in self.configs.values():
|
||||
if not isinstance(config.target_modules, list):
|
||||
self.target_modules = (
|
||||
get_normalized_target_modules(target_modules) if target_modules else set()
|
||||
)
|
||||
|
||||
for lora_id, config in self.configs.items():
|
||||
if not isinstance(config.target_modules, list):
|
||||
raise ValueError(
|
||||
f"SGLang currently only supports inferring LoRA target modules when a list of "
|
||||
"suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
|
||||
"specify `--lora-target-modules` during server startup. You can specify `all` to "
|
||||
"enable all support modules types. "
|
||||
)
|
||||
|
||||
adapter_target_modules = get_normalized_target_modules(
|
||||
config.target_modules
|
||||
)
|
||||
|
||||
if target_modules is not None:
|
||||
# When `--lora-target-modules` is provided, validate adapter target modules is a subset of the specified target modules.
|
||||
if not adapter_target_modules.issubset(self.target_modules):
|
||||
unsupported_modules = adapter_target_modules - self.target_modules
|
||||
lora_name = self.lora_refs[lora_id].lora_name
|
||||
raise ValueError(
|
||||
f"SGLang currently only supports inferring LoRA target modules when a list of "
|
||||
"suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
|
||||
"specify `--lora-target-modules` during server startup. You can specify `all` to "
|
||||
"enable all support modules types. "
|
||||
f"LoRA adapter '{lora_name}' contains target modules {sorted(unsupported_modules)} "
|
||||
f"that are not included in the specified --lora-target-modules {sorted(self.target_modules)}. "
|
||||
f"Please update --lora-target-modules to include all required modules: "
|
||||
f"{sorted(self.target_modules | adapter_target_modules)}, or use 'all' to enable all supported modules."
|
||||
)
|
||||
self.target_modules.update(config.target_modules)
|
||||
self.target_modules = get_normalized_target_modules(self.target_modules)
|
||||
else:
|
||||
# Otherwise, infer target_modules from adapter configs.
|
||||
self.target_modules.update(adapter_target_modules)
|
||||
|
||||
if max_lora_rank is not None:
|
||||
self.max_lora_rank = max_lora_rank
|
||||
|
||||
Reference in New Issue
Block a user