forked from EngineX-Cambricon/enginex-mlu370-vllm
fix: pass lm_head to LogitsProcessor instead of calling forward()
In vLLM v0.6.2, ParallelLMHead.forward() raises a RuntimeError because its weights are meant to be applied through lm_head.linear_method.apply() inside LogitsProcessor. Pass lm_head as the first argument to LogitsProcessor, which handles the hidden_states -> logits projection internally.
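The calling convention can be sketched with stub classes (illustrative stand-ins for vLLM's ParallelLMHead and LogitsProcessor, not the real implementations; the projection here is plain Python and sampling_metadata is unused):

```python
class ParallelLMHead:
    """Stub: holds the vocab projection weight; forward() is unusable by design."""
    def __init__(self, weight):
        self.weight = weight  # shape: [vocab_size][hidden_size]

    def forward(self, hidden_states):
        # Mirrors the v0.6.2 behavior described in the commit message.
        raise RuntimeError("ParallelLMHead.forward() must not be called; "
                           "pass the layer to LogitsProcessor instead.")


class LogitsProcessor:
    """Stub: applies lm_head's weights itself, standing in for
    lm_head.linear_method.apply()."""
    def __call__(self, lm_head, hidden_states, sampling_metadata=None):
        # hidden_states: [hidden_size] -> logits: [vocab_size]
        return [sum(w * h for w, h in zip(row, hidden_states))
                for row in lm_head.weight]


lm_head = ParallelLMHead(weight=[[1.0, 0.0], [0.0, 2.0]])
logits_processor = LogitsProcessor()

hidden = [3.0, 4.0]
# Correct pattern: pass the layer itself, not its forward() output.
logits = logits_processor(lm_head, hidden)
# logits == [3.0, 8.0]
```

Calling `lm_head(hidden)` or `lm_head.forward(hidden)` under this convention raises the RuntimeError the commit works around.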
@@ -114,15 +114,12 @@ class CausalMixin:
             # Non-last PP rank
             return None
 
-        # Apply lm_head (at wrapper level, not self.model.lm_head)
-        output = self.lm_head(hidden_states)
-        # Handle tuple output from vLLM Linear layers (output, bias)
-        if isinstance(output, tuple):
-            logits = output[0]
-        else:
-            logits = output
-        return self.logits_processor(None, logits, sampling_metadata)
+        # In v0.6.2, LogitsProcessor handles the lm_head projection internally
+        # via lm_head.linear_method.apply(). Pass lm_head as the first arg.
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
 
     def sample(
         self,