[Model][MiniCPM] support MiniCPM (#645)
### What this PR does / why we need it? This pr support minicpm in branch main. see https://github.com/vllm-project/vllm-ascend/pull/164 ### How was this patch tested? test locally with minicpm --------- Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -96,6 +96,8 @@
|
|||||||
# Related PR (if no, explain why): no related PR, we want add this ability into vllm
|
# Related PR (if no, explain why): no related PR, we want add this ability into vllm
|
||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove those patch when vllm merged them
|
# Remove those patch when vllm merged them
|
||||||
|
#
|
||||||
|
#
|
||||||
# * Worker Patch:
|
# * Worker Patch:
|
||||||
# ===============
|
# ===============
|
||||||
# ** File: worker/patch_0_8_4/patch_metrics.py **
|
# ** File: worker/patch_0_8_4/patch_metrics.py **
|
||||||
@@ -125,6 +127,20 @@
|
|||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Revert it when the related pr is merged in vllm.
|
# Revert it when the related pr is merged in vllm.
|
||||||
#
|
#
|
||||||
|
# ** File: worker/patch_common/patch_minicpm.py **
|
||||||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
# 1. `vllm.model_executor.models.minicpm.MiniCPMAttention.forward`
|
||||||
|
# Why:
|
||||||
|
# The forward func of MiniCPMAttention in vllm do a datatype convert
|
||||||
|
# (original datatype --> float32) to ensure the precision on cuda.
|
||||||
|
# However float32 is not supported in cann rope op, thus we keep this patch
|
||||||
|
# How:
|
||||||
|
# Removed the dtype convert operations in forward
|
||||||
|
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
|
||||||
|
# NO, only for npu due to rope op.
|
||||||
|
# Future Plan:
|
||||||
|
# Keep this patch in vllm-ascend.
|
||||||
|
#
|
||||||
# ** File: worker/patch_common/patch_multi_step_worker.py **
|
# ** File: worker/patch_common/patch_multi_step_worker.py **
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
# 1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
|
# 1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
|
||||||
@@ -156,3 +172,15 @@
|
|||||||
# - https://github.com/vllm-project/vllm-ascend/pull/395
|
# - https://github.com/vllm-project/vllm-ascend/pull/395
|
||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Revert it when the related pr is merged in vllm and vllm-ascend.
|
# Revert it when the related pr is merged in vllm and vllm-ascend.
|
||||||
|
#
|
||||||
|
# ** File: worker/patch_0_8_4/patch_tritonplaceholder.py **
|
||||||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
# 1. `triton` Module
|
||||||
|
# Why:
|
||||||
|
# Triton is not supported on npu currently, importing triton will break vllm-ascend
|
||||||
|
# How:
|
||||||
|
# ditto
|
||||||
|
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
|
||||||
|
# TritonPlaceholder is only available in vllm>0.8.4
|
||||||
|
# Future Plan:
|
||||||
|
# Revert it when branch main doesn't maintain v0.8.4.
|
||||||
|
|||||||
@@ -16,3 +16,4 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import vllm_ascend.patch.worker.patch_0_8_4.patch_metrics # noqa
|
import vllm_ascend.patch.worker.patch_0_8_4.patch_metrics # noqa
|
||||||
|
import vllm_ascend.patch.worker.patch_0_8_4.patch_tritonplaceholder # noqa
|
||||||
|
|||||||
@@ -0,0 +1,68 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
# Adapted from vllm/triton_utils/importing.py
|
||||||
|
#
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from importlib.util import find_spec
|
||||||
|
|
||||||
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
HAS_TRITON = (
|
||||||
|
find_spec("triton") is not None
|
||||||
|
or find_spec("pytorch-triton-xpu") is not None # Not compatible
|
||||||
|
)
|
||||||
|
|
||||||
|
if not HAS_TRITON:
|
||||||
|
logger.info("Triton not installed or not compatible; certain GPU-related"
|
||||||
|
" functions will not be available.")
|
||||||
|
|
||||||
|
class TritonPlaceholder(types.ModuleType):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("triton")
|
||||||
|
self.jit = self._dummy_decorator("jit")
|
||||||
|
self.autotune = self._dummy_decorator("autotune")
|
||||||
|
self.heuristics = self._dummy_decorator("heuristics")
|
||||||
|
self.language = TritonLanguagePlaceholder()
|
||||||
|
logger.warning_once(
|
||||||
|
"Triton is not installed. Using dummy decorators. "
|
||||||
|
"Install it via `pip install triton` to enable kernel"
|
||||||
|
" compilation.")
|
||||||
|
|
||||||
|
def _dummy_decorator(self, name):
|
||||||
|
|
||||||
|
def decorator(func=None, **kwargs):
|
||||||
|
if func is None:
|
||||||
|
return lambda f: f
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
class TritonLanguagePlaceholder(types.ModuleType):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("triton.language")
|
||||||
|
self.constexpr = None
|
||||||
|
self.dtype = None
|
||||||
|
|
||||||
|
sys.modules['triton'] = TritonPlaceholder()
|
||||||
|
sys.modules['triton.language'] = TritonLanguagePlaceholder()
|
||||||
|
|
||||||
|
if 'triton' in sys.modules:
|
||||||
|
logger.info("Triton module has been replaced with a placeholder.")
|
||||||
@@ -16,5 +16,6 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
|
||||||
|
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
|
||||||
|
|||||||
36
vllm_ascend/patch/worker/patch_common/patch_minicpm.py
Normal file
36
vllm_ascend/patch/worker/patch_common/patch_minicpm.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.model_executor.models.minicpm import MiniCPMAttention
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
attn_output = self.attn(q, k, v)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# The type conversion in the forward function is deleted to support the rope operator.
|
||||||
|
MiniCPMAttention.forward = forward
|
||||||
Reference in New Issue
Block a user