upgrade main to 0212 (#6712)
### What this PR does / why we need it?
Fixes `transformers_utils/processors/__init__` import error, due to
https://github.com/vllm-project/vllm/pull/33247
Fixes Fused MoE break introduced by `MoERunner abstraction,` due to
https://github.com/vllm-project/vllm/pull/32344
> delete AscendMoERunnere when
https://github.com/vllm-project/vllm/pull/35178 is merged
Fixes `Make Qwen3VL compatible with Transformers v5`, due to
https://github.com/vllm-project/vllm/pull/34262
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
@@ -28,6 +28,13 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map
|
||||
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if not vllm_version_is("0.15.0"):
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import FusedMoEMethodBase # type: ignore
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import FusedMoERouter # type: ignore
|
||||
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner # type: ignore
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
@@ -154,6 +161,77 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
if not vllm_version_is("0.15.0"):
|
||||
# Please remove this inheritance after extending vllm, todo(wxs)
|
||||
class AscendMoERunner(DefaultMoERunner):
|
||||
"""
|
||||
Default implementation of the MoE runner for executing Mixture of Experts layers.
|
||||
|
||||
This class provides a comprehensive implementation for running MoE computations
|
||||
with support for:
|
||||
- Expert routing and token dispatching
|
||||
- Shared experts computation with optional parallel execution using CUDA streams
|
||||
- Data parallel (DP) chunking for large batch processing
|
||||
- Tensor model parallel and expert parallel operations
|
||||
- Various quantization methods and custom operators
|
||||
- Both monolithic and decomposed expert execution paths
|
||||
|
||||
The runner handles the complete MoE forward pass including routing tokens to
|
||||
experts, executing expert computations, and combining results. It supports
|
||||
advanced features like overlapped execution of shared experts and optimized
|
||||
kernels for different parallel execution modes.
|
||||
|
||||
Eventually, this class will be split up and specialized for different
|
||||
configurations, e.g. the presence or absence of shared experts, a gate, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
moe_config: FusedMoEConfig,
|
||||
router: FusedMoERouter,
|
||||
routed_input_transform: torch.nn.Module | None,
|
||||
gate: torch.nn.Module | None,
|
||||
shared_experts: torch.nn.Module | None,
|
||||
quant_method: FusedMoEMethodBase,
|
||||
reduce_results: bool,
|
||||
enable_dbo: bool,
|
||||
):
|
||||
super().__init__(
|
||||
layer,
|
||||
moe_config,
|
||||
router,
|
||||
routed_input_transform,
|
||||
gate,
|
||||
shared_experts,
|
||||
quant_method,
|
||||
reduce_results,
|
||||
enable_dbo,
|
||||
)
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = torch.ops.vllm.moe_forward
|
||||
else:
|
||||
self.moe_forward = torch.ops.vllm.moe_forward_shared
|
||||
|
||||
def forward_impl(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_input: torch.Tensor | None,
|
||||
):
|
||||
"""
|
||||
Override the default forward_impl to use Ascend-specific implementation.
|
||||
This delegates to the layer's forward_impl method which contains the
|
||||
Ascend-specific MoE computation logic.
|
||||
"""
|
||||
result = layer.forward_impl(hidden_states, router_logits)
|
||||
# If the layer has shared experts, forward_impl returns a tuple (shared_out, routed_out)
|
||||
# Otherwise, it returns just routed_out
|
||||
# The torch op expects the same return type based on whether it's moe_forward or moe_forward_shared
|
||||
return result
|
||||
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
moe_counter = -1
|
||||
gate_stream: torch.npu.Stream | None = None
|
||||
@@ -237,6 +315,26 @@ class AscendFusedMoE(FusedMoE):
|
||||
|
||||
setup_moe_comm_method(self.moe_config)
|
||||
self.quant_type = self._get_quant_type()
|
||||
if not vllm_version_is("0.15.0"):
|
||||
self.runner = self._init_runner()
|
||||
|
||||
if not vllm_version_is("0.15.0"):
|
||||
|
||||
def _init_runner(self):
|
||||
# Storing the runner in the FusedMoE is an intermediate state, eventually
|
||||
# the runner will own the FusedMoE layer and provide the execution interface
|
||||
# for MoE ops.
|
||||
return AscendMoERunner(
|
||||
layer=self,
|
||||
moe_config=self.moe_config,
|
||||
router=self.router,
|
||||
routed_input_transform=self._routed_input_transform,
|
||||
gate=self.gate,
|
||||
shared_experts=self.shared_experts,
|
||||
quant_method=self.quant_method,
|
||||
reduce_results=self.reduce_results,
|
||||
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
|
||||
)
|
||||
|
||||
def _get_quant_type(self) -> QuantType:
|
||||
quant_type = QuantType.NONE
|
||||
@@ -266,6 +364,19 @@ class AscendFusedMoE(FusedMoE):
|
||||
"""
|
||||
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
|
||||
|
||||
if not vllm_version_is("0.15.0"):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
self.ensure_moe_quant_config_init()
|
||||
return self.runner.forward(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
)
|
||||
|
||||
def forward_impl( # type: ignore[override]
|
||||
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, return_with_event: bool = False
|
||||
) -> torch.Tensor | FusedMoEResult:
|
||||
@@ -414,6 +525,10 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
logger.info_once("Sequence parallelism is enabled, shared experts are replicated for best performance.")
|
||||
|
||||
self._gate = gate
|
||||
if not vllm_version_is("0.15.0"):
|
||||
# Recreate the runner with the correct shared_experts parameter
|
||||
# The parent class created the runner before self._shared_experts was set
|
||||
self.runner = self._init_runner()
|
||||
|
||||
if self.multistream_overlap_shared_expert:
|
||||
# Wrap the quant_method's process_weights_after_loading to validate that
|
||||
|
||||
Reference in New Issue
Block a user