fix sgl-kernel unit tests (#5666)
This commit is contained in:
@@ -41,6 +41,7 @@ from sgl_kernel.gemm import (
|
||||
sgl_per_token_group_quant_int8,
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
||||
from sgl_kernel.moe import (
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
moe_align_block_size,
|
||||
|
||||
15
sgl-kernel/python/sgl_kernel/grammar.py
Normal file
15
sgl-kernel/python/sgl_kernel/grammar.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def apply_token_bitmask_inplace_cuda(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
    """Apply a token bitmask to ``logits`` in place via the sgl_kernel CUDA op.

    Args:
        logits: Logits tensor modified in place by the kernel.
        bitmask: Token bitmask tensor forwarded to the kernel unchanged.
        indices: Optional subset of rows to mask. A plain ``list`` is
            converted to an int32 tensor on ``logits.device``; an existing
            tensor is moved to ``logits.device``; ``None`` is forwarded
            as-is (kernel semantics for ``None`` are defined by the op).
    """
    # Normalize `indices` so the kernel always receives either None or a
    # tensor living on the same device as `logits`.
    if indices is not None:
        if isinstance(indices, list):
            indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
        else:
            indices = indices.to(logits.device)
    torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
|
||||
Reference in New Issue
Block a user