move contiguous in fused_sigmoid_gating_delta_rule_update to model_runner_v1 (#5274)
### What this PR does / why we need it?
The contiguous() operation temporarily increases memory usage, leading
to higher peak GPU memory, which necessitates reducing
gpu_memory_utilization. However, making tensors contiguous in
modelrunnerv1 significantly enhances operator performance, resulting in
greater end-to-end model benefits despite the memory overhead.
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -360,7 +360,7 @@ def fused_sigmoid_gating_delta_rule_update(
|
||||
if not initial_state_indices.is_contiguous():
|
||||
initial_state_indices = initial_state_indices.contiguous()
|
||||
if not initial_state_source.is_contiguous():
|
||||
initial_state_source_contiguous = initial_state_source.contiguous()
|
||||
initial_state_source = initial_state_source.contiguous()
|
||||
if not cu_seqlens.is_contiguous():
|
||||
cu_seqlens = cu_seqlens.contiguous()
|
||||
|
||||
@@ -375,7 +375,7 @@ def fused_sigmoid_gating_delta_rule_update(
|
||||
v=v,
|
||||
b=b,
|
||||
o=o,
|
||||
h0_source=initial_state_source_contiguous,
|
||||
h0_source=initial_state_source,
|
||||
h0_indices=initial_state_indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
scale=scale,
|
||||
@@ -391,7 +391,5 @@ def fused_sigmoid_gating_delta_rule_update(
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
initial_state_source.copy_(
|
||||
initial_state_source_contiguous.view_as(initial_state_source))
|
||||
o = o.squeeze(0)
|
||||
return o
|
||||
|
||||
@@ -57,7 +57,6 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
||||
from vllm.utils.torch_utils import get_dtype_size
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
CommonAttentionMetadata)
|
||||
@@ -2541,6 +2540,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
) % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = raw_tensor.numel(
|
||||
) // kv_cache_spec.page_size_bytes
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
|
||||
# `num_blocks` is the number of blocks the model runner can use.
|
||||
# `kv_cache_config.num_blocks` is the number of blocks that
|
||||
@@ -2549,27 +2549,22 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# different memory capacities, `num_blocks` can be different on
|
||||
# different GPUs, and `kv_cache_config.num_blocks` is set to
|
||||
# the min of all `num_blocks`. Verify it here.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
|
||||
state_tensors = []
|
||||
storage_offset_bytes = 0
|
||||
for (shape, dtype) in zip(kv_cache_spec.shapes,
|
||||
kv_cache_spec.dtypes):
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
num_element_per_page = (
|
||||
kv_cache_spec.page_size_bytes // dtype_size)
|
||||
target_idx = 0
|
||||
start_idx = 0
|
||||
for shape, dtype in zip(kv_cache_spec.shapes,
|
||||
kv_cache_spec.dtypes):
|
||||
# normally, there is conv state and ssm state in this loop. And there is only
|
||||
# a conv state in some special models.
|
||||
target_shape = (num_blocks, *shape)
|
||||
stride = torch.empty(target_shape).stride()
|
||||
target_stride = (num_element_per_page, *stride[1:])
|
||||
assert storage_offset_bytes % dtype_size == 0
|
||||
tensor = torch.as_strided(
|
||||
raw_tensor.view(dtype),
|
||||
size=target_shape,
|
||||
stride=target_stride,
|
||||
storage_offset=storage_offset_bytes // dtype_size,
|
||||
)
|
||||
|
||||
target_idx += torch.prod(
|
||||
torch.tensor(target_shape)).item()
|
||||
tensor = raw_tensor.view(
|
||||
dtype)[start_idx:target_idx].view(target_shape)
|
||||
start_idx = target_idx
|
||||
state_tensors.append(tensor)
|
||||
storage_offset_bytes += stride[0] * dtype_size
|
||||
kv_caches[layer_name] = state_tensors
|
||||
else:
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
|
||||
Reference in New Issue
Block a user