Add PDL support for quant kernel and rope kernel (#9106)

This commit is contained in:
fzyzcjy
2025-08-20 16:56:29 +08:00
committed by GitHub
parent c9bf3877a0
commit 42c8704560
7 changed files with 80 additions and 33 deletions

View File

@@ -271,6 +271,7 @@ def apply_rope_with_cos_sin_cache_inplace(
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
enable_pdl: Optional[bool] = None,
) -> None:
r"""
Apply rotary embedding to keys and queries with precomputed cos/sin values.
@@ -307,6 +308,10 @@ def apply_rope_with_cos_sin_cache_inplace(
if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32")
if enable_pdl is None:
# the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will
enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)
if (a := fused_set_kv_buffer_arg) is not None:
assert a.k_scale is None, "k_scale is not yet supported"
assert a.v_scale is None, "v_scale is not yet supported"
@@ -323,6 +328,7 @@ def apply_rope_with_cos_sin_cache_inplace(
cos_sin_cache,
positions.long(),
(not is_neox),
enable_pdl,
get_cuda_stream(),
(
_view_3d(fused_set_kv_buffer_arg.value)