feat: integrate sampling kernels into sgl-kernel (#3086)

Co-authored-by: Zihao Ye <expye@outlook.com>
This commit is contained in:
Yineng Zhang
2025-01-24 01:54:47 +08:00
committed by GitHub
parent e0cd65c2b6
commit 5de4051bcf
6 changed files with 419 additions and 3 deletions

View File

@@ -17,3 +17,10 @@ def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
buf = torch.empty(bytes, dtype=torch.uint8, device=device)
_cache_buf[key] = buf
return buf
def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
return (x, 0)
else:
return (None, x)