diff --git a/python/pyproject.toml b/python/pyproject.toml
index 44be1c4fc..423c43e4d 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -23,7 +23,7 @@
 runtime_common = ["aiohttp", "decord", "fastapi", "psutil", "pydantic", "python-multipart",
     "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.6"]
-srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel>=0.0.2.post9"]
+srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel>=0.0.2.post10"]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index aa649254d..ec2e06942 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -272,8 +272,14 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    # FIXME(zhyncs)
-    if not_hip and num_experts >= 256:
+    if not_hip and num_experts >= 224:
+        token_cnts_buffer = torch.empty(
+            (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
+        )
+        cumsum_buffer = torch.empty(
+            num_experts + 1, dtype=torch.int32, device=topk_ids.device
+        )
+
         sgl_moe_align_block_size(
             topk_ids,
             num_experts,
@@ -281,6 +287,8 @@
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,
+            token_cnts_buffer,
+            cumsum_buffer,
         )
     else:
         ops.moe_align_block_size(
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index f98cc14fb..2612f8840 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -95,12 +95,6 @@ class ModelRunner:
         ):
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
-            # FIXME(HandH1998)
-            if (
-                "DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
-                and not self.server_args.disable_cuda_graph
-            ):
-                self.server_args.disable_cuda_graph = True
 
         if self.server_args.enable_double_sparsity:
             logger.info(
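
Note (not part of the patch): a minimal sketch of what the changed CUDA branch of moe_align_block_size now does, for readers reviewing the hunks above. Buffer shapes, dtypes, and the two trailing kernel arguments are taken verbatim from the diff; the wrapper name align_with_sgl_kernel and the block_size argument name are illustrative assumptions, since the middle of the call is hidden by the hunk boundary.

# Sketch under the assumptions stated above; not code from the repository.
import torch

def align_with_sgl_kernel(topk_ids, num_experts, block_size,
                          sorted_ids, expert_ids, num_tokens_post_pad,
                          sgl_moe_align_block_size):
    # Scratch buffer for per-expert token counts; shape copied from the diff:
    # (num_experts + 1) * num_experts int32 entries on the same device as topk_ids.
    token_cnts_buffer = torch.empty(
        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
    )
    # Scratch buffer for the running prefix sum, one slot per expert plus one extra.
    cumsum_buffer = torch.empty(
        num_experts + 1, dtype=torch.int32, device=topk_ids.device
    )
    # The sgl-kernel op now receives the two scratch buffers as trailing arguments,
    # matching the updated call in the second hunk (taken when not on HIP and
    # num_experts >= 224).
    sgl_moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )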