first commit

This commit is contained in:
2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

49
vllm_br/__init__.py Normal file
View File

@@ -0,0 +1,49 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
import torch # noqa F401
import torch_br # noqa F401
from torch_br.contrib import transfer_to_supa # noqa F401
from torch_br.supa import _debug as supa_debug
# patches
from . import utils # noqa: F401
# bypass memset
# These calls run once, as a side effect of importing the package. Each one
# passes True to a `set_disable_*` switch in torch_br's private `_debug`
# module.
# NOTE(review): presumably each switch skips a zero-fill of the named buffer
# kind (workspace / UMA output / NUMA output / reorder) -- confirm against
# the torch_br documentation.
supa_debug.set_disable_zero_ws(True)
supa_debug.set_disable_zero_output_uma(True)
supa_debug.set_disable_zero_output_numa(True)
supa_debug.set_disable_reorder_zero(True)
# Environment switches are set unconditionally, overwriting any value the
# user exported before importing this package.
# NOTE(review): the BRTB_* semantics are not visible here -- confirm.
os.environ["BRTB_ENABLE_NUMA_SPLIT"] = "1"
os.environ["BRTB_ENABLE_NUMA_ALIGN_4K"] = "1"
def register():
    """Return the dotted path of the SUPA platform class.

    The value is a plain string; nothing is imported here.
    """
    platform_path = "vllm_br.platform.SUPAPlatform"
    return platform_path
def register_model():
    """Import the patch sub-packages, then register the models.

    The sub-packages are imported only for their import-time side effects;
    the import order of the original is kept. The call is finally delegated
    to ``model_executor.register_model``.
    """
    from . import attention  # noqa: F401
    from . import config  # noqa: F401
    from . import distributed  # noqa: F401
    from . import v1  # noqa: F401

    # Aliased so the imported callable does not shadow this function's name.
    from .model_executor import register_model as _register_executor_models
    _register_executor_models()

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,16 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import layer # noqa: F401

Binary file not shown.

Binary file not shown.

130
vllm_br/attention/layer.py Normal file
View File

@@ -0,0 +1,130 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
import vllm.attention.layer
from vllm.attention.layer import (maybe_save_kv_layer_to_connector,
wait_for_kv_layer_from_connector)
from vllm.forward_context import ForwardContext, get_forward_context
#direct_register_custom_op(
# op_name="unified_attention",
# op_func=unified_attention,
# mutates_args=[],
# fake_impl=unified_attention_fake,
# dispatch_key=current_platform.dispatch_key,
#)
#direct_register_custom_op(
# op_name="unified_attention_with_output",
# op_func=unified_attention_with_output,
# mutates_args=["output"],
# fake_impl=unified_attention_with_output_fake,
# dispatch_key=current_platform.dispatch_key,
#)
def forward_(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    # For some alternate attention backends like MLA the attention output
    # shape does not match the query shape, so we optionally let the model
    # definition specify the output tensor shape.
    output_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
    """Replacement for ``vllm.attention.layer.Attention.forward``.

    The KV cache is stored on the layer (``self.kv_cache``) and indexed by
    the forward context's ``virtual_engine``. Attention metadata is read from
    ``vllm.forward_context.get_forward_context().attn_metadata``; when that
    object is a dict, the entry for ``self.layer_name`` is used.

    When ``self.use_output`` is false the backend is always invoked directly
    (no custom op), bracketed by the KV-connector wait/save hooks, and this
    happens regardless of ``self.use_direct_call``.

    Args:
        query: Query tensor.
        key: Key tensor, may be None.
        value: Value tensor, may be None.
        output_shape: Shape of the output buffer; defaults to ``query.shape``.
            Only consulted when ``self.use_output`` is true.

    Returns:
        With ``use_output``: the output buffer viewed as
        ``(-1, output_shape[-1])``. Otherwise whatever
        ``self.impl.forward`` returns.
    """
    if self.calculate_kv_scales:
        scale_metadata = get_forward_context().attn_metadata
        if scale_metadata.enable_kv_scales_calculation:
            self.calc_kv_scales(query, key, value)

    if not self.use_output:
        # Eager path: run the backend directly instead of going through
        # torch.ops.vllm.unified_attention.
        wait_for_kv_layer_from_connector(self.layer_name)
        ctx = get_forward_context()
        layer_metadata = ctx.attn_metadata
        if isinstance(layer_metadata, dict):
            layer_metadata = layer_metadata[self.layer_name]
        layer_kv_cache = self.kv_cache[ctx.virtual_engine]
        result = self.impl.forward(self, query, key, value, layer_kv_cache,
                                   layer_metadata)
        maybe_save_kv_layer_to_connector(self.layer_name, layer_kv_cache)
        return result

    if output_shape is None:
        output_shape = query.shape
    output = torch.empty(output_shape,
                         dtype=query.dtype,
                         device=query.device)
    hidden_size = output_shape[-1]

    # We skip reshaping query, key and value tensors for the MLA
    # backend since these tensors have different semantics and are
    # processed differently.
    if not self.use_mla:
        # NOTE(woosuk): We do this outside the custom op to minimize the
        # CPU overheads from the non-CUDA-graph regions.
        query = query.view(-1, self.num_heads, self.head_size)
        output = output.view(-1, self.num_heads, self.head_size)
        if key is not None:
            key = key.view(-1, self.num_kv_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_kv_heads, self.head_size)

    if self.use_direct_call:
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        # The backend writes into `output`; its return value is ignored.
        self.impl.forward(self,
                          query,
                          key,
                          value,
                          self_kv_cache,
                          attn_metadata,
                          output=output)
    else:
        torch.ops.vllm.unified_attention_with_output(
            query, key, value, output, self.layer_name)
    return output.view(-1, hidden_size)
# Monkey-patch applied at import time: every vLLM `Attention` layer now
# dispatches its forward pass through `forward_` defined above.
vllm.attention.layer.Attention.forward = forward_

View File

View File

@@ -0,0 +1,70 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
import time
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
# Holds the depyf debug context entered by start_monitoring_torch_compile;
# reset to None by end_monitoring_torch_compile after it is exited.
context_manager = None
# time.time() stamp written by start_monitoring_torch_compile.
torch_compile_start_time: float = 0.0
def start_monitoring_torch_compile(vllm_config: VllmConfig):
    """Record the compile start time and optionally open a depyf dump.

    The module-level ``torch_compile_start_time`` is always refreshed. A
    depyf debug context is entered (and stored in ``context_manager``) only
    when the compilation level is PIECEWISE and a ``debug_dump_path`` is
    configured; the dump goes to ``<debug_dump_path>/rank_<rank>``.
    """
    global torch_compile_start_time, context_manager
    torch_compile_start_time = time.time()

    cfg: CompilationConfig = vllm_config.compilation_config
    dump_requested = (cfg.level == CompilationLevel.PIECEWISE
                      and cfg.debug_dump_path)
    if not dump_requested:
        return

    # Imported lazily: depyf is only needed when a dump is requested.
    import depyf
    rank_dir = f"rank_{vllm_config.parallel_config.rank}"
    dump_dir = os.path.join(cfg.debug_dump_path, rank_dir)
    context_manager = depyf.prepare_debug(dump_dir)
    context_manager.__enter__()
def end_monitoring_torch_compile(vllm_config: VllmConfig):
    """Log the total compile time and close the depyf dump, if any.

    Nothing happens unless the compilation level is PIECEWISE. In that case
    the accumulated ``compilation_time`` is logged and a previously entered
    ``context_manager`` is exited and cleared.
    """
    global context_manager
    cfg: CompilationConfig = vllm_config.compilation_config
    if cfg.level != CompilationLevel.PIECEWISE:
        return

    logger.info("torch.compile takes %.2f s in total", cfg.compilation_time)
    if context_manager is None:
        return
    context_manager.__exit__(None, None, None)
    context_manager = None
# Module-level switch consulted before every supagraph capture; toggled via
# set_supagraph_capturing_enabled().
supagraph_capturing_enabled: bool = True


def validate_supagraph_capturing_enabled():
    """Guard that must be called before any supagraph capture.

    Used to detect, at runtime, a capture that happens while capturing is
    switched off.

    Raises:
        RuntimeError: if ``supagraph_capturing_enabled`` is currently False.
    """
    # Read-only access, so no `global` declaration is needed.
    if not supagraph_capturing_enabled:
        # The message names the SUPA graph path this guard protects.
        raise RuntimeError("SUPA graph capturing detected at an inappropriate "
                           "time. This operation is currently disabled.")
def set_supagraph_capturing_enabled(enabled: bool):
    """Switch the module-level supagraph capture guard on or off.

    Args:
        enabled: New value stored in ``supagraph_capturing_enabled``.
    """
    global supagraph_capturing_enabled
    supagraph_capturing_enabled = enabled

View File

@@ -0,0 +1,239 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
import torch_br
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import (
set_graph_pool_id)
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger, logger
from vllm.platforms import current_platform
from vllm_br.compilation.monitor import validate_supagraph_capturing_enabled
from vllm_br.config.compilation import SUPAGraphMode
from vllm_br.forward_context import BatchDescriptor
# Module-specific logger. This rebinds the `logger` name that was also
# imported from vllm.logger above.
logger = init_logger(__name__)
@dataclasses.dataclass
class SUPAGraphEntry:
    """Per-batch-descriptor record kept by ``SUPAGraphWrapper``."""

    # Cache key this entry was created for.
    batch_descriptor: BatchDescriptor
    # Captured graph; stays None until the first capture for this key.
    supagraph: Optional[torch.supa.SUPAGraph] = None
    # Output object produced during capture; handed back on every replay.
    output: Optional[Any] = None

    # for supagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[list[int]] = None
@dataclasses.dataclass
class SUPAGraphOptions:
    """Tunable switches for ``SUPAGraphWrapper``."""

    # Emit a debug log line when a graph is captured.
    debug_log_enable: bool = True
    # Patch gc.collect / torch.supa.empty_cache to no-ops during capture.
    gc_disable: bool = False
    # Kept for interface parity; the weak-ref output path is currently
    # disabled in the wrapper.
    weak_ref_output: bool = True
class SUPAGraphWrapper:
    """Wraps a runnable to add SUPA graph capturing and replaying ability. And
    provide attribute access to the underlying `runnable` via `__getattr__`.

    The workflow of this wrapper in the supagraph dispatching is as follows:
    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
    PIECEWISE).
    2. At runtime, the wrapper receives a runtime_mode and a
    batch_descriptor(key) from the forward context and blindly trusts them
    for supagraph dispatching.
    3. If runtime_mode is NONE, just call the runnable directly. A runtime
    mode that is merely different from the wrapper's own mode is NOT
    bypassed; only NONE is.
    4. Otherwise the wrapper performs supagraph capture (if the key has no
    captured graph yet, an entry is created and cached) or replay (if a graph
    was already captured for the key).

    Note: SUPAGraphWrapper does not store persistent buffers or copy any
    runtime inputs into such buffers for replay. We assume implementing them
    is done outside of the wrapper. That is because we do not make any
    assumption on the dynamic shape (batch size) of the runtime inputs, as a
    trade-off for staying orthogonal to compilation logic. Nevertheless, the
    input addresses are recorded at capture time and checked for consistency
    during replay when VLLM_LOGGING_LEVEL == "DEBUG".
    """

    def __init__(self,
                 runnable: Callable,
                 vllm_config: VllmConfig,
                 runtime_mode: SUPAGraphMode,
                 supagraph_options: Optional[SUPAGraphOptions] = None):
        """Store the runnable and prepare the (initially empty) graph cache.

        Args:
            runnable: Callable to capture and replay.
            vllm_config: Global config; `parallel_config.world_size` is read
                at replay time.
            runtime_mode: Mode of this wrapper; must not be NONE.
            supagraph_options: Optional switches; defaults to
                `SUPAGraphOptions()`.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode
        self.compilation_config = vllm_config.compilation_config

        self.first_run_finished = False
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # assert runtime_mode is not NONE(no supagraph), otherwise, we don't
        # need to initialize a SUPAGraphWrapper.
        assert self.runtime_mode != SUPAGraphMode.NONE
        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = current_platform.get_global_graph_pool()

        if supagraph_options is None:
            supagraph_options = SUPAGraphOptions()
        self.supagraph_options = supagraph_options
        # the entries for different batch descriptors that we need to capture
        # supagraphs for.
        self.concrete_supagraph_entries: dict[BatchDescriptor,
                                              SUPAGraphEntry] = {}

    def __getattr__(self, key: str):
        """Forward unknown attribute lookups to the wrapped runnable."""
        # `__getattr__` only runs when normal lookup failed. If `runnable`
        # itself is missing (e.g. on an instance built without `__init__`,
        # as copy/pickle do), reading `self.runnable` below would re-enter
        # this method without end, so fail fast instead.
        if key == "runnable":
            raise AttributeError(key)
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} not exists in the runnable of "
                             f"supagraph wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        """Return the original runnable."""
        return self.runnable

    @staticmethod
    def _tensor_addresses(args, kwargs) -> list[int]:
        """Collect `data_ptr()` of every tensor argument.

        Positional tensors come first, then keyword tensors in dict order.
        """
        positional = [
            x.data_ptr() for x in args if isinstance(x, torch.Tensor)
        ]
        keyword = [
            x.data_ptr() for x in kwargs.values()
            if isinstance(x, torch.Tensor)
        ]
        return positional + keyword

    def __call__(self, *args, **kwargs):
        """Run eagerly, capture a graph, or replay one.

        The batch descriptor and runtime mode are read from the forward
        context. The first call for a descriptor captures and returns the
        live output; later calls replay the graph and return the output
        object stored at capture time.
        """
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        supagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        # SUPAGraphMode.NONE could mean the profile run, a warmup run, or
        # running without supagraphs.
        # NOTE: only NONE bypasses the graph path. A runtime mode that
        # differs from `self.runtime_mode` still proceeds to capture/replay;
        # the mismatch check is intentionally not applied here.
        if supagraph_runtime_mode == SUPAGraphMode.NONE:
            return self.runnable(*args, **kwargs)

        if batch_descriptor not in self.concrete_supagraph_entries:
            # create a new entry for this batch descriptor
            self.concrete_supagraph_entries[batch_descriptor] = \
                SUPAGraphEntry(batch_descriptor=batch_descriptor)

        entry = self.concrete_supagraph_entries[batch_descriptor]

        if entry.supagraph is None:
            if self.supagraph_options.debug_log_enable:
                # Since we capture supagraph for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
                # piecewise mode.
                logger.debug("Capturing a supagraph on (%s,%s)",
                             self.runtime_mode.name, entry.batch_descriptor)
            # validate that supagraph capturing is legal at this point.
            validate_supagraph_capturing_enabled()

            entry.input_addresses = self._tensor_addresses(args, kwargs)
            supagraph = torch.supa.SUPAGraph()

            with ExitStack() as stack:
                if self.supagraph_options.gc_disable:
                    # during every model forward for piecewise supagraph
                    # mode, we will capture many pieces of supagraphs
                    # (roughly one per layer). running gc again and again
                    # across layers will make the supagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.supa.empty_cache", lambda: None))

                if self.graph_pool is not None:
                    set_graph_pool_id(self.graph_pool)
                else:
                    set_graph_pool_id(current_platform.graph_pool_handle())
                # mind-exploding: carefully manage the reference and memory.
                with torch.supa.graph(supagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's supagraph pool
                    output = self.runnable(*args, **kwargs)
                    # FIXME: torch.ops._C.weak_ref_tensor is not supported,
                    # so `weak_ref_output` is ignored and the strong
                    # reference is kept.

            entry.output = output
            entry.supagraph = supagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during supa graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = self._tensor_addresses(args, kwargs)
            assert new_input_addresses == entry.input_addresses, (
                f"Input addresses for supagraphs are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}")

        if self.vllm_config.parallel_config.world_size != 1:
            # prevent SCCL capturing by using the same stream with SCCL
            stream = torch.distributed.get_group_stream(
                get_world_group().device_group)
        else:
            stream = torch_br.supa.Stream()
        current_stream = torch.supa.current_stream()
        with torch_br.supa.stream(stream):
            entry.supagraph.replay()
            event = torch.supa.Event()
            stream.record_event(event)
        # Make the caller's stream wait until the replay has finished.
        current_stream.wait_event(event)
        logger.debug(" ========Supa graph reply======== ")
        logger.debug(" padded num_tokens size = %s",
                     batch_descriptor.num_tokens)
        return entry.output

256
vllm_br/config/__init__.py Normal file
View File

@@ -0,0 +1,256 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
import torch
import vllm
import vllm.envs as envs
from vllm.config import VllmConfig, logger
from vllm.config.compilation import CompilationLevel
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import random_uuid
from .compilation import SUPAGraphMode
def supa_post_init(self):
    """Verify configs are valid & consistent with each other.

    Written as a drop-in ``__post_init__`` for ``VllmConfig`` that uses
    ``SUPAGraphMode`` for the graph-mode defaults. It runs, in order:
    cross-config verification, default compilation level, default graph
    mode, chunked-prefill / prefix-caching restrictions, the platform hook
    ``check_and_update_config``, and finally the hybrid KV cache manager
    restrictions.

    Raises:
        NotImplementedError: KV-sharing fast prefill combined with EAGLE.
        AssertionError: inconsistent graph mode / compilation level, or an
            unsupported all2all backend when DBO is enabled.
    """
    self.try_verify_and_update_config()

    if self.model_config is not None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.model_config.verify_dual_chunk_attention_config(self.load_config)

    self.cache_config.verify_with_parallel_config(self.parallel_config)

    if self.lora_config is not None:
        self.lora_config.verify_with_cache_config(self.cache_config)
        self.lora_config.verify_with_model_config(self.model_config)

    if self.quant_config is None and self.model_config is not None:
        self.quant_config = VllmConfig._get_quantization_config(
            self.model_config, self.load_config)

    from vllm.platforms import current_platform
    if (self.model_config is not None
            and self.scheduler_config.chunked_prefill_enabled
            and self.model_config.dtype == torch.float32
            and current_platform.get_device_capability() == (7, 5)):
        logger.warning_once(
            "Turing devices tensor cores do not support float32 matmul. "
            "To workaround this limitation, vLLM will set 'ieee' input "
            "precision for chunked prefill triton kernels.")

    # If the user does not explicitly set a compilation level, then
    # we use the default level. The default level depends on other
    # settings (see the below code).
    if self.compilation_config.level is None:
        if envs.VLLM_USE_V1:
            if (self.model_config is not None
                    and not self.model_config.enforce_eager):
                self.compilation_config.level = CompilationLevel.PIECEWISE
            else:
                self.compilation_config.level = \
                    CompilationLevel.NO_COMPILATION
        else:
            # NB: Passing both --enforce-eager and a compilation level
            # in V0 means the compilation level wins out.
            self.compilation_config.level = CompilationLevel.NO_COMPILATION

    # async tp is built on top of sequence parallelism
    # and requires it to be enabled.
    if self.compilation_config.pass_config.enable_async_tp:
        self.compilation_config.pass_config.enable_sequence_parallelism = \
            True
    if self.compilation_config.pass_config.enable_sequence_parallelism:
        self.compilation_config.custom_ops.append("+rms_norm")

    if current_platform.support_static_graph_mode():
        # if cudagraph_mode is not explicitly set by users, set default
        # value
        if self.compilation_config.cudagraph_mode is None:
            if (envs.VLLM_USE_V1 and self.compilation_config.level
                    == CompilationLevel.PIECEWISE):
                # default to full and piecewise for most models
                self.compilation_config.cudagraph_mode = \
                    SUPAGraphMode.FULL_AND_PIECEWISE

                # pooling models and encoder-decoder models
                # do not support full cudagraphs
                if (self.model_config is not None
                        and (self.model_config.pooler_config is not None
                             or self.model_config.is_encoder_decoder)):
                    self.compilation_config.cudagraph_mode = \
                        SUPAGraphMode.PIECEWISE
            else:
                self.compilation_config.cudagraph_mode = SUPAGraphMode.NONE

        # disable cudagraph when enforce eager execution
        if (self.model_config is not None
                and self.model_config.enforce_eager):
            logger.info("Cudagraph is disabled under eager mode")
            self.compilation_config.cudagraph_mode = SUPAGraphMode.NONE
        elif envs.VLLM_USE_V1:
            self.compilation_config.cudagraph_num_of_warmups = 1

        self._set_cudagraph_sizes()
    else:
        self.compilation_config.cudagraph_mode = SUPAGraphMode.NONE

    if self.cache_config.kv_sharing_fast_prefill:
        if (self.speculative_config is not None
                and self.speculative_config.use_eagle()):
            raise NotImplementedError(
                "Fast prefill optimization for KV sharing is not "
                "compatible with EAGLE as EAGLE requires correct logits "
                "for all tokens while fast prefill gives incorrect logits "
                "for prompt tokens.")

        logger.warning_once(
            "--kv-sharing-fast-prefill requires changes on model side for "
            "correctness and to realize prefill savings. ")

    # Every reason collected here switches off both chunked prefill and
    # prefix caching further below.
    disable_chunked_prefill_reasons: list[str] = []

    if self.model_config:
        if self.model_config.pooler_config:
            pooling_type = self.model_config.pooler_config.pooling_type
            if pooling_type is None or pooling_type.lower() != "last":
                disable_chunked_prefill_reasons.append(
                    "Only \"last\" pooling supports chunked "
                    "prefill and prefix caching; disabling both.")
            if not getattr(self.model_config.hf_config, "is_causal", True):
                disable_chunked_prefill_reasons.append(
                    "Only models using causal attention supports chunked "
                    "prefill and prefix caching; disabling both.")
        elif self.model_config.is_encoder_decoder:
            self.scheduler_config.max_num_encoder_input_tokens = \
                MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
                    self.model_config)
            logger.debug(
                "Encoder-decoder model detected: setting "
                "`max_num_encoder_input_tokens` to encoder length (%s)",
                self.scheduler_config.max_num_encoder_input_tokens)
            self.scheduler_config.disable_chunked_mm_input = True
            disable_chunked_prefill_reasons.append(
                "Encoder-decoder models do not support chunked prefill nor"
                " prefix caching; disabling both.")
            if (self.model_config.architecture
                    == "WhisperForConditionalGeneration" and
                    os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
                logger.warning("Whisper is known to have issues with "
                               "forked workers. If startup is hanging, "
                               "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                               "to 'spawn'.")

    if disable_chunked_prefill_reasons:
        for reason in disable_chunked_prefill_reasons:
            logger.info(reason)
        self.scheduler_config.chunked_prefill_enabled = False
        self.scheduler_config.long_prefill_token_threshold = 0

        if self.cache_config is not None:
            self.cache_config.enable_prefix_caching = False

    # The adjacent literals of the next two warnings are separated by
    # explicit spaces so the logged sentences do not run together.
    if (self.kv_events_config is not None
            and self.kv_events_config.enable_kv_cache_events
            and not self.cache_config.enable_prefix_caching):
        logger.warning(
            "KV cache events are on, but prefix caching is not enabled. "
            "Use --enable-prefix-caching to enable.")
    if (self.kv_events_config is not None
            and self.kv_events_config.publisher != "null"
            and not self.kv_events_config.enable_kv_cache_events):
        logger.warning("KV cache events are disabled, "
                       "but the scheduler is configured to publish them. "
                       "Modify KVEventsConfig.enable_kv_cache_events "
                       "to True to enable.")
    current_platform.check_and_update_config(self)

    # final check of cudagraph mode after platform-specific update
    # NOTE(review): this check only runs when the platform reports
    # is_cuda_alike(); confirm whether the SUPA platform does, otherwise the
    # whole block is skipped.
    if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
        if (self.compilation_config.cudagraph_mode == SUPAGraphMode.FULL
                and self.model_config is not None
                and not self.model_config.disable_cascade_attn):
            logger.info("SUPAGraphMode.FULL is not supported with "
                        "cascade attention currently. Disabling cascade "
                        "attention.")
            self.model_config.disable_cascade_attn = True

        if self.compilation_config.cudagraph_mode\
                .requires_piecewise_compilation():
            assert self.compilation_config.level == \
                CompilationLevel.PIECEWISE, \
                "Compilation level should be CompilationLevel.PIECEWISE "\
                "when cudagraph_mode piecewise cudagraphs is used, "\
                f"cudagraph_mode={self.compilation_config.cudagraph_mode}"

    if self.parallel_config.enable_dbo:
        a2a_backend = envs.VLLM_ALL2ALL_BACKEND
        assert a2a_backend in \
            ["deepep_low_latency", "deepep_high_throughput"], \
            "Microbatching currently only supports the deepep_low_latency and "\
            f"deepep_high_throughput all2all backend. {a2a_backend} is not "\
            "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\
            "variable to deepep_low_latency or deepep_high_throughput and "\
            "install the DeepEP kernels."

    if not self.instance_id:
        self.instance_id = random_uuid()[:5]

    # Do this after all the updates to compilation_config.level
    if (envs.VLLM_USE_V1
            and self.compilation_config.level == CompilationLevel.PIECEWISE):
        self.compilation_config.set_splitting_ops_for_v1()

    if (envs.VLLM_USE_V1
            and not self.scheduler_config.disable_hybrid_kv_cache_manager):
        # logger should only print warning message for hybrid models. As we
        # can't know whether the model is hybrid or not now, so we don't log
        # warning message here and will log it later.
        if not current_platform.support_hybrid_kv_cache():
            # Hybrid KV cache manager is not supported on non-GPU platforms.
            self.scheduler_config.disable_hybrid_kv_cache_manager = True
        if self.kv_transfer_config is not None:
            # Hybrid KV cache manager is not compatible with KV transfer.
            self.scheduler_config.disable_hybrid_kv_cache_manager = True
        if self.kv_events_config is not None:
            # Hybrid KV cache manager is not compatible with KV events.
            self.scheduler_config.disable_hybrid_kv_cache_manager = True
        if (self.model_config is not None
                and self.model_config.attention_chunk_size is not None):
            if (self.speculative_config is not None
                    and self.speculative_config.use_eagle()):
                # Hybrid KV cache manager is not yet supported with chunked
                # local attention + eagle.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
            elif \
                    not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
                logger.warning(
                    "There is a latency regression when using chunked local"
                    " attention with the hybrid KV cache manager. Disabling"
                    " it, by default. To enable it, set the environment "
                    "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1.")
                # Hybrid KV cache manager is not yet supported with chunked
                # local attention.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
# Monkey-patch applied at import time: VllmConfig instances created after
# this point run supa_post_init as their __post_init__.
vllm.config.VllmConfig.__post_init__ = supa_post_init

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,67 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import enum
from vllm.config.compilation import CUDAGraphMode
class CompilationLevel:
    """Integer constants naming the levels of the compilation process."""

    NO_COMPILATION = 0
    DYNAMO_AS_IS = 1
    DYNAMO_ONCE = 2
    PIECEWISE = 3
class SUPAGraphMode(enum.Enum):
    """ Constants for the supagraph mode in CompilationConfig.
    Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also
    treated as concrete runtime mode for supagraph runtime dispatching.

    The two composite members carry a ``(decode_mode, mixed_mode)`` tuple of
    the plain members' integer values.
    """
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

    def decode_mode(self) -> 'SUPAGraphMode':
        """Return the first component of a composite mode, else ``self``."""
        return SUPAGraphMode(self.value[0]) if \
            self.separate_routine() else self

    def mixed_mode(self) -> 'SUPAGraphMode':
        """Return the second component of a composite mode, else ``self``."""
        return SUPAGraphMode(self.value[1]) if \
            self.separate_routine() else self

    def requires_piecewise_compilation(self) -> bool:
        """True if either component of the mode is PIECEWISE."""
        return (self.decode_mode() == SUPAGraphMode.PIECEWISE
                or self.mixed_mode() == SUPAGraphMode.PIECEWISE)

    def max_supagraph_mode(self) -> 'SUPAGraphMode':
        """Return the component with the highest value, else ``self``."""
        return SUPAGraphMode(max(
            self.value)) if self.separate_routine() else self

    def has_full_supagraphs(self) -> bool:
        """True if the strongest component of the mode is FULL."""
        return self.max_supagraph_mode() == SUPAGraphMode.FULL

    # ychun, trick for CUDAGraphMode
    def has_full_cudagraphs(self) -> bool:
        """CUDAGraphMode-compatible spelling of `has_full_supagraphs`.

        Members of different Enum classes never compare equal, so comparing
        a SUPAGraphMode member with ``CUDAGraphMode.FULL`` is always False
        and plain ``FULL`` would be reported as having no full graphs. The
        answer is therefore computed entirely within SUPAGraphMode.
        NOTE(review): this assumes CUDAGraphMode numbers FULL the same way
        (the highest plain value) -- confirm against vLLM.
        """
        return self.has_full_supagraphs()

    def separate_routine(self) -> bool:
        """True for composite members, i.e. those whose value is a tuple."""
        return isinstance(self.value, tuple)

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import device_communicators # noqa: F401
from . import kv_transfer # noqa: F401

View File

@@ -0,0 +1,60 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.logger import logger
from vllm_br import envs
class SUPACommunicator(DeviceCommunicatorBase):
    """Device communicator for SUPA devices on top of ``torch.distributed``.

    Overrides ``gather`` (emulated with ``all_gather``) and ``all_reduce``
    (optionally reduced in float32 for bfloat16 inputs).
    """

    def __init__(self,
                 cpu_group: dist.ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[dist.ProcessGroup] = None,
                 unique_name: str = ""):
        """Initialise the base communicator and bind to the current device.

        Args:
            cpu_group: Process group forwarded to the base class.
            device: Device forwarded to the base class.
            device_group: Process group used for the device collectives.
            unique_name: Name forwarded to the base class.
        """
        super().__init__(cpu_group, device, device_group, unique_name)
        # Assigned after the base initialiser so this value is the one that
        # sticks, whatever ``device`` argument was passed in.
        self.device = torch.supa.current_device()

    # TODO: Deprecate this method in the future if torch_br support gather
    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """Emulate ``gather`` with ``all_gather``.

        Every rank takes part in the all-gather; only the rank whose
        ``rank_in_group`` equals ``dst`` keeps the result.

        Args:
            input_: Local tensor to contribute.
            dst: Group-local rank that receives the gathered tensor.
            dim: Dimension passed through to ``all_gather``.

        Returns:
            The gathered tensor on rank ``dst`` and ``None`` on every other
            rank (hence the ``Optional`` return annotation).
        """
        gathered = self.all_gather(input_, dim)
        return gathered if self.rank_in_group == dst else None

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """All-reduce ``input_`` across ``self.device_group``.

        When ``envs.VLLM_BR_USE_FP32_ALL_REDUCE`` is truthy and ``input_`` is
        a bfloat16 tensor, the reduction runs on a float32 copy that is cast
        back to bfloat16 afterwards; in that case the returned tensor is a
        new tensor and ``input_`` itself is left untouched. Otherwise the
        reduction happens in place and ``input_`` is returned.

        Args:
            input_: Tensor to reduce.

        Returns:
            The reduced tensor.
        """
        use_fp32 = (envs.VLLM_BR_USE_FP32_ALL_REDUCE and input_ is not None
                    and input_.dtype == torch.bfloat16)
        if not use_fp32:
            dist.all_reduce(input_, group=self.device_group)
            return input_

        logger.debug(
            '[Patch] patch all_reduce: use fp32 all_reduce when env VLLM_BR_USE_FP32_ALL_REDUCE is set'
        )
        input_ = input_.to(torch.float32)
        dist.all_reduce(input_, group=self.device_group)
        input_ = input_.to(torch.bfloat16)
        return input_

View File

@@ -0,0 +1,18 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import base_device_communicator # noqa: F401
from . import pysccl_wrapper # noqa: F401

View File

@@ -0,0 +1,44 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
import vllm
def supa_prepare_communication_buffer_for_model(
        self, model: torch.nn.Module) -> None:
    """Prepare the communication buffer for the model.

    Does nothing unless both ``self.use_all2all`` and
    ``self.is_ep_communicator`` are truthy. Otherwise every sub-module whose
    class is named ``FusedMoE`` or ``SharedFusedMoE`` gets
    ``quant_method.init_prepare_finalize(module)`` called on it.

    Args:
        self: The communicator instance this function is bound to.
        model: Model whose ``modules()`` are scanned.
    """
    if not (self.use_all2all and self.is_ep_communicator):
        return

    moe_class_names = ("FusedMoE", "SharedFusedMoE")

    # TODO(bnell): Should use isinstance but can't. Maybe search for
    # presence of quant_method.init_prepare_finalize?
    # Collect all matches first, then initialise, so the module scan is
    # finished before any init call runs.
    moe_modules = []
    for candidate in model.modules():
        if candidate.__class__.__name__ in moe_class_names:
            moe_modules.append(candidate)

    for moe in moe_modules:
        moe.quant_method.init_prepare_finalize(moe)
# Monkey-patch: importing this module replaces
# DeviceCommunicatorBase.prepare_communication_buffer_for_model with
# supa_prepare_communication_buffer_for_model.
vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.prepare_communication_buffer_for_model = supa_prepare_communication_buffer_for_model

View File

@@ -0,0 +1,420 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# This file is a pure Python wrapper for the SCCL library.
# The main purpose is to use SCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls SCCL correctly, but `cupy` itself
# often gets stuck when initializing the SCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for SCCL. It is usually
# doable, but we often encounter issues related with succl versions, and need
# to switch between different versions of SCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of SCCL by
# changing the environment variable `VLLM_SCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch.distributed import ReduceOp
from vllm.logger import logger
from vllm_br import envs
# === export types and functions from the communication library to Python ===
# The names below carry a `succl` prefix; for the original nccl definitions
# they are modelled on, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
# Result code returned by the library calls (0 is treated as success by
# SCCLLibrary.SUCCL_CHECK).
succlResult_t = ctypes.c_int
# Opaque communicator handle.
succlComm_t = ctypes.c_void_p
class succlUniqueId(ctypes.Structure):
    """Opaque 128-byte unique id, passed by value to ``succlCommInitRank``."""
    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]
# Stream handle, passed to the collective / point-to-point calls.
suStream_t = ctypes.c_void_p
# Raw address of a send or receive buffer.
buffer_type = ctypes.c_void_p
# Integer data-type code; see succlDataTypeEnum for the values.
succlDataType_t = ctypes.c_int
class succlDataTypeEnum:
    """Integer data-type codes, plus a translator from ``torch.dtype``."""
    succlInt8 = 0
    succlChar = 0
    succlUint8 = 1
    succlInt16 = 2
    succlUint16 = 3
    succlInt32 = 4
    succlInt = 4
    succlUint32 = 5
    succlInt64 = 6
    succlUint64 = 7
    succlBfloat16 = 8
    succlFloat32 = 9
    succlFloat = 9
    succlFloat64 = 10
    succlDouble = 10
    succlNumTypes = 11

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Translate a torch dtype into the matching data-type code.

        Args:
            dtype: The torch dtype to translate.

        Returns:
            The integer code for ``dtype``.

        Raises:
            ValueError: If ``dtype`` has no entry in the table below.
        """
        torch_to_succl = (
            (torch.int8, cls.succlInt8),
            (torch.uint8, cls.succlUint8),
            (torch.int32, cls.succlInt32),
            (torch.int64, cls.succlInt64),
            # NOTE(review): float16 is sent with the bfloat16 code (this
            # table has no float16 code). Presumably harmless for pure data
            # movement since both are 2 bytes wide, but suspicious for
            # reductions -- confirm.
            (torch.float16, cls.succlBfloat16),
            (torch.float32, cls.succlFloat32),
            (torch.float64, cls.succlFloat64),
            (torch.bfloat16, cls.succlBfloat16),
        )
        for torch_dtype, code in torch_to_succl:
            if dtype == torch_dtype:
                return code
        raise ValueError(f"Unsupported dtype: {dtype}")
# Integer reduction-op code; see succlRedOpTypeEnum for the values.
succlRedOp_t = ctypes.c_int
class succlRedOpTypeEnum:
    """Integer reduction-op codes, plus a translator from ``ReduceOp``."""
    succlSum = 0
    succlProd = 1
    succlMax = 2
    succlMin = 3
    succlAvg = 4
    succlNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Translate a ``torch.distributed.ReduceOp`` into its integer code.

        Args:
            op: The reduction op to translate.

        Returns:
            The integer code for ``op``.

        Raises:
            ValueError: If ``op`` is not SUM, PRODUCT, MAX, MIN or AVG.
        """
        # Matched with ``==`` rather than a dict lookup so that the lookup
        # does not depend on ``op`` being hashable.
        torch_to_succl = (
            (ReduceOp.SUM, cls.succlSum),
            (ReduceOp.PRODUCT, cls.succlProd),
            (ReduceOp.MAX, cls.succlMax),
            (ReduceOp.MIN, cls.succlMin),
            (ReduceOp.AVG, cls.succlAvg),
        )
        for torch_op, code in torch_to_succl:
            if op == torch_op:
                return code
        raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
    """Declarative description of one C function to bind through ctypes."""
    # Symbol name looked up on the loaded library.
    name: str
    # ctypes type assigned to the bound function's ``restype``.
    restype: Any
    # ctypes types assigned to the bound function's ``argtypes``.
    argtypes: list[Any]
class SCCLLibrary:
    """Pure-Python ``ctypes`` binding for the SCCL shared library.

    Creating an instance loads the shared object (once per path), binds every
    entry of ``exported_functions`` with its declared ``restype`` and
    ``argtypes``, and exposes thin wrapper methods. Wrappers for calls that
    return a result code raise ``RuntimeError`` when that code is non-zero.

    Args:
        so_file: Path of the SCCL shared library. When omitted or empty, the
            path returned by ``find_sccl_library()`` is used.
    """

    # NOTE: compared with the C prototypes quoted in the comments below, the
    # bindings for succlCommInitRank and for every collective / point-to-point
    # call declare one extra trailing ``void*`` argument; the wrapper methods
    # always pass ``None`` for it.
    # NOTE(review): presumably an SCCL-specific extension of the NCCL-style
    # signatures -- confirm against the SCCL headers.
    exported_functions = [
        # const char* succlGetErrorString(succlResult_t result)
        Function("succlGetErrorString", ctypes.c_char_p, [succlResult_t]),
        # succlResult_t succlGetVersion(int *version);
        Function("succlGetVersion", succlResult_t,
                 [ctypes.POINTER(ctypes.c_int)]),
        # succlResult_t succlGetUniqueId(succlUniqueId* uniqueId);
        Function("succlGetUniqueId", succlResult_t,
                 [ctypes.POINTER(succlUniqueId)]),
        # succlResult_t succlCommInitRank(
        #   succlComm_t* comm, int nranks, succlUniqueId commId, int rank);
        # note that succlComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function("succlCommInitRank", succlResult_t, [
            ctypes.POINTER(succlComm_t), ctypes.c_int, succlUniqueId,
            ctypes.c_int, ctypes.c_void_p
        ]),
        # succlResult_t succlAllReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlRedOp_t op, succlComm_t comm,
        #   suStream_t stream);
        # note that suStream_t is a pointer type, so the last argument
        # is a pointer
        Function("succlAllReduce", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlRedOp_t, succlComm_t, suStream_t, ctypes.c_void_p
        ]),
        # succlResult_t succlReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlRedOp_t op, int root,
        #   succlComm_t comm, suStream_t stream);
        # note that suStream_t is a pointer type, so the last argument
        # is a pointer
        Function("succlReduce", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlRedOp_t, ctypes.c_int, succlComm_t, suStream_t,
            ctypes.c_void_p
        ]),
        # succlResult_t succlAllGather(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlComm_t comm,
        #   suStream_t stream);
        # note that suStream_t is a pointer type, so the last argument
        # is a pointer
        Function("succlAllGather", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlComm_t, suStream_t, ctypes.c_void_p
        ]),
        # succlResult_t succlReduceScatter(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlRedOp_t op, succlComm_t comm,
        #   suStream_t stream);
        # note that suStream_t is a pointer type, so the last argument
        # is a pointer
        Function("succlReduceScatter", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlRedOp_t, succlComm_t, suStream_t, ctypes.c_void_p
        ]),
        # succlResult_t succlSend(
        #   const void* sendbuff, size_t count, succlDataType_t datatype,
        #   int dest, succlComm_t comm, suStream_t stream);
        Function("succlSend", succlResult_t, [
            buffer_type, ctypes.c_size_t, succlDataType_t, ctypes.c_int,
            succlComm_t, suStream_t, ctypes.c_void_p
        ]),
        # succlResult_t succlRecv(
        #   void* recvbuff, size_t count, succlDataType_t datatype,
        #   int src, succlComm_t comm, suStream_t stream);
        Function("succlRecv", succlResult_t, [
            buffer_type, ctypes.c_size_t, succlDataType_t, ctypes.c_int,
            succlComm_t, suStream_t, ctypes.c_void_p
        ]),
        # succlResult_t succlBroadcast(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, int root, succlComm_t comm,
        #   suStream_t stream);
        Function("succlBroadcast", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            ctypes.c_int, succlComm_t, suStream_t, ctypes.c_void_p
        ]),
        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # succlResult_t succlCommDestroy(succlComm_t comm);
        Function("succlCommDestroy", succlResult_t, [succlComm_t]),
        # succlResult_t succlGroupStart();
        Function("succlGroupStart", succlResult_t, []),
        # succlResult_t succlGroupEnd();
        Function("succlGroupEnd", succlResult_t, []),
        # Function("succldemoSetdevice", succlResult_t, [ctypes.c_int]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: dict[str, dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        so_file = so_file or find_sccl_library()

        try:
            # Guard on the library cache itself (the original tested the
            # function-table cache here, which only matched by coincidence).
            if so_file not in SCCLLibrary.path_to_library_cache:
                SCCLLibrary.path_to_library_cache[so_file] = ctypes.CDLL(
                    so_file)
            self.lib = SCCLLibrary.path_to_library_cache[so_file]
        except Exception:
            logger.error(
                "Failed to load SCCL library from %s. "
                "It is expected if you are not running on Biren SUPA "
                "devices. "
                "Otherwise, the sccl library might not exist, be corrupted "
                "or it does not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable VLLM_SCCL_SO_PATH"
                " to point to the correct sccl library path.", so_file,
                platform.platform())
            # Bare raise keeps the original exception and traceback.
            raise

        # Bind the exported functions once per library path and share the
        # resulting table between all instances using that path.
        if so_file not in SCCLLibrary.path_to_dict_mapping:
            _funcs: dict[str, Any] = {}
            for func in SCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            SCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = SCCLLibrary.path_to_dict_mapping[so_file]

    def succlGetErrorString(self, result: succlResult_t) -> str:
        """Return the library's message for ``result`` as a ``str``."""
        return self._funcs["succlGetErrorString"](result).decode("utf-8")

    def SUCCL_CHECK(self, result: succlResult_t) -> None:
        """Raise ``RuntimeError`` if ``result`` is non-zero."""
        if result != 0:
            error_str = self.succlGetErrorString(result)
            raise RuntimeError(f"SCCL error: {error_str}")

    def succlGetVersion(self) -> str:
        """Return the library version as ``"major.minor.patch"``.

        The integer reported by the library is decoded as
        ``major * 10000 + minor * 100 + patch`` (e.g. 21903 --> "2.19.3").
        Integer arithmetic is used instead of slicing and zero-stripping the
        decimal string, which produced empty components whenever the minor or
        patch number was 0 (e.g. 22000 --> "2.20.").
        """
        version = ctypes.c_int()
        self.SUCCL_CHECK(self._funcs["succlGetVersion"](ctypes.byref(version)))
        major, remainder = divmod(version.value, 10000)
        minor, patch = divmod(remainder, 100)
        return f"{major}.{minor}.{patch}"

    def succlGetUniqueId(self) -> succlUniqueId:
        """Ask the library for a fresh unique id and return it."""
        unique_id = succlUniqueId()
        self.SUCCL_CHECK(self._funcs["succlGetUniqueId"](
            ctypes.byref(unique_id)))
        return unique_id

    def unique_id_from_bytes(self, data: bytes) -> succlUniqueId:
        """Build a ``succlUniqueId`` from exactly 128 raw bytes.

        Raises:
            ValueError: If ``data`` is not 128 bytes long.
        """
        if len(data) != 128:
            raise ValueError(
                f"Expected 128 bytes for succlUniqueId, got {len(data)} bytes")
        unique_id = succlUniqueId()
        ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
        return unique_id

    def succlCommInitRank(self, world_size: int, unique_id: succlUniqueId,
                          rank: int) -> succlComm_t:
        """Create and return the communicator handle for ``rank``."""
        comm = succlComm_t()
        result = self._funcs["succlCommInitRank"](ctypes.byref(comm),
                                                  world_size, unique_id, rank,
                                                  None)
        self.SUCCL_CHECK(result)
        return comm

    # def succldemoSetdevice(self, deviceid:int):
    #     self._funcs["succldemoSetdevice"](deviceid)

    def succlAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                       count: int, datatype: int, op: int, comm: succlComm_t,
                       stream: suStream_t) -> None:
        """Call ``succlAllReduce`` and check its result code."""
        # `datatype` actually should be `succlDataType_t`
        # and `op` should be `succlRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.SUCCL_CHECK(self._funcs["succlAllReduce"](sendbuff, recvbuff,
                                                       count, datatype, op,
                                                       comm, stream, None))

    def succlReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                    count: int, datatype: int, op: int, root: int,
                    comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlReduce`` and check its result code."""
        # `datatype` actually should be `succlDataType_t`
        # and `op` should be `succlRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.SUCCL_CHECK(self._funcs["succlReduce"](sendbuff, recvbuff, count,
                                                    datatype, op, root, comm,
                                                    stream, None))

    def succlReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                           count: int, datatype: int, op: int,
                           comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlReduceScatter`` and check its result code."""
        # `datatype` actually should be `succlDataType_t`
        # and `op` should be `succlRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.SUCCL_CHECK(self._funcs["succlReduceScatter"](sendbuff, recvbuff,
                                                           count, datatype, op,
                                                           comm, stream, None))

    def succlAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
                       count: int, datatype: int, comm: succlComm_t,
                       stream: suStream_t) -> None:
        """Call ``succlAllGather`` and check its result code."""
        # `datatype` actually should be `succlDataType_t`
        # which is an aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.SUCCL_CHECK(self._funcs["succlAllGather"](sendbuff, recvbuff,
                                                       count, datatype, comm,
                                                       stream, None))

    def succlSend(self, sendbuff: buffer_type, count: int, datatype: int,
                  dest: int, comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlSend`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlSend"](sendbuff, count, datatype,
                                                  dest, comm, stream, None))

    def succlRecv(self, recvbuff: buffer_type, count: int, datatype: int,
                  src: int, comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlRecv`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlRecv"](recvbuff, count, datatype,
                                                  src, comm, stream, None))

    def succlBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
                       count: int, datatype: int, root: int, comm: succlComm_t,
                       stream: suStream_t) -> None:
        """Call ``succlBroadcast`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlBroadcast"](sendbuff, recvbuff,
                                                       count, datatype, root,
                                                       comm, stream, None))

    def succlCommDestroy(self, comm: succlComm_t) -> None:
        """Call ``succlCommDestroy`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlCommDestroy"](comm))

    def succlGroupStart(self) -> None:
        """Call ``succlGroupStart`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlGroupStart"]())

    def succlGroupEnd(self) -> None:
        """Call ``succlGroupEnd`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlGroupEnd"]())
def find_sccl_library() -> str:
    """Return the SCCL shared-library path from ``VLLM_SCCL_SO_PATH``.

    The path is read from ``envs.VLLM_SCCL_SO_PATH``. There is no fallback
    lookup: the library must be configured explicitly.

    Returns:
        The configured library path.

    Raises:
        ValueError: If ``envs.VLLM_SCCL_SO_PATH`` is unset or empty.
    """
    so_file = envs.VLLM_SCCL_SO_PATH

    if not so_file:
        raise ValueError("SCCL lib file not found.")

    # manually load the sccl library
    logger.info("Found sccl from environment variable VLLM_SCCL_SO_PATH=%s",
                so_file)
    return so_file
# Names exported by `from <this module> import *`.
__all__ = [
    "SCCLLibrary", "succlDataTypeEnum", "succlRedOpTypeEnum", "succlUniqueId",
    "succlComm_t", "suStream_t", "buffer_type"
]

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import kv_connector # noqa: F401

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import v1 # noqa: F401

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import base, p2p # noqa: F401

View File

@@ -0,0 +1,28 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# from vllm.logger import logger
# from vllm.v1.core.sched.output import SchedulerOutput
# from vllm.v1.outputs import KVConnectorOutput
# class KVConnectorRole(enum.Enum):
# # Connector running in the scheduler process
# SCHEDULER = 0
# # Connector running in the worker process
# WORKER = 1
# vllm.distributed.kv_transfer.kv_connector.v1.base.KVConnectorRole=KVConnectorRole

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import p2p_succl_engine # noqa: F401
from . import p2p_succl_connector, tensor_memory_pool # noqa: F401

View File

@@ -0,0 +1,535 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
import torch_br
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
from vllm.logger import logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.p2p_succl_engine import (
P2pSucclEngine)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@dataclass
class ReqMeta:
    """Per-request record: request id, block ids and token count."""
    # Request Id
    request_id: str
    # Request block ids
    block_ids: torch.Tensor
    # Request num tokens
    num_tokens: int

    @staticmethod
    def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                  block_size: int) -> "ReqMeta":
        """Build a ``ReqMeta`` from plain Python lists.

        Args:
            request_id: Identifier of the request.
            token_ids: Token ids of the request; only their count is kept.
            block_ids: Block ids, converted with ``torch.tensor``.
            block_size: Accepted but not used by this implementation.

        Returns:
            The populated ``ReqMeta``.
        """
        return ReqMeta(request_id=request_id,
                       block_ids=torch.tensor(block_ids),
                       num_tokens=len(token_ids))
@dataclass
class P2pSucclConnectorMetadata(KVConnectorMetadata):
    """Connector metadata holding the list of ``ReqMeta`` records."""
    requests: list[ReqMeta]

    def __init__(self):
        # Explicit initialiser: every instance starts with its own empty list.
        self.requests = []

    def add_request(
        self,
        request_id: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
    ) -> None:
        """Append a ``ReqMeta`` built from the given request fields.

        Args:
            request_id: Identifier of the request.
            token_ids: Token ids of the request.
            block_ids: Block ids of the request.
            block_size: Forwarded to ``ReqMeta.make_meta``.
        """
        meta = ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)
        self.requests.append(meta)
class P2pSucclConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
    """Set up connector state; the transfer engine exists on workers only.

    Args:
        vllm_config: Supplies ``cache_config.block_size`` and
            ``kv_transfer_config``.
        role: Connector role. For ``KVConnectorRole.WORKER`` the ranks are
            read from the world group and a ``P2pSucclEngine`` is created;
            for any other role the ranks are 0 and the engine is ``None``.
    """
    super().__init__(vllm_config=vllm_config, role=role)
    self._block_size = vllm_config.cache_config.block_size
    self._requests_need_load: dict[str, Any] = {}
    self.config = vllm_config.kv_transfer_config
    self.is_producer = self.config.is_kv_producer
    self.chunked_prefill: dict[str, Any] = {}

    if role == KVConnectorRole.WORKER:
        self._rank = get_world_group().rank
        self._local_rank = get_world_group().local_rank
        self.p2p_nccl_engine = P2pSucclEngine(
            local_rank=self._local_rank,
            config=self.config,
            hostname="",
            port_offset=self._rank,
        )
    else:
        self._rank = 0
        self._local_rank = 0
        self.p2p_nccl_engine = None
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
                  **kwargs) -> None:
    """Start loading the KV cache from the connector buffer to vLLM's
    paged KV buffer.

    Only consumers load KV cache: the call returns immediately when
    ``self.is_producer`` is set or when the forward context carries no
    attention metadata.

    Args:
        forward_context (ForwardContext): the forward context.
        **kwargs: additional arguments for the load operation

    Note:
        The number of elements in kv_caches and layer_names should be
        the same.
    """
    # Only consumer/decode loads KV Cache
    if self.is_producer:
        return

    assert self.p2p_nccl_engine is not None

    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return

    def inject_kv_into_layer(
        layer: torch.Tensor,
        kv_cache: torch.Tensor,
        block_ids: torch.Tensor,
        request_id: str,
    ) -> None:
        """
        Inject received KV cache data into a layer's paged KV tensor.

        This function updates `layer` in-place with values from `kv_cache`.
        Two layouts are handled:
        - `attn_metadata` is `MLACommonMetadata`, or `layer.shape[1] == 2`:
          only plane 0 (`layer[0]` / `kv_cache[0]`) is copied.
        - otherwise `layer.shape[0] == 2`: planes 0 and 1 are copied.
        Any other layout is left untouched.

        Block `block_ids[i]` is written to row `block_index // th_gran`,
        starting at offset `(block_index % th_gran) * block_size`, where
        `th_gran = layer.shape[2] // block_size`.

        If the number of provided block IDs does not match
        `kv_cache.shape[1]`, only the first `min` of the two is used for
        `block_ids`, and a warning is logged.

        Args:
            layer (torch.Tensor): The paged KV tensor to update.
            kv_cache (torch.Tensor): The received KV cache tensor.
            block_ids (torch.Tensor): Indices of the blocks to update.
            request_id (str): Request identifier used for logging.

        Returns:
            None. The function modifies `layer` in-place.
        """
        if (isinstance(attn_metadata, MLACommonMetadata)
                or layer.shape[1] == 2):  # MLA or FlashInfer
            num_planes = 1
        elif layer.shape[0] == 2:  # FlashAttention
            num_planes = 2
        else:
            return

        num_block = kv_cache.shape[1]
        if len(block_ids) != num_block:
            # The docstring always promised this warning; it was missing.
            logger.warning(
                "kv_cache does not match, block_ids:%d, num_block:%d, "
                "request_id:%s", len(block_ids), num_block, request_id)
        block_ids = block_ids[:min(len(block_ids), num_block)]

        th_gran = layer.shape[2] // self._block_size
        for i, block_index in enumerate(block_ids.tolist()):
            dst_0, slot = divmod(block_index, th_gran)
            dst_1 = slot * self._block_size
            for plane in range(num_planes):
                layer[plane][dst_0][dst_1:dst_1 +
                                    self._block_size] = kv_cache[plane][i]

    # Get the metadata
    metadata: KVConnectorMetadata = \
        self._get_connector_metadata()
    assert isinstance(metadata, P2pSucclConnectorMetadata)

    # Only reachable when assertions are disabled (python -O).
    if metadata is None:
        return

    # Load the KV for each request each layer
    for request in metadata.requests:
        request_id = request.request_id
        ip, port = self.parse_request_id(request_id, False)
        remote_address = ip + ":" + str(port + self._rank)
        for layer_name, layer in forward_context.no_compile_layers.items():
            # Only process layers that have kv_cache
            # attribute (attention layers) Skip non-attention
            # layers like FusedMoE
            kv_cache = getattr(layer, 'kv_cache', None)
            if kv_cache is None:
                continue

            layer = kv_cache[forward_context.virtual_engine]

            kv_cache = self.p2p_nccl_engine.recv_tensor(
                request.request_id + "#" + layer_name, remote_address)

            if kv_cache is None:
                logger.warning("🚧kv_cache is None, %s", request.request_id)
                continue

            inject_kv_into_layer(layer, kv_cache, request.block_ids,
                                 request.request_id)
def wait_for_layer_load(self, layer_name: str) -> None:
    """Wait until the KV for a specific layer is loaded into vLLM's
    paged buffer.

    This hook exists for layer-by-layer pipelining. In this connector it is
    a no-op that returns immediately.

    Args:
        layer_name: the name of that layer
    """
    return None
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                  attn_metadata: "AttentionMetadata", **kwargs) -> None:
    """Start saving the KV cache of the layer from vLLM's paged buffer
    to the connector.

    For every request in the current connector metadata the blocks that
    belong to the request are gathered from ``kv_layer`` and handed to the
    P2P engine, addressed to the peer encoded in the request id.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation.
    """
    # Only producer/prefill saves KV Cache
    if not self.is_producer:
        return

    assert self.p2p_nccl_engine is not None

    def extract_kv_from_layer(
        layer: torch.Tensor,
        block_ids: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Gather the blocks in ``block_ids`` into a new send buffer.

        Two layouts are recognised:
        - MLA metadata, or ``layer.shape[1] == 2`` (FlashInfer): only
          plane 0 of the layer is copied.
        - ``layer.shape[0] == 2`` (FlashAttention): planes 0 and 1 are
          copied.

        Args:
            layer (torch.Tensor): The KV cache from the attention layer.
            block_ids (torch.Tensor): Indices of blocks to extract.

        Returns:
            The freshly allocated tensor holding the extracted blocks, or
            None if the layout is unsupported.
        """
        if (isinstance(attn_metadata, MLACommonMetadata)
                or layer.shape[1] == 2):  # MLA or FlashInfer
            planes = (0, )
        elif layer.shape[0] == 2:  # FlashAttention
            planes = (0, 1)
        else:
            return None

        origin_shape = layer.shape
        shape = [
            origin_shape[0],
            len(block_ids), self._block_size, origin_shape[3]
        ]
        layer_send = torch_br._empty_ut_only(shape,
                                             dtype=layer.dtype,
                                             tensor_type='BUFFER_ANY',
                                             device=layer.device)
        # Each row along dim 2 of the layer packs `th_gran` logical blocks
        # of `_block_size` entries; map a block id to (row, offset).
        th_gran = origin_shape[2] // self._block_size
        for i, block_index in enumerate(block_ids.tolist()):
            dst_0 = block_index // th_gran
            dst_1 = (block_index % th_gran) * self._block_size
            for plane in planes:
                layer_send[plane][i] = layer[plane][dst_0][
                    dst_1:dst_1 + self._block_size]
        return layer_send

    connector_metadata = self._get_connector_metadata()
    assert isinstance(connector_metadata, P2pSucclConnectorMetadata)
    for request in connector_metadata.requests:
        request_id = request.request_id
        ip, port = self.parse_request_id(request_id, True)
        remote_address = ip + ":" + str(port + self._rank)
        kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
        if kv_cache is None:
            # Forwarding None would raise inside the engine (and, for
            # PUT_ASYNC, kill its sender thread), so skip explicitly.
            logger.warning(
                "Unsupported KV cache layout, skip sending, "
                "request_id:%s, layer:%s, shape:%s", request_id, layer_name,
                tuple(kv_layer.shape))
            continue
        self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                         kv_cache, remote_address)
