Add get_hidden_dim to qwen3.py for correct lora (#7312)
This commit is contained in:
@@ -330,6 +330,30 @@ class Qwen3ForCausalLM(nn.Module):
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up token embeddings for *input_ids* by delegating to the wrapped base model."""
    embeddings = self.model.get_input_embeddings(input_ids)
    return embeddings
|
||||
|
||||
def get_hidden_dim(self, module_name: str) -> Tuple[int, int]:
    """Return ``(input_dim, output_dim)`` for the named linear projection.

    Used by the LoRA machinery to size adapter matrices for each target
    module of this model.

    Args:
        module_name: One of ``"q_proj"``, ``"qkv_proj"``, ``"o_proj"``,
            ``"kv_proj"``, ``"gate_up_proj"``, ``"down_proj"``.

    Returns:
        A 2-tuple ``(input_dim, output_dim)`` derived from the model config.

    Raises:
        NotImplementedError: If *module_name* is not a recognized projection.
    """
    # Attention output width uses head_dim explicitly rather than
    # hidden_size: for Qwen3, head_dim * num_attention_heads need not
    # equal hidden_size.
    if module_name in ["q_proj", "qkv_proj"]:
        return (
            self.config.hidden_size,
            self.config.head_dim * self.config.num_attention_heads,
        )
    elif module_name in ["o_proj"]:
        return (
            self.config.head_dim * self.config.num_attention_heads,
            self.config.hidden_size,
        )
    elif module_name in ["kv_proj"]:
        # K/V projections are narrower under grouped-query attention.
        return (
            self.config.hidden_size,
            self.config.head_dim * self.config.num_key_value_heads,
        )
    elif module_name == "gate_up_proj":
        return self.config.hidden_size, self.config.intermediate_size
    elif module_name == "down_proj":
        return self.config.intermediate_size, self.config.hidden_size
    else:
        # Name the offending module so misconfigured LoRA targets are debuggable.
        raise NotImplementedError(f"Unsupported module: {module_name}")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -134,10 +134,12 @@ class HFRunner:
|
||||
model_type: str = "generation",
|
||||
output_str_only: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
patch_model_do_sample_false: bool = False,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.output_str_only = output_str_only
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.patch_model_do_sample_false = patch_model_do_sample_false
|
||||
|
||||
self.in_queue = mp.Queue()
|
||||
self.out_queue = mp.Queue()
|
||||
@@ -292,6 +294,7 @@ class HFRunner:
|
||||
torch_dtype=torch_dtype,
|
||||
output_str_only=self.output_str_only,
|
||||
token_ids_logprob=token_ids_logprob,
|
||||
patch_model_do_sample_false=self.patch_model_do_sample_false,
|
||||
)
|
||||
)
|
||||
elif self.model_type == "embedding":
|
||||
@@ -380,6 +383,7 @@ class HFRunner:
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
output_str_only: bool = False,
|
||||
token_ids_logprob: Optional[int] = None,
|
||||
patch_model_do_sample_false: Optional[bool] = False,
|
||||
) -> ModelOutput:
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
@@ -407,7 +411,8 @@ class HFRunner:
|
||||
)
|
||||
else:
|
||||
model = base_model
|
||||
|
||||
if patch_model_do_sample_false:
|
||||
model.generation_config.do_sample = False
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
generation_config=GenerationConfig(
|
||||
|
||||
Reference in New Issue
Block a user