From 93754d80616830a5bc068c51d3493b84f679750d Mon Sep 17 00:00:00 2001 From: Shanshan Shen <467638484@qq.com> Date: Wed, 3 Sep 2025 09:18:44 +0800 Subject: [PATCH] [Bugfix] Fix long context seq accuracy problem for `GLM4.5` (#2601) ### What this PR does / why we need it? Fix long context seq accuracy problem for `GLM4.5`. When `max_tokens=1000`, there is cyclic output problem like: ```bash 00 00 00 00 00 00 00 00 00 00 00 00 00 00 ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? ```python import os os.environ["VLLM_USE_MODELSCOPE"] = "True" os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" from vllm import LLM, SamplingParams def main(): prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] # Create a sampling params object. sampling_params = SamplingParams(max_tokens=1000, temperature=0.0) # Create an LLM. llm = LLM(model="/root/.cache/modelscope/hub/models/ZhipuAI/GLM-4___5", tensor_parallel_size=8, enforce_eager=True, trust_remote_code=True, max_model_len=1024) # Generate texts from the prompts. outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") if __name__ == "__main__": main() ``` - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/0235103cbbdb511e6708aae600f759060a797c16 --------- Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com> Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/ops/rotary_embedding.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 5b0daa3..89e2bc7 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -64,6 +64,29 @@ def _rope_forward_oot( raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: + if self.rotary_dim < self.head_size: + num_tokens = query.shape[0] + query = query.view(num_tokens, -1, self.head_size) + key = key.view(num_tokens, -1, self.head_size) + q_rot = query[..., :self.rotary_dim] + q_pass = query[..., self.rotary_dim:] + k_rot = key[..., :self.rotary_dim] + k_pass = key[..., self.rotary_dim:] + q_rot = q_rot.contiguous().view(num_tokens, -1) + k_rot = k_rot.contiguous().view(num_tokens, -1) + torch_npu._npu_rotary_embedding( + positions, + q_rot, + k_rot, + self.head_size, + self.cos_sin_cache, + neox_style, + ) + q_rot = q_rot.view(num_tokens, -1, self.rotary_dim) + k_rot = k_rot.view(num_tokens, -1, self.rotary_dim) + q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) + k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) + return q, k # TODO: Remove the contiguous in the future. query = query.contiguous().view(query.shape[0], -1) key = key.contiguous().view(key.shape[0], -1)