[Ascend]optimize Qwen3 on Ascend (#10574)
Co-authored-by: c30031083 <chenxu140@huawei.com>
This commit is contained in:
@@ -50,6 +50,7 @@ from sglang.srt.utils import (
|
|||||||
is_hip,
|
is_hip,
|
||||||
is_sm90_supported,
|
is_sm90_supported,
|
||||||
is_sm100_supported,
|
is_sm100_supported,
|
||||||
|
prepare_weight_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
_is_flashinfer_available = is_flashinfer_available()
|
_is_flashinfer_available = is_flashinfer_available()
|
||||||
@@ -275,7 +276,11 @@ class LayerCommunicator:
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
|
cache=None,
|
||||||
):
|
):
|
||||||
|
if cache is not None:
|
||||||
|
self._context.cache = cache
|
||||||
|
|
||||||
return self._communicate_with_all_reduce_and_layer_norm_fn(
|
return self._communicate_with_all_reduce_and_layer_norm_fn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
@@ -349,6 +354,7 @@ class CommunicateContext:
|
|||||||
attn_tp_size: int
|
attn_tp_size: int
|
||||||
attn_dp_size: int
|
attn_dp_size: int
|
||||||
tp_size: int
|
tp_size: int
|
||||||
|
cache = None
|
||||||
|
|
||||||
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
|
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
|
||||||
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
||||||
@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||||
|
if context.cache is not None:
|
||||||
|
_ = prepare_weight_cache(hidden_states, context.cache)
|
||||||
hidden_states, residual = layernorm(hidden_states, residual)
|
hidden_states, residual = layernorm(hidden_states, residual)
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|||||||
@@ -638,6 +638,7 @@ class NPU_W8A8LinearMethodImpl:
|
|||||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||||
|
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||||
|
|
||||||
|
|
||||||
class NPU_W8A8LinearMethodMTImpl:
|
class NPU_W8A8LinearMethodMTImpl:
|
||||||
@@ -830,6 +831,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
|
|||||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||||
|
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||||
|
|
||||||
|
|
||||||
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
|
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
|
||||||
|
|||||||
@@ -179,6 +179,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if _is_npu:
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
torch.npu.config.allow_internal_format = True
|
||||||
|
torch_npu.npu.set_compile_mode(jit_compile=False)
|
||||||
|
|
||||||
|
|
||||||
class RankZeroFilter(logging.Filter):
|
class RankZeroFilter(logging.Filter):
|
||||||
"""Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
|
"""Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
|
||||||
|
|
||||||
|
|||||||
@@ -19,8 +19,10 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.configs.model_config import AttentionArch
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -30,12 +30,19 @@ from sglang.srt.model_loader.weight_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
|
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
|
||||||
from sglang.srt.models.qwen2 import Qwen2Model
|
from sglang.srt.models.qwen2 import Qwen2Model
|
||||||
from sglang.srt.utils import add_prefix, is_cuda
|
from sglang.srt.utils import (
|
||||||
|
add_prefix,
|
||||||
|
get_cmo_stream,
|
||||||
|
is_cuda,
|
||||||
|
is_npu,
|
||||||
|
wait_cmo_stream,
|
||||||
|
)
|
||||||
|
|
||||||
Qwen3Config = None
|
Qwen3Config = None
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
_is_npu = is_npu()
|
||||||
|
|
||||||
|
|
||||||
class Qwen3Attention(nn.Module):
|
class Qwen3Attention(nn.Module):
|
||||||
@@ -235,9 +242,18 @@ class Qwen3DecoderLayer(nn.Module):
|
|||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.layer_communicator.prepare_mlp(
|
hidden_states, residual = self.layer_communicator.prepare_mlp(
|
||||||
hidden_states, residual, forward_batch
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
forward_batch,
|
||||||
|
cache=(
|
||||||
|
[self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
|
||||||
|
if _is_npu
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
if _is_npu and get_cmo_stream():
|
||||||
|
wait_cmo_stream()
|
||||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||||
hidden_states, residual, forward_batch
|
hidden_states, residual, forward_batch
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -517,6 +517,50 @@ def make_layers(
|
|||||||
return modules, start_layer, end_layer
|
return modules, start_layer, end_layer
|
||||||
|
|
||||||
|
|
||||||
|
cmo_stream = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_cmo_stream():
|
||||||
|
"""
|
||||||
|
Cache Management Operation(CMO).
|
||||||
|
Launch a new stream to prefetch the weight of matmul when running other
|
||||||
|
AIV or communication kernels, aiming to overlap the memory access time.
|
||||||
|
"""
|
||||||
|
global cmo_stream
|
||||||
|
if cmo_stream is None:
|
||||||
|
cmo_stream = torch.get_device_module().Stream()
|
||||||
|
return cmo_stream
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_weight_cache(handle, cache):
    """Issue NPU prefetches for one or more weights on the CMO stream.

    Args:
        handle: Forwarded unchanged as the second argument of
            ``torch_npu.npu_prefetch``.
            NOTE(review): presumably the tensor whose producer the prefetch
            should be ordered after; confirm against the torch_npu docs.
        cache: Either a single weight or a ``list`` of weights. Each one is
            passed as the first argument of ``torch_npu.npu_prefetch``.

    Returns:
        None. The prefetches are only enqueued; callers that need to order
        later work after them must synchronize with the CMO stream.
    """
    # Imported lazily so this module can still be imported where torch_npu
    # is not installed.
    import torch_npu

    # 1GB, a large value to prefetch entire weight
    NPU_PREFETCH_MAX_SIZE_BYTES = 1000000000

    # Treat a lone weight as a one-element list so a single loop covers
    # both accepted shapes of `cache`.
    weights = cache if isinstance(cache, list) else [cache]

    prefetch_stream = get_cmo_stream()
    # The side stream must not run ahead of work already queued on the
    # current NPU stream.
    prefetch_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(prefetch_stream):
        for weight in weights:
            torch_npu.npu_prefetch(
                weight,
                handle,
                NPU_PREFETCH_MAX_SIZE_BYTES,
            )
|
||||||
|
|
||||||
|
|
||||||
|
def wait_cmo_stream():
    """Make the current device stream wait for the CMO stream.

    Calls ``wait_stream`` on the stream returned by
    ``torch.get_device_module().current_stream()``, passing the stream
    returned by ``get_cmo_stream()``.
    """
    device_module = torch.get_device_module()
    active_stream = device_module.current_stream()
    active_stream.wait_stream(get_cmo_stream())
|
||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed: int) -> None:
|
def set_random_seed(seed: int) -> None:
|
||||||
"""Set the random seed for all libraries."""
|
"""Set the random seed for all libraries."""
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|||||||
Reference in New Issue
Block a user