From 320877d48888ef8ac2cb3d1079518d12d4a626d8 Mon Sep 17 00:00:00 2001 From: XiaoxinWang <963372609@qq.com> Date: Fri, 26 Dec 2025 09:19:47 +0800 Subject: [PATCH] move contiguous in fused_sigmoid_gating_delta_rule_update to model_runner_v1 (#5274) ### What this PR does / why we need it? The contiguous() operation temporarily increases memory usage, leading to higher peak GPU memory, which necessitates reducing gpu_memory_utilization. However, making tensors contiguous in modelrunnerv1 significantly enhances operator performance, resulting in greater end-to-end model benefits despite the memory overhead. - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: wangxiaoxin-sherie Co-authored-by: wangxiaoxin-sherie --- vllm_ascend/ops/triton/fla/sigmoid_gating.py | 6 ++-- vllm_ascend/worker/model_runner_v1.py | 31 ++++++++------------ 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/vllm_ascend/ops/triton/fla/sigmoid_gating.py b/vllm_ascend/ops/triton/fla/sigmoid_gating.py index dd481ec4..b4c063d2 100644 --- a/vllm_ascend/ops/triton/fla/sigmoid_gating.py +++ b/vllm_ascend/ops/triton/fla/sigmoid_gating.py @@ -360,7 +360,7 @@ def fused_sigmoid_gating_delta_rule_update( if not initial_state_indices.is_contiguous(): initial_state_indices = initial_state_indices.contiguous() if not initial_state_source.is_contiguous(): - initial_state_source_contiguous = initial_state_source.contiguous() + initial_state_source = initial_state_source.contiguous() if not cu_seqlens.is_contiguous(): cu_seqlens = cu_seqlens.contiguous() @@ -375,7 +375,7 @@ def fused_sigmoid_gating_delta_rule_update( v=v, b=b, o=o, - h0_source=initial_state_source_contiguous, + h0_source=initial_state_source, h0_indices=initial_state_indices, cu_seqlens=cu_seqlens, scale=scale, @@ -391,7 +391,5 @@ def fused_sigmoid_gating_delta_rule_update( num_warps=num_warps, num_stages=num_stages, ) - initial_state_source.copy_( - initial_state_source_contiguous.view_as(initial_state_source)) o = o.squeeze(0) return o diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b6adeba0..c2b11d92 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -57,7 +57,6 @@ from vllm.sequence import IntermediateTensors from vllm.utils.import_utils import LazyLoader from vllm.utils.math_utils import cdiv from vllm.utils.mem_utils import DeviceMemoryProfiler -from vllm.utils.torch_utils import get_dtype_size from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import (AttentionCGSupport, CommonAttentionMetadata) @@ -2541,6 +2540,7 @@ class NPUModelRunner(GPUModelRunner): ) % kv_cache_spec.page_size_bytes == 0 num_blocks = raw_tensor.numel( ) // kv_cache_spec.page_size_bytes + assert num_blocks >= kv_cache_config.num_blocks # `num_blocks` is the number of blocks the model runner can use. # `kv_cache_config.num_blocks` is the number of blocks that @@ -2549,27 +2549,22 @@ class NPUModelRunner(GPUModelRunner): # different memory capacities, `num_blocks` can be different on # different GPUs, and `kv_cache_config.num_blocks` is set to # the min of all `num_blocks`. Verify it here. - assert num_blocks >= kv_cache_config.num_blocks state_tensors = [] - storage_offset_bytes = 0 - for (shape, dtype) in zip(kv_cache_spec.shapes, - kv_cache_spec.dtypes): - dtype_size = get_dtype_size(dtype) - num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size) + target_idx = 0 + start_idx = 0 + for shape, dtype in zip(kv_cache_spec.shapes, + kv_cache_spec.dtypes): + # normally, there is conv state and ssm state in this loop. And there is only + # a conv state in some special models. target_shape = (num_blocks, *shape) - stride = torch.empty(target_shape).stride() - target_stride = (num_element_per_page, *stride[1:]) - assert storage_offset_bytes % dtype_size == 0 - tensor = torch.as_strided( - raw_tensor.view(dtype), - size=target_shape, - stride=target_stride, - storage_offset=storage_offset_bytes // dtype_size, - ) + + target_idx += torch.prod( + torch.tensor(target_shape)).item() + tensor = raw_tensor.view( + dtype)[start_idx:target_idx].view(target_shape) + start_idx = target_idx state_tensors.append(tensor) - storage_offset_bytes += stride[0] * dtype_size kv_caches[layer_name] = state_tensors else: raise ValueError("Unknown KV cache spec type.")