def wait_for_save(self):
    """On the producer, delegate to the engine's ``wait_for_sent()``.

    Non-producer instances return without touching the engine.
    """
    if not self.is_producer:
        return
    engine = self.p2p_nccl_engine
    assert engine is not None
    engine.wait_for_sent()
def get_finished(
        self, finished_req_ids: set[str],
        **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Notifies worker-side connector ids of requests that have
    finished generating tokens.

    The call is forwarded to the P2P engine together with the static
    forward context taken from the compilation config.

    Returns:
        ids of requests that have finished asynchronous transfer,
        tuple of (sending/saving ids, recving/loading ids).
        The finished saves/sends req ids must belong to a set provided in a
        call to this method (this call or a prior one).
    """
    engine = self.p2p_nccl_engine
    assert engine is not None
    compilation_config = self._vllm_config.compilation_config
    layers = compilation_config.static_forward_context
    return engine.get_finished(finished_req_ids, layers)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int, bool]:
    """
    Get number of new tokens that can be loaded from the
    external KV cache beyond the num_computed_tokens.

    A producer never loads external KV and always reports 0. Otherwise
    the count is the prompt length minus one, minus the tokens that are
    already computed locally, clamped at zero.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A pair of (number of tokens that can be loaded from the external
        KV cache beyond what is already computed, False).
    """
    if self.is_producer:
        return 0, False

    prompt_len = len(request.prompt_token_ids)
    remaining = prompt_len - 1 - num_computed_tokens
    return max(remaining, 0), False
def update_state_after_alloc(self, request: "Request",
                             blocks: "KVCacheBlocks",
                             num_external_tokens: int):
    """
    Update KVConnector state after block allocation.

    On a consumer with external tokens to fetch, remember the request
    together with element 0 of ``blocks.get_block_ids()`` so that
    ``build_connector_meta`` can schedule the load.
    """
    if self.is_producer or num_external_tokens <= 0:
        return
    block_ids = blocks.get_block_ids()[0]
    self._requests_need_load[request.request_id] = (request, block_ids)
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build the connector metadata for this step.

    This function should NOT modify any fields in the scheduler_output.
    Also, calling this function will reset the state of the connector.

    Producer side: a request is added to the metadata only in the step
    where its prompt is fully scheduled; while the prompt is still being
    prefilled in chunks, its block ids and prompt tokens are kept in
    ``self.chunked_prefill``.
    Consumer side: requests recorded in ``self._requests_need_load`` are
    added, either as new requests or as requests resumed from preemption.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.
    """
    meta = P2pSucclConnectorMetadata()

    def _add(request_id, token_ids, block_ids):
        meta.add_request(request_id=request_id,
                         token_ids=token_ids,
                         block_ids=block_ids,
                         block_size=self._block_size)

    for new_req in scheduler_output.scheduled_new_reqs:
        if self.is_producer:
            scheduled = scheduler_output.num_scheduled_tokens[new_req.req_id]
            covered = scheduled + new_req.num_computed_tokens
            if covered < len(new_req.prompt_token_ids):
                # The prompt is chunked: stash what later steps need,
                # because 'CachedRequestData' has no 'prompt_token_ids'.
                self.chunked_prefill[new_req.req_id] = (
                    new_req.block_ids[0], new_req.prompt_token_ids)
            else:
                # The whole prompt is prefilled in this step.
                _add(new_req.req_id, new_req.prompt_token_ids,
                     new_req.block_ids[0])
            continue

        if new_req.req_id in self._requests_need_load:
            _add(new_req.req_id, new_req.prompt_token_ids,
                 new_req.block_ids[0])
            self._requests_need_load.pop(new_req.req_id)

    cached_reqs = scheduler_output.scheduled_cached_reqs
    for idx, req_id in enumerate(cached_reqs.req_ids):
        computed = cached_reqs.num_computed_tokens[idx]
        fresh_block_ids = cached_reqs.new_block_ids[idx]
        resumed = cached_reqs.resumed_from_preemption[idx]

        if self.is_producer:
            scheduled = scheduler_output.num_scheduled_tokens[req_id]
            covered = scheduled + computed
            assert req_id in self.chunked_prefill
            block_ids = fresh_block_ids[0]
            if not resumed:
                # Extend the blocks remembered from earlier chunks.
                block_ids = self.chunked_prefill[req_id][0] + block_ids
            prompt_token_ids = self.chunked_prefill[req_id][1]
            if covered < len(prompt_token_ids):
                # Still chunked: keep accumulating.
                self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
            else:
                # The prompt is finally prefilled completely.
                _add(req_id, prompt_token_ids, block_ids)
                self.chunked_prefill.pop(req_id, None)
            continue

        # NOTE(rob): here we rely on the resumed requests being
        # the first N requests in the list scheduled_cache_reqs.
        if not resumed:
            break
        if req_id in self._requests_need_load:
            request, _ = self._requests_need_load.pop(req_id)
            token_ids = request.all_token_ids[:computed + 1]
            # NOTE(rob): For resumed req, new_block_ids is all
            # of the block_ids for the request.
            _add(req_id, token_ids, fresh_block_ids[0])

    self._requests_need_load.clear()
    return meta
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
    """
    Called when a request has finished, before its blocks are freed.

    Any chunked-prefill state kept for the request is discarded.

    Returns:
        A pair whose first item is True if the request is being
        saved/sent asynchronously and blocks should not be freed until
        the request_id is returned from get_finished(), and whose second
        item is the optional KVTransferParams to be included in the
        request outputs returned by the engine. This implementation
        always returns (False, None).
    """
    finished_id = request.request_id
    self.chunked_prefill.pop(finished_id, None)
    outcome = (False, None)
    return outcome
# ==============================
# Static methods
# ==============================
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
# Regular expression to match the string hostname and integer port
if is_prefill:
pattern = r"___decode_addr_(.*):(\d+)"
else:
pattern = r"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match = re.search(pattern, request_id)
if match:
# Extract the ranks
ip = match.group(1)
port = int(match.group(2))
return ip, port
raise ValueError(
f"Request id {request_id} does not contain hostname and port")
@staticmethod
def check_tensors_except_dim(tensor1, tensor2, dim):
shape1 = tensor1.size()
shape2 = tensor2.size()
if len(shape1) != len(shape2) or not all(
s1 == s2
for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
raise NotImplementedError(
"Currently, only symmetric TP is supported. Asymmetric TP, PP,"
"and others will be supported in future PRs.")
# Register the connector with the KV connector factory: the first argument is
# the connector name, followed by the module path and the class name given
# to the factory for that name.
KVConnectorFactory.register_connector(
"P2pSucclConnector",
"vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.p2p_succl_connector",
"P2pSucclConnector")

View File

@@ -0,0 +1,572 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
import threading
import time
import typing
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional
import msgpack
import torch
import torch_br
import zmq
from torch_br.supa._internal import get_tensor_info
from vllm.config import KVTransferConfig
# import vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine
from vllm.utils import get_ip
from vllm_br.distributed.device_communicators.pysccl_wrapper import (
SCCLLibrary, buffer_type, succlComm_t, succlDataTypeEnum, suStream_t)
from vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool)
from vllm_br.platform import SUPAPlatform
logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 1
@contextmanager
def set_p2p_succl_context(num_channels: str):
    """Temporarily override SUCCL environment variables.

    Inside the ``with`` body ``SUCCL_MAX_NCHANNELS`` and
    ``SUCCL_MIN_NCHANNELS`` are set to ``num_channels`` and
    ``SUCCL_CUMEM_ENABLE`` is set to ``'1'``. On exit every tracked
    variable is restored to its previous value, or removed if it was
    unset before.

    Args:
        num_channels: channel count, as a string.
    """
    env_vars = [
        'SUCCL_MAX_NCHANNELS',
        'SUCCL_MIN_NCHANNELS',
        'SUCCL_CUMEM_ENABLE',
        'SUCCL_BUFFSIZE',
        'SUCCL_PROTO',  # LL,LL128,SIMPLE
        'SUCCL_ALGO',  # RING,TREE
    ]
    original_values: dict[str, Any] = {
        var: os.environ.get(var)
        for var in env_vars
    }
    logger.info("set_p2p_succl_context, original_values: %s", original_values)
    try:
        for var in ('SUCCL_MAX_NCHANNELS', 'SUCCL_MIN_NCHANNELS'):
            os.environ[var] = num_channels
        os.environ['SUCCL_CUMEM_ENABLE'] = '1'
        yield
    finally:
        for var, previous in original_values.items():
            if previous is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = previous
@dataclass
class SendQueueItem:
# One tensor waiting to be pushed to a peer: built in
# P2pSucclEngine.send_tensor and consumed by P2pSucclEngine.send_sync.
# Identifier sent to the peer in the "tensor_id" field of the PUT message.
tensor_id: str
# Key used to look up the peer's socket and communicator.
remote_address: str
# Payload to transmit.
tensor: torch.Tensor
class P2pSucclEngine:
def __init__(self,
             local_rank: int,
             config: KVTransferConfig,
             hostname: str = "",
             port_offset: int = 0,
             library_path: Optional[str] = None) -> None:
    """Set up the device, sockets, buffers and worker threads.

    Args:
        local_rank: local device index; combined with the optional
            ``device_cursor`` extra-config value to pick ``supa:<n>``.
        config: KV transfer configuration (``kv_port``,
            ``kv_buffer_size`` and the extra-config entries read below).
        hostname: address to bind to; ``get_ip()`` is used when empty.
        port_offset: added to ``kv_port`` and also stored as ``rank``.
        library_path: passed to ``SCCLLibrary``.

    Raises:
        ValueError: if the resulting port is 0.
    """
    self.config = config
    self.rank = port_offset
    self.local_rank = local_rank
    self.device = torch.device(f"supa:{self.local_rank}")
    if config is not None:
        device_cursor = self.config.get_from_extra_config(
            "device_cursor", 0)
        self.device = torch.device(
            f"supa:{self.local_rank + int(device_cursor)}")
    SUPAPlatform.set_device(self.device)
    self.succl = SCCLLibrary(library_path)

    if not hostname:
        hostname = get_ip()
    port = int(self.config.kv_port) + port_offset
    if port == 0:
        raise ValueError("Port cannot be 0")
    self._hostname = hostname
    self._port = port

    # Each card corresponds to a ZMQ address.
    self.zmq_address = f"{self._hostname}:{self._port}"

    # The `http_port` must be consistent with the port of OpenAI.
    self.http_address = (
        f"{self._hostname}:"
        f"{self.config.kv_connector_extra_config['http_port']}")

    # If `proxy_ip` or `proxy_port` is `""`,
    # then the ping thread will not be enabled.
    proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
    proxy_port = self.config.get_from_extra_config("proxy_port", "")
    if proxy_ip == "" or proxy_port == "":
        self.proxy_address = ""
    else:
        # Formatting (instead of `+`) also accepts a port configured as
        # an integer.
        self.proxy_address = f"{proxy_ip}:{proxy_port}"

    self.context = zmq.Context()
    self.router_socket = self.context.socket(zmq.ROUTER)
    self.router_socket.bind(f"tcp://{self.zmq_address}")

    self.poller = zmq.Poller()
    self.poller.register(self.router_socket, zmq.POLLIN)

    self.send_store_cv = threading.Condition()
    self.send_queue_cv = threading.Condition()
    self.recv_store_cv = threading.Condition()

    # Sending and receiving share a single stream.
    self.send_stream = torch_br.supa.Stream()
    self.recv_stream = self.send_stream

    mem_pool_size_gb = float(
        self.config.get_from_extra_config("mem_pool_size_gb",
                                          DEFAULT_MEM_POOL_SIZE_GB))
    self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb *
                                                    1024**3))  # GB

    # The sending type includes three mutually exclusive options:
    # PUT, GET, PUT_ASYNC.
    self.send_type = self.config.get_from_extra_config(
        "send_type", "PUT_ASYNC")
    if self.send_type == "GET":
        # tensor_id: torch.Tensor
        self.send_store: dict[str, torch.Tensor] = {}
    else:
        # PUT or PUT_ASYNC
        # tensor_id: torch.Tensor
        self.send_queue: deque[SendQueueItem] = deque()

    # tensor_id: torch.Tensor/(addr, dtype, shape)
    self.recv_store: dict[str, Any] = {}
    self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
    self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
    self.socks: dict[str, Any] = {}  # remote_address: client socket
    self.comms: dict[str, Any] = {}  # remote_address: (succlComm_t, rank)

    self.buffer_size = 0
    self.buffer_size_threshold = float(self.config.kv_buffer_size)

    self.succl_num_channels = self.config.get_from_extra_config(
        "nccl_num_channels", "8")

    # Start the worker threads only now, after every attribute they may
    # touch has been created.
    if self.send_type == "PUT_ASYNC":
        self._send_thread = threading.Thread(target=self.send_async,
                                             daemon=True)
        self._send_thread.start()

    self._listener_thread = threading.Thread(
        target=self.listen_for_requests, daemon=True)
    self._listener_thread.start()

    self._ping_thread = None
    if port_offset == 0 and self.proxy_address != "":
        self._ping_thread = threading.Thread(target=self.ping, daemon=True)
        self._ping_thread.start()

    logger.info(
        "💯P2pSucclEngine init, rank:%d, local_rank:%d, http_address:%s, "
        "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
        "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
        self.http_address, self.zmq_address, self.proxy_address,
        self.send_type, self.buffer_size_threshold,
        self.succl_num_channels)
def create_connect(self, remote_address: typing.Optional[str] = None):
    """Return the socket and communicator for a peer, creating them if needed.

    When no socket exists yet for ``remote_address`` a DEALER socket is
    connected to it. Unless a communicator is already registered for the
    peer, a fresh unique id is sent in a "NEW" message and a two-rank
    communicator is initialised with this side as rank 0.

    Args:
        remote_address: peer address; must not be None.

    Returns:
        Tuple of (socket, (communicator, rank)) for the peer.
    """
    assert remote_address is not None

    if remote_address in self.socks:
        return self.socks[remote_address], self.comms[remote_address]

    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    sock.connect(f"tcp://{remote_address}")
    self.socks[remote_address] = sock

    if remote_address in self.comms:
        logger.info("👋comm exists, remote_address:%s, comms:%s",
                    remote_address, self.comms)
        return sock, self.comms[remote_address]

    unique_id = self.succl.succlGetUniqueId()
    handshake = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
    sock.send(msgpack.dumps(handshake))

    rank = 0
    SUPAPlatform.set_device(self.device)
    comm: succlComm_t = self.succl.succlCommInitRank(2, unique_id, rank)
    self.comms[remote_address] = (comm, rank)
    logger.info("🤝succlCommInitRank Success, %s👉%s, MyRank:%s",
                self.zmq_address, remote_address, rank)

    return self.socks[remote_address], self.comms[remote_address]
