fix: pass lm_head to LogitsProcessor instead of calling forward()

In vLLM v0.6.2, ParallelLMHead.forward() raises RuntimeError since
its weights should be used through LogitsProcessor.linear_method.apply().
Pass lm_head as first arg to LogitsProcessor which handles the
hidden_states -> logits projection internally.
This commit is contained in:
Chranos
2026-02-06 14:21:14 +08:00
parent b702adf015
commit ebdc6fed03

View File

@@ -114,15 +114,12 @@ class CausalMixin:
# Non-last PP rank
return None
# Apply lm_head (at wrapper level, not self.model.lm_head)
output = self.lm_head(hidden_states)
# Handle tuple output from vLLM Linear layers (output, bias)
if isinstance(output, tuple):
logits = output[0]
else:
logits = output
# In v0.6.2, LogitsProcessor handles the lm_head projection internally
# via lm_head.linear_method.apply(). Pass lm_head as the first arg.
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return self.logits_processor(None, logits, sampling_metadata)
return logits
def sample(
self,