Sync from v0.13
This commit is contained in:
57
tests/v1/worker/test_utils.py
Normal file
57
tests/v1/worker/test_utils.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
def test_bind_kv_cache():
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
ctx = {
|
||||
"layers.0.self_attn": Attention(32, 128, 0.1),
|
||||
"layers.1.self_attn": Attention(32, 128, 0.1),
|
||||
"layers.2.self_attn": Attention(32, 128, 0.1),
|
||||
"layers.3.self_attn": Attention(32, 128, 0.1),
|
||||
}
|
||||
kv_cache = {
|
||||
"layers.0.self_attn": torch.zeros((1,)),
|
||||
"layers.1.self_attn": torch.zeros((1,)),
|
||||
"layers.2.self_attn": torch.zeros((1,)),
|
||||
"layers.3.self_attn": torch.zeros((1,)),
|
||||
}
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||
assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache["layers.0.self_attn"]
|
||||
assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache["layers.1.self_attn"]
|
||||
assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache["layers.2.self_attn"]
|
||||
assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache["layers.3.self_attn"]
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache["layers.0.self_attn"]
|
||||
assert runner_kv_caches[1] is kv_cache["layers.1.self_attn"]
|
||||
assert runner_kv_caches[2] is kv_cache["layers.2.self_attn"]
|
||||
assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"]
|
||||
|
||||
|
||||
def test_bind_kv_cache_non_attention():
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
# example from Jamba PP=2
|
||||
ctx = {
|
||||
"model.layers.20.attn": Attention(32, 128, 0.1),
|
||||
"model.layers.28.attn": Attention(32, 128, 0.1),
|
||||
}
|
||||
kv_cache = {
|
||||
"model.layers.20.attn": torch.zeros((1,)),
|
||||
"model.layers.28.attn": torch.zeros((1,)),
|
||||
}
|
||||
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||
|
||||
assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache["model.layers.20.attn"]
|
||||
assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache["model.layers.28.attn"]
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
|
||||
assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]
|
||||
Reference in New Issue
Block a user