[310P]: refactoring for 310p kvcache and some ops class (#6117)
### What this PR does / why we need it?
* Refactor the LayerNorm and activation operator classes to decouple the
310P device implementation from the main branch.
* Refactor `mm_encoder_attention` on 310P to use the
`torch_npu._npu_flash_attention_unpad` operator.
* Refactor the QKV inputs in the prefill stage of `attention_v1` on 310P
so they are no longer padded to 16× alignment.
* Refactor `model_runner` on 310P to align the KV-cache initialization
logic with the mainline implementation.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
use the e2e tests.
- vLLM version: v0.13.0
- vLLM main:
d68209402d
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
44
vllm_ascend/_310p/ops/layernorm.py
Normal file
44
vllm_ascend/_310p/ops/layernorm.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
|
||||
|
||||
|
||||
class AscendRMSNorm310(AscendRMSNorm):
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is not None:
|
||||
orig_dtype = residual.dtype
|
||||
if x is None or x.numel() == 0 or x.shape[-1] == 0:
|
||||
x = residual.to(dtype=residual.dtype)
|
||||
else:
|
||||
x = x + residual.to(x.dtype)
|
||||
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
||||
if self.bias is not None:
|
||||
x.add_(self.bias)
|
||||
return x
|
||||
|
||||
|
||||
class AscendGemmaRMSNorm310(AscendGemmaRMSNorm):
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is not None:
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon)
|
||||
return x
|
||||
@@ -17,11 +17,8 @@
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ops.mm_encoder_attention import MAX_PAD_SIZE, MIN_PAD_SIZE
|
||||
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention as _Base
|
||||
|
||||
|
||||
@@ -43,23 +40,6 @@ class AscendMMEncoderAttention310(_Base):
|
||||
|
||||
q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
|
||||
|
||||
enable_pad = envs_ascend.USE_OPTIMIZED_MODEL and self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
|
||||
|
||||
origin_shape = q.shape[-1]
|
||||
if enable_pad:
|
||||
pad_len = MAX_PAD_SIZE - origin_shape
|
||||
q = F.pad(q, (0, pad_len), mode="constant", value=0)
|
||||
k = F.pad(k, (0, pad_len), mode="constant", value=0)
|
||||
v = F.pad(v, (0, pad_len), mode="constant", value=0)
|
||||
|
||||
origin_dim = origin_shape
|
||||
cur_dim = q.shape[-1]
|
||||
pad16 = (16 - cur_dim % 16) % 16
|
||||
if pad16:
|
||||
q = F.pad(q, (0, pad16), mode="constant", value=0)
|
||||
k = F.pad(k, (0, pad16), mode="constant", value=0)
|
||||
v = F.pad(v, (0, pad16), mode="constant", value=0)
|
||||
|
||||
if cu_seqlens is None:
|
||||
cu_seqlens = torch.arange(
|
||||
0,
|
||||
@@ -69,36 +49,19 @@ class AscendMMEncoderAttention310(_Base):
|
||||
device=query.device,
|
||||
)
|
||||
|
||||
total_q_tokens = bsz * q_len
|
||||
context_flat = q.new_empty((total_q_tokens, self.num_heads, q.shape[-1]))
|
||||
seq_len = torch.diff(cu_seqlens).to("cpu", dtype=torch.int32)
|
||||
|
||||
st = 0
|
||||
seg_lens = torch.diff(cu_seqlens).to("cpu", dtype=torch.int64).tolist()
|
||||
for seg_len in seg_lens:
|
||||
seg_len = int(seg_len)
|
||||
ed = st + seg_len
|
||||
context_layer = torch.empty_like(q)
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=seq_len,
|
||||
scale_value=self.head_size**-0.5,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=context_layer,
|
||||
)
|
||||
|
||||
q_i = q[st:ed].unsqueeze(0) # [1, S, H, D]
|
||||
k_i = k[st:ed].unsqueeze(0)
|
||||
v_i = v[st:ed].unsqueeze(0)
|
||||
|
||||
qs = int(q_i.shape[1])
|
||||
kvs = int(k_i.shape[1])
|
||||
|
||||
out_i = torch_npu.npu_prompt_flash_attention(
|
||||
q_i,
|
||||
k_i,
|
||||
v_i,
|
||||
input_layout="BSND",
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
scale_value=self.head_size**-0.5,
|
||||
pre_tokens=qs,
|
||||
next_tokens=kvs,
|
||||
)
|
||||
context_flat[st:ed] = out_i[0]
|
||||
st = ed
|
||||
|
||||
context_flat = context_flat[..., :origin_dim]
|
||||
context_layer = einops.rearrange(context_flat, "(b s) h d -> b s h d", b=bsz).contiguous()
|
||||
context_layer = einops.rearrange(context_layer, "(b s) h d -> b s h d", b=bsz).contiguous()
|
||||
return context_layer
|
||||
|
||||
Reference in New Issue
Block a user