def send_tensor(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[str] = None,
) -> bool:
    """Hand a tensor to the engine for delivery to ``remote_address``.

    - Without a remote address the tensor is placed directly into the
      local receive store.
    - ``PUT`` sends synchronously, ``PUT_ASYNC`` enqueues the tensor for
      the sender thread.
    - Otherwise (``GET``) the tensor is kept in the send store until the
      peer asks for it; the oldest entries are evicted to stay within
      ``buffer_size_threshold``.

    Returns:
        True if the tensor was stored/queued/sent, False if it was
        rejected (too large for the GET buffer, or the synchronous send
        failed).
    """
    if remote_address is None:
        with self.recv_store_cv:
            self.recv_store[tensor_id] = tensor
            self.recv_store_cv.notify()
        return True

    item = SendQueueItem(tensor_id=tensor_id,
                         remote_address=remote_address,
                         tensor=tensor)
    mode = self.send_type
    if mode == "PUT":
        return self.send_sync(item)
    if mode == "PUT_ASYNC":
        with self.send_queue_cv:
            self.send_queue.append(item)
            self.send_queue_cv.notify()
        return True

    # GET
    with self.send_store_cv:
        tensor_size = tensor.element_size() * tensor.numel()
        limit = self.buffer_size_threshold
        if tensor_size > limit:
            logger.warning(
                "❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
                "buffer size threshold :%d, skip send to %s, rank:%d",
                tensor_id, tensor_size, self.buffer_size_threshold,
                remote_address, self.rank)
            return False

        # Evict from the front of the dict until the new tensor fits.
        while self.buffer_size + tensor_size > limit:
            assert len(self.send_store) > 0
            evicted_id = next(iter(self.send_store))
            evicted = self.send_store.pop(evicted_id)
            evicted_size = evicted.element_size() * evicted.numel()
            self.buffer_size -= evicted_size
            logger.debug(
                "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                " buffer_size:%d, oldest_tensor_size:%d, rank:%d",
                remote_address, tensor_id, tensor_size, self.buffer_size,
                evicted_size, self.rank)

        self.send_store[tensor_id] = tensor
        self.buffer_size += tensor_size
        logger.debug(
            "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
            "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
            tensor_id, tensor_size, tensor.shape, self.rank,
            self.buffer_size,
            self.buffer_size / self.buffer_size_threshold * 100)
    return True
def recv_tensor(
    self,
    tensor_id: str,
    remote_address: typing.Optional[str] = None,
) -> torch.Tensor:
    """Obtain the tensor identified by ``tensor_id``.

    In ``PUT``/``PUT_ASYNC`` mode this blocks until the listener thread
    has put an entry for ``tensor_id`` into the receive store. An entry
    spilled to the memory pool (a tuple) is loaded back onto the device;
    a ``None`` entry is logged and returned as None.

    In ``GET`` mode the tensor is requested from ``remote_address`` and
    received over the communicator.

    Returns:
        The tensor, or None if it is unavailable.
    """
    if self.send_type in ("PUT", "PUT_ASYNC"):
        wait_started = time.time()
        with self.recv_store_cv:
            while tensor_id not in self.recv_store:
                self.recv_store_cv.wait()
            received = self.recv_store[tensor_id]

        if received is None:
            duration = time.time() - wait_started
            logger.warning(
                "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
                "rank:%d", remote_address, tensor_id, duration * 1000,
                self.rank)
            return received

        if isinstance(received, tuple):
            addr, dtype, shape = received
            return self.pool.load_tensor(addr, dtype, shape, self.device)

        self.buffer_size -= (received.element_size() * received.numel())
        return received

    # GET
    if remote_address is None:
        return None

    if remote_address not in self.socks:
        self.create_connect(remote_address)

    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    request = {"cmd": "GET", "tensor_id": tensor_id}
    sock.send(msgpack.dumps(request))

    reply = msgpack.loads(sock.recv())
    if reply["ret"] != 0:
        logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
                       remote_address, tensor_id, reply["ret"])
        return None

    with torch_br.supa.stream(self.recv_stream):
        tensor = torch_br._empty_ut_only(reply["shape"],
                                         dtype=getattr(
                                             torch, reply["dtype"]),
                                         tensor_type='BUFFER_ANY',
                                         device=self.device)
    # The peer holds the other rank of the two-rank communicator.
    self.recv(comm, tensor, rank ^ 1, self.recv_stream)
    return tensor
def listen_for_requests(self):
"""Serve messages arriving on the ROUTER socket, in an endless loop.

Three commands are handled, selected by the "cmd" field of the
msgpack-decoded message:
- "NEW": initialise a two-rank communicator with this side as rank 1.
- "PUT": allocate a tensor, acknowledge, receive the data from the
  peer and publish the result in ``recv_store``.
- "GET": look up a tensor in ``send_store``, reply with its metadata
  and, if found, send the data to the peer.
Anything else is logged as unexpected.
"""
while True:
# Blocks until at least one registered socket is readable.
socks = dict(self.poller.poll())
if self.router_socket not in socks:
continue
# First frame is the sender's identity (presumably the peer's
# zmq_address, which create_connect sets as its DEALER identity).
remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.succl.unique_id_from_bytes(
bytes(data["unique_id"]))
# The side that sent "NEW" uses rank 0 (see create_connect).
rank = 1
SUPAPlatform.set_device(self.device)
comm: succlComm_t = self.succl.succlCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info("🤝suclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
try:
with torch_br.supa.stream(self.recv_stream):
tensor = torch_br._empty_ut_only(
data["shape"],
dtype=getattr(torch, data["dtype"]),
tensor_type='BUFFER_ANY',
device=self.device)
# b"0" tells the sender (send_sync) to go ahead with the transfer.
self.router_socket.send_multipart([remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()]
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
# recv_tensor recognises this tuple form and loads it back.
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data, addr)
else:
self.buffer_size += tensor_size
# NOTE(review): only torch.cuda.OutOfMemoryError is handled here;
# confirm that allocation failures on this device raise that type.
except torch.cuda.OutOfMemoryError:
# Any reply other than b"0" makes send_sync report a failure.
self.router_socket.send_multipart([remote_address, b"1"])
tensor = None
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address, remote_address.decode(),
data)
# Publish the result (possibly None) and wake a waiter in recv_tensor.
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
tensor = self.send_store.pop(tensor_id, None)
if tensor is not None:
data = {
"ret": 0,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", ""),
"tensor_type": get_tensor_info(tensor)[0]['layout']
}
# LRU
# Re-inserting after pop moves the entry to the end of the dict,
# while send_tensor evicts from the front.
self.send_store[tensor_id] = tensor
self.have_sent_tensor_id(tensor_id)
else:
data = {"ret": 1}
self.router_socket.send_multipart(
[remote_address, msgpack.dumps(data)])
# Only transfer the data when the lookup succeeded.
if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()]
self.send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream)
else:
logger.warning(
"🚧Unexpected, Received message from %s, data:%s",
remote_address, data)
def have_sent_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` as sent under its request id.

    The request id is the part of ``tensor_id`` before the first ``#``.
    """
    request_id, _, _ = tensor_id.partition('#')
    sent_ids = self.send_request_id_to_tensor_ids.setdefault(
        request_id, set())
    sent_ids.add(tensor_id)
def have_received_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` as received under its request id.

    The request id is the part of ``tensor_id`` before the first ``#``.
    """
    request_id, _, _ = tensor_id.partition('#')
    received_ids = self.recv_request_id_to_tensor_ids.setdefault(
        request_id, set())
    received_ids.add(tensor_id)
def send_async(self):
    """Sender-thread loop for ``PUT_ASYNC``: drain ``send_queue`` forever.

    Items are popped one at a time under ``send_queue_cv`` and sent with
    ``send_sync`` outside the lock. Waiters on the condition are notified
    as soon as the queue becomes empty, i.e. before the last popped item
    has actually been sent.

    A failing send is logged and the loop continues. Letting the
    exception escape would end this thread, after which queued items
    would never be drained and ``wait_for_sent`` could block forever.
    """
    while True:
        with self.send_queue_cv:
            while not self.send_queue:
                self.send_queue_cv.wait()
            item = self.send_queue.popleft()
            if not self.send_queue:
                self.send_queue_cv.notify()
        try:
            self.send_sync(item)
        except Exception:
            # Top-level boundary of the worker thread: report and go on.
            logger.exception(
                "[PUT_ASYNC]Failed to send tensor_id:%s to %s, rank:%d",
                item.tensor_id, item.remote_address, self.rank)
def wait_for_sent(self):
    """Block until ``send_queue`` is empty (``PUT_ASYNC`` mode only).

    In any other send mode this returns immediately.
    """
    if self.send_type != "PUT_ASYNC":
        return
    start_time = time.time()
    with self.send_queue_cv:
        while self.send_queue:
            self.send_queue_cv.wait()
        duration = time.time() - start_time
        logger.debug(
            "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
            " to be empty, rank:%d", duration * 1000, self.rank)
def send_sync(self, item: SendQueueItem) -> bool:
    """Push one queued tensor to its peer and wait for completion.

    A "PUT" header describing the tensor is sent first; the data is only
    transferred when the peer answers ``b"0"``.

    Args:
        item: the tensor, its id and the peer address.

    Returns:
        True if the tensor was sent, False if there is no peer address
        or the peer refused the transfer.
    """
    peer = item.remote_address
    if peer is None:
        return False
    if peer not in self.socks:
        self.create_connect(peer)

    tensor = item.tensor
    sock = self.socks[peer]
    comm, rank = self.comms[peer]

    header = {
        "cmd": "PUT",
        "tensor_id": item.tensor_id,
        "shape": tensor.shape,
        "dtype": str(tensor.dtype).replace("torch.", ""),
        "tensor_type": get_tensor_info(tensor)[0]['layout']
    }
    sock.send(msgpack.dumps(header))

    reply = sock.recv()
    if reply != b"0":
        logger.error(
            "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
            "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
            self.zmq_address, item.remote_address, rank, header,
            tensor.shape,
            tensor.element_size() * tensor.numel() / 1024**3,
            reply.decode())
        return False

    # The peer holds the other rank of the two-rank communicator.
    self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)

    if self.send_type == "PUT_ASYNC":
        self.have_sent_tensor_id(item.tensor_id)
    return True
def get_finished(
        self, finished_req_ids: set[str], no_compile_layers
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Notifies worker-side connector ids of requests that have
    finished generating tokens.

    For every finished request the per-layer entries are removed from
    ``recv_store`` (returning pooled memory where the entry was spilled
    to the memory pool) and the per-request bookkeeping sets are dropped.
    The bookkeeping is dropped for every finished request, also when
    nothing was found in ``recv_store``; otherwise the sent-tensor
    bookkeeping of an instance that never receives would grow forever.

    Args:
        finished_req_ids: ids of the requests that have finished.
        no_compile_layers: iterable of layer names used to build the
            tensor ids ``<request_id>#<layer_name>``.

    Returns:
        ids of requests that have finished asynchronous transfer,
        tuple of (sending/saving ids, recving/loading ids).
        The finished saves/sends req ids must belong to a set provided in a
        call to this method (this call or a prior one).
        Both are currently always None.
    """
    # Clear the buffer upon request completion.
    for request_id in finished_req_ids:
        for layer_name in no_compile_layers:
            tensor_id = request_id + "#" + layer_name
            if tensor_id not in self.recv_store:
                continue
            with self.recv_store_cv:
                tensor = self.recv_store.pop(tensor_id, None)
            if isinstance(tensor, tuple):
                addr, _, _ = tensor
                self.pool.free(addr)
        with self.recv_store_cv:
            self.send_request_id_to_tensor_ids.pop(request_id, None)
            self.recv_request_id_to_tensor_ids.pop(request_id, None)

    # TODO:Retrieve requests that have already sent the KV cache.
    finished_sending: set[str] = set()

    # TODO:Retrieve requests that have already received the KV cache.
    finished_recving: set[str] = set()

    return finished_sending or None, finished_recving or None
def ping(self):
    """Announce this instance to the proxy every three seconds, forever.

    Each message carries the role ("P" for a KV producer, "D"
    otherwise) plus the HTTP and ZMQ addresses of this instance.
    """
    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    sock.connect(f"tcp://{self.proxy_address}")

    role = "P" if self.config.is_kv_producer else "D"
    data = {
        "type": role,
        "http_address": self.http_address,
        "zmq_address": self.zmq_address
    }
    while True:
        sock.send(msgpack.dumps(data))
        time.sleep(3)
def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
    """Send ``tensor`` to rank ``dst`` of ``comm`` and wait for the stream.

    Args:
        comm: communicator shared with the peer.
        tensor: tensor to send; must live on ``self.device``.
        dst: destination rank.
        stream: stream to issue the send on; a new one is created when
            omitted.
    """
    assert tensor.device == self.device, (
        f"this succl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")
    work_stream = torch_br.supa.Stream() if stream is None else stream
    with torch_br.supa.stream(work_stream):
        data_ptr = buffer_type(tensor.data_ptr())
        count = tensor.numel()
        dtype = succlDataTypeEnum.from_torch(tensor.dtype)
        self.succl.succlSend(data_ptr, count, dtype, dst, comm,
                             suStream_t(work_stream.supa_stream))
    work_stream.synchronize()
def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
    """Receive into ``tensor`` from rank ``src`` of ``comm`` and wait.

    Args:
        comm: communicator shared with the peer.
        tensor: destination tensor; must live on ``self.device``.
        src: source rank.
        stream: stream to issue the receive on; a new one is created
            when omitted.
    """
    assert tensor.device == self.device, (
        f"this succl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")
    work_stream = torch_br.supa.Stream() if stream is None else stream
    with torch_br.supa.stream(work_stream):
        data_ptr = buffer_type(tensor.data_ptr())
        count = tensor.numel()
        dtype = succlDataTypeEnum.from_torch(tensor.dtype)
        self.succl.succlRecv(data_ptr, count, dtype, src, comm,
                             suStream_t(work_stream.supa_stream))
    work_stream.synchronize()
def close(self) -> None:
    """Join the engine's worker threads.

    The listener thread is joined first, then the sender thread (only in
    ``PUT_ASYNC`` mode), then the ping thread if one was started. The
    thread targets in this class loop without a stop condition, so each
    join waits until the corresponding thread ends on its own.
    """
    workers = [self._listener_thread]
    if self.send_type == "PUT_ASYNC":
        workers.append(self._send_thread)
    if self._ping_thread is not None:
        workers.append(self._ping_thread)
    for worker in workers:
        worker.join()

View File

@@ -0,0 +1,280 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import ctypes
import math
from dataclasses import dataclass
import torch
from vllm.logger import logger
@dataclass
class MemoryBlock:
# One contiguous region handled by TensorMemoryPool's buddy allocator.
# Size of the region in bytes.
size: int
# Absolute start address of the region.
addr: int
"""A memory pool for managing pinned host memory allocations for tensors.
This class implements a buddy allocation system to efficiently manage pinned
host memory for tensor storage. It supports allocation, deallocation, and
tensor storage/retrieval operations.
Key Features:
- Uses power-of-two block sizes for efficient buddy allocation
- Supports splitting and merging of memory blocks
- Provides methods to store CUDA tensors in pinned host memory
- Allows loading tensors from pinned memory back to device
- Automatically cleans up memory on destruction
Attributes:
max_block_size (int): Maximum block size (rounded to nearest power of two)
min_block_size (int): Minimum block size (rounded to nearest power of two)
free_lists (dict): Dictionary of free memory blocks by size
allocated_blocks (dict): Dictionary of currently allocated blocks
base_tensor (torch.Tensor): Base pinned memory tensor
base_address (int): Base memory address of the pinned memory region
Example:
>>> pool = TensorMemoryPool(max_block_size=1024*1024)
>>> tensor = torch.randn(100, device='cuda')
>>> addr = pool.store_tensor(tensor)
>>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
... tensor.shape, 'cuda')
>>> pool.free(addr)
"""
class TensorMemoryPool:
"""Initializes the memory pool with given size constraints.
Args:
max_block_size (int): Maximum size of memory blocks to manage
min_block_size (int, optional): Minimum size of memory blocks
to manage. Defaults to 512.
Raises:
ValueError: If block sizes are invalid or max_block_size is less
than min_block_size
"""
def __init__(self, max_block_size: int, min_block_size: int = 512):
if max_block_size <= 0 or min_block_size <= 0:
raise ValueError("Block sizes must be positive")
if max_block_size < min_block_size:
raise ValueError(
"Max block size must be greater than min block size")
self.max_block_size = self._round_to_power_of_two(max_block_size)
self.min_block_size = self._round_to_power_of_two(min_block_size)
self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
self.allocated_blocks: dict[int, MemoryBlock] = {}
self._initialize_free_lists()
self._allocate_pinned_memory()
atexit.register(self.cleanup)
def _round_to_power_of_two(self, size: int) -> int:
return 1 << (size - 1).bit_length()
def _initialize_free_lists(self):
size = self.max_block_size
while size >= self.min_block_size:
self.free_lists[size] = {}
size //= 2
def _allocate_pinned_memory(self):
    """Allocate the backing pinned buffer and seed the largest free list.

    The buffer holds ``max_block_size // 4`` float32 elements, i.e.
    ``max_block_size`` bytes, and starts out as one free block.
    """
    num_floats = self.max_block_size // 4
    self.base_tensor = torch.empty(num_floats,
                                   dtype=torch.float32,
                                   pin_memory=True)
    self.base_address = self.base_tensor.data_ptr()

    whole_pool = MemoryBlock(size=self.max_block_size,
                             addr=self.base_address)
    self.free_lists[self.max_block_size][whole_pool.addr] = whole_pool

    logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d",
                 self.base_address, self.max_block_size)
def allocate(self, size: int) -> int:
"""Allocates a memory block of at least the requested size.
Args:
size (int): Minimum size of memory to allocate
Returns:
int: Address of the allocated memory block
Raises:
ValueError: If size is invalid or insufficient memory is available
"""
if size <= 0:
raise ValueError("Allocation size must be positive")
required_size = self._round_to_power_of_two(
max(size, self.min_block_size))
if required_size > self.max_block_size:
raise ValueError("Requested size exceeds maximum block size")
current_size = required_size
while current_size <= self.max_block_size:
if self.free_lists[current_size]:
_, block = self.free_lists[current_size].popitem()
self._split_block(block, required_size)
self.allocated_blocks[block.addr] = block
return block.addr
current_size *= 2
raise ValueError("Insufficient memory")
def _split_block(self, block: MemoryBlock, required_size: int):
while (block.size > required_size
and block.size // 2 >= self.min_block_size):
buddy_size = block.size // 2
buddy_addr = block.addr + buddy_size
buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
block.size = buddy_size
self.free_lists[buddy_size][buddy.addr] = buddy
def free(self, addr: int):
"""Frees an allocated memory block.
Args:
addr (int): Address of the block to free
Raises:
ValueError: If address is invalid or not allocated
"""
if addr not in self.allocated_blocks:
raise ValueError("Invalid address to free")
block = self.allocated_blocks.pop(addr)
self._merge_buddies(block)
def _merge_buddies(self, block: MemoryBlock):
MAX_MERGE_DEPTH = 30
depth = 0
while depth < MAX_MERGE_DEPTH:
buddy_offset = block.size if (block.addr - self.base_address) % (
2 * block.size) == 0 else -block.size
buddy_addr = block.addr + buddy_offset
buddy = self.free_lists[block.size].get(buddy_addr)
if buddy:
del self.free_lists[buddy.size][buddy.addr]
merged_addr = min(block.addr, buddy.addr)
merged_size = block.size * 2
block = MemoryBlock(size=merged_size, addr=merged_addr)
depth += 1
else:
break
self.free_lists[block.size][block.addr] = block
def store_tensor(self, tensor: torch.Tensor) -> int:
    """Stores a CUDA tensor in a block allocated from this pool.

    A block of at least ``element_size * numel`` bytes is allocated, a
    tensor view with the input's dtype and shape is created over the
    block's memory, and the input is copied into that view.

    Any failure after the allocation releases the block again before the
    error propagates, so a failed store does not leak pool memory.

    Args:
        tensor (torch.Tensor): CUDA tensor to store
    Returns:
        int: Address where the tensor is stored
    Raises:
        ValueError: If tensor is not on CUDA, allocation fails, or the
            tensor view over the block cannot be created
    """
    if not tensor.is_cuda:
        raise ValueError("Only CUDA tensors can be stored")
    size = tensor.element_size() * tensor.numel()
    addr = self.allocate(size)
    block = self.allocated_blocks[addr]
    if block.size < size:
        self.free(addr)
        raise ValueError(
            f"Allocated block size {block.size} is smaller than "
            f"required size {size}")
    try:
        buffer = (ctypes.c_byte * block.size).from_address(block.addr)
        cpu_tensor = torch.frombuffer(buffer,
                                      dtype=tensor.dtype,
                                      count=tensor.numel()).reshape(
                                          tensor.shape)
    except ValueError as err:
        self.free(addr)
        raise ValueError(f"Failed to create tensor view: {err}") from err
    try:
        cpu_tensor.copy_(tensor)
    except Exception:
        # The caller never receives `addr` when the copy fails, so the
        # block must be returned to the pool here or it is lost.
        self.free(addr)
        raise
    return addr
def load_tensor(self, addr: int, dtype: torch.dtype,
                shape: tuple[int, ...], device) -> torch.Tensor:
    """Copies a tensor out of a pool block onto ``device``.

    The block registered at ``addr`` is viewed as a tensor of the given
    dtype and shape, and its contents are copied into a newly created
    tensor on ``device``.

    Args:
        addr (int): Address where tensor is stored
        dtype (torch.dtype): Data type of the tensor
        shape (tuple[int, ...]): Shape of the tensor
        device: Target device for the loaded tensor
    Returns:
        torch.Tensor: A new tensor on ``device`` holding the block's data
    Raises:
        ValueError: If ``addr`` is not allocated, or the requested tensor
            needs more bytes than the block holds
    """
    if addr not in self.allocated_blocks:
        raise ValueError("Invalid address to load")
    block = self.allocated_blocks[addr]

    element_count = math.prod(shape)
    bytes_per_element = torch.tensor([], dtype=dtype).element_size()
    if element_count * bytes_per_element > block.size:
        raise ValueError("Requested tensor size exceeds block size")

    raw = (ctypes.c_byte * block.size).from_address(block.addr)
    host_view = torch.frombuffer(raw, dtype=dtype, count=element_count)
    host_view = host_view.reshape(shape)

    # Always copy into a new tensor so the result never aliases the pool.
    result = torch.empty(shape, dtype=dtype, device=device)
    result.copy_(host_view)
    return result
def cleanup(self):
    """Cleans up all memory resources and resets the pool state.

    Clears ``free_lists`` and ``allocated_blocks`` and drops the
    ``base_tensor`` attribute. Each step is skipped when the attribute
    does not exist, so the method is safe to call more than once and on
    an instance that was only partially initialised (``__del__`` calls it
    unconditionally).
    """
    for attr_name in ('free_lists', 'allocated_blocks'):
        registry = getattr(self, attr_name, None)
        if registry is not None:
            registry.clear()
    if hasattr(self, 'base_tensor'):
        del self.base_tensor
def __del__(self):
# Finalizer: delegates to cleanup() so the pool's bookkeeping dictionaries and
# its base_tensor reference are dropped when the object is collected.
self.cleanup()

View File

@@ -0,0 +1,473 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
import torch.distributed
import torch_br
import vllm
import vllm.distributed.parallel_state
from vllm.distributed import GroupCoordinator
from vllm.distributed.parallel_state import (_WORLD, TensorMetadata,
_split_tensor_dict, get_pp_group,
get_tp_group, get_world_group,
init_model_parallel_group, logger)
from vllm_br import envs
# Context object handed through the graph-capture context managers in this
# module; it currently carries a single field.
@dataclass
class GraphCaptureContext:
# SUPA stream that the capture region runs on.
stream: torch_br.supa.Stream
@contextmanager
def graph_capture_(self,
                   graph_capture_context: Optional[GraphCaptureContext] = None
                   ):
    """Context manager that runs its body on the graph-capture stream.

    If no context is supplied, one is created around a new
    ``torch_br.supa.Stream``. When the capture stream differs from the
    current stream, the capture stream first waits on the current one, so
    work already queued there completes before capture begins. The body
    then executes with the capture stream selected.

    Intended as the replacement for ``GroupCoordinator.graph_capture``
    (see the assignment after this function).

    Args:
        graph_capture_context: Existing context to reuse, or ``None`` to
            create a new one.
    Yields:
        GraphCaptureContext: The context whose stream is active.
    """
    if graph_capture_context is None:
        graph_capture_context = GraphCaptureContext(torch_br.supa.Stream())
    capture_stream = graph_capture_context.stream

    # NOTE: the original patch keeps the custom all-reduce capture context
    # (`self.device_communicator.ca_comm.capture()`) commented out here;
    # it is intentionally not entered.

    # ensure all initialization operations complete before attempting to
    # capture the graph on another stream
    active_stream = torch_br.supa.current_stream()
    if active_stream != capture_stream:
        capture_stream.wait_stream(active_stream)

    with torch_br.supa.stream(capture_stream):
        yield graph_capture_context


vllm.distributed.parallel_state.GroupCoordinator.graph_capture = graph_capture_
@contextmanager
def graph_capture_supa(device: torch.device):
    """Context manager to wrap code that captures a SUPA graph.

    A new ``torch_br.supa.Stream`` is created on ``device`` and wrapped in
    a ``GraphCaptureContext``. That context is passed to the
    ``graph_capture`` context manager of the tensor-parallel group and
    then of the pipeline-parallel group; the body runs inside both.

    Running capture on its own stream keeps the kernels being captured
    apart from kernels that may be launched on the default stream in the
    background.

    Args:
        device: Device on which the capture stream is created.
    Yields:
        GraphCaptureContext: Context holding the capture stream.
    """
    capture_ctx = GraphCaptureContext(torch_br.supa.Stream(device=device))
    with get_tp_group().graph_capture(capture_ctx):
        with get_pp_group().graph_capture(capture_ctx):
            yield capture_ctx


vllm.distributed.parallel_state.graph_capture = graph_capture_supa
def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    The world group is looked up on `vllm.distributed.parallel_state` at
    call time. A name bound with `from ... import _WORLD` is a snapshot
    taken when this module was imported and would not observe a world group
    created afterwards.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Returns True if distributed is not initialized (single process)
              or if any of the checks raises.
    """
    try:
        # If world group is available, use it for the most accurate check
        world_group = vllm.distributed.parallel_state._WORLD
        if world_group is not None:
            return world_group.is_first_rank
        # If torch distributed is not initialized, assume single process
        if not torch.distributed.is_initialized():
            return True
        # Fallback to torch's global rank
        return torch.distributed.get_rank() == 0
    except Exception:
        # Deliberate best effort: if anything goes wrong, assume this is
        # the first rank
        return True
def generate_multi_node_parallel_groups(
    total_procs: int,
    tp_size: int,
    pp_size: int,
    dp_size: int,
) -> dict:
    """Return the hard-coded supernode rank groups for TP8 x PP2 x DP1.

    Only one configuration is supported: 16 processes with tensor
    parallel size 8, pipeline parallel size 2 and data parallel size 1.
    In that layout each TP group combines ranks 0-3 or 4-7 with the ranks
    eight above them, so the TP groups are not contiguous rank ranges.

    Args:
        total_procs: World size.
        tp_size: Tensor parallel size.
        pp_size: Pipeline parallel size.
        dp_size: Data parallel size.
    Returns:
        dict: Keys ``"tp_groups"``, ``"pp_groups"``, ``"dp_groups"`` and
        ``"ep_groups"``, each a list of rank lists.
    Raises:
        ValueError: If the configuration is not the supported one.
    """
    if (total_procs, tp_size, pp_size, dp_size) != (16, 8, 2, 1):
        raise ValueError(
            "Unsupported VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE parallel config of"
            f" tp_size: {tp_size} pp_size: {pp_size} dp_size: {dp_size}"
            f" (world size: {total_procs})."
            " Currently only 'tp8pp2dp1' is allowed.")

    tp_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
    pp_groups = [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13],
                 [10, 14], [11, 15]]
    # With dp_size == 1 every rank is its own data-parallel group.
    dp_groups = [[rank] for rank in range(total_procs)]
    # Expert-parallel groups have the same membership as the TP groups.
    ep_groups = [list(group) for group in tp_groups]
    return {
        "tp_groups": tp_groups,
        "pp_groups": pp_groups,
        "dp_groups": dp_groups,
        "ep_groups": ep_groups,
    }
# Replacement for vllm.distributed.parallel_state.initialize_model_parallel
# (installed by the assignment that follows this function). The signature
# follows the vLLM v0.11 API.
# NOTE(review): the body may still need to be re-synced with the upstream vLLM
# implementation.
def initialize_model_parallel_cross_tp(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None,
) -> None:
"""
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism.
decode_context_model_parallel_size: size of each decode context
parallel (DCP) group.
backend: name of torch distributed communication backend. Defaults to
the backend of the world group's device group.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
4 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 pipeline model-parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
When `envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE` is set, the TP, PP, DP
and EP rank lists are taken from `generate_multi_node_parallel_groups`
instead of the default layout described above. The DCP groups are always
derived from the default layout.
The groups are created in the fixed order TP, DCP, PP, DP, EP and stored
on `vllm.distributed.parallel_state`.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
data_parallel_size = 1
# Imported here rather than at module scope.
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size
# the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
# DP is the data parallel group that is part of the model,
# all the ranks in the same DP group should generate simultaneously,
# i.e. the `generate` call in the same DP group should be called together,
# otherwise it will cause deadlock.
# to get group_ranks for each dimension, transpose that dimension to the
# last dimension, then reshape to 2D, then unbind the last dimension
all_ranks = torch.arange(world_size).reshape(
-1, data_parallel_size, pipeline_model_parallel_size,
tensor_model_parallel_size) # noqa
# Supernode mode: fetch the hard-coded rank groups once; they replace the
# default TP/PP/DP/EP rank lists below.
if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
groups = generate_multi_node_parallel_groups(
world_size, tensor_model_parallel_size,
pipeline_model_parallel_size, data_parallel_size)
logger.info("supernode reorganized groups: %s", groups)
# Build the tensor model-parallel groups.
assert vllm.distributed.parallel_state._TP is None, (
"tensor model parallel group is already initialized")
if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
group_ranks = groups['tp_groups']
else:
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
vllm.distributed.parallel_state._TP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="tp")
# Build the DCP model-parallel groups.
# global _DCP
assert vllm.distributed.parallel_state._DCP is None, (
"decode context model parallel group is already initialized")
# Note(hc): In the current implementation of decode context parallel,
# dcp_size must not exceed tp_size, because the world size does not
# change by DCP, it simply reuses the GPUs of TP group, and split one
# TP group into tp_size//dcp_size DCP groups.
# NOTE(review): unlike the other groups, this one has no supernode branch
# and always uses the default `all_ranks` layout.
group_ranks = all_ranks.reshape(
-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
vllm.distributed.parallel_state._DCP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="dcp")
# Build the pipeline model-parallel groups.
assert vllm.distributed.parallel_state._PP is None, (
"pipeline model parallel group is already initialized")
if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
group_ranks = groups['pp_groups']
else:
group_ranks = all_ranks.transpose(2, 3).reshape(
-1, pipeline_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
vllm.distributed.parallel_state._PP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pp")
# Build the data-parallel groups.
assert vllm.distributed.parallel_state._DP is None, (
"data parallel group is already initialized")
if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
group_ranks = groups['dp_groups']
else:
group_ranks = all_ranks.transpose(1, 3).reshape(
-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
vllm.distributed.parallel_state._DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp")
# Build the expert-parallel groups (default size: DP size x TP size).
assert vllm.distributed.parallel_state._EP is None, (
"expert parallel group is already initialized")
if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
group_ranks = groups['ep_groups']
else:
group_ranks = all_ranks.transpose(1, 2).reshape(
-1, data_parallel_size * tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
vllm.distributed.parallel_state._EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep")
logger.info(
"rank %s in world size %s is assigned as (br) "
"DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
vllm.distributed.parallel_state._DP.rank_in_group,
vllm.distributed.parallel_state._PP.rank_in_group,
vllm.distributed.parallel_state._TP.rank_in_group,
vllm.distributed.parallel_state._EP.rank_in_group)
# Monkeypatch: route vLLM's initialize_model_parallel to the version above.
vllm.distributed.parallel_state.initialize_model_parallel = initialize_model_parallel_cross_tp
def send_tensor_dict(
self,
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank (rank within this group) of the destination
rank. When omitted, the next rank in the group is used, wrapping around.
all_gather_group: The group for the all-gather operation. If provided,
an optimization is enabled where each rank in the group sends a
slice of a tensor and the receiver reconstructs it using an
all-gather, which can improve performance. This is typically the
tensor-parallel group.
all_gather_tensors: A dictionary to specify which tensors should use
the all-gather optimization, which is only effective when
`all_gather_group` is provided. By default, this optimization is
on for any tensor whose size is divisible by the
`all_gather_group`'s world size. However, it should be disabled
for tensors that are not fully replicated across the group (e.g.,
the residual tensor when sequence parallelism is enabled). This
dictionary allows overriding the default behavior on a per-tensor
basis.
Returns `tensor_dict` unchanged when torch.distributed is not initialized
or the group has a single rank; returns None after a real send.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
# Default destination: the next rank in this group (wraps around).
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
# With CPU custom send/recv enabled, the device communicator handles the
# whole transfer and nothing below runs.
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
self.device_communicator.send_tensor_dict( # type: ignore
tensor_dict, dst)
return None
# Pre-declaration only; overwritten by `_split_tensor_dict` below.
metadata_list: list[tuple[Any, Any]] = []
assert isinstance(tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.send_object(metadata_list, dst=dst)
# Keys of the tensor-valued entries, in dict order, paired below with
# `tensor_list`.
tensor_keys = [
k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
]
assert len(tensor_keys) == len(tensor_list)
for key, tensor in zip(tensor_keys, tensor_list):
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
# A per-key entry in `all_gather_tensors` overrides the default.
use_all_gather = all_gather_tensors.get(key, use_all_gather) \
if all_gather_tensors else use_all_gather
if use_all_gather:
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor,
dst=self.ranks[dst],
group=metadata_group)
else:
# ensure tensor is ready
torch.supa.synchronize()
# use group for GPU tensors
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None
def recv_tensor_dict(
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank (rank within this group) of the source rank.
When omitted, the previous rank in the group is used, wrapping around.
all_gather_group: The group for the all-gather operation. If provided,
an optimization is enabled where each rank in the group sends a
slice of a tensor and the receiver reconstructs it using an
all-gather, which can improve performance. This is typically the
tensor-parallel group.
all_gather_tensors: A dictionary to specify which tensors should use
the all-gather optimization, which is only effective when
`all_gather_group` is provided. By default, this optimization is
on for any tensor whose size is divisible by the
`all_gather_group`'s world size. However, it should be disabled
for tensors that are not fully replicated across the group (e.g.,
the residual tensor when sequence parallelism is enabled). This
dictionary allows overriding the default behavior on a per-tensor
basis.
Returns None when torch.distributed is not initialized or the group has a
single rank; otherwise the reconstructed dictionary.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
# Default source: the previous rank in this group (wraps around).
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
# With CPU custom send/recv enabled, the device communicator handles the
# whole transfer and nothing below runs.
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
return self.device_communicator.recv_tensor_dict( # type: ignore
src)
# Metadata arrives first; TensorMetadata entries describe tensors that are
# received individually below, other values are stored as-is.
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
# A per-key entry in `all_gather_tensors` overrides the default.
use_all_gather = all_gather_tensors.get(key, use_all_gather) \
if all_gather_tensors else use_all_gather
if use_all_gather:
# Remember the full shape; only this rank's slice is received.
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
# ensure recv is done
torch.supa.synchronize()
if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0)
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
# Monkeypatch: replace GroupCoordinator's tensor-dict send/recv with the
# module-level functions defined in this file.
vllm.distributed.GroupCoordinator.send_tensor_dict = send_tensor_dict
vllm.distributed.GroupCoordinator.recv_tensor_dict = recv_tensor_dict

120
vllm_br/envs.py Normal file
View File

@@ -0,0 +1,120 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
from typing import Any, Callable, Dict
import pybrml
import torch
import torch_br
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
# begin-env-vars-definition
def check_allreduce_available():
    """Derive the default fused all-reduce setting from device 0's P2P links.

    Counts how many devices (device 0 itself included in the scan) report a
    P2P status of type ``P2P_DIRECT_LINK_TYPE`` towards device 0 and maps
    that count to the return value: 3 links -> 4, 4 links -> 8, anything
    else -> 1.

    Only device 0's links are queried, because only that row was ever used
    by the mapping. ``brmlShutdown`` is always called once ``brmlInit`` has
    succeeded, even if a query raises. When no device is reported, 1 is
    returned instead of failing on an empty result.

    Returns:
        int: 1, 4 or 8 as described above.
    """
    P2P_DIRECT_LINK_TYPE = 2

    def is_p2p_direct_link(dev0, dev1):
        return pybrml.brmlDeviceGetP2PStatus_v3(
            dev0, dev1).type == P2P_DIRECT_LINK_TYPE

    pybrml.brmlInit()
    try:
        device_count = pybrml.brmlDeviceGetCount()
        if device_count <= 0:
            # Nothing to inspect; fall back to the non-fused default.
            return 1
        first_dev = pybrml.brmlDeviceGetHandleByIndex(0)
        all_reduce_count = sum(
            is_p2p_direct_link(first_dev,
                               pybrml.brmlDeviceGetHandleByIndex(j))
            for j in range(device_count))
    finally:
        pybrml.brmlShutdown()

    if all_reduce_count == 3:
        return 4
    if all_reduce_count == 4:
        return 8
    return 1


# Evaluated once at import; used as the default for
# VLLM_BR_USE_FUSED_ALLREDUCE.
_VLLM_BR_USE_FUSED_ALLREDUCE_CACHE = check_allreduce_available()
# Maps each supported environment variable name to a zero-argument callable
# that reads and converts it. Values are evaluated lazily, on every access.
env_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_VERSION":
    lambda: os.getenv("VLLM_VERSION", None),
    # NOTE(review): returns the raw string when set, so "0" is truthy here.
    # Left unchanged because callers may rely on the string value.
    "VLLM_BR_USE_PAGED_ATTN":
    lambda: os.getenv("VLLM_BR_USE_PAGED_ATTN", False),
    "VLLM_BR_WEIGHT_TYPE":
    lambda: os.getenv("VLLM_BR_WEIGHT_TYPE", "NUMA"),
    "VLLM_BR_QUANT_METHOD":
    lambda: os.getenv("VLLM_BR_QUANT_METHOD", "INT8"),
    "VLLM_BR_USE_FUSED_ALLREDUCE":
    lambda: int(
        os.getenv("VLLM_BR_USE_FUSED_ALLREDUCE",
                  _VLLM_BR_USE_FUSED_ALLREDUCE_CACHE)),
    "VLLM_BR_EMBEDDING_S0B":
    lambda: bool(int(os.getenv("VLLM_BR_EMBEDDING_S0B", False))),
    # MoE (DeepSeek)
    "VLLM_BR_STATIC_MOE_DECODER_MAX_LEN":
    lambda: int(os.getenv("VLLM_BR_STATIC_MOE_DECODER_MAX_LEN", "256")),
    # NOTE: following are device properties
    "VLLM_BR_DEVICE_SPC_NUM":
    lambda: int(
        os.getenv(
            "VLLM_BR_DEVICE_SPC_NUM",
            torch_br.supa.get_device_properties(torch.device("supa")).
            max_compute_units)),
    "VLLM_BR_DEVICE_WARP_SIZE":
    lambda: int(os.getenv("VLLM_BR_DEVICE_WARP_SIZE", 32)),
    "VLLM_BR_USE_CPU_ALL_REDUCE":
    lambda: int(os.getenv("VLLM_BR_USE_CPU_ALL_REDUCE", 0)),
    "VLLM_SCCL_SO_PATH":
    lambda: os.getenv(
        "VLLM_SCCL_SO_PATH",
        "/usr/local/birensupa/base/latest/succl/lib/x86_64-linux-gnu/libsuccl.so"
    ),
    "VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
    lambda: bool(int(os.getenv("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", False))),
    "VLLM_PP_CPU_SEND_RECV":
    lambda: bool(int(os.getenv("VLLM_PP_CPU_SEND_RECV", False))),
    "VLLM_BR_USE_FP32_ALL_REDUCE":
    lambda: int(os.getenv("VLLM_BR_USE_FP32_ALL_REDUCE", 0)),
    # `bool(os.getenv(...))` treated every non-empty string, including "0",
    # as enabled. The common "off" spellings now disable the flag; any other
    # non-empty value (e.g. "1", "true") still enables it.
    "VLLM_BR_USE_MROPE_0_9_2":
    lambda: os.getenv("VLLM_BR_USE_MROPE_0_9_2", "0").strip().lower() not in
    ("", "0", "false", "no", "off"),
    "VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE":
    lambda: bool(int(os.getenv("VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE", "0"))),
}
# end-env-vars-definition
def __getattr__(name: str):
    """Module-level attribute hook: resolve ``name`` via ``env_variables``.

    The registered callable is invoked on every access, so the value
    reflects the environment at the time of the lookup.

    Raises:
        AttributeError: If ``name`` is not a key of ``env_variables``.
    """
    # lazy evaluation of environment variables
    getter = env_variables.get(name)
    if getter is None:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}")
    return getter()
def __dir__():
    """List the names of all registered environment variables."""
    return [*env_variables]

View File

@@ -0,0 +1,15 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

Binary file not shown.

View File

@@ -0,0 +1,356 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List
import vllm.envs as envs
from vllm.executor.ray_distributed_executor import RayDistributedExecutor
from vllm.executor.ray_utils import RayWorkerWrapper, ray
# from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm_br import envs as envs_br
if ray is not None:
from ray.actor import ActorHandle
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
ActorHandle = None
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
from vllm.executor.ray_distributed_executor import RayWorkerMetaData, logger
def get_supernode_pp_tp_global_rank_map(tp_size, pp_size):
    """Map every ``(pp_rank, tp_rank)`` pair to a global rank.

    The offsets are hard-coded so that PP=2, TP=8 yields
    ``[[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]``
    (outer index = PP rank, inner index = TP rank). The formula is
    ``pp_rank * pp_size + tp_rank`` plus 4 when ``tp_rank >= 4`` and
    plus 2 when ``pp_rank >= 1``; other sizes are computed with the same
    formula and are not validated.

    Args:
        tp_size: Tensor parallel size.
        pp_size: Pipeline parallel size.
    Returns:
        tuple: ``(rank_map, tp_driver_rank)``, where ``rank_map`` is the
        dict described above and ``tp_driver_rank`` lists, per PP rank, the
        global rank assigned to TP rank 0.
    """
    rank_map = {}
    for pp_rank in range(pp_size):
        pp_shift = 2 if pp_rank >= 1 else 0
        for tp_rank in range(tp_size):
            tp_shift = 4 if tp_rank >= 4 else 0
            rank_map[(pp_rank, tp_rank)] = (pp_rank * pp_size + tp_rank +
                                            tp_shift + pp_shift)
    # A driver entry exists only for PP ranks that have a TP rank 0.
    tp_driver_rank = [
        rank_map[(pp_rank, 0)] for pp_rank in range(pp_size) if tp_size > 0
    ]
    return rank_map, tp_driver_rank
def _init_workers_ray_br(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
if envs_br.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
rank_map, tp_driver_rank = get_supernode_pp_tp_global_rank_map(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker = None
# The remaining workers are the actual ray actors.
self.workers = []
# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
bundle_indices: List[int]
if envs.VLLM_RAY_BUNDLE_INDICES:
# Use the bundle indices specified by the user.
bundle_indices = list(map(int,
envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
assert len(bundle_indices) == self.parallel_config.world_size, \
("VLLM_RAY_BUNDLE_INDICES must have the same size"
f" as the world size, but got {bundle_indices=} "
f"and {self.parallel_config.world_size=}")
assert len(set(bundle_indices)) == len(bundle_indices), \
("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
f" but got {bundle_indices=}")
else:
# use the first N bundles that have GPU resources.
bundle_indices = []
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if bundle.get(current_platform.ray_device_key, 0):
bundle_indices.append(bundle_id)
bundle_indices = bundle_indices[:self.parallel_config.world_size]
worker_metadata: List[RayWorkerMetaData] = []
driver_ip = get_ip()
for rank, bundle_id in enumerate(bundle_indices):
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
if current_platform.ray_device_key == "GPU":
# NV+AMD GPUs, and Intel XPUs
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rpc_rank=rank)
else:
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rpc_rank=rank)
worker_metadata.append(
RayWorkerMetaData(worker=worker, created_rank=rank))
worker_ips = ray.get([
each.worker.get_node_ip.remote() # type: ignore[attr-defined]
for each in worker_metadata
])
for each, ip in zip(worker_metadata, worker_ips):
each.ip = ip
if not self.use_ray_spmd_worker:
for i, each in enumerate(worker_metadata):
# find and remove the dummy worker from the list
worker = each.worker
worker_ip = each.ip
if self.driver_dummy_worker is None and worker_ip == driver_ip:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config, rpc_rank=0)
worker_metadata.pop(i)
break
logger.debug("workers: %s", worker_metadata)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node."
f"Driver IP: {driver_ip}, worker IPs: {worker_ips}."
"Consider adjusting the Ray placement group or running "
"the driver on a GPU node.")
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the work is on a node with smaller IP address, it
should be placed first.
"""
ip = item.ip
return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
sorted_worker_metadata = sorted(worker_metadata,
key=sort_by_driver_then_worker_ip)
start_rank = 0 if self.use_ray_spmd_worker else 1
for i, item in enumerate(sorted_worker_metadata):
item.adjusted_rank = i + start_rank
self.workers = [item.worker for item in sorted_worker_metadata]
rerank_mapping = {
item.created_rank: item.adjusted_rank
for item in sorted_worker_metadata
}
self._run_workers("adjust_rank", rerank_mapping)
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
for worker in [self.driver_dummy_worker] + self.workers:
if worker is None:
# driver_dummy_worker can be None when using ray spmd worker.
continue
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote()) \
) # type: ignore
node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids = [int(x) for x in gpu_ids]
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP`"
" environment variable, make sure it is unique for"
" each node.")
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [{
current_platform.device_control_env_var:
",".join(map(str, node_gpus[node_id])),
} for (node_id, _) in worker_node_and_gpu_ids]
# Environment variables to copy from driver to workers
env_vars_to_copy = get_env_vars_to_copy(
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
additional_vars=set(current_platform.additional_env_vars).union(
self.ADDITIONAL_ENV_VARS),
destination="workers")
# Copy existing env vars to each worker's args
for args in all_args_to_update_environment_variables:
# TODO: refactor platform-specific env vars
for name in env_vars_to_copy:
if name in os.environ:
args[name] = os.environ[name]
self._env_vars_for_all_workers = (all_args_to_update_environment_variables)
self._run_workers("update_environment_variables",
self._get_env_vars_to_be_updated())
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
all_kwargs = []
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
local_rank = node_workers[node_id].index(rank)
if envs_br.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank in tp_driver_rank),
)
else:
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self._run_workers("init_worker", all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
if self.use_ray_spmd_worker:
if envs_br.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(
self.parallel_config.tensor_parallel_size):
# PP=8, TP=2
# pp_tp_workers = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
rank = rank_map[(pp_rank, tp_rank)]
self.pp_tp_workers[pp_rank].append(self.workers[rank])
else:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(
self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers = []
# Enforce rank order for correct rank to return final output.
for index, worker in enumerate(self.workers):
# The driver worker is rank 0 and not in self.workers.
rank = index + 1
if envs_br.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
if rank in tp_driver_rank:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)
else:
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)
RayDistributedExecutor._init_workers_ray = _init_workers_ray_br # noqa: E501

418
vllm_br/forward_context.py Normal file
View File

@@ -0,0 +1,418 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
import torch
import torch.distributed as dist
import vllm
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
from vllm_br.config.compilation import SUPAGraphMode
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
# Module-level logger for this file.
logger = init_logger(__name__)
# Batch-size timing statistics are collected only when the configured
# interval is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# time.perf_counter() value recorded the last time set_forward_context
# passed the logging-interval check.
last_logging_time: float = 0
# time.perf_counter() value recorded when the current tracked forward pass
# started (written by set_forward_context).
forward_start_time: float = 0
# Minimum gap between two statistics log lines; compared against
# time.perf_counter() differences.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps batch size -> list of measured forward-pass durations in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for supagraph dispatching. We should keep the number of
    items as small as possible while still properly and uniquely describing
    the padded batch for supagraph.
    """
    # Number of tokens in the (padded) batch.
    num_tokens: int
    # False can also be used for a uniform decode batch in order to dispatch
    # to the supagraph supporting non-uniform batches.
    uniform_decode: bool

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """
        Return a non-uniform version of the current batch descriptor.

        The token count is kept; ``uniform_decode`` is forced to False so the
        result differs from ``self`` whenever ``self`` is a uniform-decode
        descriptor.
        """
        # Previously this forwarded self.uniform_decode, which returned an
        # identical descriptor and never produced the non-uniform key.
        return BatchDescriptor(self.num_tokens, uniform_decode=False)
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
sequence_parallel_size: int) -> list[int]:
sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
sequence_parallel_size)
sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
return sp_tokens.tolist()
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
                                      sequence_parallel_size: int,
                                      max_num_tokens: int,
                                      chunk_idx: int) -> list[int]:
    """Return how many tokens each SP rank processes in chunk ``chunk_idx``.

    The per-rank totals come from ``_compute_sp_num_tokens``. For each rank
    the chunk size is the number of tokens left after ``chunk_idx`` full
    chunks, capped at ``max_num_tokens``. A rank with nothing left still
    reports 1 so that all ranks stay in lockstep.
    """
    # Take into account sharding if MoE activation is sequence parallel.
    totals = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
                                    sequence_parallel_size)
    already_consumed = max_num_tokens * chunk_idx
    chunk_sizes = []
    for total in totals:
        size = min(max_num_tokens, total - already_consumed)
        # ensure lockstep even if done
        chunk_sizes.append(size if size > 0 else 1)
    return chunk_sizes
