[Refact.]: Refactor some leftover implementations of 300I DUO in the main branch. (#6425)
### What this PR does / why we need it?
- Replace the RoPE operator implementation.
- Refactor some leftover implementations of 300I DUO in the main branch.
### Does this PR introduce _any_ user-facing change?
NA
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
@@ -18,13 +18,13 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from vllm_ascend.ops.activation import AscendSiluAndMul as _Base
|
from vllm_ascend.ops.activation import AscendSiluAndMul
|
||||||
|
|
||||||
|
|
||||||
class AscendSiluAndMul310(_Base):
|
class AscendSiluAndMul310(AscendSiluAndMul):
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
|
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
|
||||||
h = x.shape[-1] // 2
|
h = x.shape[-1] // 2
|
||||||
out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16)
|
out = F.silu(x[..., :h]) * x[..., h:]
|
||||||
torch.ops.vllm.maybe_wait_prefetch_done(out)
|
torch.ops.vllm.maybe_wait_prefetch_done(out)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -11,17 +11,16 @@ class AscendRMSNorm310(AscendRMSNorm):
|
|||||||
residual: torch.Tensor | None = None,
|
residual: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
orig_dtype = residual.dtype
|
|
||||||
if x is None or x.numel() == 0 or x.shape[-1] == 0:
|
if x is None or x.numel() == 0 or x.shape[-1] == 0:
|
||||||
x = residual.to(dtype=residual.dtype)
|
x = residual
|
||||||
else:
|
else:
|
||||||
x = x + residual.to(x.dtype)
|
x = x + residual
|
||||||
|
|
||||||
residual = x.to(orig_dtype)
|
residual = x
|
||||||
x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
||||||
return x, residual
|
return x, residual
|
||||||
|
|
||||||
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
x.add_(self.bias)
|
x.add_(self.bias)
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -19,10 +19,10 @@ import einops
|
|||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention as _Base
|
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
|
||||||
|
|
||||||
|
|
||||||
class AscendMMEncoderAttention310(_Base):
|
class AscendMMEncoderAttention310(AscendMMEncoderAttention):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -15,9 +15,85 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
#
|
#
|
||||||
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
|
||||||
|
import torch
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
from vllm_ascend.ops.rotary_embedding import AscendRotaryEmbedding, get_cos_and_sin_slice
|
||||||
|
|
||||||
|
|
||||||
class AscendMRotaryEmbedding310(MRotaryEmbedding):
|
def _rope_forward_oot(
|
||||||
def forward_oot(self, positions, query, key):
|
self,
|
||||||
return super().forward_oot(positions, query, key)
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
is_neox_style: bool,
|
||||||
|
offsets: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
query_shape, key_shape = query.shape, key.shape
|
||||||
|
if self.cos_sin_cache.device != query.device:
|
||||||
|
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||||
|
if self.cos_sin_cache.dtype != query.dtype:
|
||||||
|
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||||
|
cos, sin = get_cos_and_sin_slice()
|
||||||
|
if offsets is not None:
|
||||||
|
raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
|
||||||
|
rotary_mode = "half" if is_neox_style else "interleave"
|
||||||
|
if self.head_size == 128 and self.cos_sin_cache.shape[-1] == 128:
|
||||||
|
query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
|
||||||
|
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
|
||||||
|
query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin, rotary_mode=rotary_mode)
|
||||||
|
elif self.rotary_dim < self.head_size:
|
||||||
|
num_tokens = query.shape[0]
|
||||||
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
|
q_rot = query[..., : self.rotary_dim]
|
||||||
|
q_pass = query[..., self.rotary_dim :]
|
||||||
|
k_rot = key[..., : self.rotary_dim]
|
||||||
|
k_pass = key[..., self.rotary_dim :]
|
||||||
|
if self.rotary_dim == 64:
|
||||||
|
q_rot = q_rot.contiguous().view(1, num_tokens, -1, self.rotary_dim)
|
||||||
|
k_rot = k_rot.contiguous().view(1, num_tokens, -1, self.rotary_dim)
|
||||||
|
q_rot, k_rot = torch_npu.npu_apply_rotary_pos_emb(q_rot, k_rot, cos, sin, rotary_mode=rotary_mode)
|
||||||
|
else:
|
||||||
|
q_rot = q_rot.contiguous().view(num_tokens, -1)
|
||||||
|
k_rot = k_rot.contiguous().view(num_tokens, -1)
|
||||||
|
torch_npu._npu_rotary_embedding(
|
||||||
|
positions,
|
||||||
|
q_rot,
|
||||||
|
k_rot,
|
||||||
|
self.rotary_dim,
|
||||||
|
self.cos_sin_cache,
|
||||||
|
is_neox_style,
|
||||||
|
)
|
||||||
|
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
|
||||||
|
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
|
||||||
|
query = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
|
||||||
|
key = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
|
||||||
|
else:
|
||||||
|
query = query.contiguous().view(query.shape[0], -1)
|
||||||
|
key = key.contiguous().view(key.shape[0], -1)
|
||||||
|
torch_npu._npu_rotary_embedding(
|
||||||
|
positions,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
self.head_size,
|
||||||
|
self.cos_sin_cache,
|
||||||
|
is_neox_style,
|
||||||
|
)
|
||||||
|
return query.view(query_shape), key.view(key_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class AscendRotaryEmbedding310(AscendRotaryEmbedding):
|
||||||
|
def forward_oot(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
offsets: torch.Tensor | None = None,
|
||||||
|
is_neox_style_override: bool | None = None,
|
||||||
|
):
|
||||||
|
is_neox_style = self.is_neox_style
|
||||||
|
if is_neox_style_override is not None:
|
||||||
|
is_neox_style = is_neox_style_override
|
||||||
|
return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets)
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ class NPUWorker310(NPUWorker):
|
|||||||
def init_device(self):
|
def init_device(self):
|
||||||
self.device = self._init_device()
|
self.device = self._init_device()
|
||||||
|
|
||||||
torch_npu.npu.set_compile_mode(jit_compile=False)
|
# TODO: There is accuracy issue when jit_compile is disabled currently.
|
||||||
|
torch_npu.npu.set_compile_mode(jit_compile=True)
|
||||||
|
|
||||||
init_workspace_manager(self.device, num_ubatches=1)
|
init_workspace_manager(self.device, num_ubatches=1)
|
||||||
|
|
||||||
|
|||||||
@@ -628,13 +628,13 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
|
|||||||
from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
|
from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
|
||||||
from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310
|
from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310
|
||||||
from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310
|
from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310
|
||||||
from vllm_ascend._310p.ops.rotary_embedding import AscendMRotaryEmbedding310
|
from vllm_ascend._310p.ops.rotary_embedding import AscendRotaryEmbedding310
|
||||||
|
|
||||||
REGISTERED_ASCEND_OPS.update(
|
REGISTERED_ASCEND_OPS.update(
|
||||||
{
|
{
|
||||||
"SiluAndMul": AscendSiluAndMul310,
|
"SiluAndMul": AscendSiluAndMul310,
|
||||||
"MMEncoderAttention": AscendMMEncoderAttention310,
|
"MMEncoderAttention": AscendMMEncoderAttention310,
|
||||||
"MRotaryEmbedding": AscendMRotaryEmbedding310,
|
"RotaryEmbedding": AscendRotaryEmbedding310,
|
||||||
"RMSNorm": AscendRMSNorm310,
|
"RMSNorm": AscendRMSNorm310,
|
||||||
"GemmaRMSNorm": AscendGemmaRMSNorm310,
|
"GemmaRMSNorm": AscendGemmaRMSNorm310,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -149,9 +149,6 @@ AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
|
|||||||
# list when ubatching is enabled
|
# list when ubatching is enabled
|
||||||
PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
|
PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
|
||||||
|
|
||||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
|
||||||
torch_npu.npu.set_compile_mode(jit_compile=False)
|
|
||||||
|
|
||||||
|
|
||||||
SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144
|
SEQ_LEN_WITH_MAX_PA_WORKSPACE = 6144
|
||||||
|
|
||||||
@@ -2527,9 +2524,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
]
|
]
|
||||||
k_cache = raw_k_tensor.view(dtype).view(k_shape)
|
k_cache = raw_k_tensor.view(dtype).view(k_shape)
|
||||||
v_cache = raw_v_tensor.view(dtype).view(v_shape)
|
v_cache = raw_v_tensor.view(dtype).view(v_shape)
|
||||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
|
||||||
k_cache = maybe_trans_nz(k_cache)
|
|
||||||
v_cache = maybe_trans_nz(v_cache)
|
|
||||||
if self.use_sparse and raw_dsa_k_tensor is not None:
|
if self.use_sparse and raw_dsa_k_tensor is not None:
|
||||||
dsa_k_cache_shape = (num_blocks, kv_cache_spec.block_size, 1, 128)
|
dsa_k_cache_shape = (num_blocks, kv_cache_spec.block_size, 1, 128)
|
||||||
dsa_k_cache_size = (num_blocks) * kv_cache_spec.block_size * 128 * dtype.itemsize
|
dsa_k_cache_size = (num_blocks) * kv_cache_spec.block_size * 128 * dtype.itemsize
|
||||||
|
|||||||
Reference in New Issue
Block a user