init v0.11.0rc0
This commit is contained in:
@@ -20,6 +20,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
@@ -37,34 +38,39 @@ def _rope_forward_oot(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
is_neox_style: bool,
|
||||
offsets: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if _custom_rotary_embedding_enabled(query, neox_style,
|
||||
if _custom_rotary_embedding_enabled(query, is_neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
query, key = torch.ops._C.rotary_embedding(
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
is_neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
if self.rotary_dim < self.head_size:
|
||||
if self.cos is not None and \
|
||||
self.sin is not None:
|
||||
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
|
||||
# This method requires head_size and rotary_dim equal 128 and neox_style is True
|
||||
query = query.contiguous().view(1, query.shape[0], -1,
|
||||
self.head_size)
|
||||
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
|
||||
torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin)
|
||||
elif self.rotary_dim < self.head_size:
|
||||
num_tokens = query.shape[0]
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
@@ -80,25 +86,26 @@ def _rope_forward_oot(
|
||||
k_rot,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
is_neox_style,
|
||||
)
|
||||
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
|
||||
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
|
||||
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
|
||||
return q, k
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
@@ -112,6 +119,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
self.cos = None
|
||||
self.sin = None
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
@@ -123,14 +132,25 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
):
|
||||
return _rope_forward_oot(
|
||||
self,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets,
|
||||
is_neox_style_override,
|
||||
)
|
||||
is_neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
is_neox_style = is_neox_style_override
|
||||
forward_context = get_forward_context()
|
||||
is_first_layer = forward_context.is_first_layer
|
||||
# Generate cos and sin outside layers to avoid repeated calculation.
|
||||
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
|
||||
-1] == 128:
|
||||
if is_first_layer:
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
last_dim = cos_sin.size()[-1]
|
||||
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
|
||||
1, 1, 2).chunk(2, dim=-2)
|
||||
# BSNH
|
||||
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
|
||||
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
|
||||
forward_context.is_first_layer = False
|
||||
return _rope_forward_oot(self, positions, query, key, is_neox_style,
|
||||
offsets)
|
||||
|
||||
|
||||
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
@@ -168,8 +188,10 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
super(DeepseekScalingRotaryEmbedding,
|
||||
self).__init__(head_size, rotary_dim, max_position_embeddings,
|
||||
base, is_neox_style, dtype)
|
||||
self.max_seq_len = max_position_embeddings
|
||||
self._set_cos_sin_cache(seq_len=max_position_embeddings,
|
||||
|
||||
# NOTE: For ascend friendly computing, reorder sin and cos cache
|
||||
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
|
||||
self._set_cos_sin_cache(self.max_seq_len,
|
||||
device=NPUPlatform.device_type,
|
||||
dtype=dtype)
|
||||
|
||||
@@ -275,8 +297,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
@@ -297,9 +318,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(seq_len * self.scaling_factor,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||
@@ -317,16 +336,13 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
max_seq_len: Optional[int] = None):
|
||||
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
||||
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
|
||||
offsets: Optional[torch.Tensor] = None):
|
||||
if len(key.shape) == 2:
|
||||
key = key[:, None, :]
|
||||
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
||||
# calculation method which is also more compute friendly to the ascend machine
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
|
||||
neox_style = True
|
||||
is_neox_style = True
|
||||
if self.is_neox_style is False:
|
||||
b, h_q, d = query.shape
|
||||
query = query.view(b, h_q, d // 2,
|
||||
@@ -334,6 +350,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
|
||||
b, h_k, d = key.shape
|
||||
key = key.view(b, h_k, d // 2, 2).transpose(3,
|
||||
2).reshape(b, h_k, d)
|
||||
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
|
||||
neox_style)
|
||||
q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
|
||||
is_neox_style, offsets)
|
||||
return q_pe, k_pe
|
||||
|
||||
Reference in New Issue
Block a user