[Bugfix] Fix long context seq accuracy problem for GLM4.5 (#2601)
### What this PR does / why we need it?
Fix long context seq accuracy problem for `GLM4.5`.
When `max_tokens=1000`, there is cyclic output problem like:
```bash
00 00 00 00 00 00 00 00 00 00 00 00 00 00
```
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
```python
import os
os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
from vllm import LLM, SamplingParams
def main():
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=1000, temperature=0.0)
# Create an LLM.
llm = LLM(model="/root/.cache/modelscope/hub/models/ZhipuAI/GLM-4___5",
tensor_parallel_size=8,
enforce_eager=True,
trust_remote_code=True,
max_model_len=1024)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
if __name__ == "__main__":
main()
```
- vLLM version: v0.10.1.1
- vLLM main:
0235103cbb
---------
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -64,6 +64,29 @@ def _rope_forward_oot(
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
if self.rotary_dim < self.head_size:
|
||||
num_tokens = query.shape[0]
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
q_rot = query[..., :self.rotary_dim]
|
||||
q_pass = query[..., self.rotary_dim:]
|
||||
k_rot = key[..., :self.rotary_dim]
|
||||
k_pass = key[..., self.rotary_dim:]
|
||||
q_rot = q_rot.contiguous().view(num_tokens, -1)
|
||||
k_rot = k_rot.contiguous().view(num_tokens, -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
q_rot,
|
||||
k_rot,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
|
||||
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
|
||||
return q, k
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
|
||||
Reference in New Issue
Block a user