feat: integrate sampling kernels into sgl-kernel (#3086)
Co-authored-by: Zihao Ye <expye@outlook.com>
This commit is contained in:
@@ -17,3 +17,10 @@ def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
|
||||
buf = torch.empty(bytes, dtype=torch.uint8, device=device)
|
||||
_cache_buf[key] = buf
|
||||
return buf
|
||||
|
||||
|
||||
def _to_tensor_scalar_tuple(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return (x, 0)
|
||||
else:
|
||||
return (None, x)
|
||||
|
||||
Reference in New Issue
Block a user