[main] flashcomm_v1 optim in Qwen Dense Models (#2802)
### What this PR does / why we need it?
Flashcomm_v1 optim in Qwen Dense Models.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.10.1.1
- vLLM main: 5e537f45b4
Co-authored-by: 1024daniel <xxltju324@gmail.com>
```diff
@@ -44,6 +44,13 @@ class AddRMSNormW8A8Quant(RMSNorm):
         import torch_npu

         if residual is not None:
+            # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
+            # is enabled, without interfering with the normal operation of components like torchair.
+            # The final solution should be to move this check into the operator and support
+            # integration with torchair.
+            if x.size(0) != residual.size(0):
+                residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            assert x.size(0) == residual.size(0)
             x, _, residual = torch_npu.npu_add_rms_norm_quant(
                 x,
                 residual,
```
```diff
@@ -69,6 +76,13 @@ class AscendRMSNorm(RMSNorm):

         from vllm_ascend.utils import is_310p
         if residual is not None:
+            # FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
+            # is enabled, without interfering with the normal operation of components like torchair.
+            # The final solution should be to move this check into the operator and support
+            # integration with torchair.
+            if x.size(0) != residual.size(0):
+                residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            assert x.size(0) == residual.size(0)
             if is_310p():
                 orig_dtype = residual.dtype
                 x = x + residual.to(x.dtype)
```
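The guard added in both hunks has the same shape: when flashcomm_v1 has already split the hidden states across tensor-parallel ranks, `x` carries fewer rows than the full `residual`, so the residual must be chunked down to match before the fused add-RMSNorm op runs. A minimal toy sketch of that check, using plain Python lists in place of NPU tensors (the function name mirrors the `torch.ops.vllm.maybe_chunk_residual` custom op, but the slicing logic here is a hypothetical stand-in, not the real implementation):

```python
# Toy sketch of the residual-chunking guard from this PR.
# Assumption: this rank's chunk is the leading rows; the real custom op
# selects the chunk belonging to the current tensor-parallel rank.

def maybe_chunk_residual(x_rows, residual):
    """Return residual unchanged if shapes already match, else the
    leading x_rows rows (stand-in for the rank-local chunk)."""
    if len(residual) != x_rows:
        return residual[:x_rows]
    return residual


def shape_matched_residual(x, residual):
    # Mirrors the guard inserted before npu_add_rms_norm_quant:
    # chunk only when row counts disagree, then assert they match.
    if len(x) != len(residual):
        residual = maybe_chunk_residual(len(x), residual)
    assert len(x) == len(residual)
    return residual
```

For example, with a 2-row `x` and a 4-row full residual, the guard trims the residual to 2 rows; when shapes already agree it is a no-op.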