[Fix] Fix several issues preventing gemma3n LoRA support. (#8776)
@@ -386,6 +386,13 @@ class LoRAManager:
         else:
             self.target_modules = set()
             for config in self.configs.values():
+                if not isinstance(config.target_modules, list):
+                    raise ValueError(
+                        f"SGLang currently only supports inferring LoRA target modules when a list of "
+                        "suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
+                        "specify `--lora-target-modules` during server startup. You can specify `all` to "
+                        "enable all support modules types. "
+                    )
                 self.target_modules.update(config.target_modules)
 
         if max_lora_rank is not None:
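For context, the new validation only covers the case where every adapter's PEFT config lists explicit module suffixes. The sketch below is a hypothetical, standalone reproduction of that aggregation (the `AdapterConfig` dataclass and `infer_target_modules` helper are illustrative, not SGLang APIs): list-valued `target_modules` are unioned across adapters, while any other form (e.g. the "all-linear" shorthand PEFT also accepts) raises and asks the user to pass `--lora-target-modules` instead.

    from dataclasses import dataclass
    from typing import List, Union


    @dataclass
    class AdapterConfig:
        """Stand-in for a PEFT LoraConfig; only the field used here."""
        target_modules: Union[List[str], str]


    def infer_target_modules(configs: dict) -> set:
        # Union list-valued target_modules across all loaded adapters.
        target_modules = set()
        for config in configs.values():
            if not isinstance(config.target_modules, list):
                # Mirrors the ValueError raised in LoRAManager above.
                raise ValueError(
                    "Cannot infer target modules from a non-list value; "
                    "pass --lora-target-modules at server startup instead."
                )
            target_modules.update(config.target_modules)
        return target_modules


    configs = {
        "adapter_a": AdapterConfig(["q_proj", "k_proj", "v_proj"]),
        "adapter_b": AdapterConfig(["q_proj", "o_proj"]),
    }
    # -> ['k_proj', 'o_proj', 'q_proj', 'v_proj']
    print(sorted(infer_target_modules(configs)))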
@@ -492,5 +492,44 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
         loaded_params.add(name)
         return loaded_params
 
+    lora_pattern = re.compile(
+        r"^language_model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
+
+    def should_apply_lora(self, module_name: str) -> bool:
+        return bool(self.lora_pattern.match(module_name))
+
+    def get_hidden_dim(self, module_name):
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            assert len(set(self.config.intermediate_size)) == 1, (
+                "Currently SGLang requires uniform intermediate size for all layers. "
+                "Please file an issue if you need support for non-uniform intermediate sizes."
+            )
+            return self.config.hidden_size, self.config.intermediate_size[0]
+        elif module_name == "down_proj":
+            assert len(set(self.config.intermediate_size)) == 1, (
+                "Currently SGLang requires uniform intermediate size for all layers. "
+                "Please file an issue if you need support for non-uniform intermediate sizes."
+            )
+            return self.config.intermediate_size[0], self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
 
 EntryClass = Gemma3nForConditionalGeneration
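A rough sanity check of the two hooks added above. The stand-in config is a `SimpleNamespace` with made-up dimensions (the real values come from the Gemma 3n Hugging Face config, where `intermediate_size` is a per-layer list); only the regex and the shape arithmetic from the diff are reused.

    import re
    from types import SimpleNamespace

    # Same pattern as the class attribute added above.
    lora_pattern = re.compile(
        r"^language_model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
    )

    # Per-layer attention/MLP projections match; other modules do not.
    assert lora_pattern.match("language_model.layers.3.self_attn.qkv_proj")
    assert lora_pattern.match("language_model.layers.11.mlp.gate_up_proj")
    assert not lora_pattern.match("language_model.embed_tokens")

    # Illustrative dimensions only.
    config = SimpleNamespace(
        hidden_size=2048,
        head_dim=256,
        num_attention_heads=8,
        num_key_value_heads=2,
        intermediate_size=[8192] * 30,
    )

    # get_hidden_dim("qkv_proj")  -> (hidden_size, head_dim * num_attention_heads)
    print(config.hidden_size, config.head_dim * config.num_attention_heads)  # 2048 2048
    # get_hidden_dim("down_proj") -> (intermediate_size[0], hidden_size)
    print(config.intermediate_size[0], config.hidden_size)                   # 8192 2048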
@@ -1943,10 +1943,16 @@ class ServerArgs:
         if "Llama4" in model_arch:
             assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
 
-        if "Gemma2ForCausalLM" in model_arch:
+        if model_arch in [
+            "Gemma2ForCausalLM",
+            "Gemma3nForCausalLM",
+            "Gemma3nForConditionalGeneration",
+        ]:
             # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
             # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
-            logger.warning("Disable hybrid SWA memory for Gemma2ForCausalLM.")
+            logger.warning(
+                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
+            )
             self.disable_hybrid_swa_memory = True
 
         # Check LoRA
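The architecture string checked above is the one reported by the model's Hugging Face config. A hedged illustration (SGLang's actual lookup path may differ; the model ID is just an example):

    from transformers import AutoConfig

    # Gemma 2 reports "Gemma2ForCausalLM", so hybrid SWA memory would be disabled.
    cfg = AutoConfig.from_pretrained("google/gemma-2-2b-it")
    print(cfg.architectures)  # ['Gemma2ForCausalLM']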