diff --git a/docs/backend/pd_disaggregation.md b/docs/backend/pd_disaggregation.md
index 43ad0547e..9284dc048 100644
--- a/docs/backend/pd_disaggregation.md
+++ b/docs/backend/pd_disaggregation.md
@@ -111,3 +111,36 @@ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---di
 # decode 1
 $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
 ```
+
+## Ascend
+
+### Usage
+
+To use the Ascend backend, install [mf_adapter (download link)](https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com:443/sglang/mf_adapter-1.0.0-cp311-cp311-linux_aarch64.whl?AccessKeyId=HPUAXT4YM0U8JNTERLST&Expires=1783151861&Signature=3j10QDUjqk70enaq8lostYV2bEA%3D) and set `ASCEND_MF_STORE_URL`:
+
+```bash
+pip install mf_adapter-1.0.0-cp311-cp311-linux_aarch64.whl --force-reinstall
+export ASCEND_MF_STORE_URL="tcp://xxx.xx.xxx.xxx:xxxx"
+```
+To route Ascend transfers through the Mooncake backend instead, set the variable below; more details can be found in the Mooncake section.
+```bash
+export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
+```
+
+
+### Llama Single Node
+
+```bash
+$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend ascend
+$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend ascend
+$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
+```
+
+### DeepSeek Multi-Node
+
+```bash
+# prefill 0
+$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend ascend --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 1 --node-rank 0 --tp-size 16
+# decode 0
+$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend ascend --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 1 --node-rank 0 --tp-size 16
+```
diff --git a/python/sglang/srt/disaggregation/ascend/__init__.py b/python/sglang/srt/disaggregation/ascend/__init__.py
new file mode 100644
index 000000000..2550f91a4
--- /dev/null
+++ b/python/sglang/srt/disaggregation/ascend/__init__.py
@@ -0,0 +1,6 @@
+from sglang.srt.disaggregation.ascend.conn import (
+    AscendKVBootstrapServer,
+    AscendKVManager,
+    AscendKVReceiver,
+    AscendKVSender,
+)
diff --git a/python/sglang/srt/disaggregation/ascend/conn.py b/python/sglang/srt/disaggregation/ascend/conn.py
new file mode 100644
index 000000000..504212e0a
--- /dev/null
+++ b/python/sglang/srt/disaggregation/ascend/conn.py
@@ -0,0 +1,45 @@
+import logging
+
+from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
+from sglang.srt.disaggregation.mooncake.conn import (
+    MooncakeKVBootstrapServer,
+    MooncakeKVManager,
+    MooncakeKVReceiver,
+    MooncakeKVSender,
+)
+from sglang.srt.utils import get_local_ip_by_remote
+
+logger = logging.getLogger(__name__)
+
+
+class AscendKVManager(MooncakeKVManager):
+    def init_engine(self):
+        # Initialize the Ascend TransferEngine in place of the Mooncake one.
+        local_ip = get_local_ip_by_remote()
+        self.engine = AscendTransferEngine(
+            hostname=local_ip,
+            npu_id=self.kv_args.gpu_id,
+            disaggregation_mode=self.disaggregation_mode,
+        )
+
+    def register_buffer_to_engine(self):
+        # On Ascend the KV cache is one contiguous tensor, so a single registration covers all layers.
+        self.engine.register(
+            self.kv_args.kv_data_ptrs[0], sum(self.kv_args.kv_data_lens)
+        )
+        # The Ascend backend optimizes batch registration for small memory blocks.
+        self.engine.batch_register(
+            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
+        )
+
+
+class AscendKVSender(MooncakeKVSender):
+    pass
+
+
+class AscendKVReceiver(MooncakeKVReceiver):
+    pass
+
+
+class AscendKVBootstrapServer(MooncakeKVBootstrapServer):
+    pass
diff --git a/python/sglang/srt/disaggregation/ascend/transfer_engine.py b/python/sglang/srt/disaggregation/ascend/transfer_engine.py
new file mode 100644
index 000000000..0ccffffd6
--- /dev/null
+++ b/python/sglang/srt/disaggregation/ascend/transfer_engine.py
@@ -0,0 +1,58 @@
+import logging
+import os
+from typing import List
+
+from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
+from sglang.srt.disaggregation.utils import DisaggregationMode
+
+logger = logging.getLogger(__name__)
+
+
+class AscendTransferEngine(MooncakeTransferEngine):
+
+    def __init__(
+        self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
+    ):
+        try:
+            from mf_adapter import TransferEngine
+        except ImportError as e:
+            raise ImportError(
+                "Please install mf_adapter; for details, see docs/backend/pd_disaggregation.md"
+            ) from e
+
+        self.engine = TransferEngine()
+        self.hostname = hostname
+        self.npu_id = npu_id
+
+        # Centralized storage address of the AscendTransferEngine
+        self.store_url = os.getenv("ASCEND_MF_STORE_URL")
+        if disaggregation_mode == DisaggregationMode.PREFILL:
+            self.role = "Prefill"
+        elif disaggregation_mode == DisaggregationMode.DECODE:
+            self.role = "Decode"
+        else:
+            logger.error(f"Unsupported DisaggregationMode: {disaggregation_mode}")
+            raise ValueError(f"Unsupported DisaggregationMode: {disaggregation_mode}")
+        self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
+        self.initialize()
+
+    def initialize(self) -> None:
+        """Initialize the Ascend transfer instance."""
+        ret_value = self.engine.initialize(
+            self.store_url,
+            self.session_id,
+            self.role,
+            self.npu_id,
+        )
+        if ret_value != 0:
+            logger.error("Ascend Transfer Engine initialization failed.")
+            raise RuntimeError("Ascend Transfer Engine initialization failed.")
+
+    def batch_register(self, ptrs: List[int], lengths: List[int]):
+        try:
+            ret_value = self.engine.batch_register_memory(ptrs, lengths)
+        except Exception:
+            # Mark the registration as failed
+            ret_value = -1
+        if ret_value != 0:
+            logger.debug(f"Ascend memory registration for ptrs {ptrs} failed.")
diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py
index 13458ba64..db30d8c0d 100644
--- a/python/sglang/srt/disaggregation/mooncake/conn.py
+++ b/python/sglang/srt/disaggregation/mooncake/conn.py
@@ -132,13 +132,9 @@ class MooncakeKVManager(BaseKVManager):
     ):
         self.kv_args = args
         self.local_ip = get_local_ip_auto()
-        self.engine = MooncakeTransferEngine(
-            hostname=self.local_ip,
-            gpu_id=self.kv_args.gpu_id,
-            ib_device=self.kv_args.ib_device,
-        )
         self.is_mla_backend = is_mla_backend
         self.disaggregation_mode = disaggregation_mode
+        self.init_engine()
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
@@ -225,6 +221,13 @@ class MooncakeKVManager(BaseKVManager):
         self.failure_records: Dict[int, str] = {}
         self.failure_lock = threading.Lock()
 
+    def init_engine(self):
+        self.engine = MooncakeTransferEngine(
+            hostname=self.local_ip,
+            gpu_id=self.kv_args.gpu_id,
+            ib_device=self.kv_args.ib_device,
+        )
+
     def register_buffer_to_engine(self):
         for kv_data_ptr, kv_data_len in zip(
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
diff --git a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py
index 4fb108d0f..8c7ea0108 100644
--- a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py
+++ b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py
@@ -1,6 +1,8 @@
 import logging
 from typing import List, Optional
 
+from sglang.srt.utils import get_bool_env_var, get_free_port
+
 logger = logging.getLogger(__name__)
 
 
@@ -53,12 +55,21 @@ class MooncakeTransferEngine:
         device_name: Optional[str],
     ) -> None:
         """Initialize the mooncake instance."""
-        ret_value = self.engine.initialize(
-            hostname,
-            "P2PHANDSHAKE",
-            "rdma",
-            device_name if device_name is not None else "",
-        )
+        if get_bool_env_var("ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE", "false"):
+            hostname += f":{get_free_port()}:npu_{self.gpu_id}"
+            ret_value = self.engine.initialize(
+                hostname,
+                "P2PHANDSHAKE",
+                "ascend",
+                device_name if device_name is not None else "",
+            )
+        else:
+            ret_value = self.engine.initialize(
+                hostname,
+                "P2PHANDSHAKE",
+                "rdma",
+                device_name if device_name is not None else "",
+            )
         if ret_value != 0:
             logger.error("Mooncake Transfer Engine initialization failed.")
             raise RuntimeError("Mooncake Transfer Engine initialization failed.")
diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py
index d2ff845c4..720c9d5a5 100644
--- a/python/sglang/srt/disaggregation/utils.py
+++ b/python/sglang/srt/disaggregation/utils.py
@@ -15,7 +15,7 @@ import requests
 import torch
 import torch.distributed as dist
 
-from sglang.srt.utils import get_ip
+from sglang.srt.utils import get_ip, is_npu
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -94,8 +94,12 @@ class MetadataBuffers:
         custom_mem_pool: torch.cuda.MemPool = None,
     ):
         self.custom_mem_pool = custom_mem_pool
-        device = "cuda" if self.custom_mem_pool else "cpu"
-
+        device = "cpu"
+        if is_npu():
+            # For the Ascend backend, output tokens are placed on the NPU and transferred via the D2D channel.
+            device = "npu"
+        elif self.custom_mem_pool:
+            device = "cuda"
         with (
             torch.cuda.use_mem_pool(self.custom_mem_pool)
             if self.custom_mem_pool
@@ -200,6 +204,7 @@
 class TransferBackend(Enum):
     MOONCAKE = "mooncake"
     NIXL = "nixl"
+    ASCEND = "ascend"
     FAKE = "fake"
 
 
@@ -231,6 +236,23 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
+    elif transfer_backend == TransferBackend.ASCEND:
+        from sglang.srt.disaggregation.ascend import (
+            AscendKVBootstrapServer,
+            AscendKVManager,
+            AscendKVReceiver,
+            AscendKVSender,
+        )
+        from sglang.srt.disaggregation.base import KVArgs
+
+        class_mapping = {
+            KVClassType.KVARGS: KVArgs,
+            KVClassType.MANAGER: AscendKVManager,
+            KVClassType.SENDER: AscendKVSender,
+            KVClassType.RECEIVER: AscendKVReceiver,
+            KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,
+        }
+        return class_mapping.get(class_type)
     elif transfer_backend == TransferBackend.NIXL:
         from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.nixl import (
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 3c9256e3f..81f36faa6 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -285,6 +285,20 @@ class TokenizerManager:
             self.bootstrap_server = kv_bootstrap_server_class(
                 self.server_args.disaggregation_bootstrap_port
             )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
+
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = "Failed to create mf store, invalid ascend_url."
+                    error_message += f" With exception {e}"
+                    raise RuntimeError(error_message) from e
 
         # For load balancing
         self.current_load = 0
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index cc2d8e20c..2e0766222 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -604,32 +604,49 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.k_buffer = [
-                torch.zeros(
-                    (
-                        self.size // self.page_size + 1,
-                        self.page_size,
-                        self.head_num,
-                        self.head_dim,
-                    ),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (
-                        self.size // self.page_size + 1,
-                        self.page_size,
-                        self.head_num,
-                        self.head_dim,
-                    ),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
+            # Contiguous memory improves the efficiency of Ascend's transfer backend,
+            # while other backends remain unchanged.
+            self.kv_buffer = torch.zeros(
+                (
+                    2,
+                    self.layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
+            self.k_buffer = self.kv_buffer[0]
+            self.v_buffer = self.kv_buffer[1]
+
+    # for disagg
+    def get_contiguous_buf_infos(self):
+        # K buffers for all layers, then V buffers for all layers,
+        # each layer shaped [page_num, page_size, head_num, head_dim].
+        kv_data_ptrs = [
+            self.get_key_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        kv_data_lens = [
+            self.get_key_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        kv_item_lens = [
+            self.get_key_buffer(i)[0].nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i)[0].nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def set_kv_buffer(
         self,
@@ -969,18 +986,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
 
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = [
-                torch.zeros(
-                    (
-                        self.size // self.page_size + 1,
-                        self.page_size,
-                        self.kv_lora_rank + self.qk_rope_head_dim,
-                    ),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(layer_num)
-            ]
+            self.kv_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    self.kv_lora_rank + self.qk_rope_head_dim,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
 
         self.layer_transfer_counter = None
 
@@ -990,6 +1005,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         )
         self.mem_usage = kv_size / GB
 
+    # for disagg
+    def get_contiguous_buf_infos(self):
+        # MLA has a single kv_buffer, so only this buffer's per-layer info needs to be returned.
+        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
     def set_kv_buffer(
         self,
         layer: RadixAttention,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 935c1b89d..c44f53f7e 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1621,7 +1621,7 @@ class ServerArgs:
         "--disaggregation-transfer-backend",
         type=str,
         default=ServerArgs.disaggregation_transfer_backend,
-        choices=["mooncake", "nixl"],
+        choices=["mooncake", "nixl", "ascend"],
         help="The backend for disaggregation transfer. Default is mooncake.",
     )
     parser.add_argument(
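
For reviewers: a minimal sketch (not part of the patch) of how the new `ascend` value of `--disaggregation-transfer-backend` resolves to the Ascend KV classes through the `get_kv_class` mapping added in `python/sglang/srt/disaggregation/utils.py` above:

```python
# Illustration only; assumes the modules added in this diff are importable.
from sglang.srt.disaggregation.utils import (
    KVClassType,
    TransferBackend,
    get_kv_class,
)

# Value parsed from the new "ascend" CLI choice.
backend = TransferBackend("ascend")

# Each lookup returns the class registered in the new ASCEND branch.
manager_cls = get_kv_class(backend, KVClassType.MANAGER)  # AscendKVManager
sender_cls = get_kv_class(backend, KVClassType.SENDER)  # AscendKVSender
receiver_cls = get_kv_class(backend, KVClassType.RECEIVER)  # AscendKVReceiver
bootstrap_cls = get_kv_class(backend, KVClassType.BOOTSTRAP_SERVER)  # AscendKVBootstrapServer
```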