feat: support sm75 with FlashInfer v0.1.6 (#1233)

This commit is contained in:
Yineng Zhang
2024-08-28 18:39:12 +10:00
committed by GitHub
parent 6cc38b2bf3
commit 198974cd1a
5 changed files with 4 additions and 12 deletions

View File

@@ -135,7 +135,7 @@ sky status --endpoint 30000 sglang
### Common Notes
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. If you are using NVIDIA GPU devices below sm80, such as T4, you can't use SGLang for the time being. We expect to resolve this issue soon, so please stay tuned. If you encounter any FlashInfer-related issues on sm80+ devices (e.g., A100, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise a issue.
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue.
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
## Backend: SGLang Runtime (SRT)

View File

@@ -30,18 +30,11 @@ from vllm.model_executor.utils import set_weight_attrs
class SiluAndMul(CustomOp):
def __init__(self, **kwargs):
super().__init__()
self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
if self.is_lower_sm80:
return self.forward_native(x)
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)

View File

@@ -32,15 +32,12 @@ class RMSNorm(CustomOp):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.is_lower_sm80:
return self.forward_native(x, residual)
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)

View File

@@ -161,6 +161,8 @@ class ModelRunner:
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
)
self.server_args.dtype = "float16"
if torch.cuda.get_device_capability()[1] < 5:
raise RuntimeError("SGLang only supports sm75 and above.")
monkey_patch_vllm_dummy_weight_loader()
self.device_config = DeviceConfig()

View File

@@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if not server_args.disable_flashinfer:
assert_pkg_version(
"flashinfer",
"0.1.5",
"0.1.6",
"Please uninstall the old version and "
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",