@dataclass
class DPMetadata:
    """Token-count metadata shared between data-parallel (DP) ranks."""

    # Largest entry of ``num_tokens_across_dp_cpu`` (as built by ``make``).
    max_tokens_across_dp_cpu: torch.Tensor
    # Token count of every DP rank, indexed by DP rank.
    num_tokens_across_dp_cpu: torch.Tensor
    # NOTE: local_sizes should only be set by the chunked_sizes /
    # sp_local_sizes context managers; it is None outside of them.
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Gather the num_tokens across all DP ranks and return the result as a
        CPU tensor of size dp_size.
        """
        from vllm.distributed.parallel_state import get_dp_group

        reduce_device = current_platform.device_type
        reduce_group = get_dp_group().device_group
        # Copying this tensor from the device back to the CPU introduces a
        # device sync point that could adversely affect performance of vllm
        # with asynchronous scheduling. This environment variable exists to
        # quickly disable the device all-reduce if we run into this case.
        if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
            logger.info_once(
                "Using CPU all reduce to synchronize DP padding between ranks."
            )
            reduce_device = "cpu"
            reduce_group = get_dp_group().cpu_group

        # Each rank fills in only its own slot before the all-reduce.
        per_rank_counts = [0] * dp_size
        per_rank_counts[dp_rank] = num_tokens
        counts = torch.tensor(per_rank_counts,
                              device=reduce_device,
                              dtype=torch.int32)
        dist.all_reduce(counts, group=reduce_group)
        return counts.cpu()

    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        """
        Get the cumulative tokens across sequence parallel ranks.

        In this case the input to the MoEs will be distributed w.r.t. both
        DP and TP rank. When sp_size == 1, this is just the cumulative number
        of tokens across DP.
        """
        ceil_per_dp_rank = (self.num_tokens_across_dp_cpu - 1 +
                            sp_size) // sp_size
        per_sp_rank = ceil_per_dp_rank.repeat_interleave(sp_size)
        return torch.cumsum(per_sp_rank, dim=0)

    @staticmethod
    def should_ubatch_across_dp(
            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
            padded_num_tokens_per_ubatch: int, dp_size: int,
            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
        """
        1. Decides if each DP rank is going to microbatch. Either all ranks
        run with microbatching or none of them do. If this function decides
        not to run with microbatching, it will "abort", meaning that no
        padding information will be returned to the caller. It will return
        (False, None).

        2. Determines the total number of tokens that each rank will run.
        All ranks will be padded out so that they run with the same number
        of tokens.

        Returns: tuple[
            should_ubatch: Are all DP ranks going to microbatch
            num_tokens_after_padding: A tensor containing the total number of
            tokens per-microbatch for each DP rank including padding. Will be
            None if should_ubatch is False
        ]
        """
        # Row 0: original token counts, row 1: padded token counts,
        # row 2: 1 if the rank wants to microbatch, else 0.
        votes = torch.zeros(3,
                            dp_size,
                            device=current_platform.device_type,
                            dtype=torch.int32)
        votes[0][dp_rank] = orig_num_tokens_per_ubatch
        votes[1][dp_rank] = padded_num_tokens_per_ubatch
        votes[2][dp_rank] = 1 if should_ubatch else 0

        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(votes, group=get_dp_group().device_group)

        everyone_agrees: bool = bool(torch.all(votes[2] == 1).item())
        if not everyone_agrees:
            return everyone_agrees, None

        orig_counts = votes[0, :]
        padded_counts = votes[1, :]
        smallest_orig = int(orig_counts.min().item())
        largest_padded = int(padded_counts.max().item())
        if is_second_ubatch_empty(smallest_orig, largest_padded):
            logger.debug("Aborting ubatching %s %s", smallest_orig,
                         largest_padded)
            return False, None
        return everyone_agrees, padded_counts.cpu()

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        attn_metadata: Any,
        num_tokens: int,
        num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """
        Build a DPMetadata for the current rank.

        The local batch size is taken from ``attn_metadata`` when it exposes
        ``num_prefill_tokens`` and from ``num_tokens`` otherwise. If
        ``num_tokens_across_dp_cpu`` is not supplied it is computed with an
        all-reduce over the DP group.
        """
        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank

        has_v0_metadata = attn_metadata is not None and hasattr(
            attn_metadata, "num_prefill_tokens")
        if has_v0_metadata:
            # for v0 attention backends
            batchsize = attn_metadata.num_prefill_tokens + \
                attn_metadata.num_decode_tokens
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
        assert (num_tokens_across_dp_cpu is None
                or num_tokens_across_dp_cpu[dp_rank] == batchsize
                ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
        if num_tokens_across_dp_cpu is None:
            num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        largest = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(largest, num_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(self, sequence_parallel_size: int,
                      max_chunk_size_per_rank: int, chunk_idx: int):
        """
        Context manager to compute and temporarily set the per-rank local
        token sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input
        early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a
        given `chunk_idx`, this context manager sets `self.local_sizes` to the
        number of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                    we use SP between the layers to avoid
                                    redundant ops. We need this value to
                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size,
            max_chunk_size_per_rank, chunk_idx)
        self.local_sizes = sizes
        try:
            yield self.local_sizes
        finally:
            # Always clear, even if the body raised.
            self.local_sizes = None

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as
        self.chunked_sizes but without any chunking.
        """
        sizes = _compute_sp_num_tokens(self.num_tokens_across_dp_cpu,
                                       sequence_parallel_size)
        self.local_sizes = sizes
        try:
            yield self.local_sizes
        finally:
            # Always clear, even if the body raised.
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
        """Return ``local_sizes``; only valid inside one of the context
        managers above (asserts that it is set)."""
        sizes = self.local_sizes
        assert sizes is not None
        return sizes
@dataclass
class ForwardContext:
    """State made available to layers for the duration of one forward pass."""

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Set dynamically for each forward pass. One of:
    #   - AttentionMetadata for v0,
    #   - Dict[str, AttentionMetadata] for v1, a map from the layer_name of
    #     each attention layer to its attention metadata,
    #   - List[Dict[str, AttentionMetadata]] for DBO: a list of size two, one
    #     entry for each microbatch.
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
                         list[dict[str, "AttentionMetadata"]]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
    # Graph style chosen at runtime; must be one of the modes accepted by
    # __post_init__. Defaults to NONE (no graph is used).
    cudagraph_runtime_mode: SUPAGraphMode = SUPAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None
    ubatch_slices: Optional[UBatchSlices] = None

    def __post_init__(self):
        """Reject runtime modes outside the supported set."""
        supported_modes = (
            SUPAGraphMode.NONE,
            SUPAGraphMode.PIECEWISE,
            SUPAGraphMode.FULL,
            SUPAGraphMode.FULL_DECODE_ONLY,
        )
        assert self.cudagraph_runtime_mode in supported_modes, \
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
# _forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext:
    """Return the forward context currently stored on ``vllm.forward_context``.

    Asserts that a context has been installed (e.g. via
    ``set_forward_context``).
    """
    active_context = vllm.forward_context._forward_context
    assert active_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return active_context
def create_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        dp_metadata: Optional[DPMetadata] = None,
        cudagraph_runtime_mode: SUPAGraphMode = SUPAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """Build a ``ForwardContext`` from the given per-pass values.

    ``no_compile_layers`` is taken from
    ``vllm_config.compilation_config.static_forward_context``; every other
    field is passed through unchanged.
    """
    static_layers = vllm_config.compilation_config.static_forward_context
    return ForwardContext(
        no_compile_layers=static_layers,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )
@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
    """Temporarily replace the current forward context.

    Installs ``forward_context`` on ``vllm.forward_context`` for the duration
    of the ``with`` body and restores the previous value afterwards, even if
    the body raises. Used to override the context for a specific forward pass.
    """
    saved_context = vllm.forward_context._forward_context
    vllm.forward_context._forward_context = forward_context
    try:
        yield
    finally:
        vllm.forward_context._forward_context = saved_context
@contextmanager
def set_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        cudagraph_runtime_mode: SUPAGraphMode = SUPAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """A context manager that stores the current forward context,
    which can be attention metadata, etc.

    Here we can inject common logic for every model forward pass: DP metadata
    is built when data parallelism is enabled, and, when batch-size tracking
    is on, the duration of the pass is recorded and periodically logged.
    """
    global forward_start_time, last_logging_time, batchsize_logging_interval

    timing_enabled = track_batchsize and attn_metadata is not None
    if timing_enabled:
        forward_start_time = time.perf_counter()

    dp_metadata: Optional[DPMetadata] = None
    has_batch_info = attn_metadata is not None or num_tokens is not None
    if vllm_config.parallel_config.data_parallel_size > 1 and has_batch_info:
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    forward_context = create_forward_context(attn_metadata, vllm_config,
                                             virtual_engine, dp_metadata,
                                             cudagraph_runtime_mode,
                                             batch_descriptor, ubatch_slices)
    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if timing_enabled:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = attn_metadata.num_prefill_tokens + \
                    attn_metadata.num_decode_tokens
            else:
                # for v1 attention backends
                batchsize = num_tokens

            # we use synchronous scheduling right now,
            # adding a sync point here should not affect
            # scheduling of the next batch
            from vllm.platforms import current_platform
            synchronize = current_platform.synchronize
            if synchronize is not None:
                synchronize()

            now = time.perf_counter()
            # time measurement is in milliseconds
            elapsed_ms = (now - forward_start_time) * 1000
            batchsize_forward_time[batchsize].append(elapsed_ms)

            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                forward_stats = []
                for bs, times in batchsize_forward_time.items():
                    # A single sample can be a cudagraph / profiling run.
                    if len(times) > 1:
                        median_ms = torch.quantile(torch.tensor(times),
                                                   q=0.5).item()
                        forward_stats.append(
                            (bs, len(times), round(median_ms, 2)))
                # Most frequently seen batch sizes first.
                forward_stats.sort(key=lambda entry: entry[1], reverse=True)
                if forward_stats:
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"),
                                forward_stats)

View File

@@ -0,0 +1,39 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from vllm import ModelRegistry # noqa: F401
from . import parameter
from .layers import *
from .model_loader import *
from .models import *
__all__ = [
"parameter",
]
def register_model():
    """Register Biren modified models.

    Currently a no-op: the registrations below are disabled.
    """
    # Disabled registrations, kept for reference:
    #
    # ModelRegistry.register_model(
    #     "GptOssForCausalLM",
    #     "vllm_br.model_executor.models.gpt_oss:GptOssForCausalLM")
    # ModelRegistry.register_model(
    #     "Glm4MoeForCausalLM",
    #     "vllm_br.model_executor.models.glm4_moe:Glm4MoeForCausalLM")
    return None

View File

@@ -0,0 +1,25 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import vllm_br.model_executor.layers.activation
import vllm_br.model_executor.layers.fused_moe
import vllm_br.model_executor.layers.layernorm
import vllm_br.model_executor.layers.linear
import vllm_br.model_executor.layers.logits_processor
import vllm_br.model_executor.layers.quantization
import vllm_br.model_executor.layers.rotary_embedding
import vllm_br.model_executor.layers.utils
import vllm_br.model_executor.layers.vocab_parallel_embedding # noqa: F401

View File

@@ -0,0 +1,31 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
import torch_br
from fastcore.basics import patch_to
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
# NOTE(review): fastcore's patch_to presumably attaches the function under its
# own name (`silu_and_mul_forward_oot`) — confirm this is the attribute the
# dispatch code looks up.
@patch_to(SiluAndMul)
def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
    """Split ``x`` in half along the last dimension and pass both halves to
    ``torch_br.supa_silumul``."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch_br.supa_silumul(first_half, second_half)  # type: ignore
# NOTE(review): fastcore's patch_to presumably attaches the function under its
# own name (`quick_gelu_forward_oot`) — confirm this is the attribute the
# dispatch code looks up.
@patch_to(QuickGELU)
def quick_gelu_forward_oot(self, x: torch.Tensor) -> torch.Tensor:  # noqa:F811
    """Delegate to the layer's ``forward_native`` implementation."""
    native_output = self.forward_native(x)
    return native_output

View File

@@ -0,0 +1,619 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
import torch_br
import torch_br.supa._debug as supa_debug
from vllm_br import envs
def align_n(n, align_size, spc_num=envs.VLLM_BR_DEVICE_SPC_NUM) -> int:
    """Return the per-SPC block size for ``n``, rounded up to ``align_size``.

    ``n`` is first divided by ``spc_num`` (rounding up); the quotient is then
    rounded up to the next multiple of ``align_size``.
    """
    per_spc = (n + spc_num - 1) // spc_num
    aligned_units = (per_spc + align_size - 1) // align_size
    return aligned_units * align_size
