Add get_hidden_dim to qwen3.py for correct lora (#7312)

This commit is contained in:
Pavel Logachev
2025-07-20 05:31:16 +03:00
committed by GitHub
parent cbdfb77123
commit 877e35d775
5 changed files with 240 additions and 2 deletions

View File

@@ -330,6 +330,30 @@ class Qwen3ForCausalLM(nn.Module):
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up the token embeddings for ``input_ids`` via the wrapped model.

    Delegates to ``self.model.get_input_embeddings`` — the embedding table
    lives on the inner transformer, not on this causal-LM wrapper.
    """
    return self.model.get_input_embeddings(input_ids)
def get_hidden_dim(self, module_name: str) -> Tuple[int, int]:
    """Return ``(input_dim, output_dim)`` for the named projection module.

    Used by the LoRA loader to size adapter matrices correctly for this
    model's attention and MLP projections. Qwen3 may configure
    ``head_dim != hidden_size // num_attention_heads``, so attention
    projection widths are computed explicitly from ``config.head_dim``.

    Args:
        module_name: One of ``"q_proj"``, ``"qkv_proj"``, ``"o_proj"``,
            ``"kv_proj"``, ``"gate_up_proj"``, or ``"down_proj"``.

    Returns:
        A 2-tuple ``(input_dim, output_dim)`` of ints.

    Raises:
        NotImplementedError: If ``module_name`` is not a known projection.
    """
    if module_name in ("q_proj", "qkv_proj"):
        return (
            self.config.hidden_size,
            self.config.head_dim * self.config.num_attention_heads,
        )
    elif module_name == "o_proj":
        return (
            self.config.head_dim * self.config.num_attention_heads,
            self.config.hidden_size,
        )
    elif module_name == "kv_proj":
        # NOTE(review): output dim counts one head group (k or v alone),
        # matching the convention of the sibling models' get_hidden_dim.
        return (
            self.config.hidden_size,
            self.config.head_dim * self.config.num_key_value_heads,
        )
    elif module_name == "gate_up_proj":
        return self.config.hidden_size, self.config.intermediate_size
    elif module_name == "down_proj":
        return self.config.intermediate_size, self.config.hidden_size
    else:
        # Name the offending module so LoRA misconfigurations are debuggable.
        raise NotImplementedError(f"Unsupported module name: {module_name}")
@torch.no_grad()
def forward(
self,

View File

@@ -134,10 +134,12 @@ class HFRunner:
model_type: str = "generation",
output_str_only: bool = False,
trust_remote_code: bool = False,
patch_model_do_sample_false: bool = False,
):
self.model_type = model_type
self.output_str_only = output_str_only
self.trust_remote_code = trust_remote_code
self.patch_model_do_sample_false = patch_model_do_sample_false
self.in_queue = mp.Queue()
self.out_queue = mp.Queue()
@@ -292,6 +294,7 @@ class HFRunner:
torch_dtype=torch_dtype,
output_str_only=self.output_str_only,
token_ids_logprob=token_ids_logprob,
patch_model_do_sample_false=self.patch_model_do_sample_false,
)
)
elif self.model_type == "embedding":
@@ -380,6 +383,7 @@ class HFRunner:
lora_paths: Optional[List[str]] = None,
output_str_only: bool = False,
token_ids_logprob: Optional[int] = None,
patch_model_do_sample_false: Optional[bool] = False,
) -> ModelOutput:
output_strs = []
top_input_logprobs = []
@@ -407,7 +411,8 @@ class HFRunner:
)
else:
model = base_model
if patch_model_do_sample_false:
model.generation_config.do_sample = False
outputs = model.generate(
input_ids=input_ids,
generation_config=GenerationConfig(