fix sgl-kernel unit tests (#5666)

This commit is contained in:
Yineng Zhang
2025-04-23 01:18:30 -07:00
committed by GitHub
parent e62c49557d
commit 15fabcc07f
9 changed files with 313 additions and 0 deletions

View File

@@ -41,6 +41,7 @@ from sgl_kernel.gemm import (
sgl_per_token_group_quant_int8,
sgl_per_token_quant_fp8,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.moe import (
fp8_blockwise_scaled_grouped_mm,
moe_align_block_size,

View File

@@ -0,0 +1,15 @@
from typing import List, Optional, Union
import torch
def apply_token_bitmask_inplace_cuda(
logits: torch.Tensor,
bitmask: torch.Tensor,
indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
if isinstance(indices, list):
indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
if indices is not None:
indices = indices.to(logits.device)
torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)