update sgl_moe_align_block_size usage (#2617)
This commit is contained in:
@@ -23,7 +23,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
|
||||
"psutil", "pydantic", "python-multipart",
|
||||
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
|
||||
"xgrammar>=0.1.6"]
|
||||
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel>=0.0.2.post9"]
|
||||
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel>=0.0.2.post10"]
|
||||
|
||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
||||
|
||||
@@ -272,8 +272,14 @@ def moe_align_block_size(
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||
# FIXME(zhyncs)
|
||||
if not_hip and num_experts >= 256:
|
||||
if not_hip and num_experts >= 224:
|
||||
token_cnts_buffer = torch.empty(
|
||||
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
cumsum_buffer = torch.empty(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
sgl_moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
@@ -281,6 +287,8 @@ def moe_align_block_size(
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
)
|
||||
else:
|
||||
ops.moe_align_block_size(
|
||||
|
||||
@@ -95,12 +95,6 @@ class ModelRunner:
|
||||
):
|
||||
logger.info("MLA optimization is turned on. Use triton backend.")
|
||||
self.server_args.attention_backend = "triton"
|
||||
# FIXME(HandH1998)
|
||||
if (
|
||||
"DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
|
||||
and not self.server_args.disable_cuda_graph
|
||||
):
|
||||
self.server_args.disable_cuda_graph = True
|
||||
|
||||
if self.server_args.enable_double_sparsity:
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user