[Main2Main] Upgrade vLLM to 0226 (#6813)
### What this PR does / why we need it?
Breaking changes in upstream vLLM:
1. https://github.com/vllm-project/vllm/pull/33452
2. https://github.com/vllm-project/vllm/pull/33451
3. https://github.com/vllm-project/vllm/pull/32567
4. https://github.com/vllm-project/vllm/pull/32344
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main: 83b47f67b1
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: MrZ20 <2609716663@qq.com>
@@ -25,18 +25,13 @@ from torch import nn
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.layers.attention import MLAAttention
 from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backend import AttentionMetadata  # type: ignore
 
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import vllm_version_is
-
-if vllm_version_is("v0.15.0"):
-    from vllm.attention.layer import MLAAttention  # type: ignore
-else:
-    from vllm.model_executor.layers.attention import MLAAttention
 
 
 class IndexerWrapper(nn.Module):
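The hunk above drops the v0.15.0 compatibility gate: `MLAAttention` is now imported unconditionally from its new home, `vllm.model_executor.layers.attention`. For out-of-tree code that still has to run against both releases, a try/except fallback is one alternative to an explicit version check. A minimal sketch (not part of this diff; the two import paths are taken from the hunk above):

```python
# Compatibility sketch: prefer the post-v0.15.0 import path for
# MLAAttention and fall back to the old location on older releases.
try:
    from vllm.model_executor.layers.attention import MLAAttention
except ImportError:
    from vllm.attention.layer import MLAAttention  # type: ignore[no-redef]
```

This PR instead pins support to the new path and deletes the gate, since the upgrade moves the minimum supported vLLM past v0.15.0.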
@@ -126,17 +121,16 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
             o_proj=mla_modules.o_proj,
         )
 
-        if not vllm_version_is("v0.15.0"):
-            original_process_weights = self.mla_attn.process_weights_after_loading
+        original_process_weights = self.mla_attn.process_weights_after_loading
 
-            def wrapped_process_weights(act_dtype: torch.dtype):
-                from vllm_ascend.attention.sfa_v1 import AscendSFAImpl
+        def wrapped_process_weights(act_dtype: torch.dtype):
+            from vllm_ascend.attention.sfa_v1 import AscendSFAImpl
 
-                if not isinstance(self.mla_attn.impl, AscendSFAImpl):
-                    original_process_weights(act_dtype)
-                self.mla_attn.impl.process_weights_after_loading(act_dtype)
+            if not isinstance(self.mla_attn.impl, AscendSFAImpl):
+                original_process_weights(act_dtype)
+            self.mla_attn.impl.process_weights_after_loading(act_dtype)
 
-            self.mla_attn.process_weights_after_loading = wrapped_process_weights
+        self.mla_attn.process_weights_after_loading = wrapped_process_weights
 
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
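The second hunk removes the same version gate around the `process_weights_after_loading` monkey patch, so the wrapper is now installed unconditionally: capture the original bound method, then install a closure that skips the generic path when the Ascend SFA backend is in use and always invokes the impl's own hook. A self-contained sketch of that wrap-and-delegate pattern, with placeholder classes (`Attn`, `DefaultImpl`, `SFAImpl`) standing in for `self.mla_attn` and `AscendSFAImpl`:

```python
import torch


class DefaultImpl:
    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        print(f"impl post-processing in {act_dtype}")


class SFAImpl(DefaultImpl):
    pass


class Attn:
    def __init__(self, impl: DefaultImpl) -> None:
        self.impl = impl

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        print(f"generic post-processing in {act_dtype}")


attn = Attn(SFAImpl())
# Keep a reference to the original bound method before shadowing it.
original = attn.process_weights_after_loading


def wrapped(act_dtype: torch.dtype) -> None:
    # Skip the generic path for the SFA backend; it post-processes
    # its own weights via the impl hook below.
    if not isinstance(attn.impl, SFAImpl):
        original(act_dtype)
    attn.impl.process_weights_after_loading(act_dtype)


# Instance attribute shadows the class method, exactly as the diff
# does with self.mla_attn.process_weights_after_loading.
attn.process_weights_after_loading = wrapped
attn.process_weights_after_loading(torch.bfloat16)  # only the impl hook runs
```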