def _br_qweight_cvt(quant_method,
qweight,
qzeros,
size_k,
size_n,
override_group_size=None):
group_size = override_group_size or quant_method.quant_config.group_size
curr_dev = qweight.device
group_num = size_k // group_size if group_size > 0 else 1
qweight = qweight.cpu().view(torch.int8).reshape(
size_k // 4, size_n,
4).permute(0, 2, 1).contiguous().reshape(group_num,
size_k // group_num, size_n)
if qzeros is not None and not torch.all(qzeros == 0):
qzeros = qzeros.cpu().view(torch.int8).to(torch.int32) + 1
qweight = (qweight.to(torch.int32) - qzeros.unsqueeze(1)).to(
torch.int8)
qwei_int8 = qweight.reshape(size_k, size_n).to(curr_dev)
return qwei_int8
def _numa_scales_cvt(scales, wn, spc_num):
align_size = 32
wn_block = (wn + spc_num - 1) // spc_num
wn_block = (wn_block + align_size - 1) // align_size * align_size
cvt_scales = torch.nn.functional.pad(scales, (0, spc_num * wn_block - wn),
mode='constant',
value=0)
cvt_scales = cvt_scales.reshape(spc_num, wn_block).contiguous()
return cvt_scales
def cross_weight_32(t1, t2, spc_num, dim=1, need_pad=True):
    """Interleave ``t1`` and ``t2`` in 32-wide chunks along ``dim``.

    Both tensors are first zero-padded on their last dimension so that the
    width (taken from ``t1.shape[dim]``) is a multiple of 32; when
    ``spc_num > 16`` each half of the width is padded separately. The padded
    tensors are cut into 32-wide chunks along ``dim`` and concatenated as
    ``t1[0], t2[0], t1[1], t2[1], ...``.

    With ``need_pad=False`` the interleaved tensor is returned as is.
    Otherwise it is zero-padded again on the last dimension up to a multiple
    of an alignment derived from ``spc_num`` (per half when ``spc_num > 16``).
    """
    chunk_width = 32
    dual_die = spc_num > 16

    def _zero_pad_last(t, amount):
        # Right-pad the last dimension with `amount` zeros.
        return torch.nn.functional.pad(t, (0, amount), "constant", 0)

    def _pad_halves(t, amount):
        # Split the last dimension in two and right-pad each half.
        lower, upper = torch.chunk(t, 2, dim=-1)
        return torch.cat(
            [_zero_pad_last(lower, amount),
             _zero_pad_last(upper, amount)],
            dim=-1)

    width = t1.shape[dim]
    # NOTE: br166 must ensure dual-dies width are 32-aligned
    if dual_die:
        assert width % 2 == 0
        half_width = width // 2
        half_aligned = (half_width + chunk_width -
                        1) // chunk_width * chunk_width
        half_pad = half_aligned - half_width
        if half_pad > 0:
            t1 = _pad_halves(t1, half_pad)
            t2 = _pad_halves(t2, half_pad)
        width = half_aligned * 2
    else:
        aligned = (width + chunk_width - 1) // chunk_width * chunk_width
        t1 = _zero_pad_last(t1, aligned - width)
        t2 = _zero_pad_last(t2, aligned - width)
        width = aligned

    cnt = width // chunk_width
    t1_chunks = torch.chunk(t1, cnt, dim)
    t2_chunks = torch.chunk(t2, cnt, dim)
    interleaved = [
        piece for i in range(cnt) for piece in (t1_chunks[i], t2_chunks[i])
    ]
    no_pad = torch.cat(interleaved, dim=dim)
    if not need_pad:
        return no_pad

    if dual_die:
        align = (spc_num // 2) * chunk_width * 2
        width_align = (width + align - 1) // align * align
        return _pad_halves(no_pad, width_align - width)

    align = spc_num * chunk_width * 2  # 768
    width_align = (width * 2 + align - 1) // align * align
    return _zero_pad_last(no_pad, width_align - width * 2)
# NOTE: bias and scales are not converted to the NUMA layout; they remain UMA tensors in order to support fused_linear.
def _convert_to_uma_tensor(tensor,
                           align_size,
                           layout,
                           dtype,
                           do_transpose=False,
                           wk=None,
                           wn=None,
                           parallel_type="col_parallel"):
    """Copy ``tensor`` into a freshly allocated UMA tensor on the current
    SUPA device.

    Args:
        tensor: Source tensor.
        align_size: Not used by this function.
        layout: ``"colmajor"`` or ``"linear_bias"`` (case-insensitive).
        dtype: dtype of the allocated tensor.
        do_transpose: For ``"colmajor"`` only, transpose the two dimensions
            before copying.
        wk, wn: Optional size overrides; by default they are read from
            ``tensor.shape``.
        parallel_type: ``"col_parallel"`` or ``"row_parallel"``; selects the
            allocation arguments.

    Returns:
        The tensor allocated with ``torch_br._empty_ut_only`` and filled with
        the (possibly reshaped) data.

    Raises:
        ValueError: If ``layout`` is neither ``"colmajor"`` nor
            ``"linear_bias"``.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    layout = layout.lower()

    if layout == "colmajor":
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[0]
        if do_transpose:
            data = tensor.cpu().permute(1, 0).contiguous()
            d_shape = (wk, wn)
        else:
            data = tensor.cpu().contiguous()
            d_shape = (wn, wk)
        # The two parallel types differ only in the `axis` argument.
        # NOTE(review): exact meaning of sbp/axis is defined by torch_br —
        # confirm against its documentation.
        split_axis = 0 if parallel_type == "col_parallel" else 1
        uma_tensor = torch_br._empty_ut_only(
            size=d_shape,
            dtype=dtype,
            is_numa=False,
            device=torch.supa.current_device(),
            tensor_type=layout,
            sbp="SB",
            axis=split_axis)
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    elif layout == "linear_bias":
        axis = 0
        wn = wn or tensor.shape[-1]
        wk = 1
        data = tensor
        rank = len(data.shape)
        if rank == 2 and data.shape[0] == 1:
            # (1, n) -> (n,)
            data = tensor.cpu().reshape(-1).contiguous()
        elif rank == 2:
            axis = 1
            wk = data.shape[0]
        elif rank == 3 and data.shape[1] == 1:
            # (k, 1, n) -> (k, n)
            data = tensor.cpu().reshape(
                (data.shape[0], data.shape[2])).contiguous()
            axis = 1
            wk = data.shape[0]
        d_shape = (wk, wn) if axis != 0 else (wn, )
        if parallel_type == "row_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout)
        elif parallel_type == "col_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=axis,
                sbp="SB")
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    else:
        raise ValueError("uma tensor only support colmajor and linear_bias")
    return uma_tensor
def _convert_to_numa_tensor_vit(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Copy ``tensor`` into a new NUMA tensor split into per-SPC blocks.

    The SPC count comes from ``envs.VLLM_BR_DEVICE_SPC_NUM``; a value above
    16 is handled as two dies with half of the SPCs each. Force-UMA mode is
    switched off for the duration of the call and is always restored to its
    previous value, including when the conversion raises.

    Args:
        tensor: Source weight (``colmajor``, indexed as 2-D) or bias/scale
            (``linear_bias``).
        align_size: Each per-SPC block width is rounded up to a multiple of
            this value; the padding is zero-filled.
        layout: ``"colmajor"`` or ``"linear_bias"`` (case-insensitive).
        dtype: dtype of the allocated tensor.
        do_transpose: ``colmajor`` only; transpose ``tensor`` before
            splitting.
        wk, wn: Optional overrides for the weight dimensions. Defaults are
            ``tensor.shape[0]`` / ``tensor.shape[1]`` for ``colmajor`` and
            ``tensor.shape[-1]`` (``wn``) for ``linear_bias``.
            NOTE(review): the defaults are read from the un-transposed
            tensor, so callers using ``do_transpose`` presumably pass both
            explicitly - confirm.
        parallel_type: ``"col_parallel"`` or ``"row_parallel"``; only
            changes the result on the dual-die path.
        pad_zeros: Dual-die ``row_parallel`` ``linear_bias`` only: when true
            the second die's rows are zeros, otherwise they repeat the first
            die's rows.

    Returns:
        The filled NUMA tensor.

    Raises:
        ValueError: if ``layout`` is not supported.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    try:
        spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
        layout = layout.lower()
        die_num = 1
        if spc_num > 16:
            spc_num = spc_num // 2
            die_num = 2
        die_spc_num = die_num * spc_num
        if layout == "colmajor":
            wk = wk or tensor.shape[0]
            wn = wn or tensor.shape[1]
            host = tensor.cpu()
            if do_transpose:
                host = host.permute(1, 0)
            host = host.contiguous()
            if die_num == 1:
                wn_block = (wn + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                data = torch.nn.functional.pad(
                    host, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                # (wk, spc * block) -> (spc, wk, block)
                data = data.reshape(wk, spc_num,
                                    wn_block).permute(1, 0, 2).contiguous()
                torch.supa.synchronize()
                numa_tensor.copy_(data.to(torch.supa.current_device()))
            elif parallel_type == "col_parallel":
                # The N dimension is halved per die, then split across SPCs.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                weight = host.reshape(wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # row_parallel: the K dimension is halved per die, N is split
                # across the SPCs of one die.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                weight = torch.nn.functional.pad(
                    host, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)
        elif layout == "linear_bias":
            # NOTE: index -1 for both scales and bias
            wn = tensor.shape[-1] if wn is None else wn
            group_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
            if die_num == 1:
                wn_block = (wn + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(spc_num * group_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                data = torch.nn.functional.pad(tensor.cpu(),
                                               (0, spc_num * wn_block - wn),
                                               mode='constant',
                                               value=0)
                if group_num > 1:
                    data = data.type(dtype).reshape(
                        group_num, spc_num,
                        wn_block).permute(1, 0, 2).contiguous().reshape(
                            spc_num * group_num, wn_block)
                else:
                    data = data.type(dtype).reshape(spc_num,
                                                    wn_block).contiguous()
                torch.supa.synchronize()
                numa_tensor.copy_(data.to(torch.supa.current_device()))
            elif parallel_type == "col_parallel":
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = tensor.cpu().reshape(die_num, wn // die_num)
                bias = torch.nn.functional.pad(
                    bias, (0, spc_num * wn_block - wn // die_num, 0, 0),
                    mode='constant',
                    value=0)
                bias = bias.type(torch.float32).reshape(
                    die_spc_num, wn_block).contiguous()
                numa_tensor.copy_(bias)
            else:
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = torch.nn.functional.pad(tensor.cpu(),
                                               (0, spc_num * wn_block - wn),
                                               mode='constant',
                                               value=0)
                bias = bias.type(torch.float32).reshape(
                    spc_num, wn_block).contiguous()
                if pad_zeros:
                    second_die = torch.zeros((spc_num, wn_block),
                                             dtype=bias.dtype)
                else:
                    second_die = bias
                bias = torch.concat([bias, second_die], dim=0)
                numa_tensor.copy_(bias)
        else:
            raise ValueError(f"Unsupported tensor_type: {layout}")
        torch.supa.synchronize()
        return numa_tensor
    finally:
        # Restore the caller's force-UMA setting on success and on failure.
        supa_debug.set_enable_force_uma(enable_force_uma)
def _convert_to_numa_tensor(tensor,
                            align_size,
                            layout,
                            dtype,
                            do_transpose=False,
                            wk=None,
                            wn=None,
                            parallel_type="col_parallel",
                            pad_zeros=False):
    """Convert a weight or bias tensor into the SUPA device layout.

    ``colmajor`` weights are split into per-SPC blocks of a NUMA tensor;
    ``linear_bias`` tensors are allocated with ``is_numa=False`` and copied
    without splitting. The SPC count comes from
    ``envs.VLLM_BR_DEVICE_SPC_NUM``; a value above 16 is handled as two dies
    with half of the SPCs each. Force-UMA mode is switched off for the
    duration of the call and is always restored to its previous value,
    including when the conversion raises.

    Args:
        tensor: Source weight (``colmajor``, indexed as 2-D) or bias/scale
            (``linear_bias``).
        align_size: ``colmajor`` only; each per-SPC block width is rounded
            up to a multiple of this value, with zero padding.
        layout: ``"colmajor"`` or ``"linear_bias"`` (case-insensitive).
        dtype: dtype of the allocated tensor.
        do_transpose: ``colmajor`` only; transpose ``tensor`` before
            splitting.
        wk, wn: Optional overrides for the weight dimensions. Defaults are
            ``tensor.shape[0]`` / ``tensor.shape[1]`` for ``colmajor`` and
            ``tensor.shape[-1]`` (``wn``) for ``linear_bias``.
            NOTE(review): the defaults are read from the un-transposed
            tensor, so callers using ``do_transpose`` presumably pass both
            explicitly - confirm.
        parallel_type: ``"col_parallel"`` or ``"row_parallel"``; only
            changes the result on the dual-die path.
        pad_zeros: Accepted for signature parity; not read by this function.

    Returns:
        The filled device tensor.

    Raises:
        ValueError: if ``layout`` is not supported.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    try:
        spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
        layout = layout.lower()
        die_num = 1
        if spc_num > 16:
            spc_num = spc_num // 2
            die_num = 2
        die_spc_num = die_num * spc_num
        if layout == "colmajor":
            wk = wk or tensor.shape[0]
            wn = wn or tensor.shape[1]
            host = tensor.cpu()
            if do_transpose:
                host = host.permute(1, 0)
            host = host.contiguous()
            if die_num == 1:
                wn_block = (wn + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                data = torch.nn.functional.pad(
                    host, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                # (wk, spc * block) -> (spc, wk, block)
                data = data.reshape(wk, spc_num,
                                    wn_block).permute(1, 0, 2).contiguous()
                torch.supa.synchronize()
                numa_tensor.copy_(data.to(torch.supa.current_device()))
            elif parallel_type == "col_parallel":
                # The N dimension is halved per die, then split across SPCs.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                weight = host.reshape(wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # row_parallel: the K dimension is halved per die, N is split
                # across the SPCs of one die.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                weight = torch.nn.functional.pad(
                    host, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)
        elif layout == "linear_bias":
            # NOTE: bias and scales will not be converted to numa arch, still
            # use uma weight for support fused_linear
            # NOTE: index -1 for both scales and bias
            wn = tensor.shape[-1] if wn is None else wn
            expert_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
            bias_shape = (expert_num, wn) if expert_num > 1 else (wn, )
            if die_num == 1:
                numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                      dtype=dtype,
                                                      is_numa=False,
                                                      device=tensor.device,
                                                      tensor_type=layout)
                # bias_shape already encodes the (expert_num, wn) / (wn,)
                # choice, so a single reshape covers both cases.
                data = tensor.cpu().type(dtype).reshape(bias_shape)
                torch.supa.synchronize()
                numa_tensor.copy_(data.to(tensor.device))
            elif parallel_type == "col_parallel":
                axis = 1 if expert_num > 1 else 0
                numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                      dtype=dtype,
                                                      is_numa=False,
                                                      device=tensor.device,
                                                      tensor_type="buffer_any",
                                                      axis=axis,
                                                      sbp="SB")
                if expert_num == 1:
                    tensor = tensor.reshape(-1)
                numa_tensor.copy_(tensor.to(torch.supa.current_device()))
            else:
                numa_tensor = torch_br._empty_ut_only(
                    size=bias_shape,
                    dtype=dtype,
                    is_numa=False,
                    device=tensor.device,
                    tensor_type="linear_bias",
                    sbp="BB")
                bias = tensor.reshape(expert_num, wn).cpu().type(dtype)
                if expert_num == 1:
                    bias = bias.reshape(-1)
                numa_tensor.copy_(bias.to(torch.supa.current_device()))
        else:
            raise ValueError(f"Unsupported tensor_type: {layout}")
        torch.supa.synchronize()
        return numa_tensor
    finally:
        # Restore the caller's force-UMA setting on success and on failure.
        supa_debug.set_enable_force_uma(enable_force_uma)
def _convert_to_crossed_numa_tensor(t1,
                                    t2,
                                    spc_num,
                                    dim=1,
                                    need_pad=True,
                                    layout="colmajor",
                                    do_transpose=False):
    """Interleave ``t1``/``t2`` with ``cross_weight_32``, then convert to NUMA.

    Equivalent to the V0 sequence cross_weight_32 followed by
    numa_weight_convert/_numa_weight_cvt. The NUMA conversion always uses an
    alignment of 32 and keeps the dtype of the interleaved tensor.
    """
    crossed = cross_weight_32(t1, t2, spc_num, dim, need_pad)
    return _convert_to_numa_tensor(crossed, 32, layout, crossed.dtype,
                                   do_transpose)
def _convert_to_numa_tensor_moe(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wb=None,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Convert a batched (3-D) ``colmajor`` weight into a dual-die NUMA tensor.

    Only the dual-die configuration is supported
    (``envs.VLLM_BR_DEVICE_SPC_NUM`` above 16). Force-UMA mode is switched
    off for the duration of the call and is always restored to its previous
    value, including when the dual-die assertion or the layout check fails.

    Args:
        tensor: Source weight, indexed as ``(wb, wk, wn)``.
        align_size: Each per-SPC block width is rounded up to a multiple of
            this value; the padding is zero-filled.
        layout: Must be ``"colmajor"`` (case-insensitive).
        dtype: dtype of the allocated tensor.
        do_transpose: Swap the last two dimensions before splitting.
        wb, wk, wn: Optional overrides for ``tensor.shape[0..2]``.
            NOTE(review): the defaults are read from the un-transposed
            tensor, so callers using ``do_transpose`` presumably pass
            ``wk``/``wn`` explicitly - confirm.
        parallel_type: ``"col_parallel"`` splits ``wn`` across both dies;
            ``"row_parallel"`` splits ``wk`` across the dies and ``wn``
            across the SPCs of one die.
        pad_zeros: Accepted for signature parity; not read by this function.

    Returns:
        ``(numa_tensor, (die_spc_num, wk, wn_block))``. The second element
        reports the full ``wk`` even for ``row_parallel``, where the
        allocated tensor uses ``wk // die_num``.

    Raises:
        AssertionError: on a single-die device or unknown ``parallel_type``.
        ValueError: if ``layout`` is not ``colmajor``.
    """
    assert parallel_type in ("col_parallel", "row_parallel")
    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    try:
        spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
        layout = layout.lower()
        die_num = 1
        if spc_num > 16:
            spc_num = spc_num // 2
            die_num = 2
        die_spc_num = die_num * spc_num
        assert die_num == 2
        if layout != "colmajor":
            raise ValueError(f"Unsupported tensor_type: {layout}")

        wb = wb or tensor.shape[0]
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[2]
        host = tensor.cpu()
        if do_transpose:
            host = host.permute(0, 2, 1)
        host = host.contiguous()

        if parallel_type == "col_parallel":
            wn_block = (wn // die_num + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            weight = host.reshape(wb, wk, die_num, wn // die_num)
            weight = torch.nn.functional.pad(
                weight,
                (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0, 0, 0),
                mode='constant',
                value=0)
            # (wb, wk, die*spc, block) -> (die*spc, wb, wk, block), flattened
            weight = weight.reshape(wb, wk, die_spc_num, wn_block).permute(
                2, 0, 1, 3).reshape(wb * die_spc_num, wk,
                                    wn_block).contiguous()
            numa_tensor.copy_(weight)
        else:
            # row_parallel (guaranteed by the assertion at the top)
            wn_block = (wn + spc_num - 1) // spc_num
            wn_block = (wn_block + align_size - 1) // align_size * align_size
            numa_tensor = torch_br._empty_ut_only(
                size=(die_spc_num * wb, wk // die_num, wn_block),
                dtype=dtype,
                is_numa=True,
                device=torch.supa.current_device(),
                tensor_type=layout)
            weight = torch.nn.functional.pad(
                host, (0, spc_num * wn_block - wn, 0, 0, 0, 0),
                mode='constant',
                value=0)
            # (wb, die, wk/die, spc, block) -> (die, spc, wb, wk/die, block)
            weight = weight.reshape(wb, die_num, wk // die_num, spc_num,
                                    wn_block).permute(
                                        1, 3, 0, 2, 4).contiguous().reshape(
                                            die_spc_num * wb, wk // die_num,
                                            wn_block)
            numa_tensor.copy_(weight)
        torch.supa.synchronize()
        return numa_tensor, (die_spc_num, wk, wn_block)
    finally:
        # Restore the caller's force-UMA setting on success and on failure.
        supa_debug.set_enable_force_uma(enable_force_uma)
def is_br166_device():
    """Return True when the SUPA device reports 17 to 32 compute units."""
    properties = torch_br.supa.get_device_properties(torch.device("supa"))
    compute_units = properties.max_compute_units
    return bool(16 < compute_units <= 32)

View File

@@ -0,0 +1,23 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import layer, supa_moe # noqa: E402
from .layer import * # noqa: E402
# Submodules listed as this package's public names.
__all__ = [
    "layer",
    "supa_moe",
]

View File

@@ -0,0 +1,413 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Callable, Optional
import torch
import torch_br
from fastcore.basics import patch_to
from torch_br.utils.tensor_methods import Sbp
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.utils import set_weight_attrs
from vllm_br import envs
from ..br_utils import (_convert_to_crossed_numa_tensor,
_convert_to_numa_tensor, align_n, cross_weight_32)
from .supa_moe import (fused_moe_quant_device, fused_moe_quant_dyn,
fused_oss_moe_dyn)
@patch_to(UnquantizedFusedMoEMethod)
def forward_oot(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
):
    """SUPA out-of-tree forward for ``UnquantizedFusedMoEMethod``.

    With ``activation == "swigluoai"`` the call is forwarded to
    ``fused_oss_moe_dyn`` and ``router_logits`` is passed through as the
    gating weight. Otherwise ``router_logits`` is unpacked as a 3-tuple
    ``(gating_weight, shared_gate_up_weight, shared_down_weight)`` and the
    kernel is chosen by token count: more than
    ``envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN`` rows in ``x`` selects
    ``fused_moe_quant_dyn`` (prefill), anything else
    ``fused_moe_quant_device`` (decoder).
    """
    # Keyword arguments common to all three kernels.
    shared_kwargs = dict(
        renormalize=renormalize,
        inplace=True,
        use_grouped_topk=use_grouped_topk,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )

    if activation == "swigluoai":
        return fused_oss_moe_dyn(x,
                                 layer.w13_weight,
                                 layer.w13_bias,
                                 layer.w2_weight,
                                 layer.w2_bias,
                                 router_logits,
                                 top_k,
                                 layer.intermediate_size_per_partition,
                                 **shared_kwargs,
                                 ep_rank=layer.ep_rank,
                                 ep_size=layer.ep_size)

    num_tokens = x.shape[0]
    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits
    if num_tokens > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN:
        moe_kernel = fused_moe_quant_dyn  # prefill
    else:
        moe_kernel = fused_moe_quant_device  # decoder

    # The two scale arguments are passed as None for unquantized weights.
    return moe_kernel(x,
                      shared_gate_up_weight,
                      shared_down_weight,
                      layer.w13_weight,
                      layer.w2_weight,
                      None,
                      None,
                      gating_weight,
                      top_k,
                      layer.intermediate_size_per_partition,
                      **shared_kwargs,
                      tp_rank=get_tp_group().rank_in_group,
                      global_rank=get_tp_group().rank,
                      tp_size=get_tensor_model_parallel_world_size(),
                      ep_rank=layer.ep_rank,
                      ep_size=layer.ep_size)
@patch_to(UnquantizedFusedMoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register the expert weights (and biases, if any) on ``layer``.

    All parameters are created on the CPU with ``requires_grad=False`` and
    tagged with ``extra_weight_attrs``. Weights are uninitialised
    (``torch.empty``); biases start at zero and are created only when
    ``self.moe.has_bias`` is true. Registration order is ``w13_weight``,
    ``w13_bias``, ``w2_weight``, ``w2_bias``.
    """

    def _register(name, factory, *shape):
        param = torch.nn.Parameter(factory(*shape,
                                           device="cpu",
                                           dtype=params_dtype),
                                   requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    gate_up_size = 2 * intermediate_size_per_partition

    # Fused gate_up_proj (column parallel)
    _register("w13_weight", torch.empty, num_experts, gate_up_size,
              hidden_size)
    if self.moe.has_bias:
        _register("w13_bias", torch.zeros, num_experts, gate_up_size)

    # down_proj (row parallel)
    _register("w2_weight", torch.empty, num_experts, hidden_size,
              intermediate_size_per_partition)
    if self.moe.has_bias:
        _register("w2_bias", torch.zeros, num_experts, hidden_size)
@patch_to(UnquantizedFusedMoEMethod)
def process_weights_after_loading(self: UnquantizedFusedMoEMethod,
                                  layer: FusedMoE) -> None:
    """Replace the loaded CPU expert parameters with SUPA device tensors.

    ``w13_weight`` and ``w2_weight`` are converted expert by expert into
    NUMA colmajor tensors and written into one large pre-allocated tensor
    per parameter. ``w13_bias`` / ``w2_bias`` are converted only when the
    layer has them. Unless the activation is ``"swigluoai"``, the two halves
    of w13 (weight and bias) are interleaved with the ``cross_weight_32``
    helpers before being stored.
    """
    device = torch.supa.current_device()
    total_spc = envs.VLLM_BR_DEVICE_SPC_NUM
    dual_die = total_spc > 16
    die_count = 2 if dual_die else 1
    spc_per_die = total_spc // die_count
    # swigluoai activation, no need do interweave
    plain_w13 = layer.activation == "swigluoai"
    num_local = layer.local_num_experts

    def _store_expert(dest, expert_id, numa_src):
        # Each expert occupies shape[-2] * shape[-1] elements of the view.
        src_shape = numa_src.shape
        plane = src_shape[-2] * src_shape[-1]
        dest.view_as_usharp("COLMAJOR", src_shape, Sbp.ss(0),
                            expert_id * plane).copy_(numa_src)

    # NOTE: w13_weight
    # after _load_w13, w13_weight is a colparallel weight, shape
    # [num_experts, 2 * intermediate_size_per_partition, hidden_size]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block]
    w13_align = 32 if plain_w13 else 64
    w13_block = align_n((layer.intermediate_size_per_partition * 2) //
                        die_count,
                        align_size=w13_align,
                        spc_num=spc_per_die)
    supa_w13_weight = torch_br._empty_ut_only(
        size=(total_spc * num_local, layer.hidden_size, w13_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=device,
        tensor_type="colmajor",
        axis=0,
        sbp="SS" if dual_die else None)
    for expert_id in range(num_local):
        w13_t = layer.w13_weight[expert_id].transpose(0, 1).contiguous()
        if plain_w13:
            numa_w13 = _convert_to_numa_tensor(w13_t, w13_align, 'COLMAJOR',
                                               w13_t.dtype)
        else:
            first_half, second_half = w13_t.chunk(2, dim=1)
            numa_w13 = _convert_to_crossed_numa_tensor(first_half,
                                                       second_half,
                                                       total_spc,
                                                       dim=1,
                                                       need_pad=True,
                                                       layout='COLMAJOR')
        _store_expert(supa_w13_weight, expert_id, numa_w13)
    layer.w13_weight.data = supa_w13_weight

    # NOTE: w13_bias
    if getattr(layer, "w13_bias", None) is not None:
        supa_w13_bias = torch_br._empty_ut_only(
            size=(num_local, layer.intermediate_size_per_partition * 2),
            dtype=torch.float32,
            is_numa=False,
            device=device,
            tensor_type="linear_bias",
            sbp="BB" if dual_die else None)
        for expert_id in range(num_local):
            bias_row = layer.w13_bias[expert_id]
            if not plain_w13:
                first_bias, second_bias = bias_row.chunk(2, dim=-1)
                bias_row = cross_weight_32(
                    first_bias,
                    second_bias,
                    total_spc,
                    dim=0,
                    need_pad=False,
                )
            supa_w13_bias[expert_id].copy_(bias_row)
        layer.w13_bias.data = supa_w13_bias

    # NOTE: w2_weight
    # after _load_w2, w2_weight is a rowparallel weight, shape
    # [num_experts, hidden_size, intermediate_size_per_partition]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block]
    w2_align = 32
    w2_block = align_n(layer.hidden_size,
                       align_size=w2_align,
                       spc_num=spc_per_die)
    supa_w2_weight = torch_br._empty_ut_only(
        size=(total_spc * num_local,
              layer.intermediate_size_per_partition // die_count, w2_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=device,
        tensor_type="colmajor",
        axis=0,
        sbp="SS" if dual_die else None)
    for expert_id in range(num_local):
        w2_t = layer.w2_weight[expert_id].transpose(0, 1).contiguous()
        numa_w2 = _convert_to_numa_tensor(w2_t,
                                          w2_align,
                                          'COLMAJOR',
                                          w2_t.dtype,
                                          parallel_type="row_parallel")
        _store_expert(supa_w2_weight, expert_id, numa_w2)
    layer.w2_weight.data = supa_w2_weight

    # NOTE: w2_bias
    if getattr(layer, "w2_bias", None) is not None:
        supa_w2_bias = torch.zeros((num_local, layer.hidden_size),
                                   dtype=torch.float32,
                                   device=device)
        for expert_id in range(num_local):
            supa_w2_bias[expert_id].copy_(layer.w2_bias[expert_id])
        layer.w2_bias.data = supa_w2_bias
@patch_to(FusedMoE)
def forward(self: FusedMoE, hidden_states: torch.Tensor,
            router_logits: torch.Tensor):
    """Run the layer's quant method and flag fused all-reduce results.

    Note: ``router_logits`` is not a logits tensor here. It is a tuple of
    the gate, shared_experts.gate_up_proj and shared_experts.down_proj
    weights, and is handed to the quant method unchanged.

    The returned tensor gets ``all_reduced = True`` when the token count is
    within ``envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN``, the quant method is
    not ``"INT4"``, fused all-reduce is enabled and the
    (SPC count, TP size) pair is one of the supported combinations.
    """
    assert self.quant_method is not None
    assert self.dp_size == 1, 'dp_size > 1 is not supported for now, please refer v0.11.0 moe codes'

    moe_out = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        global_num_experts=self.global_num_experts,
        expert_map=self.expert_map,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        custom_routing_function=self.custom_routing_function,
        scoring_func=self.scoring_func,
        e_score_correction_bias=self.e_score_correction_bias,
        activation=self.activation,
        apply_router_weight_on_input=self.apply_router_weight_on_input,
    )

    # NOTE: if using supa-moe-ccl kernel, add property `all_reduced` to the
    # returned hidden states
    supported_spc_tp = ((16, 4), (16, 8), (32, 2), (32, 4))
    world_size = get_tensor_model_parallel_world_size()
    already_reduced = (
        hidden_states.shape[0] <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
        and envs.VLLM_BR_QUANT_METHOD != "INT4"
        and envs.VLLM_BR_USE_FUSED_ALLREDUCE
        and (envs.VLLM_BR_DEVICE_SPC_NUM, world_size) in supported_spc_tp)
    if already_reduced:
        moe_out.all_reduced = True
    return moe_out
@patch_to(FusedMoE)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str,
              loaded_weight: torch.Tensor, tp_rank: int):
    """Copy one half (``w1`` or ``w3``) of the fused w13 weight.

    ``expert_data`` holds both halves along ``shard_dim``. This rank's slice
    of ``loaded_weight`` is selected with ``tp_rank`` and copied, via the
    CPU, into the first half for ``"w1"`` (gate_proj) or the second half for
    ``"w3"`` (up_proj).
    """
    # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
    half = expert_data.shape[shard_dim] // 2
    rank_slice = loaded_weight.narrow(shard_dim, half * tp_rank, half)

    if shard_id == "w1":
        dest_start = 0
    else:
        assert shard_id == "w3"
        dest_start = half
    expert_data.narrow(shard_dim, dest_start, half).copy_(rank_slice.cpu())
@patch_to(FusedMoE)
def _load_w2(self,
             expert_data: torch.Tensor,
             shard_dim: int,
             loaded_weight: torch.Tensor,
             tp_rank: int,
             load_full: bool = False):
    """Copy the w2 (down_proj) weight into ``expert_data`` via the CPU.

    Unless ``load_full`` is set, only this rank's slice of ``loaded_weight``
    along ``shard_dim`` is copied; the slice width equals
    ``expert_data.shape[shard_dim]``.
    """
    # down_proj: "RowParallel" so tp sharding on input_dim
    width = expert_data.shape[shard_dim]
    source = loaded_weight
    if not load_full:
        source = source.narrow(shard_dim, width * tp_rank, width)
    expert_data.copy_(source.cpu())
def wrapper_FusedMoE_init(fn):
    """Wrap an ``__init__`` so ``e_score_correction_bias`` ends up as float.

    The wrapped initializer runs first; if the instance's
    ``e_score_correction_bias`` is not None afterwards, its ``.data`` is
    replaced by the result of calling ``.float()`` on it.
    """

    @wraps(fn)
    def _init_then_cast(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        bias = self.e_score_correction_bias
        if bias is None:
            return
        bias.data = bias.float()

    return _init_then_cast
# Patch applied at import time: FusedMoE instances constructed afterwards run
# the original __init__ and then the e_score_correction_bias float conversion.
FusedMoE.__init__ = wrapper_FusedMoE_init(FusedMoE.__init__) # noqa: E501

View File

@@ -0,0 +1,518 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Callable, Optional
import torch
import torch_br
from torch_br.utils.tensor_methods import Sbp
from vllm_br import envs
# gpt-oss moe forward version
def fused_oss_moe_dyn(
    hidden_states: torch.Tensor,
    w13: torch.Tensor,
    w13_bias: torch.Tensor,
    w2: torch.Tensor,
    w2_bias: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """MoE forward with biased experts and the ``act_swiglu_oai`` activation.

    Steps: route tokens with ``supa_moe_router_v2_infer``, group them per
    local expert with ``supa_permutation_infer``, run the gate/up FFN and
    the down FFN with ``supa_moe_fused_ffn_dyn_infer``, then scatter the
    results back with ``supa_unpermutation_infer``.

    Args:
        hidden_states: Input activations; ``shape[0]`` is the token count.
        w13, w13_bias: Weight and bias handed to the first FFN call.
        w2, w2_bias: Weight and bias handed to the second FFN call.
        gating_weight: Router weight; ``shape[-2]`` is taken as the total
            number of experts.
        topk: Number of experts selected per token.
        intermediate_size: Width of the first FFN's per-expert output.
        e_score_correction_bias: Passed to the router as ``gating_bias``.
        ep_rank, ep_size: Expert-parallel rank and size. Both are used in
            integer arithmetic, so despite the ``Optional`` annotation they
            must be integers.
        renormalize, inplace, use_grouped_topk, num_expert_group,
        topk_group, custom_routing_function, scoring_func: Accepted for
            signature parity; not read by this function.

    Returns:
        The combined expert output with a leading dimension of size 1
        added. When no token is routed to a local expert, this is a zero
        tensor shaped like ``hidden_states`` (plus the leading dimension).
    """
    total_expert_num = gating_weight.shape[-2]
    # The first router output is never read below, so it is not transposed
    # or copied back to the device.
    _, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
        hidden_states,
        gating_weight,
        topk,
        ep_size,
        ep_rank,
        gating_bias=e_score_correction_bias)

    input_device = hidden_states.device

    def _transposed_on_device(t):
        # Transpose on the host, then copy back to the input's device.
        return t.cpu().permute(1, 0).contiguous().to(input_device)

    indices_supa = _transposed_on_device(indices_supa)
    indice_per_expert = _transposed_on_device(indice_per_expert)
    prob_per_expert = _transposed_on_device(prob_per_expert)

    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    local_expert_num = total_expert_num // ep_size  # type: ignore
    b_seq = hidden_states.shape[0]
    indices_trans_supa = torch_br._empty_ut_only(
        # second dimension: token count rounded up to a multiple of 64
        size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
        dtype=torch.int32,
        is_numa=False,
        device=hidden_states.device,
        tensor_type="colmajor",
        sbp="BB" if is_dual_die else None)

    tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
                                            minlength=total_expert_num)
    local_start = ep_rank * local_expert_num  # type: ignore
    tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
    )[local_start:local_start + local_expert_num]

    if not any(tokens_per_expert_list):
        # No token is routed to any expert owned by this rank.
        return torch.zeros_like(hidden_states).unsqueeze(0)

    expert_tokens = torch_br.supa_permutation_infer(
        global_hidden_states=hidden_states,
        indices=indice_per_expert,
        tokens_per_expert=tokens_per_expert_list,
        indices_trans=indices_trans_supa)
    assert len(
        expert_tokens) == local_expert_num, "Number of experts mismatch"

    gate_up_outputs = []
    down_outputs = []
    expert_device = expert_tokens[0].device
    hidden_size = expert_tokens[0].shape[-1]
    for expert_idx in range(local_expert_num):
        token_count = tokens_per_expert_list[expert_idx]
        if token_count == 0:
            # Idle experts still need (empty) entries in both output lists.
            gate_up_outputs.append(
                torch.empty(size=(0, intermediate_size),
                            dtype=torch.bfloat16,
                            device=expert_device))
            down_outputs.append(
                torch.empty(size=(0, hidden_size),
                            dtype=torch.float32,
                            device=expert_device))
            continue
        gate_up_outputs.append(
            torch_br._empty_ut_only(size=(token_count, intermediate_size),
                                    dtype=torch.bfloat16,
                                    is_numa=False,
                                    device=expert_device,
                                    tensor_type="colmajor",
                                    sbp="SB" if is_dual_die else None,
                                    axis=1))
        down_outputs.append(
            torch_br._empty_ut_only(size=(token_count, hidden_size),
                                    dtype=torch.float32,
                                    is_numa=False,
                                    device=expert_device,
                                    tensor_type="colmajor",
                                    sbp="BB" if is_dual_die else None))

    max_tokens = max(tokens_per_expert_list)
    torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
                                          expert_tokens,
                                          w13,
                                          tokens_per_expert_list,
                                          max_tokens,
                                          bias=w13_bias,
                                          act_mode="act_swiglu_oai")
    torch_br.supa_moe_fused_ffn_dyn_infer(down_outputs,
                                          gate_up_outputs,
                                          w2,
                                          tokens_per_expert_list,
                                          max_tokens,
                                          bias=w2_bias,
                                          act_mode="act_default")
    output = torch_br.supa_unpermutation_infer(
        input_list=down_outputs,
        indices=indices_trans_supa,
        probs=prob_per_expert,
        tokens_per_expert=tokens_per_expert_list)
    return output.unsqueeze(0)
def fused_moe_quant_dyn(
hidden_states: torch.Tensor,
shared_gate_up_weight: torch.Tensor,
down_weight: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_weight: torch.Tensor,
topk: int,
intermediate_size: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = None,
global_rank: Optional[int] = None,
tp_size: Optional[int] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
) -> torch.Tensor:
"""
Compute a Mixture of Experts (MoE) layer with routing done on every call.

Tokens are routed by a SUPA router kernel, gathered per local expert,
run through the gate/up (w13) and down (w2) expert kernels, and scattered
back weighted by the routing probabilities. On the grouped top-k path the
shared-expert output produced by the fused router kernel is added to the
routed result.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- shared_gate_up_weight (torch.Tensor): Shared-expert gate/up weight passed
to the fused shared+router kernel (grouped top-k path). Must be None on
the non-grouped path.
- down_weight (torch.Tensor): Shared-expert down weight passed to the same
kernel. Must be None on the non-grouped path.
- w13 (torch.Tensor): Expert weights used for the gate/up step.
- w2 (torch.Tensor): Expert weights used for the down step.
- w13_scale (torch.Tensor): Scales passed alongside w13.
- w2_scale (torch.Tensor): Scales passed alongside w2.
- gating_weight (torch.Tensor): Router weight; its last dimension is taken
as the total number of experts.
- topk (int): The number of top-k experts to select.
- intermediate_size (int): Width of the gate/up output buffers.
- use_grouped_topk (bool): If True, use the fused shared-expert + grouped
router kernel; otherwise use the plain router kernel.
- num_expert_group (Optional[int]): Required when use_grouped_topk is True.
- topk_group (Optional[int]): Required when use_grouped_topk is True; must
be None otherwise.
- scoring_func (str): Forwarded to the grouped router kernel only.
- e_score_correction_bias (Optional[torch.Tensor]): Forwarded as
e_score_correction_bias (grouped path) or gating_bias (non-grouped path).
- ep_rank (Optional[int]): Expert-parallel rank; used to select this rank's
slice of the per-expert token counts.
- ep_size (Optional[int]): Expert-parallel world size; used as a divisor,
so it must not be None in practice.
- renormalize, inplace, custom_routing_function, tp_rank, global_rank,
tp_size: Accepted but not read by this function.

Returns:
- torch.Tensor: The MoE output with an extra leading dimension of size 1
(the result of unsqueeze(0)).
"""
# Dual-die handling is switched on when the configured SPC count exceeds 16.
is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
# The last dimension of the router weight gives the global expert count.
total_expert_num = gating_weight.shape[-1]
# NOTE(review): this binding is never read; cur_device is rebound from
# expert_tokens[0].device before its first use below.
cur_device = hidden_states.device
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
# Fused kernel: returns the shared-expert output together with the routing
# indices and probabilities (the second return value is discarded).
shared_output, _, indices_supa, indice_per_expert, prob_per_expert = torch_br.supa_fused_shared_router_prefill_v2_infer(
hidden_states,
shared_gate_up_weight,
down_weight,
gating_weight,
intermediate_size,
topk,
num_expert_group,
topk_group,
ep_size,
ep_rank,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
if is_dual_die:
# Copy the shared output into a freshly allocated non-NUMA colmajor tensor.
shared_tmp = torch_br._empty_ut_only(size=shared_output.shape,
dtype=shared_output.dtype,
is_numa=False,
device=shared_output.device,
tensor_type="colmajor")
shared_tmp.copy_(shared_output)
shared_output = shared_tmp
else:
assert topk_group is None, "Only support non group topk router"
assert shared_gate_up_weight is None and down_weight is None
# Router-only kernel; note the return order differs from the grouped kernel
# (probabilities come before the per-expert indices).
_, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
hidden_states,
gating_weight.permute(1, 0).contiguous(),
topk,
ep_size,
ep_rank,
gating_bias=e_score_correction_bias)
# No shared expert on this path.
shared_output = None
# Swap the two axes of each routing tensor and make them contiguous.
indices_supa = indices_supa.permute(1, 0).contiguous()
indice_per_expert = indice_per_expert.permute(1, 0).contiguous()
prob_per_expert = prob_per_expert.permute(1, 0).contiguous()
# Experts are split evenly across expert-parallel ranks.
local_expert_num = total_expert_num // ep_size # type: ignore
b_seq = hidden_states.shape[0]
# Scratch index buffer: one row per local expert, row length is b_seq
# rounded up to a multiple of 64.
indices_trans_supa = torch_br._empty_ut_only(
size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
dtype=torch.int32,
is_numa=False,
device=hidden_states.device,
tensor_type="colmajor",
sbp="BB" if is_dual_die else None)
# Count how many routed slots hit each expert (over all experts).
tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
minlength=total_expert_num)
# Copy the counts to the host and keep only this rank's expert slice.
tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
)[ep_rank * local_expert_num:(ep_rank + 1) * # type: ignore
local_expert_num]
# Despite its name, this is the number of local experts that received
# at least one token.
topk_per_expert = sum(1 for x in tokens_per_expert_list if x != 0)
if topk_per_expert > 0:
# Gather the tokens of each local expert into a per-expert list.
expert_tokens = torch_br.supa_permutation_infer(
global_hidden_states=hidden_states,
indices=indice_per_expert,
tokens_per_expert=tokens_per_expert_list,
indices_trans=indices_trans_supa)
assert len(
expert_tokens) == local_expert_num, "Number of experts mismatch"
supa_device = torch.supa.current_device()
spc_num = torch_br.supa.get_device_properties(
supa_device).max_compute_units
out_expert_tokens = []
# NOTE(review): this flag is hard-coded to True, so the per-expert loop
# below only runs when total_expert_num == 128.
use_moe_fused_ffn_dyn = True
if not use_moe_fused_ffn_dyn or total_expert_num == 128:
# Per-expert path: one pair of fused linear kernels per local expert.
# Element count of one expert's weight, used as the view offset stride.
w13_hw = w13.shape[-2] * w13.shape[-1]
w2_hw = w2.shape[-2] * w2.shape[-1]
for i in range(local_expert_num):
expert_token = expert_tokens[i]
if tokens_per_expert_list[i] == 0:
# Experts without tokens contribute their (unprocessed) entry as-is.
out_expert_tokens.append(expert_token)
continue
# View expert i's slice of the stacked weights at offset i * <per-expert size>.
expert_gate_up_weight = w13.view_as_usharp(
"COLMAJOR", (spc_num, w13.shape[-2], w13.shape[-1]),
Sbp.ss(0), i * w13_hw)
# NOTE: rebinds the `down_weight` parameter; the shared-expert weight
# has already been consumed by the router call above.
down_weight = w2.view_as_usharp(
"COLMAJOR", (spc_num, w2.shape[-2], w2.shape[-1]),
Sbp.ss(0), i * w2_hw)
expert_gate_up_scale = w13_scale[
i] if w13_scale is not None else None
down_scale = w2_scale[i] if w2_scale is not None else None
gate_up_output = torch_br._empty_ut_only(
size=(expert_token.shape[0], intermediate_size),
dtype=torch.bfloat16,
is_numa=False,
device=expert_token.device,
tensor_type="colmajor",
sbp="SB" if is_dual_die else None,
axis=1)
# Gate/up projection with the SwiGLU activation mode.
torch_br.supa_fused_linear_infer(gate_up_output,
expert_token,
expert_gate_up_weight,
expert_gate_up_scale,
act_mode="act_swiglu")
down_output = torch_br._empty_ut_only(
size=expert_token.shape,
dtype=torch.float32,
is_numa=False,
device=gate_up_output.device,
tensor_type="colmajor",
sbp="BB" if is_dual_die else None)
# Down projection (no activation mode given).
torch_br.supa_fused_linear_infer(down_output, gate_up_output,
down_weight, down_scale)
out_expert_tokens.append(down_output)
else:
# Fused path: pre-allocate every expert's output buffers, then issue
# one kernel call per projection covering all local experts.
gate_up_outputs = []
cur_device = expert_tokens[0].device
hidden_size = expert_tokens[0].shape[-1]
for i in range(local_expert_num):
if tokens_per_expert_list[i] == 0:
# Zero-row placeholders keep the lists aligned with the expert index.
gate_up_outputs.append(
torch.empty(size=(0, intermediate_size),
dtype=torch.bfloat16,
device=cur_device))
out_expert_tokens.append(
torch.empty(size=(0, hidden_size),
dtype=torch.float32,
device=cur_device))
continue
gate_up_output = torch_br._empty_ut_only(
size=(tokens_per_expert_list[i], intermediate_size),
dtype=torch.bfloat16,
is_numa=False,
device=cur_device,
tensor_type="colmajor",
sbp="SB" if is_dual_die else None,
axis=1)
gate_up_outputs.append(gate_up_output)
down_output = torch_br._empty_ut_only(
size=(tokens_per_expert_list[i], hidden_size),
dtype=torch.float32,
is_numa=False,
device=cur_device,
tensor_type="colmajor",
sbp="BB" if is_dual_die else None)
out_expert_tokens.append(down_output)
# Gate/up projection for all local experts, SwiGLU activation mode.
torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
expert_tokens,
w13,
tokens_per_expert_list,
max(tokens_per_expert_list),
scales=w13_scale,
act_mode="act_swiglu")
# Down projection for all local experts, default activation mode.
torch_br.supa_moe_fused_ffn_dyn_infer(out_expert_tokens,
gate_up_outputs,
w2,
tokens_per_expert_list,
max(tokens_per_expert_list),
scales=w2_scale,
act_mode="act_default")
# Scatter the per-expert results back, weighted by the routing probabilities.
out_states = torch_br.supa_unpermutation_infer(
input_list=out_expert_tokens,
indices=indices_trans_supa,
probs=prob_per_expert,
tokens_per_expert=tokens_per_expert_list)
output = out_states if shared_output is None else out_states + shared_output
else:
# No local expert received a token: return zeros, or only the shared output.
output = torch.zeros_like(
hidden_states) if shared_output is None else shared_output
return output.unsqueeze(0)
def fused_moe_quant_device(
hidden_states: torch.Tensor,
shared_gate_up_weight: torch.Tensor,
down_weight: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_weight: torch.Tensor,
topk: int,
intermediate_size: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = None,
global_rank: Optional[int] = None,
tp_size: Optional[int] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
) -> torch.Tensor:
"""
Compute a Mixture of Experts (MoE) layer with routing and expert FFN both
executed by fused device kernels.

A router kernel produces `shared_output`, `masked_probs` and
`hitted_experts`; a fused FFN kernel is then called with `shared_output` as
its first argument, and `shared_output` is what this function returns.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- shared_gate_up_weight (torch.Tensor): Shared-expert gate/up weight for the
fused shared+router kernels. Must be None when topk_group is None.
- down_weight (torch.Tensor): Shared-expert down weight for the same
kernels. Must be None when topk_group is None.
- w13 (torch.Tensor): Expert gate/up weights. An int32 dtype selects the
`s4` FFN kernel.
- w2 (torch.Tensor): Expert down weights.
- w13_scale (torch.Tensor): Scales passed alongside w13.
- w2_scale (torch.Tensor): Scales passed alongside w2.
- gating_weight (torch.Tensor): Router weight; its last dimension is taken
as the number of experts.
- topk (int): The number of top-k experts to select.
- intermediate_size (int): Forwarded to the fused shared+router kernels.
- use_grouped_topk (bool): Must be True whenever topk_group is given.
- num_expert_group (Optional[int]): Required when topk_group is given.
- topk_group (Optional[int]): None selects the plain decoder router;
otherwise a fused shared-expert + grouped router kernel is used.
- scoring_func (str): Forwarded to the fused shared+router kernels.
- e_score_correction_bias (Optional[torch.Tensor]): Forwarded to the fused
shared+router kernels; a placeholder tensor is substituted when None.
- tp_rank, global_rank, tp_size (Optional[int]): Used only by the fused
FFN + all-reduce kernel.
- ep_rank, ep_size (Optional[int]): Forwarded to the router kernels;
ep_size also selects between the v2 and non-v2 fused router.
- renormalize, inplace, custom_routing_function: Accepted but not read by
this function.

Returns:
- torch.Tensor: `shared_output` with an extra leading dimension of size 1
(the result of unsqueeze(0)).
"""
# Dual-die handling is switched on when the configured SPC count exceeds 16.
is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
expert_num = gating_weight.shape[-1]
b_seq = hidden_states.shape[-2]
if topk_group is None:
assert shared_gate_up_weight is None and down_weight is None
# Plain decoder router; no shared-expert weights are involved.
shared_output, masked_probs, hitted_experts = torch_br.supa_moe_router_decoder_infer(
hidden_states, gating_weight, topk, ep_size, ep_rank)
else:
assert use_grouped_topk is True, "Only support group topk router"
assert num_expert_group is not None and topk_group is not None
if ep_size > 1: # type: ignore
# Expert-parallel variant: the v2 kernel additionally takes ep_size / ep_rank.
# NOTE(review): when e_score_correction_bias is None the fallback is
# torch.empty(...), i.e. uninitialised memory - confirm the kernel ignores
# the bias in that case, otherwise torch.zeros would be the safe choice.
shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_v2_infer(
hidden_states,
shared_gate_up_weight,
down_weight,
gating_weight,
intermediate_size,
topk,
num_expert_group,
topk_group,
ep_size,
ep_rank,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias
if e_score_correction_bias is not None else torch.empty(
(expert_num),
dtype=torch.float32,
device=hidden_states.device))
else:
# Non expert-parallel variant; scoring_func is passed positionally here.
# NOTE(review): same uninitialised torch.empty(...) bias fallback as above.
shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_infer(
hidden_states,
shared_gate_up_weight,
down_weight,
gating_weight,
intermediate_size,
topk,
num_expert_group,
topk_group,
scoring_func,
e_score_correction_bias=e_score_correction_bias
if e_score_correction_bias is not None else torch.empty(
(expert_num),
dtype=torch.float32,
device=hidden_states.device))
if is_dual_die:
# Re-view the shared output as a COLMAJOR tensor with Sbp.bb() placement.
shared_output = shared_output.view_as_usharp(
"COLMAJOR", shared_output.shape, Sbp.bb())
if w13.dtype == torch.int32:
# int32 expert weights select the `s4` FFN kernel
# (presumably packed 4-bit weights - TODO confirm).
torch_br.supa_moe_fused_ffn_s4_infer(shared_output, hidden_states, w13,
w2, hitted_experts, masked_probs,
w13_scale, w2_scale)
else:
# (SPC count, tp_size) combinations for which the fused FFN + all-reduce
# kernel is used.
support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
if envs.VLLM_BR_USE_FUSED_ALLREDUCE and b_seq <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and (
envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types:
# ffn+allreduce only support tp 4|8 and 16spc
# NOTE(review): the comment above disagrees with support_types, which
# also lists 32-SPC devices with tp 2|4.
torch_br.supa_moe_fused_ffn_allreduce(shared_output, hidden_states,
w13, w2, hitted_experts,
masked_probs, tp_rank,
tp_size, global_rank, 0,
w13_scale, w2_scale)
else:
torch_br.supa_moe_fused_ffn_infer(shared_output, hidden_states,
w13, w2, hitted_experts,
masked_probs, w13_scale,
w2_scale)
# The FFN kernels return nothing here; shared_output (their first argument)
# is returned, presumably updated in place by the kernel - TODO confirm.
return shared_output.unsqueeze(0)

View File

@@ -0,0 +1,67 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
from typing import Optional, Tuple, Union
import torch
import torch_br
from fastcore.basics import patch_to
from torch import Tensor, nn
from vllm.model_executor.layers.layernorm import RMSNorm
@patch_to(RMSNorm)
def forward_oot(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Run RMSNorm through the SUPA kernels.

    A bfloat16 weight is replaced by its float32 copy the first time it is
    seen, so later calls skip the conversion.

    Args:
        x: Input activations.
        residual: Optional residual; when given, the fused add + RMSNorm
            kernel is used.

    Returns:
        The normalized tensor, or the two outputs of the fused add + RMSNorm
        kernel when ``residual`` is given.
    """
    weight_param = self.weight
    if weight_param.data.dtype == torch.bfloat16:
        weight_param.data = weight_param.data.to(torch.float32)

    if residual is None:
        # Adjust the rank before calling the kernel: a 2-D input gains a
        # leading axis, a 4-D input has its leading axis squeezed.
        rank = len(x.shape)
        if rank == 2:
            x = x.unsqueeze(0)
        elif rank == 4:
            x = x.squeeze(0)
        return torch_br.supa_rmsnorm_infer(
            x,
            self.weight.data,
            self.variance_epsilon  # type: ignore
        )

    normed, residual_sum = torch_br.supa_add_rmsnorm_infer(  # type: ignore
        x, residual, self.weight.data, self.variance_epsilon)
    return normed, residual_sum
@patch_to(RMSNorm)
def enabled(cls) -> bool:
    """Always report the patched RMSNorm op as enabled."""
    # NOTE(review): patched in without `cls_method=True`, so `cls` receives
    # whatever object the attribute is looked up on - confirm call sites.
    return True
@patch_to(nn.LayerNorm)
def forward(self, input: Tensor) -> Tensor:
    """Apply layer normalization, optionally through the fused BR kernel.

    The fused kernel is chosen when the ``USE_BR_FUSED_LAYERNORM``
    environment variable is set to anything other than ``false``, ``0`` or
    the empty string (case-insensitive); the variable is read on every call.
    """
    switch = os.environ.get("USE_BR_FUSED_LAYERNORM", 'False').lower()
    if switch in {'false', '0', ''}:
        return nn.functional.layer_norm(input, self.normalized_shape,
                                        self.weight, self.bias, self.eps)
    return torch_br.fused_layernorm(input, self.weight, self.bias, self.eps)

View File

@@ -0,0 +1,767 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Literal, Optional, Union
import torch
import torch.nn.functional as F
import torch_br
import torch_br.supa._debug as supa_debug
from fastcore.basics import patch_to
from torch.nn.parameter import Parameter
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import logger
from vllm.model_executor.layers.linear import (adjust_bitsandbytes_4bit_shard,
adjust_marlin_shard,
adjust_scalar_to_fused_array)
from vllm_br import envs
from vllm_br.utils import get_grandparent_pid
from .br_utils import (_convert_to_crossed_numa_tensor,
_convert_to_numa_tensor, _convert_to_numa_tensor_vit,
is_br166_device)
from vllm.model_executor.layers.linear import ( # isort:skip
LinearBase, MergedColumnParallelLinear, QuantizationConfig,
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod,
QKVParallelLinear)
def _should_skip_linear_post_process(layer, use_ds_mla, use_ds_mla_sparse):
    """Tell whether ``layer`` is an MLA linear whose weights are processed elsewhere.

    NOTE: SUPA: MLA linears are handled in MLA.process_weights_after_loading,
    so the generic linear post-processing must leave them alone.

    Args:
        layer: Object exposing a ``prefix`` string.
        use_ds_mla: Truthy to also skip ``o_proj`` (non-sparse mode only).
        use_ds_mla_sparse: Truthy to skip only ``kv_b_proj``.

    Returns:
        True when any skipped projection name occurs in ``layer.prefix``.
    """
    # TODO: Hard code for native dsa op
    if use_ds_mla_sparse:
        skipped_names = ["kv_b_proj"]
    else:
        skipped_names = [
            "q_a_proj",
            "q_b_proj",
            "kv_a_proj_with_mqa",
            "kv_b_proj",
        ]
        if use_ds_mla:
            skipped_names.append("o_proj")

    for name in skipped_names:
        if name in layer.prefix:
            logger.debug(
                f'[SUPA] skip {layer.prefix} UnquantizedLinearMethod.process_weights_after_loading'  # noqa: G004
            )
            return True
    return False
# NOTE: ReplicatedLinear, usually used in MoE as a gate module.
# In DeepseekV3, it needs to be transposed.
def process_weights_ReplicatedLinear(
        layer: ReplicatedLinear) -> Literal[True, False]:
    """Replace the layer weight with its contiguous transpose; always returns True."""
    transposed = layer.weight.data.transpose(1, 0)
    layer.weight.data = transposed.contiguous()
    return True
def _pack_shared_gate_up_for_die(gate: torch.Tensor,
                                 up: torch.Tensor,
                                 spc_for_shared: int,
                                 spc_for_router: int,
                                 align_size: int = 32) -> torch.Tensor:
    """Interleave one die's gate/up weights and lay them out per SPC region.

    ``gate`` and ``up`` are ``(hidden_size, im_size)``. Both are zero-padded
    along the last axis, cut into ``align_size``-wide blocks and interleaved
    as ``[gate block | up block]`` pairs. The result is regrouped into
    ``spc_for_shared`` regions, followed by ``spc_for_router`` all-zero
    ("invalid") regions.

    Returns:
        Tensor of shape
        ``(spc_for_shared + spc_for_router, hidden_size, region_size)``.
    """
    weight_dtype = gate.dtype
    hidden_size = gate.shape[0]
    im_size = gate.shape[-1]

    # Round the interleaved width (gate + up) up so that it splits evenly
    # into `spc_for_shared` regions made of whole gate/up block pairs.
    pair_width = align_size * 2
    n_align_size = pair_width * spc_for_shared
    swiglu_w_aligned = ((
        (im_size * 2) + n_align_size - 1) // n_align_size) * n_align_size
    region_size = swiglu_w_aligned // spc_for_shared
    block_nums = (region_size // pair_width) * spc_for_shared
    pad_cols = (swiglu_w_aligned // 2) - im_size

    def _as_blocks(weight: torch.Tensor) -> torch.Tensor:
        padded = torch.nn.functional.pad(weight, (0, pad_cols, 0, 0),
                                         mode='constant',
                                         value=0)
        return padded.reshape(hidden_size, block_nums,
                              align_size).contiguous()

    interleaved = torch.zeros([hidden_size, block_nums, pair_width],
                              dtype=weight_dtype,
                              device='supa')
    interleaved[:, :, :align_size] = _as_blocks(gate)
    interleaved[:, :, align_size:pair_width] = _as_blocks(up)

    shared_regions = interleaved.reshape(
        hidden_size, spc_for_shared, region_size).permute(1, 0,
                                                          2).contiguous()
    invalid_regions = torch.zeros(
        [spc_for_router, hidden_size, region_size],
        dtype=weight_dtype,
        device='supa')  # invalid regions
    return torch.cat([shared_regions, invalid_regions], dim=0)


def process_share_expert_weight(layer: MergedColumnParallelLinear):
    """Re-lay the shared-expert gate/up weight into the SUPA NUMA layout.

    The merged weight is transposed and split into gate and up halves. When
    the configured SPC count exceeds 16 (treated as a two-die device) each
    half is further split in two and packed per die, and the per-die results
    are concatenated; otherwise the halves are packed once. The packed
    tensor is copied into a NUMA colmajor tensor which replaces
    ``layer.weight.data``.
    """
    gate_up_weight = layer.weight.transpose(1, 0).contiguous()
    gate_weight, up_weight = torch.chunk(gate_up_weight, 2, dim=-1)

    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    is_br166 = die_spc_num > 16
    spc_num = die_spc_num // 2 if is_br166 else die_spc_num
    # 2&2 for 4spc, 4&8 for 12spc, 8&8 for 16spc
    spc_for_shared = 2 if spc_num == 4 else 8
    spc_for_router = spc_num - spc_for_shared

    if is_br166:
        gate_d0, gate_d1 = torch.chunk(gate_weight, 2, dim=-1)
        up_d0, up_d1 = torch.chunk(up_weight, 2, dim=-1)
        gate_up_weight_whole = torch.cat([
            _pack_shared_gate_up_for_die(gate_d0, up_d0, spc_for_shared,
                                         spc_for_router),
            _pack_shared_gate_up_for_die(gate_d1, up_d1, spc_for_shared,
                                         spc_for_router),
        ],
                                         dim=0)
    else:
        gate_up_weight_whole = _pack_shared_gate_up_for_die(
            gate_weight, up_weight, spc_for_shared, spc_for_router)

    gate_up_weight_supa = torch_br._empty_ut_only(
        size=gate_up_weight_whole.shape,
        dtype=gate_weight.dtype,
        is_numa=True,
        device="supa",
        tensor_type="colmajor")
    gate_up_weight_supa.copy_(gate_up_weight_whole)
    layer.weight.data = gate_up_weight_supa
# NOTE: MergedColumnParallelLinear, usually used in MergedGateUpMLPSiluL2
def process_weights_QuantMergedColumnParallelLinear(
        layer: MergedColumnParallelLinear):
    """Convert a (possibly quantized) merged gate/up layer to SUPA NUMA tensors.

    Shared-expert layers are delegated to ``process_share_expert_weight``.
    For every other layer the weight (``qweight`` when present, else
    ``weight``), the optional ``scales`` and the optional ``bias`` are each
    converted in place.
    """
    if 'shared_experts' in layer.prefix:
        process_share_expert_weight(layer)
        return

    def _interleave_halves(merged, **layout_kwargs):
        # Split into gate / up along the last axis and cross them per SPC.
        gate_part, up_part = torch.chunk(merged, 2, dim=-1)
        return _convert_to_crossed_numa_tensor(gate_part,
                                               up_part,
                                               envs.VLLM_BR_DEVICE_SPC_NUM,
                                               dim=-1,
                                               do_transpose=False,
                                               **layout_kwargs)

    #NOTE: normal MLP gate_up, after load weight, convert to supa numa tensor
    if hasattr(layer, "qweight"):
        layer.qweight.data = _interleave_halves(layer.qweight, need_pad=True)
    else:
        transposed = layer.weight.permute(1, 0).contiguous()
        layer.weight.data = _convert_to_numa_tensor(
            transposed,
            32,
            "colmajor",
            transposed.dtype,
            False,
            parallel_type="col_parallel")

    if hasattr(layer, "scales") and layer.scales is not None:
        layer.scales.data = _interleave_halves(layer.scales,
                                               need_pad=False,
                                               layout="linear_bias")

    if hasattr(layer, "bias") and layer.bias is not None:
        layer.bias.data = _interleave_halves(layer.bias,
                                             need_pad=False,
                                             layout="linear_bias")
def process_weights_MergedColumnParallelLinear(
        layer: MergedColumnParallelLinear):
    """Convert an unquantized merged gate/up layer to SUPA NUMA tensors.

    Shared-expert layers are delegated to ``process_share_expert_weight``.
    Otherwise the transposed weight is either interleaved (gate/up crossed)
    or, when ``layer.no_need_cross`` is truthy, converted as a plain colmajor
    NUMA tensor. A bias, when present, is interleaved as well.
    """
    if 'shared_experts' in layer.prefix:
        #NOTE: by default, gate module and shared_expert(1) module will be involved into calculation in 1 kernel
        process_share_expert_weight(layer)
        return

    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    transposed = layer.weight.permute(1, 0).contiguous()

    if getattr(layer, "no_need_cross", False):
        layer.weight.data = _convert_to_numa_tensor(transposed,
                                                    align_size=32,
                                                    layout="colmajor",
                                                    dtype=transposed.dtype,
                                                    do_transpose=False)
    else:
        gate_part, up_part = torch.chunk(transposed, 2, dim=-1)
        layer.weight.data = _convert_to_crossed_numa_tensor(
            gate_part,
            up_part,
            spc_num,
            dim=-1,
            need_pad=True,
            do_transpose=False)

    if hasattr(layer, "bias") and layer.bias is not None:
        gate_bias, up_bias = torch.chunk(layer.bias, 2, dim=-1)
        layer.bias.data = _convert_to_crossed_numa_tensor(
            gate_bias,
            up_bias,
            spc_num,
            dim=-1,
            need_pad=False,
            layout="linear_bias",
            do_transpose=False)
@patch_to(UnquantizedLinearMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Convert a loaded linear layer's weight and bias to SUPA NUMA tensors.

Nothing is done for MLA linears (see _should_skip_linear_post_process) or
when self.weight_type is not "NUMA". ReplicatedLinear and
MergedColumnParallelLinear layers first go through their dedicated
helpers; whatever still needs processing afterwards gets the generic
weight / bias conversion below.
"""
if _should_skip_linear_post_process(
layer, self.use_ds_mla,
self.use_ds_mla_sparse) or self.weight_type != "NUMA":
return
# Defaults for the generic conversion; the layer-type dispatch below may
# override them.
still_need_process = True
do_transpose = True
parallel_type = "col_parallel"
# NOTE: all process_weights func should done before process_weights_after_loading
match layer:
case ReplicatedLinear():
# The helper already transposes the weight, hence do_transpose = False.
process_weights_ReplicatedLinear(layer)
# Generic conversion is skipped for non-"indexer" layers whose
# output_size is 64, 128, 160 or 256.
still_need_process = not ("indexer" not in layer.prefix and (
layer.output_size == 64 or layer.output_size == 160 # Glm4-Moe
or layer.output_size == 128 or layer.output_size == 256))
do_transpose = False
case MergedColumnParallelLinear():
# Fully handled by the dedicated helper.
process_weights_MergedColumnParallelLinear(layer)
still_need_process = False
do_transpose = False
case RowParallelLinear():
parallel_type = "row_parallel"
case _:
pass
# NOTE(review): the weight_type test repeats the check at the top of the
# function.
if not still_need_process or self.weight_type != "NUMA":
return
# process numa weight and bias
if hasattr(layer, "weight") and len(layer.weight.shape) == 2:
# Vision layers on BR166 devices use the ViT-specific converter.
if 'vision' in layer.prefix and is_br166_device():
# wk / wn are the two weight dimensions, swapped when do_transpose is set.
layer.weight.data = _convert_to_numa_tensor_vit(
layer.weight,
envs.VLLM_BR_DEVICE_WARP_SIZE,
"colmajor",
torch.bfloat16,
do_transpose=do_transpose,
wk=(layer.weight.data.shape[1]
if do_transpose else layer.weight.data.shape[0]),
wn=(layer.weight.data.shape[0]
if do_transpose else layer.weight.data.shape[1]),
parallel_type=parallel_type) # noqa: SIM210
else:
layer.weight.data = _convert_to_numa_tensor(
layer.weight,
envs.VLLM_BR_DEVICE_WARP_SIZE,
"colmajor",
torch.bfloat16,
do_transpose=do_transpose,
wk=(layer.weight.data.shape[1]
if do_transpose else layer.weight.data.shape[0]),
wn=(layer.weight.data.shape[0]
if do_transpose else layer.weight.data.shape[1]),
parallel_type=parallel_type) # noqa: SIM210
if hasattr(layer, "bias") and layer.bias is not None:
# Row-parallel biases are converted with pad_zeros=True.
pad_zeros = (parallel_type == "row_parallel")
if 'vision' in layer.prefix and is_br166_device():
# Row-parallel vision layers that reduce results keep their bias as-is.
if (pad_zeros and layer.reduce_results):
return
layer.bias.data = _convert_to_numa_tensor_vit(
layer.bias,
envs.VLLM_BR_DEVICE_WARP_SIZE,
"linear_bias",
torch.float32,
parallel_type=parallel_type,
pad_zeros=pad_zeros)
else:
layer.bias.data = _convert_to_numa_tensor(
layer.bias,
envs.VLLM_BR_DEVICE_WARP_SIZE,
"linear_bias",
torch.float32,
parallel_type=parallel_type,
pad_zeros=pad_zeros)
def _linear_with_sublas(x: torch.Tensor, weight: torch.Tensor,
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Run ``F.linear`` with the sublas API switch on, and always turn it off.

    The ``finally`` guarantees the process-wide debug switch is reset even
    when ``F.linear`` raises.
    """
    supa_debug.set_enable_sublas_api(True)
    try:
        return F.linear(x, weight, bias)
    finally:
        supa_debug.set_enable_sublas_api(False)


def _numa_output_size_and_act_mode(layer: torch.nn.Module) -> tuple[int, str]:
    """Return ``(output_w, activation_mode)`` for a layer with a NUMA weight.

    Merged gate/up layers whose weights are crossed (``no_need_cross`` absent
    or falsy) run with the SwiGLU activation and half the output width.
    """
    output_size = (layer.output_size_per_partition if hasattr(
        layer, "output_size_per_partition") else layer.output_size)
    crossed_gate_up = isinstance(layer, MergedColumnParallelLinear) and not (
        hasattr(layer, "no_need_cross") and layer.no_need_cross)
    if crossed_gate_up:
        return output_size // 2, "act_swiglu"
    return output_size, "act_default"


@patch_to(UnquantizedLinearMethod)
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the linear layer using the SUPA kernels.

    A 3-D ``layer.weight`` is a NUMA tensor and is dispatched to the fused
    SUPA kernels; any other weight falls back to ``F.linear`` with the
    sublas API enabled for the duration of the call.

    Args:
        layer: The linear layer; ``prefix``, ``weight`` and the output size
            attributes are read from it. For row-parallel layers on the
            non-vision path ``layer.reduce_results`` is overwritten.
        x: Input activations.
        bias: Optional bias.

    Returns:
        The layer output.
    """
    # numa weight is 3-dims
    if 'vision' in layer.prefix and is_br166_device():
        if len(layer.weight.shape) != 3:
            return _linear_with_sublas(x, layer.weight, bias)
        is_row = isinstance(layer, RowParallelLinear)
        output_size, act_mode = _numa_output_size_and_act_mode(layer)
        if bias is None or (is_row and layer.reduce_results):
            return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                               output_w=output_size,
                                               activation_mode=act_mode)
        return torch_br.br_matmul_infer(x, layer.weight, bias, output_size)

    if len(layer.weight.shape) != 3:
        return _linear_with_sublas(x, layer.weight, bias)

    output_size, act_mode = _numa_output_size_and_act_mode(layer)
    bias_list = [bias] if bias is not None else None

    if isinstance(layer, RowParallelLinear):
        seq_len = x.shape[-2]
        tp_size = get_tensor_model_parallel_world_size()
        # TODO(CaoJun): This is WA, delete (16, 8) so that the test_vllm_model_accu_qwen25_72b_instruct can run through
        support_types = ((16, 4), (32, 2), (32, 4))
        # bypass tp8 and tp4pp2 allreduce
        pp_size = get_pp_group().world_size
        all_rank = tp_size * pp_size
        # The decision is stored on the layer so the caller can see that the
        # reduction has already been fused into the matmul.
        layer.reduce_results = not (
            all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
            and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
            (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
        if not layer.reduce_results:
            tp_group = get_tp_group()
            tp_rank = tp_group.rank_in_group
            global_rank = get_tp_group().rank
            assert global_rank % tp_size == tp_rank
            # NOTE(review): `bias` is not forwarded to the fused
            # linear + all-reduce kernel - confirm callers never pass a bias
            # when this path is taken.
            return torch_br.supa_fused_linear_allreduce_opt(
                x, layer.weight, output_size, tp_rank, tp_size, global_rank,
                0)

    return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                       output_w=output_size,
                                       bias=bias_list,
                                       activation_mode=act_mode)
@patch_to(LinearBase)
def __init__(
    self,
    input_size: int,
    output_size: int,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    *,
    return_bias: bool = True,
    disable_tp: bool = False,
):
    """Initialise the shared state of every linear layer.

    Args:
        input_size: Input dimension of the layer.
        output_size: Output dimension of the layer.
        skip_bias_add: Stored as ``self.skip_bias_add``.
        params_dtype: Parameter dtype; the torch default dtype when None.
        quant_config: Source of the quantization method; the unquantized
            method is used when None.
        prefix: Layer name, stored and passed to ``get_quant_method``.
        return_bias: Stored as ``self.return_bias``.
        disable_tp: When True the layer records rank 0 / world size 1
            instead of querying the tensor-parallel group.
    """
    super(LinearBase, self).__init__()

    # Keep input parameters
    self.input_size = input_size
    self.output_size = output_size
    self.skip_bias_add = skip_bias_add
    self.params_dtype = (torch.get_default_dtype()
                         if params_dtype is None else params_dtype)

    if quant_config is not None:
        self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
    else:
        self.quant_method = UnquantizedLinearMethod()

    self.return_bias = return_bias
    self.prefix = prefix

    if disable_tp:
        self.tp_rank = 0
        self.tp_size = 1
    else:
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
@patch_to(RowParallelLinear)
def forward(
    self, input_
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
    """Row-parallel linear forward with an optional CPU (PCIe) all-reduce.

    Returns:
        The output tensor alone when ``self.return_bias`` is False, otherwise
        ``(output, bias)`` where ``bias`` is ``self.bias`` if
        ``self.skip_bias_add`` is set and None otherwise.
    """
    cpu_all_reduce = envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0
    # Look the grandparent pid up once and cache it on the layer.
    if cpu_all_reduce and not hasattr(self, "grandparent_pid"):
        self.grandparent_pid = get_grandparent_pid()

    if not self.input_is_parallel:
        rank = get_tensor_model_parallel_rank()
        shards = split_tensor_along_last_dim(input_,
                                             num_partitions=self.tp_size)
        local_input = shards[rank].contiguous()
    else:
        local_input = input_

    # Matrix multiply.
    assert self.quant_method is not None
    # Only fuse bias add into GEMM for rank 0 (this ensures that
    # bias will not get added more than once in TP>1 case)
    if self.tp_rank > 0 or self.skip_bias_add:
        fused_bias = None
    else:
        fused_bias = self.bias
    partial = self.quant_method.apply(self, local_input, bias=fused_bias)

    if not (self.reduce_results and self.tp_size > 1):
        output = partial
    elif (envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and self.tp_size >= 4
          and partial.shape[1] <= 32):
        # CPU all reduce will be applied.
        rank = get_tensor_model_parallel_rank()
        output = torch_br.supa_allreduce_pcie_infer(partial, rank,
                                                    self.tp_size,
                                                    self.grandparent_pid)
    else:
        output = tensor_model_parallel_all_reduce(partial)

    output_bias = self.bias if self.skip_bias_add else None
    if self.return_bias:
        return output, output_bias
    return output
@patch_to(QKVParallelLinear)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
# Load one checkpoint tensor into the fused QKV parameter `param`.
#
# param: destination parameter. Optional attributes read from it with
#   getattr (is_gguf_weight, is_gguf_weight_type, output_dim, packed_dim,
#   is_metadata, needs_scalar_to_array, use_bitsandbytes_4bit,
#   is_sharded_weight, ignore_warning) select the loading path.
# loaded_weight: tensor read from the checkpoint.
# loaded_shard_id: "q", "k" or "v" for a single projection, or None when
#   the checkpoint already stores q/k/v fused.
#
# Returns None; the parameter (or its GGUF containers) is updated in place.
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
# Record the quantization type per shard (or the same type for all three
# shards when no shard id is given).
idx_map = {"q": 0, "k": 1, "v": 2}
if loaded_shard_id is not None:
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
else:
param.shard_weight_type = {
k: loaded_weight.item()
for k in idx_map
}
return
if is_gguf_weight:
# GGUF weights are kept per shard in param.data_container; a weight that
# carries a shard id is first narrowed to this rank's slice of output_dim.
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size
if loaded_shard_id is not None:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
return
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv).
# (e.g., Phi-3's qkv_proj).
if output_dim is None:
if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
# Offsets and sizes here use the total (un-sharded) head counts because
# they address the full fused tensor from the checkpoint.
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
("k", self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
("v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size),
]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.total_num_heads * self.head_size),
"k": (self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
"v":
((self.total_num_heads + self.total_num_kv_heads) *
self.head_size, self.total_num_kv_heads * self.head_size),
"total":
((self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_size, 0)
}
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, shard_id)
loaded_weight_shard = loaded_weight.narrow(output_dim,
shard_offset,
shard_size)
# Re-enter this loader once per q/k/v slice with an explicit shard id.
self.weight_loader(param, loaded_weight_shard, shard_id)
return
tp_rank = get_tensor_model_parallel_rank()
assert loaded_shard_id in ["q", "k", "v"]
# If output dim is defined, use the default loading process.
if output_dim is not None:
# Offsets and sizes here use the per-rank head counts because they
# address this rank's fused parameter.
if loaded_shard_id == "q":
shard_offset = 0
shard_size = self.num_heads * self.head_size
elif loaded_shard_id == "k":
shard_offset = self.num_heads * self.head_size
shard_size = self.num_kv_heads * self.head_size
elif loaded_shard_id == "v":
shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
"k": (self.num_heads * self.head_size,
self.num_kv_heads * self.head_size),
"v": ((self.num_heads + self.num_kv_heads) * self.head_size,
self.num_kv_heads * self.head_size),
"total":
((self.num_heads + 2 * self.num_kv_heads) * self.head_size, 0)
}
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id)
# SUPA-specific: when VLLM_BR_DEVICE_SPC_NUM > 16 the destination becomes a
# pair of views, each covering half of the shard; the second view is
# additionally offset by half of the parameter's extent along output_dim.
if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
half_w = param_data.shape[output_dim] // 2
param_data = (param_data.narrow(output_dim, shard_offset // 2,
shard_size // 2),
param_data.narrow(output_dim,
shard_offset // 2 + half_w,
shard_size // 2))
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
# q is indexed by the TP rank; k/v by tp_rank // num_kv_head_replicas.
if loaded_shard_id == "q":
shard_id = tp_rank
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
# A tuple destination (built in the SPC_NUM > 16 case above) receives the
# first and second halves of loaded_weight along output_dim respectively.
if isinstance(param_data, tuple):
half_w = loaded_weight.shape[output_dim] // 2
param_data[0].copy_(loaded_weight.narrow(output_dim, 0, half_w))
param_data[1].copy_(loaded_weight.narrow(output_dim, half_w, half_w))
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

View File

@@ -0,0 +1,72 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
import torch_br
from fastcore.basics import patch_to
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm_br import envs
# TODO(shouqing): need to open this patch when fix hang in mtp
@patch_to(LogitsProcessor)
def _get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Compute next-token logits and assemble them across TP ranks.

    Each rank writes its local logits into its own column slice of a
    zero-filled buffer whose last dimension is ``tp_size`` times the local
    width; an all-reduce then yields the full logits on every rank.

    Args:
        hidden_states: Input to the LM head.
        lm_head: Embedding layer whose ``quant_method.apply`` produces the
            local logits.
        embedding_bias: Optional bias forwarded to ``quant_method.apply``.

    Returns:
        The assembled logits truncated to ``self.org_vocab_size`` along the
        last dimension.
    """
    local_logits = lm_head.quant_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)

    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        # work around the hang in s1b copy to bb
        staging = torch_br._empty_ut_only(size=local_logits.shape,
                                          dtype=local_logits.dtype,
                                          is_numa=False,
                                          device=local_logits.device,
                                          tensor_type="colmajor")
        staging.copy_(local_logits)
        local_logits = staging

    world_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    local_width = local_logits.shape[-1]

    # Zero everywhere except this rank's slice, so the sum over ranks is the
    # concatenation of all local logits.
    assembled = torch.zeros((local_logits.shape[0], local_width * world_size),
                            dtype=local_logits.dtype,
                            device=local_logits.device)
    col_begin = local_width * rank
    assembled[:, col_begin:col_begin + local_width].copy_(local_logits)
    full_logits = tensor_model_parallel_all_reduce(assembled)

    if full_logits is None:
        return full_logits
    # Remove paddings in vocab (if any).
    return full_logits[..., :self.org_vocab_size]

View File

@@ -0,0 +1,19 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import compressed_tensors, gptq
__all__ = ["gptq", 'compressed_tensors']

View File

@@ -0,0 +1,18 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from .compressed_tensors import *
from .compressed_tensors_moe import *
from .compressed_tensors_wNa16 import *

View File

@@ -0,0 +1,64 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Any, cast
from fastcore.basics import patch_to
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig)
@patch_to(CompressedTensorsConfig, cls_method=True)
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
    """Build a ``CompressedTensorsConfig`` from a raw quantization config.

    [PatchNote] add qkv_quantized param support

    Args:
        config: Quantization config dictionary. The optional
            ``"qkv_quantized"`` key defaults to ``True`` when absent.

    Returns:
        A new instance of ``cls`` constructed from the parsed fields.
    """
    ignored_layers: list[str] = cast(list[str], config.get("ignore", []))
    fmt = cast(str, config.get("format"))
    scheme_map = cls._quantization_scheme_map_from_config(config=config)
    sparsity_map, sparsity_ignored = cls._parse_sparsity_config(config=config)
    transform_cfg = config.get("transform_config")
    # SUPA-specific flag; consumed by the patched constructor.
    qkv_flag = cls.get_from_keys_or(config, ["qkv_quantized"], default=True)

    init_kwargs = {
        "target_scheme_map": scheme_map,
        "ignore": ignored_layers,
        "quant_format": fmt,
        "sparsity_scheme_map": sparsity_map,
        "sparsity_ignore_list": sparsity_ignored,
        "config": config,
        "transform_config": transform_cfg,
        "qkv_quantized": qkv_flag,
    }
    return cls(**init_kwargs)
def wrapper_CompressedTensorsConfig_init(fn):
    """Wrap an ``__init__`` so it accepts an extra ``qkv_quantized`` keyword.

    Args:
        fn: The original ``__init__`` to delegate to.

    Returns:
        A replacement ``__init__`` that removes ``qkv_quantized`` (default
        ``True``) from the keyword arguments, calls ``fn`` with the rest, and
        then stores the value as ``self.qkv_quantized``.
    """

    @wraps(fn)
    def init_with_qkv_flag(self, *args, **kwargs):
        # The wrapped constructor does not know this keyword, so strip it
        # before delegating.
        flag = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = flag

    return init_with_qkv_flag
# Monkey-patch the constructor at import time so that it accepts, and stores
# on the instance, the extra ``qkv_quantized`` keyword that the patched
# ``from_config`` passes.
CompressedTensorsConfig.__init__ = wrapper_CompressedTensorsConfig_init(
CompressedTensorsConfig.__init__) # noqa: E501

View File

@@ -0,0 +1,594 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Callable, Optional
import torch
import torch_br
from compressed_tensors.compressors.quantized_compressors import (
unpack_from_int32)
from fastcore.basics import patch_to
from torch_br.utils.tensor_methods import Sbp
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
WNA16_SUPPORTED_BITS, CompressedTensorsMoEMethod,
CompressedTensorsWNA16MoEMethod, CompressionFormat)
from vllm.model_executor.utils import set_weight_attrs
from vllm_br import envs
from ...br_utils import (_convert_to_crossed_numa_tensor,
_convert_to_numa_tensor, align_n, cross_weight_32)
from ...fused_moe.supa_moe import fused_moe_quant_device, fused_moe_quant_dyn
@patch_to(CompressedTensorsMoEMethod)
def get_moe_method(
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    layer: torch.nn.Module,
) -> "CompressedTensorsMoEMethod":
    """Select the MoE quantization method for ``layer``.

    NOTE:
        1. SUPA only supports CompressedTensorsWNA16MoEMethod without Marlin
        2. Only Linear targets are supported for MoE layers

    Args:
        quant_config: Config whose ``target_scheme_map`` is consulted (and,
            if it has no ``"Linear"`` entry, updated in place).
        layer: The MoE layer; its ``moe_config`` is passed to the method.

    Returns:
        A ``CompressedTensorsWNA16MoEMethod`` instance.

    Raises:
        RuntimeError: If the scheme is not a wNa16 group/channel scheme.
    """
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    scheme_map = quant_config.target_scheme_map
    keys = list(scheme_map.keys())
    assert len(keys) > 0, ("No valid quant key!!!")

    # [Patch]: Only the "Linear" target is supported for MoE layers. For
    # temporary compatibility the first target is re-keyed to "Linear".
    # This function runs once per MoE layer on the same shared config, so the
    # re-keying must be idempotent: doing it unconditionally would pop a
    # different entry on each call whenever the map holds several targets,
    # handing different schemes to different layers and discarding the rest.
    target_key = "Linear"
    if target_key not in scheme_map:
        scheme_map[target_key] = scheme_map.pop(keys[0])

    weight_quant = scheme_map[target_key].get("weights")
    input_quant = scheme_map[target_key].get("input_activations")

    if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
        logger.info_once("Using CompressedTensorsWNA16MoEMethod")
        return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config)
    else:
        raise RuntimeError(
            f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@patch_to(CompressedTensorsWNA16MoEMethod)
def __init__(
    self,
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    moe: FusedMoEConfig,
):
    """Initialise the WNA16 MoE method from the ``"Linear"`` weight scheme.

    Args:
        quant_config: Config whose ``target_scheme_map["Linear"]["weights"]``
            supplies ``num_bits``, ``strategy``, ``group_size`` and
            ``symmetric``.
        moe: MoE configuration forwarded to the base-class constructor.

    Raises:
        AssertionError: If the weight scheme is not symmetric.
        ValueError: If the format is not pack-quantized or ``num_bits`` is
            not in ``WNA16_SUPPORTED_BITS``.
    """
    # Explicit two-argument super(): this function is defined outside the
    # class body and attached afterwards, so zero-argument super() would not
    # work here.
    super(CompressedTensorsWNA16MoEMethod, self).__init__(moe)
    self.quant_config = quant_config
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    config = self.quant_config.target_scheme_map["Linear"].get("weights")
    self.num_bits = config.num_bits
    self.packed_factor = 32 // config.num_bits
    self.strategy = config.strategy
    # [Patch]: SUPA use CompressedTensorsWNA16MoEMethod for both
    # channel/group strategies, so the upstream checks on the strategy and
    # on grouped actorder are intentionally not enforced here.
    self.group_size = config.group_size
    assert config.symmetric, (
        "Only symmetric quantization is supported for MoE")

    if not (self.quant_config.quant_format
            == CompressionFormat.pack_quantized.value
            and self.num_bits in WNA16_SUPPORTED_BITS):
        # Build ONE message string; passing the fragments as separate
        # arguments makes the exception render as a tuple of pieces.
        raise ValueError("For Fused MoE layers, only "
                         f"{CompressionFormat.pack_quantized.value} "
                         "is supported for the following bits: "
                         f"{WNA16_SUPPORTED_BITS}")
@patch_to(CompressedTensorsWNA16MoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register the CPU-side packed weights, scales and index buffers.

    Every parameter is created on ``"cpu"`` with ``requires_grad=False``,
    registered on ``layer`` and tagged with ``extra_weight_attrs`` (extended
    with ``is_transposed=True`` and ``quant_method=self.strategy``).

    Args:
        layer: Module that receives the parameters.
        num_experts: Leading dimension of every parameter.
        hidden_size: Model hidden size.
        intermediate_size_per_partition: Per-rank intermediate size.
        params_dtype: dtype of the scale parameters.
        **extra_weight_attrs: Attributes attached to every parameter.

    Side effects:
        Sets ``self.group_size`` to ``-1`` for the ``"channel"`` strategy and
        sets ``layer.a13_scale`` / ``layer.a2_scale`` to ``None``.
    """
    # Will transpose the loaded weight along the
    # intermediate and hidden dim sizes. Will
    # shard for TP along the transposed dims
    extra_weight_attrs.update({
        "is_transposed": True,
        "quant_method": self.strategy
    })
    inter_size = intermediate_size_per_partition

    def _register(name: str, data: torch.Tensor) -> torch.nn.Parameter:
        # Wrap, register and annotate one parameter; returns it for callers
        # that need to attach further attributes.
        param = torch.nn.Parameter(data, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)
        return param

    # Packed integer weights.
    _register(
        "w13_weight_packed",
        torch.empty(num_experts,
                    hidden_size // self.packed_factor,
                    2 * inter_size,
                    dtype=torch.int32,
                    device="cpu"))
    _register(
        "w2_weight_packed",
        torch.empty(num_experts,
                    inter_size // self.packed_factor,
                    hidden_size,
                    dtype=torch.int32,
                    device="cpu"))

    # Number of scale groups along the input dimension.
    if self.strategy == "channel":
        num_groups_w2 = num_groups_w13 = 1
        self.group_size = -1
    else:
        num_groups_w2 = inter_size // self.group_size
        num_groups_w13 = hidden_size // self.group_size

    # Scales.
    _register(
        "w13_weight_scale",
        torch.ones(num_experts,
                   num_groups_w13,
                   2 * inter_size,
                   dtype=params_dtype,
                   device="cpu"))
    w2_scale_param = _register(
        "w2_weight_scale",
        torch.ones(num_experts,
                   num_groups_w2,
                   hidden_size,
                   dtype=params_dtype,
                   device="cpu"))
    set_weight_attrs(w2_scale_param, {"load_full_w2": False})

    # Shape records (two entries per expert).
    for shape_name in ("w2_weight_shape", "w13_weight_shape"):
        _register(shape_name, torch.empty(num_experts, 2, device="cpu"))

    # g_idx buffers and their sort indices; w13 ones span hidden_size and
    # w2 ones span the per-partition intermediate size.
    index_buffers = (
        ("w13_weight_g_idx", hidden_size),
        ("w2_weight_g_idx", inter_size),
        ("w13_g_idx_sort_indices", hidden_size),
        ("w2_g_idx_sort_indices", inter_size),
    )
    for index_name, width in index_buffers:
        _register(
            index_name,
            torch.empty(num_experts,
                        width,
                        device="cpu",
                        dtype=torch.int32))

    layer.a13_scale = None
    layer.a2_scale = None
@patch_to(CompressedTensorsWNA16MoEMethod)
def process_weights_after_loading(self: CompressedTensorsWNA16MoEMethod,
layer: FusedMoE) -> None:
# Re-layout the loaded packed MoE weights and scales into SUPA device
# tensors and swap them in as the `.data` of w13_weight_packed,
# w13_weight_scale, w2_weight_packed and w2_weight_scale.
#
# Handles self.num_bits == 8 and == 4; any other value raises ValueError.
# Finally the shape / g_idx / sort-index attributes on `layer` are set to
# None to reduce memory usage.
die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
# More than 16 SPCs is treated as a dual-die device.
die_num = 1 if die_spc_num <= 16 else 2
spc_num = die_spc_num // die_num
cur_device = torch.supa.current_device()
is_dual_die = (die_spc_num > 16)
if self.num_bits == 8:
# NOTE: w13_weight
# after _load_w13, w13_weight is a colparallel weight, shape
# [num_experts, hidden_size // 4, 2 * intermediate_size_per_partition] INT32
# for SUPA, transform it to a NUMA colmajor weight, shape
# [spc_num * num_experts, wk, wn_block] INT8
wk = layer.hidden_size
wn = layer.intermediate_size_per_partition * 2
align_size = 64
wn_block = align_n(wn // die_num,
align_size=align_size,
spc_num=spc_num)
supa_w13_weight_packed = torch_br._empty_ut_only(
size=(die_spc_num * layer.local_num_experts, wk, wn_block),
dtype=torch.int8,
is_numa=True,
device=cur_device,
tensor_type="colmajor",
axis=0,
sbp="SS" if is_dual_die else None)
for expert_id in range(layer.local_num_experts):
expert_w13 = layer.w13_weight_packed[
expert_id] # hidden_size // 4, 2 * intermediate_size_per_partition
expert_1, expert_3 = expert_w13.chunk(
2, dim=1) # each half is still packed INT32 (num_bits == 8)
# Unpack each half to [hidden_size, intermediate_size_per_partition].
unpacked_expert_1 = unpack_from_int32(
expert_1, self.num_bits,
torch.Size(
[layer.hidden_size,
layer.intermediate_size_per_partition]), 0)
unpacked_expert_3 = unpack_from_int32(
expert_3, self.num_bits,
torch.Size(
[layer.hidden_size,
layer.intermediate_size_per_partition]), 0)
pad_expert_w13 = _convert_to_crossed_numa_tensor(
unpacked_expert_1,
unpacked_expert_3,
die_spc_num,
dim=1,
need_pad=True,
layout='COLMAJOR',
do_transpose=False)
# Copy into the per-expert view of the device tensor; the view offset is
# expert_id times the product of the last two dims of the converted tensor.
hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
narrow_data = supa_w13_weight_packed.view_as_usharp(
"COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
expert_id * hw_size)
narrow_data.copy_(pad_expert_w13)
layer.w13_weight_packed.data = supa_w13_weight_packed
# NOTE: w13_scale
# after _load_w13, w13_weight is a colparallel weight, shape
# S8: [num_experts, 1, 2 * intermediate_size_per_partition]
# for SUPA, transform it to a NUMA colmajor weight, shape
# [num_experts, wn]
supa_w13_scales = torch_br._empty_ut_only(
size=(layer.local_num_experts, wn),
dtype=torch.float32,
is_numa=False,
device=cur_device,
tensor_type="linear_bias",
sbp="BB" if is_dual_die else None)
for expert_id in range(layer.local_num_experts):
expert_w13_scales = layer.w13_weight_scale[expert_id]
expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
2, dim=1) # the two halves of the scales along dim 1
crossed_expert_w13_scales = cross_weight_32(
expert_1_scale.squeeze(),
expert_3_scale.squeeze(),
die_spc_num,
dim=0,
need_pad=False,
)
narrow_data = supa_w13_scales[expert_id]
narrow_data.copy_(crossed_expert_w13_scales)
layer.w13_weight_scale.data = supa_w13_scales
# NOTE: w2_weight
# after _load_w2, w2_weight is a colparallel weight, shape
# [num_experts, intermediate_size_per_partition // 4, hidden_size] INT32
# for SUPA, transform it to a NUMA colmajor weight, shape
# [spc_num * num_experts, wk, wn_block] INT8
wk = layer.intermediate_size_per_partition
wn = layer.hidden_size
align_size = 32
wn_block = align_n(wn, align_size=align_size, spc_num=spc_num)
supa_w2_weight_packed = torch_br._empty_ut_only(
size=(die_spc_num * layer.local_num_experts, wk // die_num,
wn_block),
dtype=torch.int8,
is_numa=True,
device=cur_device,
tensor_type="colmajor",
axis=0,
sbp="SS" if is_dual_die else None)
for expert_id in range(layer.local_num_experts):
expert_w2 = layer.w2_weight_packed[expert_id]
unpacked_expert_2 = unpack_from_int32(
expert_w2, self.num_bits,
torch.Size(
[layer.intermediate_size_per_partition,
layer.hidden_size]), 0)
# NOTE(review): align_size (32) is passed in the position where the 4-bit
# path below passes spc_num -- confirm against _convert_to_numa_tensor.
pad_expert_w2 = _convert_to_numa_tensor(
unpacked_expert_2,
align_size,
'COLMAJOR',
expert_w2.dtype,
do_transpose=False,
parallel_type="row_parallel")
pad_expert_w2_shape = pad_expert_w2.shape
hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
narrow_data = supa_w2_weight_packed.view_as_usharp(
"COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
expert_id * hw_size)
narrow_data.copy_(pad_expert_w2)
layer.w2_weight_packed.data = supa_w2_weight_packed
# NOTE: w2_scale
# after _load_w2, w2_weight is a colparallel weight, shape
# S8: [num_experts, 1, hidden_size]
# for SUPA, transform it to a NUMA colmajor weight, shape
# [num_experts, wn]
supa_w2_scales = torch_br._empty_ut_only(size=(layer.local_num_experts,
wn),
dtype=torch.float32,
is_numa=False,
device=cur_device,
tensor_type="linear_bias")
for expert_id in range(layer.local_num_experts):
expert_w2 = layer.w2_weight_scale[expert_id]
# A stride of local_num_experts over a first dim of the same length selects
# exactly the row at expert_id, kept as a leading dimension of size 1.
narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
narrow_data.copy_(expert_w2)
layer.w2_weight_scale.data = supa_w2_scales
elif self.num_bits == 4:
# NOTE: w13_weight
# after _load_w13, w13_weight is a colparallel weight, shape
# [num_experts, hidden_size // 8, 2 * intermediate_size_per_partition] INT32
# for SUPA, transform it to a NUMA colmajor weight, shape
# [spc_num * num_experts, wk, wn_block] INT32
wk = layer.hidden_size // 8
wn = layer.intermediate_size_per_partition * 2
wn_block = align_n(wn, align_size=64, spc_num=spc_num)
supa_w13_weight_packed = torch_br._empty_ut_only(
size=(spc_num * layer.local_num_experts, wk, wn_block),
dtype=torch.int32,
is_numa=True,
device=cur_device,
tensor_type="colmajor")
for expert_id in range(layer.local_num_experts):
expert_w13 = layer.w13_weight_packed[
expert_id] # hidden_size // 8, 2 * intermediate_size_per_partition
# NOTE(review): this path chunks along dim 0 (and converts with
# do_transpose=True) while the 8-bit path chunks along dim 1 -- confirm
# this matches the layout the loader actually produces.
expert_1, expert_3 = expert_w13.chunk(
2, dim=0) # each is a packed int4 weight
pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
expert_3,
spc_num,
dim=1,
need_pad=True,
layout='COLMAJOR',
do_transpose=True)
hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
narrow_data = supa_w13_weight_packed.view_as_usharp(
"COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
expert_id * hw_size)
narrow_data.copy_(pad_expert_w13)
layer.w13_weight_packed.data = supa_w13_weight_packed
# NOTE: w13_scale
# after _load_w13, w13_weight is a colparallel weight, shape
# S4: [num_experts, hidden_size // 128, 2 * intermediate_size_per_partition]
# for SUPA, transform it to a NUMA colmajor weight, shape
# [num_experts, group_nums, wn]
supa_w13_scales = torch_br._empty_ut_only(
size=(layer.local_num_experts,
layer.hidden_size // self.group_size, wn),
dtype=torch.float32,
is_numa=False,
device=cur_device,
tensor_type="colmajor")
for expert_id in range(layer.local_num_experts):
expert_w13_scales = layer.w13_weight_scale[expert_id]
# NOTE(review): scales are chunked along dim 0 here versus dim 1 in the
# 8-bit path -- confirm.
expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
2, dim=0) # the two halves of the scales along dim 0
crossed_expert_w13_scales = cross_weight_32(
expert_1_scale,
expert_3_scale,
spc_num,
dim=1,
need_pad=False,
)
narrow_data = supa_w13_scales[expert_id]
narrow_data.copy_(crossed_expert_w13_scales)
layer.w13_weight_scale.data = supa_w13_scales
# NOTE: w2_weight
# after _load_w2, w2_weight is a colparallel weight, shape
# [num_experts, intermediate_size_per_partition // 8, hidden_size] INT32
# for SUPA, transform it to a NUMA colmajor weight, shape
# [spc_num * num_experts, wk, wn_block] INT32
wk = layer.intermediate_size_per_partition // 8
wn = layer.hidden_size
wn_block = align_n(wn, align_size=32, spc_num=spc_num)
supa_w2_weight_packed = torch_br._empty_ut_only(
size=(spc_num * layer.local_num_experts, wk, wn_block),
dtype=torch.int32,
is_numa=True,
device=cur_device,
tensor_type="colmajor")
for expert_id in range(layer.local_num_experts):
expert_w2 = layer.w2_weight_packed[expert_id]
pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
spc_num,
'COLMAJOR',
expert_w2.dtype,
do_transpose=True)
pad_expert_w2_shape = pad_expert_w2.shape
hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
narrow_data = supa_w2_weight_packed.view_as_usharp(
"COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
expert_id * hw_size)
narrow_data.copy_(pad_expert_w2)
layer.w2_weight_packed.data = supa_w2_weight_packed
# NOTE: w2_scale
# after _load_w2, w2_weight is a colparallel weight, shape
# S4: [num_experts, intermediate_size_per_partition // 128, hidden_size]
# for SUPA, transform it to a NUMA colmajor weight, shape
# [num_experts, group_nums, wn]
supa_w2_scales = torch_br._empty_ut_only(
size=(layer.local_num_experts,
layer.intermediate_size_per_partition // self.group_size,
wn),
dtype=torch.float32,
is_numa=False,
device=cur_device,
tensor_type="colmajor")
for expert_id in range(layer.local_num_experts):
expert_w2 = layer.w2_weight_scale[expert_id]
# Same single-row strided selection as in the 8-bit path.
narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
narrow_data.copy_(expert_w2)
layer.w2_weight_scale.data = supa_w2_scales
else:
raise ValueError(
f"Unsupported num_bits: {self.num_bits}. Only 4 and 8 are supported."
)
# Drop the other parameters registered by CompressedTensorsWNA16MoEMethod
# (shapes, g_idx, sort indices) to reduce memory usage.
layer.w13_weight_shape = None
layer.w13_weight_g_idx = None
layer.w13_g_idx_sort_indices = None
layer.w2_weight_shape = None
layer.w2_weight_g_idx = None
layer.w2_g_idx_sort_indices = None
@patch_to(CompressedTensorsWNA16MoEMethod)
def apply(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
) -> torch.Tensor:
    """Run the quantized fused-MoE kernel for ``layer``.

    ``router_logits`` is unpacked into three items here: the gating weight
    and the shared gate-up / down weights. When the first dimension of ``x``
    exceeds ``envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN`` the call goes to
    ``fused_moe_quant_dyn``; otherwise to ``fused_moe_quant_device``, which
    additionally receives the TP rank, global rank and TP size.

    ``global_num_experts``, ``expert_map``, ``apply_router_weight_on_input``
    and ``activation`` are accepted for interface compatibility but are not
    used by this implementation.

    Returns:
        The result of the selected fused-MoE kernel.
    """
    leading_dim = x.shape[0]
    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits

    # Positional arguments shared by both kernels (order matters).
    kernel_args = (
        x,
        shared_gate_up_weight,
        shared_down_weight,
        layer.w13_weight_packed,
        layer.w2_weight_packed,
        layer.w13_weight_scale,
        layer.w2_weight_scale,
        gating_weight,
        top_k,
        layer.intermediate_size_per_partition,
    )
    # Keyword arguments shared by both kernels.
    kernel_kwargs = {
        "renormalize": renormalize,
        "inplace": True,
        "use_grouped_topk": use_grouped_topk,
        "topk_group": topk_group,
        "num_expert_group": num_expert_group,
        "custom_routing_function": custom_routing_function,
        "scoring_func": scoring_func,
        "e_score_correction_bias": e_score_correction_bias,
        "ep_rank": layer.ep_rank,
        "ep_size": layer.ep_size,
    }

    if leading_dim > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN:
        return fused_moe_quant_dyn(*kernel_args, **kernel_kwargs)

    tp_group = get_tp_group()
    return fused_moe_quant_device(
        *kernel_args,
        tp_rank=tp_group.rank_in_group,
        global_rank=tp_group.rank,
        tp_size=get_tensor_model_parallel_world_size(),
        **kernel_kwargs,
    )

View File

@@ -0,0 +1,267 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Callable, Optional
import torch
import torch_br
from compressed_tensors.compressors.quantized_compressors import (
unpack_from_int32)
from fastcore.basics import patch_to
from vllm.distributed import (get_pipeline_model_parallel_group,
get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
# yapf: enable
from vllm_br import envs
from ...br_utils import _convert_to_numa_tensor
@patch_to(CompressedTensorsWNA16)
def create_weights(self, layer: torch.nn.Module, input_size: int,
                   input_size_per_partition: int, output_size: int,
                   output_partition_sizes: list[int],
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):
    """Create the CPU-resident parameters of a WNA16 layer and register them.

    ``weight_packed``, ``weight_scale`` and ``weight_shape`` are always
    registered on ``layer``; ``weight_zero_point`` only when the scheme is
    not symmetric and ``weight_g_idx`` only when ``self.has_g_idx`` is set.
    The per-partition output and input sizes are stored on ``self``.
    """
    out_rows = sum(output_partition_sizes)
    self.output_size_per_partition = out_rows

    # A group size of -1 is replaced by the full input size.
    group_size = input_size if self.group_size == -1 else self.group_size
    is_row_parallel = input_size != input_size_per_partition
    partition_scales = not marlin_repeat_scales_on_all_ranks(
        self.has_g_idx, self.group_size, is_row_parallel)

    num_groups = input_size // group_size
    if partition_scales:
        assert input_size_per_partition % group_size == 0
        num_groups = input_size_per_partition // group_size

    packed_weight = PackedvLLMParameter(
        data=torch.empty(out_rows,
                         input_size_per_partition // self.pack_factor,
                         dtype=torch.int32,
                         device="cpu"),
        input_dim=1,
        output_dim=0,
        packed_dim=1,
        packed_factor=self.pack_factor,
        weight_loader=weight_loader)

    scale_kwargs = dict(
        weight_loader=weight_loader,
        data=torch.empty(out_rows,
                         num_groups,
                         device="cpu",
                         dtype=params_dtype),
    )
    zero_point_kwargs = dict(
        weight_loader=weight_loader,
        data=torch.zeros(out_rows // self.pack_factor,
                         num_groups,
                         device="cpu",
                         dtype=torch.int32),
    )

    if partition_scales:
        scale_param = GroupQuantScaleParameter(output_dim=0,
                                               input_dim=1,
                                               **scale_kwargs)
        if not self.symmetric:
            zero_point_param = PackedvLLMParameter(
                input_dim=1,
                output_dim=0,
                packed_dim=0,
                packed_factor=self.pack_factor,
                **zero_point_kwargs)
    else:
        scale_param = ChannelQuantScaleParameter(output_dim=0,
                                                 **scale_kwargs)
        if not self.symmetric:
            zero_point_param = PackedColumnParameter(
                output_dim=0,
                packed_dim=0,
                packed_factor=self.pack_factor,
                **zero_point_kwargs)

    # Two-element tensor holding the shape of the weights before packing.
    shape_param = BasevLLMParameter(data=torch.empty(2,
                                                     dtype=torch.int64,
                                                     device="cpu"),
                                    weight_loader=weight_loader)

    layer.register_parameter("weight_packed", packed_weight)
    layer.register_parameter("weight_scale", scale_param)
    layer.register_parameter("weight_shape", shape_param)
    if not self.symmetric:
        layer.register_parameter("weight_zero_point", zero_point_param)

    # Group index, used for activation reordering.
    if self.has_g_idx:
        g_idx_param = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
                                       input_dim=0,
                                       weight_loader=weight_loader)
        layer.register_parameter("weight_g_idx", g_idx_param)

    self.input_size_per_partition = input_size_per_partition
@patch_to(CompressedTensorsWNA16)
def process_weights_after_loading(self: CompressedTensorsWNA16,
                                  layer: torch.nn.Module) -> None:
    """Unpack the loaded WNA16 tensors and convert them with
    ``_convert_to_numa_tensor``.

    The packed int32 weight is unpacked to its per-partition shape, the
    scales are re-wrapped as a plain ``Parameter`` and cast to float32, and
    weight, scale and bias (when present) are then converted.
    """
    # spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # cur_device = torch.supa.current_device()
    self.num_bits = 32 // self.pack_factor
    unpacked_shape = torch.Size(
        [self.output_size_per_partition, self.input_size_per_partition])
    layer.weight_packed.data = unpack_from_int32(layer.weight_packed.data,
                                                 self.num_bits,
                                                 unpacked_shape, 1)

    layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                            requires_grad=False)
    layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)

    if isinstance(layer, RowParallelLinear):
        parallel_type = "row_parallel"
    else:
        parallel_type = "col_parallel"

    if hasattr(layer, 'weight_packed') and len(
            layer.weight_packed.shape) == 2:
        packed = layer.weight_packed.data
        # The weight is always converted with a transpose, so K comes from
        # dim 1 and N from dim 0 of the unpacked tensor.
        layer.weight_packed.data = _convert_to_numa_tensor(
            packed,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            do_transpose=True,
            wk=packed.shape[1],
            wn=packed.shape[0],
            parallel_type=parallel_type)  # noqa: SIM210

    if hasattr(layer, 'weight_scale') and layer.weight_scale is not None:
        layer.weight_scale.data = _convert_to_numa_tensor(
            layer.weight_scale.data.T,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=False)

    if hasattr(layer, 'bias') and layer.bias is not None:
        # Zero padding is requested for the bias of row-parallel layers only.
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias.data,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=(parallel_type == "row_parallel"))
@patch_to(CompressedTensorsWNA16)
def apply_weights(self: CompressedTensorsWNA16,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the quantized linear layer to ``x``.

    When ``layer.weight_packed`` is 3-D (the converted NUMA layout) the
    fused MLP kernel is used; ``MergedColumnParallelLinear`` layers run it
    in ``act_swiglu`` mode with half the output width.  For
    ``RowParallelLinear`` layers the fused linear+allreduce kernel is used
    instead when the rank count, sequence length and (SPC count, TP size)
    checks all pass, and ``layer.reduce_results`` is updated to record
    which path was taken.  Any other weight layout falls back to
    ``sudnn_qmatmul_infer``.

    Args:
        layer: Layer holding ``weight_packed`` and ``weight_scale``.
        x: Input activations.
        bias: Optional bias forwarded to the kernel.

    Returns:
        The kernel output tensor.
    """
    # numa weight is 3-dims
    if len(layer.weight_packed.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear):
            act_mode = "act_swiglu"
            output_size //= 2

        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            all_rank = tp_size * pp_size
            support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
            if not layer.reduce_results:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x,
                    layer.weight_packed.data,
                    output_size,
                    tp_rank,
                    tp_size,
                    global_rank,
                    0,
                    scales=layer.weight_scale.data,
                    bias=bias,
                    act_mode=act_mode)

        # Shared by row-parallel layers that keep the separate all-reduce
        # and by every other layer type.  The keyword was previously
        # misspelled ``activaion_mode`` on the row-parallel path.
        return torch_br.br_fused_mlp_infer(
            x, [layer.weight_packed.data],
            output_w=output_size,
            scales=[layer.weight_scale.data]
            if layer.weight_scale.data is not None else None,
            bias=[bias] if bias is not None else None,
            activation_mode=act_mode)

    xn = x.shape[0]
    xh = x.shape[1]
    ww = layer.weight_packed.shape[1]
    # TODO, hard code to skip dry_run stage
    if xh >= 4096:
        return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
    return torch_br.sudnn_qmatmul_infer(x,
                                        layer.weight_packed,
                                        layer.weight_scale,
                                        bias=bias)

View File

@@ -0,0 +1,34 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
    """
    Map a compressed-tensors k/v cache scale name to the vLLM name.

    A name qualifies when it ends with ``.output_scale`` and mentions
    ``.k_proj`` or ``.v_proj`` (checked in that order); the
    ``.<proj>.output_scale`` part is then rewritten to ``.attn.k_scale``
    or ``.attn.v_scale``.

    :param name: param name
    :return: matching param name for KV cache scale in vLLM, or None
    """
    if not name.endswith(".output_scale"):
        return None

    renames = (
        (".k_proj", ".attn.k_scale"),
        (".v_proj", ".attn.v_scale"),
    )
    for projection, scale_suffix in renames:
        if projection in name:
            return name.replace(f"{projection}.output_scale", scale_suffix)

    # Not a k/v projection scale.
    return None

View File

@@ -0,0 +1,244 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Any, Dict, Optional
import torch
import torch_br
from fastcore.basics import patch_to
from torch.nn.parameter import Parameter
from vllm.distributed import (get_pipeline_model_parallel_group,
get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
GPTQLinearMethod)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm_br import envs
from ..br_utils import _br_qweight_cvt, _convert_to_numa_tensor
from ..linear import (process_weights_MergedColumnParallelLinear,
process_weights_QuantMergedColumnParallelLinear,
process_weights_ReplicatedLinear)
@patch_to(GPTQConfig, cls_method=True)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Return the supported activation dtypes: bfloat16 only."""
    return [torch.bfloat16]
@patch_to(GPTQConfig)
def get_quant_method(self: GPTQConfig, layer: torch.nn.Module,
                     prefix: str) -> Optional["GPTQLinearMethod"]:
    """Resolve the quant method for ``layer`` through
    ``get_linear_quant_method``, with ``GPTQLinearMethod`` as the linear
    method class."""
    return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
@patch_to(GPTQConfig, cls_method=True)
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
    """
    Build a ``GPTQConfig`` from a quantization config dict.

    [PatchNote] also reads the optional ``qkv_quantized`` key (default
    ``True``) and forwards it to the constructor.
    """

    def optional(key: str, fallback: Any) -> Any:
        return cls.get_from_keys_or(config, [key], default=fallback)

    dynamic = optional("dynamic", {})
    if dynamic is None:
        dynamic = {}

    # Lookups run in the listed order; ``bits``, ``group_size`` and
    # ``desc_act`` go through ``get_from_keys`` (no default supplied).
    settings = {
        "weight_bits": cls.get_from_keys(config, ["bits"]),
        "group_size": cls.get_from_keys(config, ["group_size"]),
        "desc_act": cls.get_from_keys(config, ["desc_act"]),
        "lm_head_quantized": optional("lm_head", False),
        "dynamic": dynamic,
        "autoround_version": optional("autoround_version", ""),
        "modules_in_block_to_quantize":
        optional("modules_in_block_to_quantize", None),
        "qkv_quantized": optional("qkv_quantized", True),
    }
    return cls(**settings)
def wrapper_GPTQConfig_init(fn):
    """Wrap an ``__init__`` so it accepts an extra ``qkv_quantized`` keyword.

    The keyword (default ``True``) is removed before ``fn`` is called and
    stored on the instance as ``qkv_quantized`` once ``fn`` has returned.
    """

    @wraps(fn)
    def init_with_qkv_flag(self, *args, **kwargs):
        # ``fn`` does not know this keyword, so strip it before delegating.
        flag = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = flag

    return init_with_qkv_flag
# Rebind the constructor so GPTQConfig(...) accepts the extra
# ``qkv_quantized`` keyword and stores it on the instance.
GPTQConfig.__init__ = wrapper_GPTQConfig_init(
    GPTQConfig.__init__)  # noqa: E501
@patch_to(GPTQLinearMethod)
def process_weights_after_loading(self: GPTQLinearMethod,
                                  layer: torch.nn.Module) -> None:
    """Convert loaded GPTQ tensors after weight loading.

    Runs the layer-type specific ``process_weights_*`` helper, converts an
    int32 ``qweight`` with ``_br_qweight_cvt``, casts the scales to float32,
    re-wraps the tensors as plain ``Parameter`` objects, and finally, when
    still required and ``self.weight_type`` is ``"NUMA"``, converts
    ``qweight``, ``scales`` and ``bias`` with ``_convert_to_numa_tensor``.

    NOTE(review): this listing lost its indentation.  The nesting of the
    ``merge_col_quant`` and scale-cast statements relative to the int32
    check below is a best guess and must be confirmed against the real
    file.
    """
    still_need_process = True
    merge_col_quant = False
    # NOTE: all process_weights func should done before process_weights_after_loading
    parallel_type = "col_parallel"
    match layer:
        case ReplicatedLinear():
            process_weights_ReplicatedLinear(layer)
            # Only replicated layers with 64 or 256 outputs continue to the
            # generic conversion at the end of this function.
            still_need_process = layer.output_size == 64 or layer.output_size == 256
        case MergedColumnParallelLinear():
            if hasattr(layer, "qweight"):
                # Quantized merged layer: handled after qweight conversion.
                merge_col_quant = True
            else:
                process_weights_MergedColumnParallelLinear(layer)
                still_need_process = False
        case RowParallelLinear():
            parallel_type = "row_parallel"
        case _:
            pass
    # NOTE: if use exllama, br gptq needs similar treatment
    # exllama needs to shuffle the weight after the weight is loaded
    # here we do the shuffle on first forward pass
    if layer.qweight.dtype == torch.int32:
        # Partitioned sizes are preferred when the layer defines them.
        input_size = layer.input_size_per_partition if hasattr(
            layer, 'input_size_per_partition') else layer.input_size
        output_size = layer.output_size_per_partition if hasattr(
            layer, 'output_size_per_partition') else layer.output_size
        br_qweight = _br_qweight_cvt(self, layer.qweight, layer.qzeros,
                                     input_size, output_size)
        layer.qweight.data = br_qweight
    if merge_col_quant:
        process_weights_QuantMergedColumnParallelLinear(layer)
        still_need_process = False
    br_scales = layer.scales.to(torch.float32)
    layer.scales.data = br_scales
    # for torch.compile
    layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
    layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
    layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
    layer.scales = Parameter(layer.scales.data, requires_grad=False)
    # Stop here when a helper already finished the layer or the weight type
    # is not "NUMA".
    if not still_need_process or self.weight_type != "NUMA":
        return
    if hasattr(layer, 'qweight') and len(layer.qweight.shape) == 2:
        layer.qweight.data = _convert_to_numa_tensor(
            layer.qweight,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            parallel_type=parallel_type)
    if hasattr(layer, 'scales') and layer.scales is not None:
        pad_zeros = False
        layer.scales.data = _convert_to_numa_tensor(
            layer.scales,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
    if hasattr(layer, 'bias') and layer.bias is not None:
        # Zero padding is requested for row-parallel biases only.
        pad_zeros = (parallel_type == "row_parallel")
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
@patch_to(GPTQLinearMethod)
def apply(self: GPTQLinearMethod,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the GPTQ-quantized linear layer to ``x``.

    A 3-D ``layer.qweight`` selects the fused kernels: row-parallel layers
    use the fused linear+allreduce op when the rank, sequence-length and
    (SPC count, TP size) checks pass, everything else goes through
    ``br_fused_mlp_infer``.  Other weight layouts fall back to
    ``sudnn_qmatmul_infer``.
    """
    # numa weight is 3-dims
    if len(layer.qweight.shape) == 3:
        if hasattr(layer, "output_size_per_partition"):
            out_width = layer.output_size_per_partition
        else:
            out_width = layer.output_size

        is_merged = isinstance(layer, MergedColumnParallelLinear)
        act_mode = "act_swiglu" if is_merged else "act_default"
        if is_merged:
            out_width //= 2

        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            fused_allreduce = (
                tp_size * pp_size <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
                and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in (
                    (16, 4), (16, 8), (32, 2), (32, 4)))
            layer.reduce_results = not fused_allreduce
            if not layer.reduce_results:
                tp_group = get_tp_group()
                tp_rank = tp_group.rank_in_group
                global_rank = tp_group.rank
                assert global_rank % tp_size == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x,
                    layer.qweight,
                    out_width,
                    tp_rank,
                    tp_size,
                    global_rank,
                    0,
                    scales=layer.scales,
                    bias=bias,
                    act_mode=act_mode)

        scale_list = [layer.scales] if layer.scales is not None else None
        bias_list = [bias] if bias is not None else None
        return torch_br.br_fused_mlp_infer(x, [layer.qweight],
                                           output_w=out_width,
                                           scales=scale_list,
                                           bias=bias_list,
                                           activation_mode=act_mode)

    # TODO, hard code to skip dry_run stage
    rows, hidden = x.shape[0], x.shape[1]
    weight_cols = layer.qweight.shape[1]
    if hidden >= 4096:
        return torch.ones((rows, hidden, weight_cols),
                          dtype=x.dtype,
                          device=x.device)
    return torch_br.sudnn_qmatmul_infer(x,
                                        layer.qweight,
                                        layer.scales,
                                        bias=bias)

View File

@@ -0,0 +1,924 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import itertools
from typing import Any, Optional, Tuple, Union
import torch
import torch_br
from fastcore.basics import patch_to
from transformers import PretrainedConfig
import vllm.model_executor.layers.rotary_embedding
import vllm.model_executor.models.chatglm
import vllm.model_executor.models.deepseek_v2
import vllm_br.envs as br_envs
from vllm.logger import logger
from vllm.model_executor.layers.rotary_embedding import (
_ROPE_DICT, DeepseekScalingRotaryEmbedding, DualChunkRotaryEmbedding,
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
Llama3RotaryEmbedding, Llama4VisionRotaryEmbedding, MRotaryEmbedding,
NTKScalingRotaryEmbedding, Phi3LongRoPEScaledRotaryEmbedding,
RotaryEmbedding, YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.common import (
rotate_gptj, rotate_neox, yarn_find_correction_range,
yarn_linear_ramp_mask)
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
yarn_get_mscale)
from vllm.model_executor.layers.rotary_embedding.mrope import (
apply_interleaved_rope)
@patch_to(RotaryEmbedding)
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    op_type: str = "Half",  # FIXME: other op type not supported yet
) -> None:
    """Store the RoPE settings and register the cos/sin buffers.

    ``MRotaryEmbedding`` and ``DeepseekScalingRotaryEmbedding`` instances
    register one ``cos_sin_cache`` buffer cast to ``dtype``; every other
    instance registers separate float32 ``sin_cache`` and ``cos_cache``
    buffers.  All buffers are non-persistent.
    """
    logger.info('[Patch] RotaryEmbedding use SUPA RoPE')
    super(RotaryEmbedding, self).__init__()  # type: ignore
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    self.is_neox_style = is_neox_style
    self.dtype = dtype
    self.op_type = op_type  # FIXME: other op type not supported yet

    if isinstance(self, (MRotaryEmbedding, DeepseekScalingRotaryEmbedding)):
        fused_cache = self._compute_cos_sin_cache().to(dtype)
        # Each class keeps the device query it used before.
        if isinstance(self, MRotaryEmbedding):
            target = torch.cuda.current_device()
        else:
            target = torch.supa.current_device()
        self.cos_sin_cache: torch.Tensor  # type: ignore
        self.register_buffer("cos_sin_cache",
                             fused_cache.to(target),
                             persistent=False)
        return

    sin_cache, cos_cache = self._compute_cos_sin_cache()
    target = torch.cuda.current_device()
    self.register_buffer("sin_cache",
                         sin_cache.to(torch.float32).to(target),
                         persistent=False)
    self.register_buffer("cos_cache",
                         cos_cache.to(torch.float32).to(target),
                         persistent=False)
@patch_to(RotaryEmbedding)
def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the rotary tables on CPU.

    ``MRotaryEmbedding`` instances get one tensor with cos and sin
    concatenated on the last dim.  Otherwise ``(sin, cos)`` is returned:
    tiled twice along the last dim for the ``"Half"`` and ``"TeleChat"`` op
    types, or pairwise-repeated with the sign of the sin argument
    alternating per column for any other op type.
    """
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.base)
        steps = torch.arange(self.max_position_embeddings, dtype=torch.float)
        angles = torch.einsum("i,j -> ij", steps, inv_freq)

        if isinstance(self, MRotaryEmbedding):
            return torch.cat((angles.cos(), angles.sin()), dim=-1)

        if self.op_type in ("Half", "TeleChat"):
            tiled = angles.repeat(1, 2)
            return tiled.sin(), tiled.cos()

        paired = angles.repeat_interleave(2, dim=-1)
        # -1 for even flat indices, +1 for odd ones.
        signs = torch.arange(paired.numel()) % 2 * 2 - 1
        cos = paired.cos()
        sin = (paired * signs.reshape_as(paired)).sin()
        return sin, cos
@patch_to(RotaryEmbedding)
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``query`` and ``key`` with ``torch_br.supa_rope_infer_v2``.

    The registered ``sin_cache``/``cos_cache`` buffers are used with
    ``self.op_type`` as the rope type.  ``offsets`` is accepted but not
    used.
    """
    rotated_query, rotated_key = torch_br.supa_rope_infer_v2(
        query,
        key,
        self.sin_cache,
        self.cos_cache,
        positions,
        self.head_size,
        rope_type=self.op_type,
        rotary_size=self.rotary_dim)
    return rotated_query, rotated_key
@patch_to(RotaryEmbedding, cls_method=True)
def enabled(cls) -> bool:
    """Report the patched rotary embedding as always enabled.

    Patched as a classmethod to match its ``cls`` parameter, so it can be
    called on the class as well as on instances.
    """
    return True
class SupaDeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding with YaRN-style frequency scaling that produces
    separate ``(sin, cos)`` tables.

    The tables are pairwise-repeated along the last dim, scaled by
    ``self.mscale``, and the sign of the sin argument alternates per
    column.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        # These must be set before the base initializer runs, because it
        # calls ``_compute_cos_sin_cache``, which reads them.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend interpolated and extrapolated inverse frequencies on CPU."""
        with torch.device('cpu'):
            pos_freqs = self.base**(torch.arange(
                0, self.rotary_dim, 2, dtype=torch.float, device="cpu") /
                                    self.rotary_dim)
            inv_freq_extrapolation = 1.0 / pos_freqs
            inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
            low, high = yarn_find_correction_range(
                self.beta_fast, self.beta_slow, self.rotary_dim, self.base,
                self.max_position_embeddings)
            # Get n-d rotational scaling corrected for extrapolation
            inv_freq_mask = (1 - yarn_linear_ramp_mask(
                low, high, self.rotary_dim // 2,
                dtype=torch.float)) * self.extrapolation_factor
            inv_freq = inv_freq_interpolation * (
                1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(sin, cos)`` tables, each scaled by ``self.mscale``.

        The return annotation now matches the tuple that is actually
        returned (it previously claimed a single tensor).
        """
        with torch.device('cpu'):
            inv_freq = self._compute_inv_freq(self.scaling_factor)
            t = torch.arange(self.max_position_embeddings *
                             self.scaling_factor,
                             dtype=torch.float)
            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos_freqs = freqs.repeat_interleave(2, dim=-1)
            cos = (cos_freqs.cos() * self.mscale)
            # -1 for even flat indices, +1 for odd ones.
            scales = torch.arange(cos_freqs.numel()) % 2 * 2 - 1
            sin_freqs = cos_freqs * scales.reshape_as(cos_freqs)
            sin = (sin_freqs.sin() * self.mscale)
        return sin, cos
@patch_to(DeepseekScalingRotaryEmbedding)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Compute the inverse frequencies on CPU as a blend of the
    interpolated (divided by ``scaling_factor``) and extrapolated values."""
    with torch.device('cpu'):
        exponents = torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float,
            device="cpu") / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(self.beta_fast,
                                               self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        ramp = yarn_linear_ramp_mask(low,
                                     high,
                                     self.rotary_dim // 2,
                                     dtype=torch.float)
        blend = (1 - ramp) * self.extrapolation_factor
        inv_freq = interpolated * (1 - blend) + extrapolated * blend
    return inv_freq
@patch_to(DeepseekScalingRotaryEmbedding)
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the combined cache on CPU: cos then sin, concatenated on the
    last dim, both scaled by ``self.mscale``."""
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        table_len = self.max_position_embeddings * self.scaling_factor
        steps = torch.arange(table_len, dtype=torch.float32)
        angles = torch.einsum("i,j -> ij", steps, inv_freq)
        halves = [angles.cos() * self.mscale, angles.sin() * self.mscale]
        return torch.cat(halves, dim=-1)
@patch_to(DeepseekScalingRotaryEmbedding)
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward().

    Rotates the first ``self.rotary_dim`` features of ``query`` and ``key``
    with the cos/sin rows looked up for ``positions`` (plus ``offsets`` when
    given) and re-attaches the remaining features.  ``key`` is required
    despite its ``Optional`` annotation.
    """
    assert key is not None
    self._match_cos_sin_cache_dtype(query)
    query_rot = query[..., :self.rotary_dim]
    key_rot = key[..., :self.rotary_dim]
    if self.rotary_dim < self.head_size:
        # Features beyond rotary_dim pass through unrotated.
        query_pass = query[..., self.rotary_dim:]
        key_pass = key[..., self.rotary_dim:]
    cos_sin = self.cos_sin_cache[
        torch.add(positions, offsets) if offsets is not None else positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    if self.is_neox_style:
        # NOTE(woosuk): Here we assume that the positions tensor has the
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
    else:
        # repeat_interleave is done on CPU, then the result is moved back to
        # the current SUPA device.
        # NOTE(review): presumably a workaround for the device op; confirm.
        device = torch.supa.current_device()
        cos = cos.to('cpu')
        sin = sin.to('cpu')
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        cos = cos.to(device)
        sin = sin.to(device)
    rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
    device = query_rot.device
    # For more than 1024 rows in dim 0 the rotation is computed on CPU and
    # the results are moved back to the original device afterwards.
    if query.shape[0] > 1024:
        query_rot = query_rot.to('cpu')
        key_rot = key_rot.to('cpu')
        cos = cos.to('cpu')
        sin = sin.to('cpu')
    query_rot = query_rot * cos + rotate_fn(query_rot) * sin
    key_rot = key_rot * cos + rotate_fn(key_rot) * sin
    if query.shape[0] > 1024:
        query_rot = query_rot.to(device)
        key_rot = key_rot.to(device)
    if self.rotary_dim < self.head_size:
        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)
    else:
        query = query_rot
        key = key_rot
    return query, key
@patch_to(DeepseekScalingRotaryEmbedding)
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Out-of-tree forward: delegates to ``forward_native``."""
    rotated_query, rotated_key = self.forward_native(positions, query, key,
                                                     offsets)
    return rotated_query, rotated_key
@patch_to(YaRNScalingRotaryEmbedding)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Compute the inverse frequencies on CPU as a blend of the
    interpolated (divided by ``scaling_factor``) and extrapolated values."""
    with torch.device('cpu'):
        exponents = torch.arange(0, self.rotary_dim, 2,
                                 dtype=torch.float) / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(self.beta_fast,
                                               self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        ramp = yarn_linear_ramp_mask(low,
                                     high,
                                     self.rotary_dim // 2,
                                     dtype=torch.float)
        blend = (1 - ramp) * self.extrapolation_factor
        inv_freq = interpolated * (1 - blend) + extrapolated * blend
    return inv_freq
@patch_to(YaRNScalingRotaryEmbedding)
def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the ``(sin, cos)`` tables on CPU.

    The angle table is tiled twice along the last dim and both tables are
    scaled by ``self.mscale``.  The return annotation now matches the tuple
    that is actually returned (it previously claimed a single tensor).
    """
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        freqs = freqs.repeat(1, 2)
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale
    return sin, cos
def dtnamicNTK_compute_cos_sin_cache(
        self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the ``(sin, cos)`` tables on CPU.

    For the ``"Half"`` and ``"TeleChat"`` op types the angle table is tiled
    twice along the last dim.  For any other op type each angle is repeated
    pairwise and the sign of the sin argument alternates per column.
    """
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.base)
        steps = torch.arange(self.max_position_embeddings, dtype=torch.float)
        angles = torch.einsum("i,j -> ij", steps, inv_freq)

        if self.op_type in ("Half", "TeleChat"):
            tiled = angles.repeat(1, 2)
            return tiled.sin(), tiled.cos()

        paired = angles.repeat_interleave(2, dim=-1)
        # -1 for even flat indices, +1 for odd ones.
        signs = torch.arange(paired.numel()) % 2 * 2 - 1
        cos = paired.cos()
        sin = (paired * signs.reshape_as(paired)).sin()
        return sin, cos
def dynamicNTKScaling_rope_forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``query`` and ``key`` with ``torch_br.supa_rope_infer_v2``.

    The ``"MRope"`` rope type is used when the last dims of ``query`` and
    ``key`` differ, ``self.op_type`` otherwise.  ``offsets`` is accepted
    but not used.
    """
    widths_differ = query.shape[-1] != key.shape[-1]
    rope_type = "MRope" if widths_differ else self.op_type
    rotated_query, rotated_key = torch_br.supa_rope_infer_v2(
        query,
        key,
        self.sin_cache,
        self.cos_cache,
        positions,
        self.head_size,
        rope_type=rope_type)
    return rotated_query, rotated_key
# Replace the cache builder and the forward pass of
# DynamicNTKScalingRotaryEmbedding with the two functions defined above.
DynamicNTKScalingRotaryEmbedding._compute_cos_sin_cache = dtnamicNTK_compute_cos_sin_cache
DynamicNTKScalingRotaryEmbedding.forward = dynamicNTKScaling_rope_forward
def _apply_rotary_emb_torch(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                      is_neox_style: bool) -> torch.Tensor:
    """
    Rotate ``x``; thin wrapper around ``_apply_rotary_emb_torch``.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    return _apply_rotary_emb_torch(x=x,
                                   cos=cos,
                                   sin=sin,
                                   is_neox_style=is_neox_style)
def forward_MRotaryEmbedding_0_9_2(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Torch-only multimodal rotary forward.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]

    Returns:
        The rotated ``(query, key)`` pair, each in its input shape.
    """
    assert positions.ndim == 1 or positions.ndim == 2
    assert key is not None

    num_tokens = positions.shape[-1]
    cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

    if positions.ndim == 2:
        assert self.mrope_section

        def pick_sections(table: torch.Tensor) -> torch.Tensor:
            # Section i of the last dim is taken from row i of dim 0.
            pieces = table.split(self.mrope_section, dim=-1)
            return torch.cat(
                [piece[idx] for idx, piece in enumerate(pieces)], dim=-1)

        cos = pick_sections(cos)
        sin = pick_sections(sin)

    def rotate(states: torch.Tensor) -> torch.Tensor:
        # Rotate the first rotary_dim features of every head; the remaining
        # features are passed through unchanged.
        original_shape = states.shape
        heads = states.view(num_tokens, -1, self.head_size)
        rotated = _apply_rotary_emb(heads[..., :self.rotary_dim], cos, sin,
                                    self.is_neox_style)
        untouched = heads[..., self.rotary_dim:]
        return torch.cat((rotated, untouched),
                         dim=-1).reshape(original_shape)

    return rotate(query), rotate(key)
def forward_supa(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply multimodal rotary embedding through the SUPA fused kernel.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]

    Returns:
        The ``(query, key)`` pair returned by
        ``torch_br.supa_rope_infer_v2``, or by the PyTorch fallback when
        ``VLLM_BR_USE_MROPE_0_9_2`` is set.
    """
    if br_envs.VLLM_BR_USE_MROPE_0_9_2:
        return forward_MRotaryEmbedding_0_9_2(self, positions, query, key)

    assert positions.ndim == 1 or positions.ndim == 2

    def data_in_supa(t: torch.Tensor) -> bool:
        return str(t.device).startswith('supa')

    def data_in_cpu(t: torch.Tensor) -> bool:
        return t.device == torch.device('cpu')

    if positions.ndim == 2:
        # use bypass for decode stage
        if (positions.shape[1] == 1):
            cos_sin = self.cos_sin_cache[positions]
            cos, sin = cos_sin.chunk(2, dim=-1)
            # Only the first position row is used on this path.
            cos = cos[0]
            sin = sin[0]
        else:
            cos_sin = self.cos_sin_cache[positions.to(torch.int64)]
            cos, sin = cos_sin.chunk(2, dim=-1)
            assert self.mrope_section
            if self.mrope_interleaved:
                cos = apply_interleaved_rope(cos, self.mrope_section)
                sin = apply_interleaved_rope(sin, self.mrope_section)
            else:
                # Section i of the rotary channels takes its values from
                # position row i.
                cos = torch.cat([
                    m[i] for i, m in enumerate(
                        cos.split(self.mrope_section, dim=-1))
                ],
                                dim=-1)
                sin = torch.cat([
                    m[i] for i, m in enumerate(
                        sin.split(self.mrope_section, dim=-1))
                ],
                                dim=-1)
    else:
        # Text-only input: without this branch cos/sin were never assigned
        # and the code below raised UnboundLocalError for 1-D positions.
        # NOTE(review): assumes the "MRope" kernel accepts 1-D positions --
        # confirm against the torch_br kernel contract.
        cos_sin = self.cos_sin_cache[positions.to(torch.int64)]
        cos, sin = cos_sin.chunk(2, dim=-1)

    # The kernel inputs must live on the same device as query/key.
    if data_in_supa(query) and data_in_supa(key):
        sin = sin.supa() if data_in_cpu(sin) else sin
        cos = cos.supa() if data_in_cpu(cos) else cos
        positions = positions.supa() if data_in_cpu(positions) else positions

    query, key = torch_br.supa_rope_infer_v2(query,
                                             key,
                                             sin.to(torch.float32),
                                             cos.to(torch.float32),
                                             positions.to(torch.int32),
                                             self.head_size,
                                             rope_type="MRope")
    return query, key
# Monkey-patch: replace MRotaryEmbedding.forward with the SUPA implementation.
MRotaryEmbedding.forward = forward_supa
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    op_type: str = "Half",
) -> RotaryEmbedding:
    """Return a (cached) rotary-embedding module for the given configuration.

    Instances are memoised in ``_ROPE_DICT``; two calls with the same
    arguments return the same object.

    Args:
        head_size: Attention head size.
        rotary_dim: Number of rotated channels per head; scaled by
            ``partial_rotary_factor`` when that factor is below 1.0.
        max_position: Maximum position index.
        base: RoPE base.
        is_neox_style: Neox-style (True) or GPT-J-style (False) rotation.
        rope_scaling: Optional scaling config; its ``"rope_type"`` entry
            selects the embedding class.
        dtype: Cache dtype; defaults to ``torch.get_default_dtype()``.
        partial_rotary_factor: Fraction of ``rotary_dim`` to rotate.
        dual_chunk_attention_config: When given, a
            ``DualChunkRotaryEmbedding`` is built.
        op_type: Forwarded to ``RotaryEmbedding`` when no scaling is
            configured.

    Returns:
        The rotary-embedding module.

    Raises:
        ValueError: If ``rope_scaling["rope_type"]`` is not recognised.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None

    if dual_chunk_attention_config is not None:
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    # op_type is part of the key: otherwise a "Half" and a "DeepSeek" request
    # with otherwise equal arguments would share one cached instance.
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dual_chunk_attention_args, dtype, op_type)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style, dtype,
                                              **extra_kwargs)
    elif not rope_scaling:
        rotary_emb = RotaryEmbedding(head_size,
                                     rotary_dim,
                                     max_position,
                                     base,
                                     is_neox_style,
                                     dtype,
                                     op_type=op_type)
    else:
        scaling_type = rope_scaling["rope_type"]

        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor, low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
                                                     max_position, base,
                                                     is_neox_style, dtype)
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                # The M-RoPE cache is always built in float32, regardless of
                # the requested dtype.
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype=torch.float32,
                    mrope_section=rope_scaling["mrope_section"],
                    mrope_interleaved=rope_scaling.get("mrope_interleaved",
                                                       False),
                )
            else:
                # NOTE(review): op_type is not forwarded on this path, unlike
                # the no-scaling branch above -- confirm this is intended.
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "ntk":
            scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get('mixed_b', None)
            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
                                                   max_position, base,
                                                   is_neox_style,
                                                   scaling_factor, dtype,
                                                   mixed_b)
        elif scaling_type == "dynamic":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor, dtype)
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "deepseek_yarn_supa":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = SupaDeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
def deepseek_get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """Forward to ``get_rope`` with ``op_type`` fixed to ``"DeepSeek"``.

    Takes the same arguments as ``get_rope`` minus ``op_type`` and passes
    them through unchanged.
    """
    return get_rope(
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position=max_position,
        base=base,
        is_neox_style=is_neox_style,
        rope_scaling=rope_scaling,
        dtype=dtype,
        partial_rotary_factor=partial_rotary_factor,
        dual_chunk_attention_config=dual_chunk_attention_config,
        op_type="DeepSeek",
    )
def chatglm2_get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """Forward to ``get_rope`` with ``op_type`` fixed to ``"DeepSeek"``.

    Takes the same arguments as ``get_rope`` minus ``op_type`` and passes
    them through unchanged.
    """
    # NOTE(review): uses the same "DeepSeek" op_type as deepseek_get_rope --
    # confirm this is the intended kernel variant for ChatGLM.
    return get_rope(
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position=max_position,
        base=base,
        is_neox_style=is_neox_style,
        rope_scaling=rope_scaling,
        dtype=dtype,
        partial_rotary_factor=partial_rotary_factor,
        dual_chunk_attention_config=dual_chunk_attention_config,
        op_type="DeepSeek",
    )
# Monkey-patch vLLM: rebind the get_rope name in the rotary-embedding module
# and in the deepseek_v2 / chatglm model modules to the factories above.
vllm.model_executor.layers.rotary_embedding.get_rope = get_rope
vllm.model_executor.models.deepseek_v2.get_rope = deepseek_get_rope
vllm.model_executor.models.chatglm.get_rope = chatglm2_get_rope
@patch_to(MRotaryEmbedding)
def _glm4v_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value for GLM4V.

    Tokens are classified as text, image or video, grouped into consecutive
    runs of the same kind, and each run contributes a [3, run_len] block of
    T/H/W position ids starting one past the largest id emitted so far.

    Returns:
        ``(llm_positions, mrope_position_delta)``: the positions sliced to
        ``[context_len:seq_len]`` and ``max(position) + 1 - len(input_tokens)``
        computed on that slice.
    """
    image_token_id = hf_config.image_token_id
    video_start_token_id = hf_config.video_start_token_id
    video_end_token_id = hf_config.video_end_token_id
    merge = hf_config.vision_config.spatial_merge_size

    position_blocks: list = []
    if image_grid_thw is None and video_grid_thw is None:
        # Pure text: all three rows are the plain token index.
        position_blocks.append(
            torch.arange(len(input_tokens)).view(1, -1).expand(3, -1))
    else:
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()

        # An image token between video start/end markers counts as "video".
        token_kinds: list[str] = []
        inside_video = False
        for token in input_tokens:
            if token == video_start_token_id:
                inside_video = True
            elif token == video_end_token_id:
                inside_video = False

            if token != image_token_id:
                token_kinds.append("text")
            elif inside_video:
                token_kinds.append("video")
            else:
                token_kinds.append("image")

        # Collapse into (kind, start, end) runs of identical kind.
        runs: list[tuple[str, int, int]] = []
        for kind, members in itertools.groupby(enumerate(token_kinds),
                                               lambda pair: pair[1]):
            indices = [index for index, _ in members]
            runs.append((kind, indices[0], indices[-1] + 1))

        frame_no = 1
        grid_idx = 0
        for kind, run_start, run_end in runs:
            offset = position_blocks[-1].max() + 1 if position_blocks else 0

            if kind == "image":
                grid = image_grid_thw[grid_idx]
                t, h, w = grid[0], grid[1], grid[2]
                grid_t, grid_h, grid_w = t, h // merge, w // merge

                t_index = torch.arange(grid_t).view(-1, 1).expand(
                    -1, grid_h * grid_w).flatten()
                h_index = torch.arange(grid_h).view(1, -1, 1).expand(
                    grid_t, -1, grid_w).flatten()
                w_index = torch.arange(grid_w).view(1, 1, -1).expand(
                    grid_t, grid_h, -1).flatten()
                position_blocks.append(
                    torch.stack([t_index, h_index, w_index]) + offset)
                grid_idx += 1

            elif kind == "video":
                # NOTE(review): reads image_grid_thw (not video_grid_thw) for
                # the frame size -- confirm this is intended.
                grid = image_grid_thw[grid_idx]
                t, h, w = frame_no, grid[1], grid[2]
                grid_t, grid_h, grid_w = t, h // merge, w // merge

                for t_idx in range(grid_t):
                    t_index = torch.tensor(t_idx).view(-1, 1).expand(
                        -1, grid_h * grid_w).flatten()
                    h_index = torch.arange(grid_h).view(1, -1, 1).expand(
                        1, -1, grid_w).flatten()
                    w_index = torch.arange(grid_w).view(1, 1, -1).expand(
                        1, grid_h, -1).flatten()
                    position_blocks.append(
                        torch.stack([t_index, h_index, w_index]) + offset)

                grid_idx += 1
                frame_no += 1

            else:
                run_len = run_end - run_start
                position_blocks.append(
                    torch.arange(run_len).view(1, -1).expand(3, -1) + offset)
                # A text run resets the video frame counter.
                frame_no = 1

    llm_positions = torch.cat(position_blocks, dim=1).reshape(3, -1)
    llm_positions = llm_positions[:, context_len:seq_len]
    mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
    return llm_positions, mrope_position_delta
@patch_to(MRotaryEmbedding)
def get_input_positions_tensor_for_glm(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Dispatch M-RoPE position computation by model family.

    Uses the omni helper when ``thinker_uses_mrope(hf_config)`` is true, the
    GLM4V helper when ``"glm4v"`` occurs in ``hf_config.model_type``, and the
    generic vision-language helper otherwise.

    Returns:
        Whatever the selected helper returns (annotated as a positions tensor
        and an int delta).
    """
    from vllm.transformers_utils.config import thinker_uses_mrope

    # Arguments every helper accepts.
    shared = dict(
        input_tokens=input_tokens,
        hf_config=hf_config,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        context_len=context_len,
        seq_len=seq_len,
    )

    if thinker_uses_mrope(hf_config):
        return cls._omni_get_input_positions_tensor(
            second_per_grid_ts=second_per_grid_ts,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
            **shared,
        )

    if "glm4v" in hf_config.model_type:
        # The GLM4V helper is handed ``cls`` explicitly as its first argument.
        return cls._glm4v_get_input_positions_tensor(cls, **shared)

    return cls._vl_get_input_positions_tensor(
        second_per_grid_ts=second_per_grid_ts,
        **shared,
    )

View File

@@ -0,0 +1,65 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
import vllm
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask
def apply_penalties_fit(logits: torch.Tensor,
                        prompt_tokens_tensor: torch.Tensor,
                        output_tokens_tensor: torch.Tensor,
                        presence_penalties: torch.Tensor,
                        frequency_penalties: torch.Tensor,
                        repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to ``logits`` in place.

    Args:
        logits: The input logits tensor of shape [num_seqs, vocab_size].
        prompt_tokens_tensor: A tensor containing the prompt tokens. The
            prompts are padded to the maximum prompt length within the batch
            using `vocab_size` as the padding value, because it does not
            correspond to any valid token ID in the vocabulary.
        output_tokens_tensor: The output tokens tensor.
        presence_penalties: The presence penalties of shape (num_seqs, )
        frequency_penalties: The frequency penalties of shape (num_seqs, )
        repetition_penalties: The repetition penalties of shape (num_seqs, )

    Returns:
        The same ``logits`` tensor, modified in place.
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Tokens seen in the prompt or the output get the sequence's repetition
    # penalty; every other token gets 1.0 (no-op).
    seen_mask = prompt_mask | output_mask
    expanded_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
        1, vocab_size)
    penalties = torch.where(seen_mask, expanded_penalties, 1.0)

    # Positive logits are divided by the penalty, the rest are multiplied.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits.mul_(scaling)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits.sub_(frequency_penalties.unsqueeze(dim=1) * output_bin_counts)
    logits.sub_(presence_penalties.unsqueeze(dim=1) * output_mask)
    return logits
# Monkey-patch: rebind vLLM's apply_penalties to the implementation above.
vllm.model_executor.layers.utils.apply_penalties = apply_penalties_fit

View File

@@ -0,0 +1,139 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torch_br
import torch_br.supa._debug as supa_debug
from fastcore.basics import patch_to
import vllm
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm_br import envs
@patch_to(UnquantizedEmbeddingMethod)
def process_weights_after_loading(self, module):
    """Re-allocate the embedding weight when VLLM_BR_EMBEDDING_S0B is set.

    The weight is copied to the host, a new "colmajor" buffer with
    ``sbp='SB'`` is allocated through ``torch_br._empty_ut_only`` and the
    saved values are copied into it. Does nothing when the flag is off.
    """
    if not envs.VLLM_BR_EMBEDDING_S0B:
        return

    host_copy = module.weight.data.cpu()
    module.weight.data = torch_br._empty_ut_only(
        module.weight.shape,
        "colmajor",
        False,
        0,
        dtype=module.weight.dtype,
        sbp='SB',
    )
    module.weight.data.copy_(host_copy)
@patch_to(UnquantizedEmbeddingMethod)
def embedding(self, layer: torch.nn.Module,
              input_: torch.Tensor) -> torch.Tensor:
    """Look up embeddings for ``input_`` in ``layer.weight``.

    With VLLM_BR_EMBEDDING_S0B set, the lookup runs through
    ``torch_br.out_embedding`` into a pre-allocated
    [1, num_tokens, hidden] buffer whose leading dimension is squeezed away
    before returning; otherwise it is a plain ``F.embedding`` call.
    """
    if not envs.VLLM_BR_EMBEDDING_S0B:
        return F.embedding(input_, layer.weight)

    out_shape = [1, input_.shape[0], layer.weight.shape[-1]]
    y_supa = torch_br._empty_ut_only(
        out_shape,
        is_numa=False,
        dtype=layer.weight.dtype,
        sbp='BB',
        tensor_type="colmajor",
    )
    torch_br.out_embedding(y_supa, layer.weight.data, input_, -1, -1)
    y_supa.squeeze_(0)
    return y_supa
@torch.jit.script
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Remap token ids to shard-local indices and flag out-of-shard ids.

    Returns ``(input_, mask)``. On the PyTorch path, ids inside the original
    or added vocab range have their range offset subtracted, all other ids
    become 0, and ``mask`` is True for ids outside both ranges.
    """
    # torch.jit.script will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    # NOTE(review): this value is read inside a scripted function, so it is
    # presumably captured when the function is compiled -- confirm.
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    if spc_num > 16:
        # PyTorch path: build the two range masks explicitly.
        org_vocab_mask = (input_ >= org_vocab_start_index) & (
            input_ < org_vocab_end_index)
        added_vocab_mask = (input_ >= added_vocab_start_index) & (
            input_ < added_vocab_end_index)
        # Added-vocab ids are placed after the original shard and its padding.
        added_offset = added_vocab_start_index - (
            org_vocab_end_index -
            org_vocab_start_index) - num_org_vocab_padding
        valid_offset = (org_vocab_start_index *
                        org_vocab_mask) + (added_offset * added_vocab_mask)
        vocab_mask = org_vocab_mask | added_vocab_mask
        input_ = vocab_mask * (input_ - valid_offset)
        return input_, ~vocab_mask
    else:
        # Fused device kernel; presumably equivalent to the branch above --
        # verify against torch_br.
        input_, inv_vocab_mask = torch_br.supa_embedding_mask_infer(
            input_, org_vocab_start_index, org_vocab_end_index,
            num_org_vocab_padding, added_vocab_start_index,
            added_vocab_end_index)
        return input_, inv_vocab_mask
# Monkey-patch: rebind vLLM's get_masked_input_and_mask to the version above.
vllm.model_executor.layers.vocab_parallel_embedding.get_masked_input_and_mask = get_masked_input_and_mask
def vocab_parallel_embedding_forward(self, input_) -> torch.Tensor:
    """Embedding lookup for a vocabulary sharded across tensor-parallel ranks.

    With ``tp_size > 1`` the ids are remapped to this shard, rows for ids
    outside the shard are zeroed, and the partial results are all-reduced.
    With a single rank the ids are looked up directly.
    """
    is_sharded = self.tp_size > 1

    if is_sharded:
        # Build the mask.
        shard = self.shard_indices
        masked_input, input_mask = get_masked_input_and_mask(
            input_,
            shard.org_vocab_start_index,
            shard.org_vocab_end_index,
            shard.num_org_vocab_padding,
            shard.added_vocab_start_index,
            shard.added_vocab_end_index,
        )
    else:
        masked_input = input_

    # Get the embeddings.
    output_parallel = self.quant_method.embedding(self, masked_input.long())

    if is_sharded:
        # Zero the rows of ids that this shard does not own.
        output_parallel.masked_fill_(input_mask.unsqueeze(-1),
                                     0)  # type: ignore

    # Reduce across all the model parallel GPUs.
    return tensor_model_parallel_all_reduce(output_parallel)
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute the linear projection of ``x`` with ``layer.weight``.

    Args:
        layer: Module holding ``weight`` (2-D, or 3-D for the NUMA layout).
        x: Input activations.
        bias: Optional bias.

    Returns:
        The projected tensor: from ``torch_br.br_fused_mlp_infer`` for 3-D
        weights, from ``F.linear`` otherwise.
    """
    # numa weight is 3-dims
    if len(layer.weight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        return torch_br.br_fused_mlp_infer(
            x, [layer.weight],
            output_w=output_size,
            bias=[bias] if bias is not None else None)

    # The sublas API is enabled only around this call; try/finally makes
    # sure the process-wide flag is cleared even if F.linear raises.
    supa_debug.set_enable_sublas_api(True)
    try:
        output = F.linear(x, layer.weight, bias)
    finally:
        supa_debug.set_enable_sublas_api(False)
    return output
# Monkey-patch: install the SUPA implementations on the vLLM classes.
UnquantizedEmbeddingMethod.apply = apply
VocabParallelEmbedding.forward = vocab_parallel_embedding_forward

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import default_loader # noqa: F401

View File

@@ -0,0 +1,83 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import time
import torch
from torch import nn
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import logger
from vllm.model_executor.model_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (initialize_model,
set_default_torch_dtype)
from .utils import process_weights_after_loading
def load_model(self, vllm_config: VllmConfig,
               model_config: ModelConfig) -> nn.Module:
    """Build the model, place its parameters, load weights and post-process.

    Parameters whose name contains one of the listed MLP / attention
    substrings are kept on the CPU before loading; every other parameter is
    moved to the target device.

    Args:
        vllm_config: Supplies ``device_config.device`` and is forwarded to
            ``initialize_model``.
        model_config: Supplies dtype and quantization settings.

    Returns:
        The loaded model in eval mode.

    Raises:
        ValueError: For non-quantized models, if some parameters were not
            initialised from the checkpoint.
    """
    device_config = vllm_config.device_config
    target_device = torch.device(device_config.device)
    with set_default_torch_dtype(model_config.dtype):
        model = initialize_model(vllm_config=vllm_config,
                                 model_config=model_config)
        # NOTE: on SUPA, the device context may not take effect, so
        # parameters are moved to the device manually.
        # model = model.to(target_device)
        # NOTE: keep the large MLP/MoE and attention projection weights on
        # cpu to reduce device memory usage; more layers can be moved to cpu
        # if necessary.
        cpu_resident_weights = [
            "mlp.experts.w13_weight_packed",
            "mlp.experts.w2_weight_packed",
            "mlp.gate_up_proj",
            "mlp.down_proj",
            "mlp.experts",
            "self_attn.qkv_proj",
            "self_attn.o_proj",
        ]

        for name, param in model.named_parameters():
            if any(s in name for s in cpu_resident_weights):
                param.data = param.to("cpu")
            else:
                param.data = param.to(target_device)
        torch.supa.empty_cache()

        weights_to_load = {name for name, _ in model.named_parameters()}
        loaded_weights = model.load_weights(
            self.get_all_weights(model_config, model))
        torch.supa.empty_cache()
        self.counter_after_loading_weights = time.perf_counter()
        logger.info(
            "Loading weights took %.2f seconds",
            self.counter_after_loading_weights -
            self.counter_before_loading_weights)
        # We only enable strict check for non-quantized models
        # that have loaded weights tracking currently.
        if model_config.quantization is None and loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError("Following weights were not initialized from "
                                 f"checkpoint: {weights_not_loaded}")

        process_weights_after_loading(model, model_config, target_device)
        # Use the SUPA allocator here as well, matching the earlier flushes
        # (this call previously went to torch.cuda.empty_cache()).
        torch.supa.empty_cache()
    return model.eval()
# Monkey-patch: replace DefaultModelLoader.load_model with the loader above.
DefaultModelLoader.load_model = load_model

Some files were not shown because too many files have changed in this diff Show More