Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
vllm/compilation/passes/fusion/rope_kvcache_fusion.py — new file, 230 lines
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.attention import (
|
||||
Attention,
|
||||
get_attention_context,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherRotaryEmbedding,
|
||||
)
|
||||
from .rms_quant_fusion import (
|
||||
empty_bf16,
|
||||
empty_i64,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def fused_rope_and_unified_kv_cache_update_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """
    This impl fetches the KV cache and slot mapping from the forward context,
    then calls the layer impl's `AttentionImpl.do_rope_and_kv_cache_update` method.
    It also returns a dummy tensor, similar to `Attention.unified_kv_cache_update`,
    that is passed to unified_attention to signal a side effect and
    the data dependency between them to ensure torch.compile preserves ordering.
    """
    # The KV cache and slot mapping are runtime state looked up by layer
    # name from the forward context — they cannot be graph inputs here.
    _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
    if layer_slot_mapping is not None:
        # Mutates `query` and `key` in place (RoPE) and writes k/v into
        # `kv_cache` at the slots given by `layer_slot_mapping`.
        attn_layer.impl.do_rope_and_kv_cache_update(
            attn_layer,
            query,
            key,
            value,
            positions,
            cos_sin_cache,
            is_neox,
            kv_cache,
            layer_slot_mapping,
        )
    # NOTE(review): when `layer_slot_mapping` is None (e.g. a dummy/profiling
    # run), neither RoPE nor the cache write happens — presumably intentional,
    # mirroring the unfused path; confirm against callers.

    # Dummy ordering token consumed by unified_attention downstream.
    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
|
||||
|
||||
|
||||
def fused_rope_and_unified_kv_cache_update_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """Fake (meta) implementation used during tracing.

    Produces an empty tensor on `query`'s device with `query`'s dtype,
    standing in for the dummy ordering token returned by the real impl.
    """
    return query.new_empty(0)
|
||||
|
||||
|
||||
# Register the fused op under torch.ops.vllm so the pattern replacement can
# target it. `query`/`key` are declared as mutated because RoPE rotates them
# in place; the KV-cache write is a hidden side effect reached through the
# forward context rather than a declared tensor argument.
direct_register_custom_op(
    op_name="fused_rope_and_unified_kv_cache_update",
    op_func=fused_rope_and_unified_kv_cache_update_impl,
    mutates_args=["query", "key"],
    fake_impl=fused_rope_and_unified_kv_cache_update_fake,
)
|
||||
|
||||
|
||||
class RopeReshapeKVCachePattern:
    """
    This pattern matches the following unfused inplace ops:
        q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
        kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)

    and replaces it with the fused inplace op:
        kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
            q, k, v, positions, cos_sin_cache, is_neox, layer_name
        )
    """

    # Resolved handle to the custom op registered above.
    FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default

    def __init__(
        self,
        layer: Attention,
        is_neox: bool,
    ) -> None:
        # One pattern instance per layer: `layer_name` is baked into the
        # traced graph as a string constant and cannot be wildcarded.
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.num_kv_heads = layer.num_kv_heads
        self.head_size = layer.head_size
        self.head_size_v = layer.head_size_v
        self.is_neox = is_neox

        # Flattened per-token widths of the q/k/v segments in the fused
        # qkv projection output.
        self.q_size = self.num_heads * self.head_size
        self.k_size = self.num_kv_heads * self.head_size
        self.v_size = self.num_kv_heads * self.head_size_v

        # Matches whichever rotary-embedding op variant appears in the graph.
        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=self.is_neox,
            head_size=self.head_size,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
        )

    def get_inputs(self) -> list[torch.Tensor]:
        # Sample inputs to help pattern tracing
        T = 5  # arbitrary small token count for tracing
        L = 4096  # arbitrary table length for the cos/sin cache
        qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
        positions = empty_i64(T)
        cos_sin_cache = empty_bf16(L, self.head_size)
        return [
            qkv,
            positions,
            cos_sin_cache,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        # Traced as the subgraph to search for: split qkv, apply RoPE to
        # q/k, reshape to (tokens, heads, head_size), then the unfused
        # KV-cache update.
        def pattern(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
            q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
            return dummy, q, k, v

        # Traced as the replacement: same split/reshape, but RoPE and the
        # cache write happen inside the single fused custom op.
        def replacement(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            # auto_functionalized wraps the mutating op for the functional
            # graph: results[0] is the dummy ordering token, results[1]/[2]
            # are the post-mutation query/key.
            results = auto_functionalized(
                self.FUSED_OP,
                query=q,
                key=k,
                value=v,
                positions=positions,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                layer_name=self.layer_name,
            )
            return results[0], results[1], results[2], v

        # NOTE: use view_to_reshape to unify view/reshape to simplify
        # pattern and increase matching opportunities
        def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
            gm = pm.fwd_only(*args, **kwargs)
            view_to_reshape(gm)
            return gm

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
        )
|
||||
|
||||
|
||||
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses the rotary embedding and KV cache update operations
    into a single fused kernel if available.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    This fusion eliminates the need for separate kernel launches and
    intermediate memory operations between the RoPE and cache update steps.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rope_kv_cache_fusion_pass"
        )

        cc = config.compilation_config
        # Upper bound on the token count for which the fused kernel is
        # profitable; see is_applicable_for_range below.
        self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num

        # Register one pattern per (layer, neox-style) combination, but only
        # for attention backends that provide the fused kernel.
        attn_layers = get_layers_from_vllm_config(config, Attention)
        # PERF102 fix: iterate values() directly — the layer name key was unused.
        for layer in attn_layers.values():
            if layer.impl.fused_rope_kvcache_supported():
                for is_neox in (True, False):
                    RopeReshapeKVCachePattern(
                        layer=layer,
                        is_neox=is_neox,
                    ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to `graph`, recording the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        # This pass works best for the small-batch decode setting.
        # For large-batch e.g. prefill, it is better to use two separate kernels
        # since they are compute bound and the fused kernels require further tuning.
        return compile_range.end <= self.max_token_num

    def uuid(self) -> str:
        # Hash this pass together with the pattern class so cached
        # compilations are invalidated when either changes.
        return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)
|
||||
Reference in New Issue
Block a user