[feat][torchair] support super kernel feat for quantized dsr1 (#3485)
### What this PR does / why we need it? Port #1916 and #2157 to master branch to fuse operators in deepseek moe layers, which can reduce scheduling overhead on devices. Note that this feature is valid only when `tp_size = 1` and `multistream_overlap_shared_expert` is enabled with torchair graph mode. ### Does this PR introduce _any_ user-facing change? Users can enable this feature with `--additional-config '{"torchair_graph_config":{"enabled":true, "enable_super_kernel":true}, "multistream_overlap_shared_expert":true}'`. ### How was this patch tested? E2E deepseek serving with 2P1D disaggregated prefill scenarios. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -37,7 +37,8 @@ class AscendConfig:
|
||||
|
||||
torchair_graph_config = additional_config.get("torchair_graph_config",
|
||||
{})
|
||||
self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config)
|
||||
self.torchair_graph_config = TorchairGraphConfig(
|
||||
torchair_graph_config, vllm_config, additional_config)
|
||||
|
||||
ascend_scheduler_config = additional_config.get(
|
||||
"ascend_scheduler_config", {})
|
||||
@@ -133,7 +134,7 @@ class TorchairGraphConfig:
|
||||
Configuration Object for torchair_graph_config from additional_config
|
||||
"""
|
||||
|
||||
def __init__(self, torchair_graph_config):
|
||||
def __init__(self, torchair_graph_config, vllm_config, additional_config):
|
||||
self.enabled = torchair_graph_config.get("enabled", False)
|
||||
self.mode = torchair_graph_config.get("mode", '')
|
||||
self.use_cached_graph = torchair_graph_config.get(
|
||||
@@ -151,6 +152,8 @@ class TorchairGraphConfig:
|
||||
self.enable_frozen_parameter = torchair_graph_config.get(
|
||||
"enable_frozen_parameter", True)
|
||||
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
|
||||
self.enable_super_kernel = torchair_graph_config.get(
|
||||
"enable_super_kernel", False)
|
||||
|
||||
if not isinstance(self.graph_batch_sizes, list):
|
||||
raise TypeError("graph_batch_sizes must be list[int]")
|
||||
@@ -186,6 +189,20 @@ class TorchairGraphConfig:
|
||||
raise RuntimeError(
|
||||
"enable_kv_nz is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_super_kernel:
|
||||
raise RuntimeError(
|
||||
"enable_super_kernel is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_super_kernel:
|
||||
if vllm_config.parallel_config.tensor_parallel_size != 1:
|
||||
raise RuntimeError(
|
||||
"enable_super_kernel is valid only when tensor_parallel_size is 1"
|
||||
)
|
||||
if not additional_config.get("multistream_overlap_shared_expert",
|
||||
False):
|
||||
raise RuntimeError(
|
||||
"enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled"
|
||||
)
|
||||
if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
|
||||
raise RuntimeError(
|
||||
"use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"
|
||||
|
||||
Reference in New Issue
Block a user