Add PDL support for quant kernel and rope kernel (#9106)
This commit is contained in:
@@ -271,6 +271,7 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool = True,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
enable_pdl: Optional[bool] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Apply rotary embedding to keys and queries with precomputed cos/sin values.
|
||||
@@ -307,6 +308,10 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
if cos_sin_cache.dtype != torch.float32:
|
||||
raise ValueError("cos_sin_cache should be float32")
|
||||
|
||||
if enable_pdl is None:
|
||||
# The non-fused branch does not support PDL yet; it will once that branch is switched to our own implementation.
|
||||
enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)
|
||||
|
||||
if (a := fused_set_kv_buffer_arg) is not None:
|
||||
assert a.k_scale is None, "k_scale is not yet supported"
|
||||
assert a.v_scale is None, "v_scale is not yet supported"
|
||||
@@ -323,6 +328,7 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
cos_sin_cache,
|
||||
positions.long(),
|
||||
(not is_neox),
|
||||
enable_pdl,
|
||||
get_cuda_stream(),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.value)
|
||||
|
||||
Reference in New Issue
Block a user