diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index fba8d8f18..e050da91d 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -50,6 +50,7 @@ from sglang.srt.utils import ( is_hip, is_sm90_supported, is_sm100_supported, + prepare_weight_cache, ) _is_flashinfer_available = is_flashinfer_available() @@ -275,7 +276,11 @@ class LayerCommunicator: hidden_states: torch.Tensor, residual: torch.Tensor, forward_batch: ForwardBatch, + cache=None, ): + if cache is not None: + self._context.cache = cache + return self._communicate_with_all_reduce_and_layer_norm_fn( hidden_states=hidden_states, residual=residual, @@ -349,6 +354,7 @@ class CommunicateContext: attn_tp_size: int attn_dp_size: int tp_size: int + cache = None def is_same_group_size(self, a: ScatterMode, b: ScatterMode): return self.process_group_sizes[a] == self.process_group_sizes[b] @@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn: ) else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) + if context.cache is not None: + _ = prepare_weight_cache(hidden_states, context.cache) hidden_states, residual = layernorm(hidden_states, residual) return hidden_states, residual diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 5ccb0259d..cab505a50 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -638,6 +638,7 @@ class NPU_W8A8LinearMethodImpl: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data) + layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29) class NPU_W8A8LinearMethodMTImpl: @@ -830,6 +831,7 @@ class NPU_W8A8DynamicLinearMethodImpl: layer.weight_scale.data = layer.weight_scale.data.flatten() layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29) class NPU_W8A8DynamicLinearMethod(LinearMethodBase): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1f08e43a1..d053e2bb8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -179,6 +179,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300 logger = logging.getLogger(__name__) +if _is_npu: + import torch_npu + + torch.npu.config.allow_internal_format = True + torch_npu.npu.set_compile_mode(jit_compile=False) + + class RankZeroFilter(logging.Filter): """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank.""" diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index 0ff19d582..d7619b2d7 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -19,8 +19,10 @@ import logging import threading from typing import TYPE_CHECKING, Optional, Union +import numpy as np import torch +from sglang.srt.configs.model_config import AttentionArch from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index bc5f054d7..a7551bb82 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -30,12 +30,19 @@ from sglang.srt.model_loader.weight_utils import ( ) from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP from sglang.srt.models.qwen2 import Qwen2Model -from sglang.srt.utils import add_prefix, is_cuda +from sglang.srt.utils import ( + add_prefix, + get_cmo_stream, + is_cuda, + is_npu, + wait_cmo_stream, +) Qwen3Config = None logger = logging.getLogger(__name__) _is_cuda = is_cuda() +_is_npu = is_npu() class Qwen3Attention(nn.Module): @@ -235,9 +242,18 @@ class Qwen3DecoderLayer(nn.Module): # Fully Connected hidden_states, residual = self.layer_communicator.prepare_mlp( - hidden_states, residual, forward_batch + hidden_states, + residual, + forward_batch, + cache=( + [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight] + if _is_npu + else None + ), ) hidden_states = self.mlp(hidden_states) + if _is_npu and get_cmo_stream(): + wait_cmo_stream() hidden_states, residual = self.layer_communicator.postprocess_layer( hidden_states, residual, forward_batch ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 0681bdfe2..812d72a08 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -517,6 +517,50 @@ def make_layers( return modules, start_layer, end_layer +cmo_stream = None + + +def get_cmo_stream(): + """ + Cache Management Operation(CMO). + Launch a new stream to prefetch the weight of matmul when running other + AIV or communication kernels, aiming to overlap the memory access time. + """ + global cmo_stream + if cmo_stream is None: + cmo_stream = torch.get_device_module().Stream() + return cmo_stream + + +def prepare_weight_cache(handle, cache): + import torch_npu + + NPU_PREFETCH_MAX_SIZE_BYTES = ( + 1000000000 # 1GB, a large value to prefetch entire weight + ) + stream = get_cmo_stream() + stream.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(stream): + if isinstance(cache, list): + for weight in cache: + torch_npu.npu_prefetch( + weight, + handle, + NPU_PREFETCH_MAX_SIZE_BYTES, + ) + else: + torch_npu.npu_prefetch( + cache, + handle, + NPU_PREFETCH_MAX_SIZE_BYTES, + ) + + +def wait_cmo_stream(): + cur_stream = torch.get_device_module().current_stream() + cur_stream.wait_stream(get_cmo_stream()) + + def set_random_seed(seed: int) -> None: """Set the random seed for all libraries.""" random.seed(seed)