fix some typos (#6209)
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
@@ -20,7 +20,7 @@ class AttentionBackend(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
"""Init the global shared states for cuda graph."""
|
||||
"""Init the global shared states for CUDA graph."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
@@ -33,7 +33,7 @@ class AttentionBackend(ABC):
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
||||
"""Init the metadata for a forward pass for capturing a CUDA graph."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
@@ -47,7 +47,7 @@ class AttentionBackend(ABC):
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
"""Init the metadata for a forward pass for replaying a cuda graph."""
|
||||
"""Init the metadata for a forward pass for replaying a CUDA graph."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
|
||||
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
||||
|
||||
class DoubleSparseAttnBackend(AttentionBackend):
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
# Lazy import to avoid the initialization of cuda context
|
||||
# Lazy import to avoid the initialization of CUDA context
|
||||
from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
|
||||
extend_attention_fwd,
|
||||
flash_decode_attention_fwd,
|
||||
|
||||
@@ -664,7 +664,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
|
||||
if wrapper.is_cuda_graph_enabled:
|
||||
# Directly write to the cuda graph input buffer
|
||||
# Directly write to the CUDA graph input buffer
|
||||
kv_indices = wrapper._paged_kv_indices_buf
|
||||
else:
|
||||
kv_indices = torch.empty(
|
||||
@@ -1173,7 +1173,7 @@ def fast_decode_plan(
|
||||
"""
|
||||
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
|
||||
Modifications:
|
||||
- Remove unnecessary device-to-device copy for the cuda graph buffers.
|
||||
- Remove unnecessary device-to-device copy for the CUDA graph buffers.
|
||||
- Remove unnecessary host-to-device copy for the metadata buffers.
|
||||
"""
|
||||
batch_size = len(last_page_len)
|
||||
|
||||
@@ -874,7 +874,7 @@ def fast_mla_decode_plan(
|
||||
) -> None:
|
||||
"""A faster version of BatchMLAPagedAttentionWrapper::plan,
|
||||
for skipping the stream synchronization in original plan function during
|
||||
cuda graph replaying.
|
||||
CUDA graph replaying.
|
||||
"""
|
||||
self._causal = causal
|
||||
self._page_size = page_size
|
||||
|
||||
@@ -92,7 +92,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
skip_prefill: bool = False,
|
||||
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Lazy import to avoid the initialization of cuda context
|
||||
# Lazy import to avoid the initialization of CUDA context
|
||||
from sglang.srt.layers.attention.triton_ops.decode_attention import (
|
||||
decode_attention_fwd,
|
||||
)
|
||||
|
||||
@@ -257,7 +257,7 @@ class VisionFlash3Attention(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
if not _is_cuda:
|
||||
raise Exception("VisionFlash3Attention is only available for cuda")
|
||||
raise Exception("VisionFlash3Attention is only available for CUDA")
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user