Fix FusedSetKVBufferArg in RotaryEmbedding (#11003)

This commit is contained in:
Lianmin Zheng
2025-09-28 11:17:27 -07:00
committed by GitHub
parent 9816989bff
commit 07440f5f34
2 changed files with 32 additions and 9 deletions

View File

@@ -1,6 +1,5 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import pytest
import torch
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
@@ -84,8 +83,13 @@ class RotaryEmbedding(torch.nn.Module):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for native implementation"
if offsets is not None:
positions = positions + offsets
@@ -125,8 +129,8 @@ class FlashInferRotaryEmbedding(RotaryEmbedding):
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
apply_rope_with_cos_sin_cache_inplace(