[AMD] Expand test coverage for AMD CI and enable apply_token_bitmask_inplace_cuda in sgl-kernel (#8268)
This commit is contained in:
@@ -32,10 +32,15 @@ from sglang.srt.constrained.base_grammar_backend import (
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
)
|
||||
from sglang.srt.constrained.triton_ops.bitmask_ops import (
|
||||
apply_token_bitmask_inplace_triton,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
if _is_hip:
|
||||
from sgl_kernel import apply_token_bitmask_inplace_cuda
|
||||
else:
|
||||
from sglang.srt.constrained.triton_ops.bitmask_ops import (
|
||||
apply_token_bitmask_inplace_triton,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -94,7 +99,10 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
|
||||
def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
|
||||
if logits.device.type == "cuda":
|
||||
apply_token_bitmask_inplace_triton(logits, vocab_mask)
|
||||
if _is_hip:
|
||||
apply_token_bitmask_inplace_cuda(logits, vocab_mask)
|
||||
else:
|
||||
apply_token_bitmask_inplace_triton(logits, vocab_mask)
|
||||
elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
|
||||
self.apply_vocab_mask_cpu(logits, vocab_mask)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user