231 lines
8.1 KiB
Python
231 lines
8.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import torch
|
|
import torch._inductor.pattern_matcher as pm
|
|
from torch import fx
|
|
from torch._higher_order_ops import auto_functionalized
|
|
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
|
|
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
|
from vllm.config.utils import Range
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.attention.attention import (
|
|
Attention,
|
|
get_attention_context,
|
|
)
|
|
from vllm.utils.torch_utils import direct_register_custom_op
|
|
|
|
from ..inductor_pass import enable_fake_mode
|
|
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
|
from .matcher_utils import (
|
|
MatcherRotaryEmbedding,
|
|
)
|
|
from .rms_quant_fusion import (
|
|
empty_bf16,
|
|
empty_i64,
|
|
)
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
def fused_rope_and_unified_kv_cache_update_impl(
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
positions: torch.Tensor,
|
|
cos_sin_cache: torch.Tensor,
|
|
is_neox: bool,
|
|
layer_name: str = "",
|
|
) -> torch.Tensor:
|
|
"""
|
|
This impl fetches the KV cache and slot mapping from the forward context,
|
|
then calls the layer impl's `AttentionImpl.do_rope_and_kv_cache_update` method.
|
|
It also returns a dummy tensor, similar to `Attention.unified_kv_cache_update`,
|
|
that is passed to unified_attention to signal a side effect and
|
|
the data dependency between them to ensure torch.compile preserves ordering.
|
|
"""
|
|
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
|
|
if layer_slot_mapping is not None:
|
|
attn_layer.impl.do_rope_and_kv_cache_update(
|
|
attn_layer,
|
|
query,
|
|
key,
|
|
value,
|
|
positions,
|
|
cos_sin_cache,
|
|
is_neox,
|
|
kv_cache,
|
|
layer_slot_mapping,
|
|
)
|
|
|
|
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
|
|
|
|
|
|
def fused_rope_and_unified_kv_cache_update_fake(
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
positions: torch.Tensor,
|
|
cos_sin_cache: torch.Tensor,
|
|
is_neox: bool,
|
|
layer_name: str = "",
|
|
) -> torch.Tensor:
|
|
return torch.empty(0, device=query.device, dtype=query.dtype)
|
|
|
|
|
|
direct_register_custom_op(
|
|
op_name="fused_rope_and_unified_kv_cache_update",
|
|
op_func=fused_rope_and_unified_kv_cache_update_impl,
|
|
mutates_args=["query", "key"],
|
|
fake_impl=fused_rope_and_unified_kv_cache_update_fake,
|
|
)
|
|
|
|
|
|
class RopeReshapeKVCachePattern:
    """
    This pattern matches the following unfused inplace ops:
        q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
        kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)

    and replaces it with the fused inplace op:
        kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
            q, k, v, positions, cos_sin_cache, is_neox, layer_name
        )
    """

    # Handle to the fused custom op registered above via direct_register_custom_op.
    FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default

    def __init__(
        self,
        layer: Attention,
        is_neox: bool,
    ) -> None:
        """Capture per-layer geometry needed to trace and rewrite the pattern."""
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.num_kv_heads = layer.num_kv_heads
        self.head_size = layer.head_size
        self.head_size_v = layer.head_size_v
        self.is_neox = is_neox

        # Flattened per-token widths of the q/k/v segments of the packed qkv tensor.
        self.q_size = self.num_heads * self.head_size
        self.k_size = self.num_kv_heads * self.head_size
        self.v_size = self.num_kv_heads * self.head_size_v

        # Matcher that stands in for the unfused rotary-embedding op when tracing.
        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=self.is_neox,
            head_size=self.head_size,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
        )

    def get_inputs(self) -> list[torch.Tensor]:
        # Sample inputs to help pattern tracing
        T = 5  # arbitrary token count for tracing
        L = 4096  # arbitrary cos/sin cache length
        qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
        positions = empty_i64(T)
        cos_sin_cache = empty_bf16(L, self.head_size)
        return [
            qkv,
            positions,
            cos_sin_cache,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Trace `pattern` and `replacement` and register them on `pm_pass`."""

        def pattern(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            # Unfused form: split qkv, apply RoPE, reshape per-head, then
            # write k/v into the cache via the standalone update op.
            q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
            q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
            return dummy, q, k, v

        def replacement(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            # auto_functionalized wraps the mutating fused op; results holds
            # the op output followed by the functionalized mutated args
            # (NOTE(review): presumably (dummy, query, key) per mutates_args
            # order — verify against auto_functionalized semantics).
            results = auto_functionalized(
                self.FUSED_OP,
                query=q,
                key=k,
                value=v,
                positions=positions,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                layer_name=self.layer_name,
            )
            return results[0], results[1], results[2], v

        # NOTE: use view_to_reshape to unify view/reshape to simplify
        # pattern and increase matching opportunities
        def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
            gm = pm.fwd_only(*args, **kwargs)
            view_to_reshape(gm)
            return gm

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
        )
|
|
|
|
|
|
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses the rotary embedding and KV cache update operations
    into a single fused kernel if available.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    This fusion eliminates the need for separate kernel launches and
    intermediate memory operations between the RoPE and cache update steps.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Register one fusion pattern per supported attention layer."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rope_kv_cache_fusion_pass"
        )

        cc = config.compilation_config
        # Token-count ceiling below which the fused kernel is applied.
        self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num

        # Layer names cannot be wildcarded in the pattern matcher, so register
        # one pattern per (layer, neox-style) combination.
        # (Fix: iterate .values() directly — the keys were discarded.)
        attn_layers = get_layers_from_vllm_config(config, Attention)
        for layer in attn_layers.values():
            if layer.impl.fused_rope_kvcache_supported():
                for is_neox in (True, False):
                    RopeReshapeKVCachePattern(
                        layer=layer,
                        is_neox=is_neox,
                    ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to `graph`, fusing RoPE + cache update."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Enable the pass only for small-token compile ranges."""
        # This pass works best for the small-batch decode setting.
        # For large-batch e.g. prefill, it is better to use two separate kernels
        # since they are compute bound and the fused kernels require further tuning.
        return compile_range.end <= self.max_token_num

    def uuid(self) -> str:
        """Hash includes the pattern class source so edits invalidate caches."""
        return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)