[Ascend]optimize Qwen3 on Ascend (#10574)
Co-authored-by: c30031083 <chenxu140@huawei.com>
This commit is contained in:
@@ -50,6 +50,7 @@ from sglang.srt.utils import (
|
||||
is_hip,
|
||||
is_sm90_supported,
|
||||
is_sm100_supported,
|
||||
prepare_weight_cache,
|
||||
)
|
||||
|
||||
_is_flashinfer_available = is_flashinfer_available()
|
||||
@@ -275,7 +276,11 @@ class LayerCommunicator:
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
cache=None,
|
||||
):
|
||||
if cache is not None:
|
||||
self._context.cache = cache
|
||||
|
||||
return self._communicate_with_all_reduce_and_layer_norm_fn(
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
@@ -349,6 +354,7 @@ class CommunicateContext:
|
||||
attn_tp_size: int
|
||||
attn_dp_size: int
|
||||
tp_size: int
|
||||
cache = None
|
||||
|
||||
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
|
||||
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
||||
@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
if context.cache is not None:
|
||||
_ = prepare_weight_cache(hidden_states, context.cache)
|
||||
hidden_states, residual = layernorm(hidden_states, residual)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@@ -638,6 +638,7 @@ class NPU_W8A8LinearMethodImpl:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||
|
||||
|
||||
class NPU_W8A8LinearMethodMTImpl:
|
||||
@@ -830,6 +831,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||
|
||||
|
||||
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
|
||||
|
||||
@@ -179,6 +179,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if _is_npu:
|
||||
import torch_npu
|
||||
|
||||
torch.npu.config.allow_internal_format = True
|
||||
torch_npu.npu.set_compile_mode(jit_compile=False)
|
||||
|
||||
|
||||
class RankZeroFilter(logging.Filter):
|
||||
"""Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
|
||||
|
||||
|
||||
@@ -19,8 +19,10 @@ import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.configs.model_config import AttentionArch
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,12 +30,19 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
)
|
||||
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
|
||||
from sglang.srt.models.qwen2 import Qwen2Model
|
||||
from sglang.srt.utils import add_prefix, is_cuda
|
||||
from sglang.srt.utils import (
|
||||
add_prefix,
|
||||
get_cmo_stream,
|
||||
is_cuda,
|
||||
is_npu,
|
||||
wait_cmo_stream,
|
||||
)
|
||||
|
||||
Qwen3Config = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Module):
|
||||
@@ -235,9 +242,18 @@ class Qwen3DecoderLayer(nn.Module):
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.layer_communicator.prepare_mlp(
|
||||
hidden_states, residual, forward_batch
|
||||
hidden_states,
|
||||
residual,
|
||||
forward_batch,
|
||||
cache=(
|
||||
[self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
|
||||
if _is_npu
|
||||
else None
|
||||
),
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
if _is_npu and get_cmo_stream():
|
||||
wait_cmo_stream()
|
||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
@@ -517,6 +517,50 @@ def make_layers(
|
||||
return modules, start_layer, end_layer
|
||||
|
||||
|
||||
cmo_stream = None
|
||||
|
||||
|
||||
def get_cmo_stream():
|
||||
"""
|
||||
Cache Management Operation(CMO).
|
||||
Launch a new stream to prefetch the weight of matmul when running other
|
||||
AIV or communication kernels, aiming to overlap the memory access time.
|
||||
"""
|
||||
global cmo_stream
|
||||
if cmo_stream is None:
|
||||
cmo_stream = torch.get_device_module().Stream()
|
||||
return cmo_stream
|
||||
|
||||
|
||||
def prepare_weight_cache(handle, cache):
|
||||
import torch_npu
|
||||
|
||||
NPU_PREFETCH_MAX_SIZE_BYTES = (
|
||||
1000000000 # 1GB, a large value to prefetch entire weight
|
||||
)
|
||||
stream = get_cmo_stream()
|
||||
stream.wait_stream(torch.npu.current_stream())
|
||||
with torch.npu.stream(stream):
|
||||
if isinstance(cache, list):
|
||||
for weight in cache:
|
||||
torch_npu.npu_prefetch(
|
||||
weight,
|
||||
handle,
|
||||
NPU_PREFETCH_MAX_SIZE_BYTES,
|
||||
)
|
||||
else:
|
||||
torch_npu.npu_prefetch(
|
||||
cache,
|
||||
handle,
|
||||
NPU_PREFETCH_MAX_SIZE_BYTES,
|
||||
)
|
||||
|
||||
|
||||
def wait_cmo_stream():
|
||||
cur_stream = torch.get_device_module().current_stream()
|
||||
cur_stream.wait_stream(get_cmo_stream())
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> None:
|
||||
"""Set the random seed for all libraries."""
|
||||
random.seed(seed)
|
||||
|
||||
Reference in New Issue
Block a user