[Ascend]optimize Qwen3 on Ascend (#10574)

Co-authored-by: c30031083 <chenxu140@huawei.com>
This commit is contained in:
ronnie_zheng
2025-09-23 03:18:36 +03:00
committed by GitHub
parent 095093ee5a
commit e22f3a5ec9
6 changed files with 81 additions and 2 deletions

View File

@@ -50,6 +50,7 @@ from sglang.srt.utils import (
is_hip, is_hip,
is_sm90_supported, is_sm90_supported,
is_sm100_supported, is_sm100_supported,
prepare_weight_cache,
) )
_is_flashinfer_available = is_flashinfer_available() _is_flashinfer_available = is_flashinfer_available()
@@ -275,7 +276,11 @@ class LayerCommunicator:
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
cache=None,
): ):
if cache is not None:
self._context.cache = cache
return self._communicate_with_all_reduce_and_layer_norm_fn( return self._communicate_with_all_reduce_and_layer_norm_fn(
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
@@ -349,6 +354,7 @@ class CommunicateContext:
attn_tp_size: int attn_tp_size: int
attn_dp_size: int attn_dp_size: int
tp_size: int tp_size: int
cache = None
def is_same_group_size(self, a: ScatterMode, b: ScatterMode): def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
return self.process_group_sizes[a] == self.process_group_sizes[b] return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
) )
else: else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states) hidden_states = tensor_model_parallel_all_reduce(hidden_states)
if context.cache is not None:
_ = prepare_weight_cache(hidden_states, context.cache)
hidden_states, residual = layernorm(hidden_states, residual) hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual return hidden_states, residual

View File

@@ -638,6 +638,7 @@ class NPU_W8A8LinearMethodImpl:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8LinearMethodMTImpl: class NPU_W8A8LinearMethodMTImpl:
@@ -830,6 +831,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
layer.weight_scale.data = layer.weight_scale.data.flatten() layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten() layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8DynamicLinearMethod(LinearMethodBase): class NPU_W8A8DynamicLinearMethod(LinearMethodBase):

View File

@@ -179,6 +179,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if _is_npu:
import torch_npu
torch.npu.config.allow_internal_format = True
torch_npu.npu.set_compile_mode(jit_compile=False)
class RankZeroFilter(logging.Filter): class RankZeroFilter(logging.Filter):
"""Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank.""" """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

View File

@@ -19,8 +19,10 @@ import logging
import threading import threading
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
import numpy as np
import torch import torch
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -30,12 +30,19 @@ from sglang.srt.model_loader.weight_utils import (
) )
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.utils import add_prefix, is_cuda from sglang.srt.utils import (
add_prefix,
get_cmo_stream,
is_cuda,
is_npu,
wait_cmo_stream,
)
Qwen3Config = None Qwen3Config = None
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_npu = is_npu()
class Qwen3Attention(nn.Module): class Qwen3Attention(nn.Module):
@@ -235,9 +242,18 @@ class Qwen3DecoderLayer(nn.Module):
# Fully Connected # Fully Connected
hidden_states, residual = self.layer_communicator.prepare_mlp( hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch hidden_states,
residual,
forward_batch,
cache=(
[self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
if _is_npu
else None
),
) )
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if _is_npu and get_cmo_stream():
wait_cmo_stream()
hidden_states, residual = self.layer_communicator.postprocess_layer( hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch hidden_states, residual, forward_batch
) )

View File

@@ -517,6 +517,50 @@ def make_layers(
return modules, start_layer, end_layer return modules, start_layer, end_layer
cmo_stream = None
def get_cmo_stream():
"""
Cache Management Operation(CMO).
Launch a new stream to prefetch the weight of matmul when running other
AIV or communication kernels, aiming to overlap the memory access time.
"""
global cmo_stream
if cmo_stream is None:
cmo_stream = torch.get_device_module().Stream()
return cmo_stream
def prepare_weight_cache(handle, cache):
import torch_npu
NPU_PREFETCH_MAX_SIZE_BYTES = (
1000000000 # 1GB, a large value to prefetch entire weight
)
stream = get_cmo_stream()
stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(stream):
if isinstance(cache, list):
for weight in cache:
torch_npu.npu_prefetch(
weight,
handle,
NPU_PREFETCH_MAX_SIZE_BYTES,
)
else:
torch_npu.npu_prefetch(
cache,
handle,
NPU_PREFETCH_MAX_SIZE_BYTES,
)
def wait_cmo_stream():
cur_stream = torch.get_device_module().current_stream()
cur_stream.wait_stream(get_cmo_stream())
def set_random_seed(seed: int) -> None: def set_random_seed(seed: int) -> None:
"""Set the random seed for all libraries.""" """Set the random seed for all libraries."""
random.seed(seed) random.seed(seed)