first commit
This commit is contained in:
49
vllm_br/__init__.py
Normal file
49
vllm_br/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
import os
|
||||
|
||||
import torch # noqa F401
|
||||
import torch_br # noqa F401
|
||||
from torch_br.contrib import transfer_to_supa # noqa F401
|
||||
from torch_br.supa import _debug as supa_debug
|
||||
|
||||
# patches
|
||||
from . import utils # noqa: F401
|
||||
|
||||
# bypass memset: flip every "disable zero" switch that the torch_br SUPA
# debug module exposes. The four calls run in their original order.
# NOTE(review): the exact effect of each switch is defined by torch_br and is
# not visible from this file.
supa_debug.set_disable_zero_ws(True)
supa_debug.set_disable_zero_output_uma(True)
supa_debug.set_disable_zero_output_numa(True)
supa_debug.set_disable_reorder_zero(True)

# Publish the BRTB NUMA switches through the process environment so that
# anything reading them after this import sees them set to "1".
os.environ.update({
    "BRTB_ENABLE_NUMA_SPLIT": "1",
    "BRTB_ENABLE_NUMA_ALIGN_4K": "1",
})
|
||||
|
||||
|
||||
def register():
    """Return the dotted import path of the SUPA platform class.

    Returns:
        The string ``"vllm_br.platform.SUPAPlatform"``.
    """
    platform_qualname = "vllm_br.platform.SUPAPlatform"
    return platform_qualname
|
||||
|
||||
|
||||
def register_model():
    """Import the vllm_br patch packages, then run the model registration.

    The sub-packages are imported purely for their import-time side effects
    (their names are unused here). The imports are deferred to call time
    rather than done at module import.
    """
    # Side-effect imports, in the same order as before.
    from . import attention, config, distributed, v1  # noqa: F401

    # Aliased so the imported callable does not shadow this function's name.
    from .model_executor import register_model as _register_models

    _register_models()
|
||||
BIN
vllm_br/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/__pycache__/envs.cpython-310.pyc
Normal file
BIN
vllm_br/__pycache__/envs.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/__pycache__/forward_context.cpython-310.pyc
Normal file
BIN
vllm_br/__pycache__/forward_context.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/__pycache__/platform.cpython-310.pyc
Normal file
BIN
vllm_br/__pycache__/platform.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm_br/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
16
vllm_br/attention/__init__.py
Normal file
16
vllm_br/attention/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from . import layer # noqa: F401
|
||||
BIN
vllm_br/attention/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/attention/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/attention/__pycache__/layer.cpython-310.pyc
Normal file
BIN
vllm_br/attention/__pycache__/layer.cpython-310.pyc
Normal file
Binary file not shown.
130
vllm_br/attention/layer.py
Normal file
130
vllm_br/attention/layer.py
Normal file
@@ -0,0 +1,130 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.attention.layer
|
||||
from vllm.attention.layer import (maybe_save_kv_layer_to_connector,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
#direct_register_custom_op(
|
||||
# op_name="unified_attention",
|
||||
# op_func=unified_attention,
|
||||
# mutates_args=[],
|
||||
# fake_impl=unified_attention_fake,
|
||||
# dispatch_key=current_platform.dispatch_key,
|
||||
#)
|
||||
|
||||
#direct_register_custom_op(
|
||||
# op_name="unified_attention_with_output",
|
||||
# op_func=unified_attention_with_output,
|
||||
# mutates_args=["output"],
|
||||
# fake_impl=unified_attention_with_output_fake,
|
||||
# dispatch_key=current_platform.dispatch_key,
|
||||
#)
|
||||
|
||||
|
||||
def _resolve_layer_context(layer):
    """Return ``(attn_metadata, kv_cache)`` for *layer* from the forward context.

    When the context's ``attn_metadata`` is a dict it is indexed by
    ``layer.layer_name``; the KV cache is picked from ``layer.kv_cache`` using
    the context's ``virtual_engine`` index.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer.layer_name]
    return attn_metadata, layer.kv_cache[forward_context.virtual_engine]


def forward_(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    # For some alternate attention backends like MLA the attention output
    # shape does not match the query shape, so we optionally let the model
    # definition specify the output tensor shape.
    output_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
    """Replacement for ``vllm.attention.layer.Attention.forward``.

    The KV cache is stored inside this class and is accessed via
    `self.kv_cache`.

    Attention metadata (`attn_metadata`) is set using a context manager in
    the model runner's `execute_model` method. It is accessed via forward
    context using
    `vllm.forward_context.get_forward_context().attn_metadata`.

    Args:
        query: Query tensor.
        key: Key tensor; may be ``None`` on the output-buffer path.
        value: Value tensor; may be ``None`` on the output-buffer path.
        output_shape: Shape of the output buffer allocated when
            ``self.use_output`` is set. Defaults to ``query.shape``.

    Returns:
        With ``self.use_output``: the pre-allocated buffer viewed as
        ``(-1, output_shape[-1])``. Otherwise: whatever
        ``self.impl.forward`` returns.
    """
    if self.calculate_kv_scales:
        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata.enable_kv_scales_calculation:
            self.calc_kv_scales(query, key, value)

    if self.use_output:
        if output_shape is None:
            output_shape = query.shape
        # The buffer is deliberately left uninitialised (torch.empty); the
        # attention implementation is expected to write it in place.
        output = torch.empty(output_shape,
                             dtype=query.dtype,
                             device=query.device)
        hidden_size = output_shape[-1]
        # We skip reshaping query, key and value tensors for the MLA
        # backend since these tensors have different semantics and are
        # processed differently.
        if not self.use_mla:
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
        if self.use_direct_call:
            attn_metadata, self_kv_cache = _resolve_layer_context(self)
            self.impl.forward(self,
                              query,
                              key,
                              value,
                              self_kv_cache,
                              attn_metadata,
                              output=output)
        else:
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, self.layer_name)
        return output.view(-1, hidden_size)

    # Non-output path. The original code branched on `self.use_direct_call`
    # here, but both branches were identical (the `unified_attention` custom
    # op call is disabled), so a single direct call covers both cases.
    wait_for_kv_layer_from_connector(self.layer_name)
    attn_metadata, self_kv_cache = _resolve_layer_context(self)
    output = self.impl.forward(self, query, key, value, self_kv_cache,
                               attn_metadata)
    maybe_save_kv_layer_to_connector(self.layer_name, self_kv_cache)
    return output
|
||||
|
||||
|
||||
# Monkey-patch: every vLLM `Attention` layer now uses `forward_` (defined
# above) as its `forward` method. Takes effect when this module is imported.
vllm.attention.layer.Attention.forward = forward_
|
||||
0
vllm_br/compilation/__init__.py
Normal file
0
vllm_br/compilation/__init__.py
Normal file
BIN
vllm_br/compilation/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/compilation/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/compilation/__pycache__/monitor.cpython-310.pyc
Normal file
BIN
vllm_br/compilation/__pycache__/monitor.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/compilation/__pycache__/supa_graph.cpython-310.pyc
Normal file
BIN
vllm_br/compilation/__pycache__/supa_graph.cpython-310.pyc
Normal file
Binary file not shown.
70
vllm_br/compilation/monitor.py
Normal file
70
vllm_br/compilation/monitor.py
Normal file
@@ -0,0 +1,70 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Debug context opened by start_monitoring_torch_compile() (via
# depyf.prepare_debug) and closed/reset by end_monitoring_torch_compile().
# None while no debug dump session is active.
context_manager = None
|
||||
# Wall-clock timestamp (time.time()) recorded by
# start_monitoring_torch_compile(); 0.0 until that function first runs.
torch_compile_start_time: float = 0.0
|
||||
|
||||
|
||||
def start_monitoring_torch_compile(vllm_config: VllmConfig):
    """Record the compile start time and optionally open a depyf debug dump.

    Always stores ``time.time()`` in the module-level
    ``torch_compile_start_time``. When the compilation level is
    ``CompilationLevel.PIECEWISE`` and ``debug_dump_path`` is set, a
    ``depyf.prepare_debug`` context for ``<debug_dump_path>/rank_<rank>`` is
    created, stored in the module-level ``context_manager`` and entered.

    Args:
        vllm_config: Supplies ``compilation_config`` and
            ``parallel_config.rank``.
    """
    global torch_compile_start_time, context_manager

    torch_compile_start_time = time.time()

    cfg: CompilationConfig = vllm_config.compilation_config
    if not (cfg.level == CompilationLevel.PIECEWISE):
        return
    if not cfg.debug_dump_path:
        return

    # depyf is only needed for debug dumps, so it is imported lazily.
    import depyf

    rank_dir = os.path.join(cfg.debug_dump_path,
                            f"rank_{vllm_config.parallel_config.rank}")
    debug_session = depyf.prepare_debug(rank_dir)
    # Publish the session before entering it so the end-of-compile hook can
    # always find it.
    context_manager = debug_session
    debug_session.__enter__()
|
||||
|
||||
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig):
    """Log the total compile time and close any open debug dump context.

    Does nothing unless the compilation level is
    ``CompilationLevel.PIECEWISE``. In that case it logs
    ``compilation_config.compilation_time`` and, if the module-level
    ``context_manager`` is set, exits it and resets it to ``None``.

    Args:
        vllm_config: Supplies ``compilation_config``.
    """
    global context_manager

    cfg: CompilationConfig = vllm_config.compilation_config
    if not (cfg.level == CompilationLevel.PIECEWISE):
        return

    logger.info("torch.compile takes %.2f s in total", cfg.compilation_time)

    active_session = context_manager
    if active_session is None:
        return
    active_session.__exit__(None, None, None)
    context_manager = None
|
||||
|
||||
|
||||
# Module-wide switch read by validate_supagraph_capturing_enabled() and
# written by set_supagraph_capturing_enabled(). Capturing is allowed by
# default.
supagraph_capturing_enabled: bool = True
|
||||
|
||||
|
||||
def validate_supagraph_capturing_enabled():
    """Raise if a SUPA graph capture is attempted while capturing is disabled.

    Used to monitor whether a supagraph capture is legal at runtime; it
    should be called before any supagraph capturing.

    Raises:
        RuntimeError: If the module-level ``supagraph_capturing_enabled``
            flag is false.
    """
    # The flag is only read here, so no `global` declaration is needed.
    if not supagraph_capturing_enabled:
        raise RuntimeError("SUPA graph capturing detected at an inappropriate "
                           "time. This operation is currently disabled.")
|
||||
|
||||
|
||||
def set_supagraph_capturing_enabled(enabled: bool):
    """Set the module-level flag that permits or forbids supagraph capturing.

    Args:
        enabled: New value stored in ``supagraph_capturing_enabled``.
    """
    global supagraph_capturing_enabled
    supagraph_capturing_enabled = enabled
|
||||
239
vllm_br/compilation/supa_graph.py
Normal file
239
vllm_br/compilation/supa_graph.py
Normal file
@@ -0,0 +1,239 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import dataclasses
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger, logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm_br.compilation.monitor import validate_supagraph_capturing_enabled
|
||||
from vllm_br.config.compilation import SUPAGraphMode
|
||||
from vllm_br.forward_context import BatchDescriptor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SUPAGraphEntry:
    """Cache record for one captured SUPA graph, keyed by batch descriptor."""

    # Key this entry was created for.
    batch_descriptor: BatchDescriptor
    # Captured graph; None until the first capture for this descriptor.
    supagraph: Optional[torch.supa.SUPAGraph] = dataclasses.field(default=None)
    # Output produced by the runnable during capture; handed back on replay.
    output: Optional[Any] = dataclasses.field(default=None)

    # For supagraph debugging: the input addresses seen during capture, so
    # they can be compared against the ones seen during replay.
    input_addresses: Optional[list[int]] = dataclasses.field(default=None)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SUPAGraphOptions:
    """Per-wrapper switches consumed by ``SUPAGraphWrapper``."""

    # Emit a debug log line when a graph is captured.
    debug_log_enable: bool = dataclasses.field(default=True)
    # Patch out gc.collect / torch.supa.empty_cache during capture.
    gc_disable: bool = dataclasses.field(default=False)
    # NOTE(review): not read by the wrapper code visible in this file (the
    # weak-ref output path is commented out there) — confirm before relying
    # on it.
    weak_ref_output: bool = dataclasses.field(default=True)
|
||||
|
||||
|
||||
class SUPAGraphWrapper:
    """Wraps a runnable to add SUPA graph capturing and replaying ability. And
    provide attribute access to the underlying `runnable` via `__getattr__`.

    The workflow of this wrapper in the supagraph dispatching is as follows:
    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
    PIECEWISE).
    2. At runtime, the wrapper receives a runtime_mode and a
    batch_descriptor(key) from the forward context and blindly trust them
    for supagraph dispatching.
    3. If runtime_mode is NONE, just call the runnable directly. (The check
    that also bypassed the wrapper when runtime_mode differs from the
    wrapper's own mode is currently disabled in `__call__`.)
    4. Otherwise the wrapper will perform supagraph capture(if key does not
    exist, create a new entry and cache it) or replay (if key exists in the
    cache).

    Note: SUPAGraphWrapper does not store persistent buffers or copy any
    runtime inputs into that buffers for replay. We assume implementing them
    is done outside of the wrapper. That is because we do not make any
    assumption on the dynamic shape (batch size) of the runtime inputs, as a
    trade-off for staying orthogonal to compilation logic. Nevertheless,
    the input addresses are recorded at capture time and checked for
    consistency during replay when VLLM_LOGGING_LEVEL == "DEBUG".
    """

    def __init__(self,
                 runnable: Callable,
                 vllm_config: VllmConfig,
                 runtime_mode: SUPAGraphMode,
                 supagraph_options: Optional[SUPAGraphOptions] = None):
        """Store the runnable and configuration; no graph is captured yet.

        Args:
            runnable: Callable to capture and replay.
            vllm_config: Supplies ``compilation_config`` and, at replay
                time, ``parallel_config.world_size``.
            runtime_mode: Mode assigned to this wrapper; must not be
                ``SUPAGraphMode.NONE``.
            supagraph_options: Optional switches; defaults to
                ``SUPAGraphOptions()``.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode
        self.compilation_config = vllm_config.compilation_config

        self.first_run_finished = False
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # assert runtime_mode is not NONE(no supagraph), otherwise, we don't
        # need to initialize a SUPAGraphWrapper.
        assert self.runtime_mode != SUPAGraphMode.NONE
        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = current_platform.get_global_graph_pool()

        if supagraph_options is None:
            supagraph_options = SUPAGraphOptions()
        self.supagraph_options = supagraph_options
        # the entries for different batch descriptors that we need to capture
        # supagraphs for.
        self.concrete_supagraph_entries: dict[BatchDescriptor, SUPAGraphEntry]\
            = {}

    def __getattr__(self, key: str):
        """Forward unknown attribute lookups to the wrapped runnable.

        Raises:
            AttributeError: If the runnable has no such attribute, or if
                ``runnable`` itself has not been set on this instance.
        """
        # `__getattr__` only runs after normal lookup failed. If `runnable`
        # itself is missing (instance built via `__new__`, copy or unpickle),
        # reading `self.runnable` below would re-enter this method and
        # recurse until RecursionError, so fail fast instead.
        if key == "runnable":
            raise AttributeError(key)
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} not exists in the runnable of "
                             f"supagraph wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        """Return the original runnable."""
        # in case we need to access the original runnable.
        return self.runnable

    @staticmethod
    def _tensor_addresses(args, kwargs) -> list[int]:
        """Return ``data_ptr()`` of every tensor in *args*, then in *kwargs*."""
        return [
            x.data_ptr() for x in args if isinstance(x, torch.Tensor)
        ] + [
            x.data_ptr()
            for x in kwargs.values() if isinstance(x, torch.Tensor)
        ]

    def __call__(self, *args, **kwargs):
        """Run, capture or replay the runnable based on the forward context.

        Returns:
            The runnable's output: computed directly (mode NONE), produced
            during capture (first call for a batch descriptor), or the
            output stored at capture time (replay).
        """
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        supagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        # NOTE: the additional bypass on
        # `supagraph_runtime_mode != self.runtime_mode` is intentionally
        # disabled here; only NONE skips capture/replay.
        if supagraph_runtime_mode == SUPAGraphMode.NONE:
            # SUPAGraphMode.NONE could mean the profile run, a warmup run, or
            # running without supagraphs.
            return self.runnable(*args, **kwargs)

        if batch_descriptor not in self.concrete_supagraph_entries:
            # create a new entry for this batch descriptor
            self.concrete_supagraph_entries[batch_descriptor] = \
                SUPAGraphEntry(batch_descriptor=batch_descriptor)

        entry = self.concrete_supagraph_entries[batch_descriptor]

        if entry.supagraph is None:
            if self.supagraph_options.debug_log_enable:
                # Since we capture supagraph for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
                # piecewise mode.
                logger.debug("Capturing a supagraph on (%s,%s)",
                             self.runtime_mode.name, entry.batch_descriptor)
            # validate that supagraph capturing is legal at this point.
            validate_supagraph_capturing_enabled()

            entry.input_addresses = self._tensor_addresses(args, kwargs)
            supagraph = torch.supa.SUPAGraph()

            with ExitStack() as stack:
                if self.supagraph_options.gc_disable:
                    # during every model forward for piecewise supagraph
                    # mode, we will capture many pieces of supagraphs
                    # (roughly one per layer). running gc again and again
                    # across layers will make the supagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.supa.empty_cache", lambda: None))

                if self.graph_pool is not None:
                    set_graph_pool_id(self.graph_pool)
                else:
                    set_graph_pool_id(current_platform.graph_pool_handle())
                # mind-exploding: carefully manage the reference and memory.
                with torch.supa.graph(supagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's supagraph pool
                    output = self.runnable(*args, **kwargs)
                    # (FIXME): torch.ops._C.weak_ref_tensor is not supported,
                    # so `weak_ref_output` is not honoured and the output is
                    # kept as a strong reference.

            # A strong reference is stored (weak_ref_tensors is unavailable,
            # see FIXME above).
            entry.output = output
            entry.supagraph = supagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during supa graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = self._tensor_addresses(args, kwargs)
            assert new_input_addresses == entry.input_addresses, (
                f"Input addresses for supagraphs are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}")

        if self.vllm_config.parallel_config.world_size != 1:
            # prevent SCCL capturing by using the same stream with SCCL
            stream = torch.distributed.get_group_stream(
                get_world_group().device_group)
        else:
            stream = torch_br.supa.Stream()
        current_stream = torch.supa.current_stream()
        with torch_br.supa.stream(stream):
            entry.supagraph.replay()
            event = torch.supa.Event()
            stream.record_event(event)
        # Make the caller's stream wait until the replay stream has reached
        # the recorded event.
        current_stream.wait_event(event)
        logger.debug(" ========Supa graph reply======== ")
        logger.debug(" padded num_tokens size = %s",
                     batch_descriptor.num_tokens)
        return entry.output
|
||||
256
vllm_br/config/__init__.py
Normal file
256
vllm_br/config/__init__.py
Normal file
@@ -0,0 +1,256 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig, logger
|
||||
from vllm.config.compilation import CompilationLevel
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.utils import random_uuid
|
||||
from .compilation import SUPAGraphMode
|
||||
|
||||
|
||||
def supa_post_init(self):
|
||||
"""Verify configs are valid & consistent with each other.
|
||||
"""
|
||||
|
||||
self.try_verify_and_update_config()
|
||||
|
||||
if self.model_config is not None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.model_config.verify_dual_chunk_attention_config(self.load_config)
|
||||
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
|
||||
if self.lora_config is not None:
|
||||
self.lora_config.verify_with_cache_config(self.cache_config)
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
|
||||
if self.quant_config is None and self.model_config is not None:
|
||||
self.quant_config = VllmConfig._get_quantization_config(
|
||||
self.model_config, self.load_config)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if self.model_config is not None and \
|
||||
self.scheduler_config.chunked_prefill_enabled and \
|
||||
self.model_config.dtype == torch.float32 and \
|
||||
current_platform.get_device_capability() == (7, 5):
|
||||
logger.warning_once(
|
||||
"Turing devices tensor cores do not support float32 matmul. "
|
||||
"To workaround this limitation, vLLM will set 'ieee' input "
|
||||
"precision for chunked prefill triton kernels.")
|
||||
|
||||
# If the user does not explicitly set a compilation level, then
|
||||
# we use the default level. The default level depends on other
|
||||
# settings (see the below code).
|
||||
if self.compilation_config.level is None:
|
||||
if envs.VLLM_USE_V1:
|
||||
if (self.model_config is not None
|
||||
and not self.model_config.enforce_eager):
|
||||
self.compilation_config.level = CompilationLevel.PIECEWISE
|
||||
else:
|
||||
self.compilation_config.level = \
|
||||
CompilationLevel.NO_COMPILATION
|
||||
|
||||
else:
|
||||
# NB: Passing both --enforce-eager and a compilation level
|
||||
# in V0 means the compilation level wins out.
|
||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
# async tp is built on top of sequence parallelism
|
||||
# and requires it to be enabled.
|
||||
if self.compilation_config.pass_config.enable_async_tp:
|
||||
self.compilation_config.pass_config.enable_sequence_parallelism = \
|
||||
True
|
||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
self.compilation_config.custom_ops.append("+rms_norm")
|
||||
|
||||
if current_platform.support_static_graph_mode():
|
||||
# if cudagraph_mode is not explicitly set by users, set default
|
||||
# value
|
||||
if self.compilation_config.cudagraph_mode is None:
|
||||
if envs.VLLM_USE_V1 and self.compilation_config.level \
|
||||
== CompilationLevel.PIECEWISE:
|
||||
# default to full and piecewise for most models
|
||||
self.compilation_config.cudagraph_mode = \
|
||||
SUPAGraphMode.FULL_AND_PIECEWISE
|
||||
|
||||
# pooling models and encoder-decoder models
|
||||
# do not support full cudagraphs
|
||||
if self.model_config is not None and \
|
||||
(self.model_config.pooler_config is not None
|
||||
or self.model_config.is_encoder_decoder):
|
||||
self.compilation_config.cudagraph_mode = \
|
||||
SUPAGraphMode.PIECEWISE
|
||||
else:
|
||||
self.compilation_config.cudagraph_mode = SUPAGraphMode.NONE
|
||||
|
||||
# disable cudagraph when enforce eager execution
|
||||
if self.model_config is not None and \
|
||||
self.model_config.enforce_eager:
|
||||
logger.info("Cudagraph is disabled under eager mode")
|
||||
self.compilation_config.cudagraph_mode = SUPAGraphMode.NONE
|
||||
elif envs.VLLM_USE_V1:
|
||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||
|
||||
self._set_cudagraph_sizes()
|
||||
else:
|
||||
self.compilation_config.cudagraph_mode = SUPAGraphMode.NONE
|
||||
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
|
||||
if self.speculative_config is not None and \
|
||||
self.speculative_config.use_eagle():
|
||||
raise NotImplementedError(
|
||||
"Fast prefill optimization for KV sharing is not "
|
||||
"compatible with EAGLE as EAGLE requires correct logits "
|
||||
"for all tokens while fast prefill gives incorrect logits "
|
||||
"for prompt tokens.")
|
||||
|
||||
logger.warning_once(
|
||||
"--kv-sharing-fast-prefill requires changes on model side for "
|
||||
"correctness and to realize prefill savings. ")
|
||||
|
||||
disable_chunked_prefill_reasons: list[str] = []
|
||||
|
||||
if self.model_config:
|
||||
if self.model_config.pooler_config:
|
||||
pooling_type = self.model_config.pooler_config.pooling_type
|
||||
if pooling_type is None or pooling_type.lower() != "last":
|
||||
disable_chunked_prefill_reasons.append(
|
||||
"Only \"last\" pooling supports chunked "
|
||||
"prefill and prefix caching; disabling both.")
|
||||
if not getattr(self.model_config.hf_config, "is_causal", True):
|
||||
disable_chunked_prefill_reasons.append(
|
||||
"Only models using causal attention supports chunked "
|
||||
"prefill and prefix caching; disabling both.")
|
||||
elif self.model_config.is_encoder_decoder:
|
||||
self.scheduler_config.max_num_encoder_input_tokens = \
|
||||
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
|
||||
logger.debug(
|
||||
"Encoder-decoder model detected: setting "
|
||||
"`max_num_encoder_input_tokens` to encoder length (%s)",
|
||||
self.scheduler_config.max_num_encoder_input_tokens)
|
||||
self.scheduler_config.disable_chunked_mm_input = True
|
||||
disable_chunked_prefill_reasons.append(
|
||||
"Encoder-decoder models do not support chunked prefill nor"
|
||||
" prefix caching; disabling both.")
|
||||
if (self.model_config.architecture
|
||||
== "WhisperForConditionalGeneration" and
|
||||
os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
|
||||
logger.warning("Whisper is known to have issues with "
|
||||
"forked workers. If startup is hanging, "
|
||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||
"to 'spawn'.")
|
||||
|
||||
if disable_chunked_prefill_reasons:
|
||||
for reason in disable_chunked_prefill_reasons:
|
||||
logger.info(reason)
|
||||
self.scheduler_config.chunked_prefill_enabled = False
|
||||
self.scheduler_config.long_prefill_token_threshold = 0
|
||||
|
||||
if self.cache_config is not None:
|
||||
self.cache_config.enable_prefix_caching = False
|
||||
|
||||
if (self.kv_events_config is not None
|
||||
and self.kv_events_config.enable_kv_cache_events
|
||||
and not self.cache_config.enable_prefix_caching):
|
||||
logger.warning(
|
||||
"KV cache events are on, but prefix caching is not enabled."
|
||||
"Use --enable-prefix-caching to enable.")
|
||||
if (self.kv_events_config is not None
|
||||
and self.kv_events_config.publisher != "null"
|
||||
and not self.kv_events_config.enable_kv_cache_events):
|
||||
logger.warning("KV cache events are disabled,"
|
||||
"but the scheduler is configured to publish them."
|
||||
"Modify KVEventsConfig.enable_kv_cache_events"
|
||||
"to True to enable.")
|
||||
current_platform.check_and_update_config(self)
|
||||
|
||||
# final check of cudagraph mode after platform-specific update
|
||||
if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
|
||||
if self.compilation_config.cudagraph_mode == SUPAGraphMode.FULL \
|
||||
and self.model_config is not None and \
|
||||
not self.model_config.disable_cascade_attn:
|
||||
logger.info("SUPAGraphMode.FULL is not supported with "
|
||||
"cascade attention currently. Disabling cascade"
|
||||
"attention.")
|
||||
self.model_config.disable_cascade_attn = True
|
||||
|
||||
if self.compilation_config.cudagraph_mode\
|
||||
.requires_piecewise_compilation():
|
||||
assert self.compilation_config.level == \
|
||||
CompilationLevel.PIECEWISE, \
|
||||
"Compilation level should be CompilationLevel.PIECEWISE "\
|
||||
"when cudagraph_mode piecewise cudagraphs is used, "\
|
||||
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
|
||||
|
||||
if self.parallel_config.enable_dbo:
|
||||
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
assert a2a_backend in \
|
||||
["deepep_low_latency", "deepep_high_throughput"], \
|
||||
"Microbatching currently only supports the deepep_low_latency and "\
|
||||
f"deepep_high_throughput all2all backend. {a2a_backend} is not "\
|
||||
"supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\
|
||||
"variable to deepep_low_latency or deepep_high_throughput and "\
|
||||
"install the DeepEP kernels."
|
||||
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
# Do this after all the updates to compilation_config.level
|
||||
if envs.VLLM_USE_V1 and \
|
||||
self.compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
if (envs.VLLM_USE_V1
|
||||
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
|
||||
# logger should only print warning message for hybrid models. As we
|
||||
# can't know whether the model is hybrid or not now, so we don't log
|
||||
# warning message here and will log it later.
|
||||
if not current_platform.support_hybrid_kv_cache():
|
||||
# Hybrid KV cache manager is not supported on non-GPU platforms.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
if self.kv_transfer_config is not None:
|
||||
# Hybrid KV cache manager is not compatible with KV transfer.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
if self.kv_events_config is not None:
|
||||
# Hybrid KV cache manager is not compatible with KV events.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
if self.model_config is not None and \
|
||||
self.model_config.attention_chunk_size is not None:
|
||||
if self.speculative_config is not None and \
|
||||
self.speculative_config.use_eagle():
|
||||
# Hybrid KV cache manager is not yet supported with chunked
|
||||
# local attention + eagle.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
elif \
|
||||
not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
|
||||
logger.warning(
|
||||
"There is a latency regression when using chunked local"
|
||||
" attention with the hybrid KV cache manager. Disabling"
|
||||
" it, by default. To enable it, set the environment "
|
||||
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1.")
|
||||
# Hybrid KV cache manager is not yet supported with chunked
|
||||
# local attention.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
|
||||
|
||||
# Monkey-patch: install supa_post_init as vLLM's VllmConfig.__post_init__.
vllm.config.VllmConfig.__post_init__ = supa_post_init
|
||||
BIN
vllm_br/config/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/config/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/config/__pycache__/compilation.cpython-310.pyc
Normal file
BIN
vllm_br/config/__pycache__/compilation.cpython-310.pyc
Normal file
Binary file not shown.
67
vllm_br/config/compilation.py
Normal file
67
vllm_br/config/compilation.py
Normal file
@@ -0,0 +1,67 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import enum
|
||||
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
|
||||
|
||||
class CompilationLevel:
    """Integer constants naming the levels of the compilation process."""

    (
        NO_COMPILATION,  # 0
        DYNAMO_AS_IS,  # 1
        DYNAMO_ONCE,  # 2
        PIECEWISE,  # 3
    ) = range(4)
|
||||
|
||||
|
||||
class SUPAGraphMode(enum.Enum):
    """Constants for the supagraph mode in CompilationConfig.

    The scalar members ``NONE``, ``PIECEWISE`` and ``FULL`` are also treated
    as concrete runtime modes for supagraph runtime dispatching. The
    tuple-valued members combine two scalar modes as
    ``(decode_mode, mixed_mode)``.
    """
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    # Inside the class body the names above are still plain ints, so these
    # two values are the tuples (2, 0) and (2, 1).
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

    def separate_routine(self) -> bool:
        """Return True when this mode is a (decode, mixed) pair."""
        return isinstance(self.value, tuple)

    def decode_mode(self) -> 'SUPAGraphMode':
        """Return the mode stored first in a pair, or ``self`` if scalar."""
        return SUPAGraphMode(self.value[0]) if \
            self.separate_routine() else self

    def mixed_mode(self) -> 'SUPAGraphMode':
        """Return the mode stored second in a pair, or ``self`` if scalar."""
        return SUPAGraphMode(self.value[1]) if \
            self.separate_routine() else self

    def requires_piecewise_compilation(self) -> bool:
        """Return True if either the decode or the mixed mode is PIECEWISE."""
        return (self.decode_mode() == SUPAGraphMode.PIECEWISE
                or self.mixed_mode() == SUPAGraphMode.PIECEWISE)

    def max_supagraph_mode(self) -> 'SUPAGraphMode':
        """Return the larger-valued mode of a pair, or ``self`` if scalar."""
        return SUPAGraphMode(max(
            self.value)) if self.separate_routine() else self

    def has_full_supagraphs(self) -> bool:
        """Return True if ``max_supagraph_mode()`` is ``FULL``."""
        return self.max_supagraph_mode() == SUPAGraphMode.FULL

    # ychun, trick for CUDAGraphMode
    def has_full_cudagraphs(self) -> bool:
        """Same answer as ``has_full_supagraphs()``, under the CUDA-style name.

        The previous implementation compared a scalar ``SUPAGraphMode``
        member against ``CUDAGraphMode.FULL``. Members of two different Enum
        classes never compare equal, so plain ``FULL`` reported False even
        though ``FULL_DECODE_ONLY`` / ``FULL_AND_PIECEWISE`` reported True.
        Delegating keeps the tuple-mode results unchanged and fixes the
        scalar case.
        """
        return self.has_full_supagraphs()
|
||||
17
vllm_br/distributed/__init__.py
Normal file
17
vllm_br/distributed/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from . import device_communicators # noqa: F401
|
||||
from . import kv_transfer # noqa: F401
|
||||
BIN
vllm_br/distributed/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/distributed/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/distributed/__pycache__/communicator.cpython-310.pyc
Normal file
BIN
vllm_br/distributed/__pycache__/communicator.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/distributed/__pycache__/parallel_state.cpython-310.pyc
Normal file
BIN
vllm_br/distributed/__pycache__/parallel_state.cpython-310.pyc
Normal file
Binary file not shown.
60
vllm_br/distributed/communicator.py
Normal file
60
vllm_br/distributed/communicator.py
Normal file
@@ -0,0 +1,60 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
DeviceCommunicatorBase)
|
||||
from vllm.logger import logger
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
class SUPACommunicator(DeviceCommunicatorBase):
    """Device communicator used on the SUPA platform.

    Overrides ``gather`` and ``all_reduce`` of ``DeviceCommunicatorBase``.
    """

    def __init__(self,
                 cpu_group: dist.ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[dist.ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        # The current SUPA device is always used, regardless of the `device`
        # argument forwarded to the base class above.
        self.device = torch.supa.current_device()

    # TODO: Deprecate this method in the future if torch_br support gather
    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """Emulate gather with all_gather.

        Every rank takes part in ``self.all_gather(input_, dim)``; only the
        rank whose ``rank_in_group`` equals ``dst`` returns the gathered
        tensor. All other ranks return ``None``.
        """
        output_tensor = self.all_gather(input_, dim)
        if self.rank_in_group == dst:
            return output_tensor
        return None

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """All-reduce ``input_`` over ``self.device_group``.

        When ``envs.VLLM_BR_USE_FP32_ALL_REDUCE`` is set and ``input_`` is
        bfloat16, the reduction runs on a float32 copy and the result is cast
        back to bfloat16, so the returned tensor is a new object. Otherwise
        ``dist.all_reduce`` is applied to ``input_`` directly and ``input_``
        is returned.
        """
        if envs.VLLM_BR_USE_FP32_ALL_REDUCE and input_ is not None and input_.dtype == torch.bfloat16:
            logger.debug(
                '[Patch] patch all_reduce: use fp32 all_reduce when env VLLM_BR_USE_FP32_ALL_REDUCE is set'
            )
            input_ = input_.to(torch.float32)
            dist.all_reduce(input_, group=self.device_group)
            input_ = input_.to(torch.bfloat16)
        else:
            dist.all_reduce(input_, group=self.device_group)
        return input_
|
||||
18
vllm_br/distributed/device_communicators/__init__.py
Normal file
18
vllm_br/distributed/device_communicators/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import base_device_communicator # noqa: F401
|
||||
from . import pysccl_wrapper # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,44 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import torch
|
||||
|
||||
import vllm
|
||||
|
||||
|
||||
def supa_prepare_communication_buffer_for_model(
        self, model: torch.nn.Module) -> None:
    """Run ``init_prepare_finalize`` on every MoE layer found in ``model``.

    Nothing happens unless both ``self.use_all2all`` and
    ``self.is_ep_communicator`` are truthy.
    """
    if not (self.use_all2all and self.is_ep_communicator):
        return

    # TODO(bnell): Should use isinstance but can't. Maybe search for
    # presence of quant_method.init_prepare_finalize?
    moe_class_names = ("FusedMoE", "SharedFusedMoE")

    # Collect the layers first, then initialise them, so the module tree is
    # fully walked before any layer is touched.
    moe_layers = [
        layer for layer in model.modules()
        if layer.__class__.__name__ in moe_class_names
    ]
    for layer in moe_layers:
        layer.quant_method.init_prepare_finalize(layer)
|
||||
|
||||
|
||||
# Monkey-patch: replace DeviceCommunicatorBase.prepare_communication_buffer_for_model
# in vLLM with the SUPA implementation defined above.
vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase.prepare_communication_buffer_for_model = supa_prepare_communication_buffer_for_model
|
||||
420
vllm_br/distributed/device_communicators/pysccl_wrapper.py
Normal file
420
vllm_br/distributed/device_communicators/pysccl_wrapper.py
Normal file
@@ -0,0 +1,420 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# This file is a pure Python wrapper for the SCCL library.
|
||||
# The main purpose is to use SCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls SCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the SCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for SCCL. It is usually
|
||||
# doable, but we often encounter issues related with succl versions, and need
|
||||
# to switch between different versions of SCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of SCCL by
|
||||
# changing the environment variable `VLLM_SCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from vllm.logger import logger
|
||||
from vllm_br import envs
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
succlResult_t = ctypes.c_int
|
||||
succlComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class succlUniqueId(ctypes.Structure):
    """ctypes layout of a SCCL unique id: a single 128-byte array."""
    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]
|
||||
|
||||
|
||||
suStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
succlDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class succlDataTypeEnum:
    """Integer codes for SCCL element types, with a torch dtype converter."""
    succlInt8 = 0
    succlChar = 0
    succlUint8 = 1
    succlInt16 = 2
    succlUint16 = 3
    succlInt32 = 4
    succlInt = 4
    succlUint32 = 5
    succlInt64 = 6
    succlUint64 = 7
    succlBfloat16 = 8
    succlFloat32 = 9
    succlFloat = 9
    succlFloat64 = 10
    succlDouble = 10
    succlNumTypes = 11

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Translate a torch dtype into its SCCL data type code.

        Raises:
            ValueError: if ``dtype`` has no entry in the table below.
        """
        torch_to_succl = (
            (torch.int8, cls.succlInt8),
            (torch.uint8, cls.succlUint8),
            (torch.int32, cls.succlInt32),
            (torch.int64, cls.succlInt64),
            # NOTE(review): float16 is sent with the bfloat16 code (this
            # class defines no float16 code). Presumably acceptable for pure
            # data movement; confirm it is intended for reductions.
            (torch.float16, cls.succlBfloat16),
            (torch.float32, cls.succlFloat32),
            (torch.float64, cls.succlFloat64),
            (torch.bfloat16, cls.succlBfloat16),
        )
        for torch_dtype, succl_code in torch_to_succl:
            if dtype == torch_dtype:
                return succl_code
        raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
succlRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class succlRedOpTypeEnum:
    """Integer codes for SCCL reduction ops, with a torch ReduceOp converter."""
    succlSum = 0
    succlProd = 1
    succlMax = 2
    succlMin = 3
    succlAvg = 4
    succlNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Translate a ``torch.distributed.ReduceOp`` into its SCCL op code.

        Raises:
            ValueError: if ``op`` is not SUM, PRODUCT, MAX, MIN or AVG.
        """
        torch_to_succl = (
            (ReduceOp.SUM, cls.succlSum),
            (ReduceOp.PRODUCT, cls.succlProd),
            (ReduceOp.MAX, cls.succlMax),
            (ReduceOp.MIN, cls.succlMin),
            (ReduceOp.AVG, cls.succlAvg),
        )
        for torch_op, succl_code in torch_to_succl:
            if op == torch_op:
                return succl_code
        raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
class Function:
    """Signature record for one C function exported by the SCCL library."""
    # Symbol name looked up on the loaded library.
    name: str
    # Value assigned to the ctypes function's ``restype``.
    restype: Any
    # Value assigned to the ctypes function's ``argtypes``.
    argtypes: list[Any]
|
||||
|
||||
|
||||
class SCCLLibrary:
    """Pure-Python (ctypes) binding to the SCCL shared library.

    The library is loaded once per path and its exported functions are bound
    once per path; both results are cached on the class so that several
    ``SCCLLibrary`` instances for the same file share them.

    Each wrapper method forwards to the bound C function and raises
    ``RuntimeError`` (via ``SUCCL_CHECK``) on a non-zero result code.
    """

    # NOTE: except for succlGetErrorString/GetVersion/GetUniqueId/CommDestroy
    # and the group calls, every binding below declares one trailing
    # ``ctypes.c_void_p`` argument that the prototype quoted in the comment
    # above it does not show. The wrapper methods always pass ``None`` for it.
    exported_functions = [
        # const char* succlGetErrorString(succlResult_t result)
        Function("succlGetErrorString", ctypes.c_char_p, [succlResult_t]),
        # succlResult_t succlGetVersion(int *version);
        Function("succlGetVersion", succlResult_t,
                 [ctypes.POINTER(ctypes.c_int)]),
        # succlResult_t succlGetUniqueId(succlUniqueId* uniqueId);
        Function("succlGetUniqueId", succlResult_t,
                 [ctypes.POINTER(succlUniqueId)]),
        # succlResult_t succlCommInitRank(
        #   succlComm_t* comm, int nranks, succlUniqueId commId, int rank);
        # note that succlComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function("succlCommInitRank", succlResult_t, [
            ctypes.POINTER(succlComm_t), ctypes.c_int, succlUniqueId,
            ctypes.c_int, ctypes.c_void_p
        ]),
        # succlResult_t succlAllReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlRedOp_t op, succlComm_t comm,
        #   suStream_t stream);
        # note that suStream_t is a pointer type, so the last argument
        # is a pointer
        Function("succlAllReduce", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlRedOp_t, succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # succlResult_t succlReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlRedOp_t op, int root,
        #   succlComm_t comm, suStream_t stream);
        # note that suStream_t is a pointer type, so the last argument
        # is a pointer
        Function("succlReduce", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlRedOp_t, ctypes.c_int, succlComm_t, suStream_t,
            ctypes.c_void_p
        ]),

        # succlResult_t succlAllGather(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlComm_t comm,
        #   suStream_t stream);
        # note that suStream_t is a pointer type, so the last argument
        # is a pointer
        Function("succlAllGather", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # succlResult_t succlReduceScatter(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, succlRedOp_t op, succlComm_t comm,
        #   suStream_t stream);
        # note that suStream_t is a pointer type, so the last argument
        # is a pointer
        Function("succlReduceScatter", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            succlRedOp_t, succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # succlResult_t succlSend(
        #   const void* sendbuff, size_t count, succlDataType_t datatype,
        #   int dest, succlComm_t comm, suStream_t stream);
        Function("succlSend", succlResult_t, [
            buffer_type, ctypes.c_size_t, succlDataType_t, ctypes.c_int,
            succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # succlResult_t succlRecv(
        #   void* recvbuff, size_t count, succlDataType_t datatype,
        #   int src, succlComm_t comm, suStream_t stream);
        Function("succlRecv", succlResult_t, [
            buffer_type, ctypes.c_size_t, succlDataType_t, ctypes.c_int,
            succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # succlResult_t succlBroadcast(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   succlDataType_t datatype, int root, succlComm_t comm,
        #   suStream_t stream);
        Function("succlBroadcast", succlResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, succlDataType_t,
            ctypes.c_int, succlComm_t, suStream_t, ctypes.c_void_p
        ]),

        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # succlResult_t succlCommDestroy(succlComm_t comm);
        Function("succlCommDestroy", succlResult_t, [succlComm_t]),
        # succlResult_t succlGroupStart();
        Function("succlGroupStart", succlResult_t, []),
        # succlResult_t succlGroupEnd();
        Function("succlGroupEnd", succlResult_t, []),
        # Function("succldemoSetdevice", succlResult_t, [ctypes.c_int]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: dict[str, dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        """Load (or reuse) the SCCL library and bind its exported functions.

        Args:
            so_file: path of the shared library. When falsy, the path
                returned by ``find_sccl_library()`` is used.

        Raises:
            Exception: whatever ``ctypes.CDLL`` raised while loading; the
                failure is logged and re-raised unchanged.
        """
        so_file = so_file or find_sccl_library()
        try:
            # Guard the library cache with its own membership test, so a
            # library whose symbol binding failed earlier is not loaded twice.
            if so_file not in SCCLLibrary.path_to_library_cache:
                lib = ctypes.CDLL(so_file)
                SCCLLibrary.path_to_library_cache[so_file] = lib
            self.lib = SCCLLibrary.path_to_library_cache[so_file]
        except Exception:
            logger.error(
                "Failed to load SCCL library from %s. "
                "It is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the sccl library might not exist, be corrupted "
                "or it does not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable VLLM_SCCL_SO_PATH"
                " to point to the correct sccl library path.", so_file,
                platform.platform())
            raise

        if so_file not in SCCLLibrary.path_to_dict_mapping:
            _funcs: dict[str, Any] = {}
            for func in SCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            SCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = SCCLLibrary.path_to_dict_mapping[so_file]

    def succlGetErrorString(self, result: succlResult_t) -> str:
        """Return the library's message for ``result``, decoded as UTF-8."""
        return self._funcs["succlGetErrorString"](result).decode("utf-8")

    def SUCCL_CHECK(self, result: succlResult_t) -> None:
        """Raise ``RuntimeError`` if ``result`` is non-zero."""
        if result != 0:
            error_str = self.succlGetErrorString(result)
            raise RuntimeError(f"SCCL error: {error_str}")

    def succlGetVersion(self) -> str:
        """Return the library version as ``"major.minor.patch"``.

        The integer reported by the library is decoded arithmetically
        (something like 21903 --> "2.19.3"), so components that are zero or
        a major version above 9 are rendered correctly.
        """
        version = ctypes.c_int()
        self.SUCCL_CHECK(self._funcs["succlGetVersion"](ctypes.byref(version)))
        major, remainder = divmod(version.value, 10000)
        minor, patch = divmod(remainder, 100)
        return f"{major}.{minor}.{patch}"

    def succlGetUniqueId(self) -> succlUniqueId:
        """Ask the library for a fresh unique id and return it."""
        unique_id = succlUniqueId()
        self.SUCCL_CHECK(self._funcs["succlGetUniqueId"](
            ctypes.byref(unique_id)))
        return unique_id

    def unique_id_from_bytes(self, data: bytes) -> succlUniqueId:
        """Build a ``succlUniqueId`` from exactly 128 raw bytes.

        Raises:
            ValueError: if ``data`` is not 128 bytes long.
        """
        if len(data) != 128:
            raise ValueError(
                f"Expected 128 bytes for succlUniqueId, got {len(data)} bytes")
        unique_id = succlUniqueId()
        ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
        return unique_id

    def succlCommInitRank(self, world_size: int, unique_id: succlUniqueId,
                          rank: int) -> succlComm_t:
        """Create and return the communicator handle for ``rank``."""
        comm = succlComm_t()
        result = self._funcs["succlCommInitRank"](ctypes.byref(comm),
                                                  world_size, unique_id, rank,
                                                  None)
        self.SUCCL_CHECK(result)
        return comm

    # def succldemoSetdevice(self, deviceid:int):
    #     self._funcs["succldemoSetdevice"](deviceid)

    def succlAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                       count: int, datatype: int, op: int, comm: succlComm_t,
                       stream: suStream_t) -> None:
        """Call ``succlAllReduce`` and check its result code."""
        # `datatype` actually should be `succlDataType_t`
        # and `op` should be `succlRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.SUCCL_CHECK(self._funcs["succlAllReduce"](sendbuff, recvbuff,
                                                       count, datatype, op,
                                                       comm, stream, None))

    def succlReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                    count: int, datatype: int, op: int, root: int,
                    comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlReduce`` and check its result code."""
        # `datatype` actually should be `succlDataType_t`
        # and `op` should be `succlRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.SUCCL_CHECK(self._funcs["succlReduce"](sendbuff, recvbuff, count,
                                                    datatype, op, root, comm,
                                                    stream, None))

    def succlReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                           count: int, datatype: int, op: int,
                           comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlReduceScatter`` and check its result code."""
        # `datatype` actually should be `succlDataType_t`
        # and `op` should be `succlRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.SUCCL_CHECK(self._funcs["succlReduceScatter"](sendbuff, recvbuff,
                                                           count, datatype, op,
                                                           comm, stream, None))

    def succlAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
                       count: int, datatype: int, comm: succlComm_t,
                       stream: suStream_t) -> None:
        """Call ``succlAllGather`` and check its result code."""
        # `datatype` actually should be `succlDataType_t`
        # which is an aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.SUCCL_CHECK(self._funcs["succlAllGather"](sendbuff, recvbuff,
                                                       count, datatype, comm,
                                                       stream, None))

    def succlSend(self, sendbuff: buffer_type, count: int, datatype: int,
                  dest: int, comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlSend`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlSend"](sendbuff, count, datatype,
                                                  dest, comm, stream, None))

    def succlRecv(self, recvbuff: buffer_type, count: int, datatype: int,
                  src: int, comm: succlComm_t, stream: suStream_t) -> None:
        """Call ``succlRecv`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlRecv"](recvbuff, count, datatype,
                                                  src, comm, stream, None))

    def succlBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
                       count: int, datatype: int, root: int, comm: succlComm_t,
                       stream: suStream_t) -> None:
        """Call ``succlBroadcast`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlBroadcast"](sendbuff, recvbuff,
                                                       count, datatype, root,
                                                       comm, stream, None))

    def succlCommDestroy(self, comm: succlComm_t) -> None:
        """Call ``succlCommDestroy`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlCommDestroy"](comm))

    def succlGroupStart(self) -> None:
        """Call ``succlGroupStart`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlGroupStart"]())

    def succlGroupEnd(self) -> None:
        """Call ``succlGroupEnd`` and check its result code."""
        self.SUCCL_CHECK(self._funcs["succlGroupEnd"]())
|
||||
|
||||
|
||||
def find_sccl_library() -> str:
    """Return the SCCL shared-library path from ``VLLM_SCCL_SO_PATH``.

    The path is read from ``envs.VLLM_SCCL_SO_PATH``; there is no fallback
    search.

    Raises:
        ValueError: if ``VLLM_SCCL_SO_PATH`` is unset or empty.
    """
    so_file = envs.VLLM_SCCL_SO_PATH
    if not so_file:
        raise ValueError("SCCL lib file not found.")

    # manually load the sccl library
    logger.info(
        "Found sccl from environment variable VLLM_SCCL_SO_PATH=%s",
        so_file)
    return so_file
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SCCLLibrary", "succlDataTypeEnum", "succlRedOpTypeEnum", "succlUniqueId",
|
||||
"succlComm_t", "suStream_t", "buffer_type"
|
||||
]
|
||||
17
vllm_br/distributed/kv_transfer/__init__.py
Normal file
17
vllm_br/distributed/kv_transfer/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import kv_connector # noqa: F401
|
||||
Binary file not shown.
17
vllm_br/distributed/kv_transfer/kv_connector/__init__.py
Normal file
17
vllm_br/distributed/kv_transfer/kv_connector/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import v1 # noqa: F401
|
||||
Binary file not shown.
17
vllm_br/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
17
vllm_br/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import base, p2p # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
28
vllm_br/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
28
vllm_br/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
@@ -0,0 +1,28 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# from vllm.logger import logger
|
||||
# from vllm.v1.core.sched.output import SchedulerOutput
|
||||
# from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
# class KVConnectorRole(enum.Enum):
|
||||
# # Connector running in the scheduler process
|
||||
# SCHEDULER = 0
|
||||
|
||||
# # Connector running in the worker process
|
||||
# WORKER = 1
|
||||
|
||||
# vllm.distributed.kv_transfer.kv_connector.v1.base.KVConnectorRole=KVConnectorRole
|
||||
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from . import p2p_succl_engine # noqa: F401
|
||||
from . import p2p_succl_connector, tensor_memory_pool # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,535 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.p2p_succl_engine import (
|
||||
P2pSucclEngine)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@dataclass
class ReqMeta:
    """Per-request record carried inside the connector metadata."""

    # Identifier of the request.
    request_id: str
    # Block ids of the request, stored as a tensor.
    block_ids: torch.Tensor
    # Number of tokens of the request.
    num_tokens: int

    @staticmethod
    def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                  block_size: int) -> "ReqMeta":
        """Build a ``ReqMeta`` from plain Python lists.

        Args:
            request_id: identifier of the request.
            token_ids: token ids of the request; only their count is kept.
            block_ids: block ids of the request; converted to a tensor.
            block_size: accepted for interface compatibility; not used here.

        Returns:
            The populated ``ReqMeta``.
        """
        token_count = len(token_ids)
        return ReqMeta(request_id, torch.tensor(block_ids), token_count)
|
||||
|
||||
|
||||
@dataclass
class P2pSucclConnectorMetadata(KVConnectorMetadata):
    """Connector metadata holding one ``ReqMeta`` per request."""

    # Requests collected so far.
    requests: list[ReqMeta]

    def __init__(self):
        # Start with no requests; entries are added via ``add_request``.
        self.requests = []

    def add_request(
        self,
        request_id: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
    ) -> None:
        """Append a new ``ReqMeta`` built from the given request fields."""
        req_meta = ReqMeta.make_meta(request_id, token_ids, block_ids,
                                     block_size)
        self.requests.append(req_meta)
|
||||
|
||||
|
||||
class P2pSucclConnector(KVConnectorBase_V1):
    """KV connector that moves per-layer KV cache through a ``P2pSucclEngine``.

    A producer instance (``is_kv_producer``) extracts the KV blocks of each
    request and sends them; a consumer instance receives them and writes them
    into its own paged KV buffer. The engine is only created for the WORKER
    role; on the scheduler side ``p2p_nccl_engine`` is ``None``.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._block_size = vllm_config.cache_config.block_size
        # request_id -> (request, block ids) awaiting a load on the consumer.
        self._requests_need_load: dict[str, Any] = {}
        self.config = vllm_config.kv_transfer_config
        self.is_producer = self.config.is_kv_producer
        # request_id -> (block ids so far, prompt token ids) for prompts
        # whose prefill is still being chunked on the producer.
        self.chunked_prefill: dict[str, Any] = {}

        if role == KVConnectorRole.WORKER:
            world_group = get_world_group()
            self._rank = world_group.rank
            self._local_rank = world_group.local_rank
            self.p2p_nccl_engine: Optional[P2pSucclEngine] = P2pSucclEngine(
                local_rank=self._local_rank,
                config=self.config,
                hostname="",
                port_offset=self._rank,
            )
        else:
            # Scheduler-side connectors never transfer tensors themselves.
            self._rank = 0
            self._local_rank = 0
            self.p2p_nccl_engine = None

    # ==============================
    # Worker-side methods
    # ==============================

    def _kv_planes(self, attn_metadata: Any,
                   layer: torch.Tensor) -> Optional[tuple[int, ...]]:
        """Return the leading indices of ``layer`` that hold KV data.

        ``(0,)`` when the metadata is ``MLACommonMetadata`` or
        ``layer.shape[1] == 2`` (MLA or FlashInfer layout), ``(0, 1)`` when
        ``layer.shape[0] == 2`` (FlashAttention layout), ``None`` otherwise.
        """
        if (isinstance(attn_metadata, MLACommonMetadata)
                or layer.shape[1] == 2):  # MLA or FlashInfer
            return (0, )
        if layer.shape[0] == 2:  # FlashAttention
            return (0, 1)
        return None

    def _inject_kv_into_layer(
        self,
        attn_metadata: Any,
        layer: torch.Tensor,
        kv_cache: torch.Tensor,
        block_ids: torch.Tensor,
        request_id: str,
    ) -> None:
        """Copy received KV blocks into ``layer`` in place.

        When the number of block ids differs from ``kv_cache.shape[1]`` only
        the overlapping prefix is written and a warning is logged. Layers
        whose layout is not recognised are left untouched.

        Args:
            attn_metadata: attention metadata of the current forward pass.
            layer: the paged KV tensor to update.
            kv_cache: the received KV tensor; block ``i`` is read from
                ``kv_cache[plane][i]``.
            block_ids: destination block indices.
            request_id: request identifier, used for logging only.
        """
        planes = self._kv_planes(attn_metadata, layer)
        if planes is None:
            return

        num_block = kv_cache.shape[1]
        if len(block_ids) != num_block:
            logger.warning(
                "🚧kv_cache does not match, block_ids:%d, num_block:%d, "
                "request_id:%s", len(block_ids), num_block, request_id)
        block_ids = block_ids[:min(len(block_ids), num_block)]

        # Each row along dim 2 of ``layer`` holds ``th_gran`` blocks of
        # ``block_size`` entries; a block index is split into (row, offset).
        th_gran = layer.shape[2] // self._block_size
        for i, block_index in enumerate(block_ids.tolist()):
            dst_0 = block_index // th_gran
            dst_1 = (block_index % th_gran) * self._block_size
            for plane in planes:
                layer[plane][dst_0][dst_1:dst_1 +
                                    self._block_size] = kv_cache[plane][i]

    def _extract_kv_from_layer(
        self,
        attn_metadata: Any,
        layer: torch.Tensor,
        block_ids: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Gather the KV blocks listed in ``block_ids`` into a new tensor.

        Args:
            attn_metadata: attention metadata of the current forward pass.
            layer: the paged KV tensor of one attention layer.
            block_ids: indices of the blocks to extract.

        Returns:
            A tensor of shape ``[layer.shape[0], len(block_ids), block_size,
            layer.shape[3]]``, or ``None`` if the layout is not recognised.
        """
        planes = self._kv_planes(attn_metadata, layer)
        if planes is None:
            return None

        origin_shape = layer.shape
        shape = [
            origin_shape[0],
            len(block_ids), self._block_size, origin_shape[3]
        ]
        # NOTE(review): in the MLA/FlashInfer case only index 0 of the first
        # dimension is filled; any further planes of this freshly allocated
        # buffer keep whatever ``_empty_ut_only`` returned -- confirm that
        # origin_shape[0] is always 1 on that path.
        layer_send = torch_br._empty_ut_only(shape,
                                             dtype=layer.dtype,
                                             tensor_type='BUFFER_ANY',
                                             device=layer.device)
        th_gran = origin_shape[2] // self._block_size
        for i, block_index in enumerate(block_ids.tolist()):
            src_0 = block_index // th_gran
            src_1 = (block_index % th_gran) * self._block_size
            for plane in planes:
                layer_send[plane][i] = layer[plane][src_0][src_1:src_1 +
                                                           self._block_size]
        return layer_send

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Receive the KV cache of every pending request into the paged buffer.

        For each request in the connector metadata and each layer in
        ``forward_context.no_compile_layers`` that has a ``kv_cache``
        attribute, one tensor is received from the engine and injected.
        Producers return immediately.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation (unused).
        """
        # Only consumer/decode loads KV Cache
        if self.is_producer:
            return
        assert self.p2p_nccl_engine is not None

        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            return

        # Get the metadata
        metadata: KVConnectorMetadata = self._get_connector_metadata()
        assert isinstance(metadata, P2pSucclConnectorMetadata)

        # Load the KV for each request each layer
        for request in metadata.requests:
            request_id = request.request_id
            ip, port = self.parse_request_id(request_id, False)
            remote_address = ip + ":" + str(port + self._rank)
            for layer_name in forward_context.no_compile_layers:
                layer = forward_context.no_compile_layers[layer_name]

                # Only process layers that have kv_cache attribute
                # (attention layers). Skip non-attention layers like FusedMoE.
                kv_cache = getattr(layer, 'kv_cache', None)
                if kv_cache is None:
                    continue

                layer = kv_cache[forward_context.virtual_engine]
                kv_cache = self.p2p_nccl_engine.recv_tensor(
                    request.request_id + "#" + layer_name, remote_address)
                if kv_cache is None:
                    logger.warning("🚧kv_cache is None, %s", request.request_id)
                    continue

                self._inject_kv_into_layer(attn_metadata, layer, kv_cache,
                                           request.block_ids,
                                           request.request_id)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op: this connector does not load layers asynchronously.

        Args:
            layer_name: the name of that layer
        """
        return

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """Send the KV cache of one layer for every request in the metadata.

        Consumers return immediately. If the layer layout is not recognised
        no tensor can be extracted; that request/layer is skipped with a
        warning instead of handing ``None`` to the engine.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation (unused).
        """
        # Only producer/prefill saves KV Cache
        if not self.is_producer:
            return
        assert self.p2p_nccl_engine is not None

        connector_metadata = self._get_connector_metadata()
        assert isinstance(connector_metadata, P2pSucclConnectorMetadata)
        for request in connector_metadata.requests:
            request_id = request.request_id
            ip, port = self.parse_request_id(request_id, True)
            remote_address = ip + ":" + str(port + self._rank)

            kv_cache = self._extract_kv_from_layer(attn_metadata, kv_layer,
                                                   request.block_ids)
            if kv_cache is None:
                logger.warning(
                    "🚧unsupported kv layout, nothing sent, request_id:%s, "
                    "layer_name:%s", request_id, layer_name)
                continue
            self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                             kv_cache, remote_address)

    def wait_for_save(self):
        """On a producer, block until the engine reports all sends done."""
        if self.is_producer:
            assert self.p2p_nccl_engine is not None
            self.p2p_nccl_engine.wait_for_sent()

    def get_finished(
            self, finished_req_ids: set[str],
            **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """Forward the finished request ids to the engine.

        Returns:
            Whatever the engine's ``get_finished`` returns for
            ``finished_req_ids`` and the static forward context; per the
            annotation, a tuple of (sending/saving ids, recving/loading ids).
        """
        assert self.p2p_nccl_engine is not None

        no_compile_layers = (
            self._vllm_config.compilation_config.static_forward_context)
        return self.p2p_nccl_engine.get_finished(finished_req_ids,
                                                 no_compile_layers)

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """Return how many prompt tokens can be loaded from the external cache.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            ``(count, False)`` where ``count`` is
            ``len(prompt_token_ids) - 1 - num_computed_tokens`` clamped at
            zero; always ``(0, False)`` on a producer.
        """
        if self.is_producer:
            return 0, False

        num_external_tokens = (len(request.prompt_token_ids) - 1 -
                               num_computed_tokens)
        return max(num_external_tokens, 0), False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Remember a consumer request that still needs its KV loaded."""
        if not self.is_producer and num_external_tokens > 0:
            self._requests_need_load[request.request_id] = (
                request, blocks.get_block_ids()[0])

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the connector metadata for this step.

        On a producer a request is added once its whole prompt has been
        scheduled; partially prefilled prompts are tracked in
        ``chunked_prefill`` until then. On a consumer a request is added when
        it is registered in ``_requests_need_load``. The pending-load table
        is cleared before returning.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        meta = P2pSucclConnectorMetadata()

        for new_req in scheduler_output.scheduled_new_reqs:
            if self.is_producer:
                num_scheduled_tokens = (
                    scheduler_output.num_scheduled_tokens)[new_req.req_id]
                num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
                # the request's prompt is chunked prefill
                if num_tokens < len(new_req.prompt_token_ids):
                    # 'CachedRequestData' has no attribute 'prompt_token_ids'
                    self.chunked_prefill[new_req.req_id] = (
                        new_req.block_ids[0], new_req.prompt_token_ids)
                    continue
                # the request's prompt is not chunked prefill
                meta.add_request(request_id=new_req.req_id,
                                 token_ids=new_req.prompt_token_ids,
                                 block_ids=new_req.block_ids[0],
                                 block_size=self._block_size)
                continue
            if new_req.req_id in self._requests_need_load:
                meta.add_request(request_id=new_req.req_id,
                                 token_ids=new_req.prompt_token_ids,
                                 block_ids=new_req.block_ids[0],
                                 block_size=self._block_size)
                self._requests_need_load.pop(new_req.req_id)

        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            num_computed_tokens = cached_reqs.num_computed_tokens[i]
            new_block_ids = cached_reqs.new_block_ids[i]
            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
            if self.is_producer:
                num_scheduled_tokens = (
                    scheduler_output.num_scheduled_tokens)[req_id]
                num_tokens = (num_scheduled_tokens + num_computed_tokens)
                assert req_id in self.chunked_prefill
                block_ids = new_block_ids[0]
                if not resumed_from_preemption:
                    # Extend the blocks recorded by earlier chunks.
                    block_ids = (self.chunked_prefill[req_id][0] + block_ids)
                prompt_token_ids = self.chunked_prefill[req_id][1]
                # the request's prompt is chunked prefill again
                if num_tokens < len(prompt_token_ids):
                    self.chunked_prefill[req_id] = (block_ids,
                                                    prompt_token_ids)
                    continue
                # the request's prompt is all prefilled finally
                meta.add_request(request_id=req_id,
                                 token_ids=prompt_token_ids,
                                 block_ids=block_ids,
                                 block_size=self._block_size)
                self.chunked_prefill.pop(req_id, None)
                continue

            # NOTE(rob): here we rely on the resumed requests being
            # the first N requests in the list scheduled_cache_reqs.
            if not resumed_from_preemption:
                break
            if req_id in self._requests_need_load:
                request, _ = self._requests_need_load.pop(req_id)
                total_tokens = num_computed_tokens + 1
                token_ids = request.all_token_ids[:total_tokens]

                # NOTE(rob): For resumed req, new_block_ids is all
                # of the block_ids for the request.
                block_ids = new_block_ids[0]

                meta.add_request(request_id=req_id,
                                 token_ids=token_ids,
                                 block_ids=block_ids,
                                 block_size=self._block_size)

        self._requests_need_load.clear()
        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Drop chunked-prefill bookkeeping for a finished request.

        Returns:
            Always ``(False, None)``: this connector does not delay the
            freeing of blocks and returns no extra transfer parameters.
        """
        self.chunked_prefill.pop(request.request_id, None)

        return False, None

    # ==============================
    # Static methods
    # ==============================

    @staticmethod
    def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
        """Extract the peer ``(host, port)`` embedded in a request id.

        Args:
            request_id: request id containing ``___decode_addr_<host>:<port>``
                and/or ``___prefill_addr_<host>:<port>___``.
            is_prefill: when true the decode address is looked up, otherwise
                the prefill address.

        Raises:
            ValueError: if the expected address pattern is not present.
        """
        # Regular expression to match the string hostname and integer port
        if is_prefill:
            pattern = r"___decode_addr_(.*):(\d+)"
        else:
            pattern = r"___prefill_addr_(.*):(\d+)___"

        # Use re.search to find the pattern in the request_id
        match = re.search(pattern, request_id)
        if match:
            # Extract the host and the port
            ip = match.group(1)
            port = int(match.group(2))

            return ip, port
        raise ValueError(
            f"Request id {request_id} does not contain hostname and port")

    @staticmethod
    def check_tensors_except_dim(tensor1, tensor2, dim):
        """Raise unless both tensors have equal sizes in every dim but ``dim``.

        Raises:
            NotImplementedError: if the ranks differ or any dimension other
                than ``dim`` has a different size.
        """
        shape1 = tensor1.size()
        shape2 = tensor2.size()

        if len(shape1) != len(shape2) or not all(
                s1 == s2
                for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
            raise NotImplementedError(
                "Currently, only symmetric TP is supported. Asymmetric TP, PP,"
                "and others will be supported in future PRs.")
|
||||
|
||||
|
||||
# Register this connector with vLLM's connector factory under the name
# "P2pSucclConnector", pointing at this module and class.
KVConnectorFactory.register_connector(
    "P2pSucclConnector",
    "vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.p2p_succl_connector",
    "P2pSucclConnector",
)
|
||||
@@ -0,0 +1,572 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import torch_br
|
||||
import zmq
|
||||
from torch_br.supa._internal import get_tensor_info
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
# import vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine
|
||||
from vllm.utils import get_ip
|
||||
from vllm_br.distributed.device_communicators.pysccl_wrapper import (
|
||||
SCCLLibrary, buffer_type, succlComm_t, succlDataTypeEnum, suStream_t)
|
||||
from vllm_br.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
|
||||
TensorMemoryPool)
|
||||
from vllm_br.platform import SUPAPlatform
|
||||
|
||||
# Module-level logger.
logger = logging.getLogger(__name__)

# Default for the "mem_pool_size_gb" extra-config option, in GB.
DEFAULT_MEM_POOL_SIZE_GB = 1
|
||||
|
||||
|
||||
@contextmanager
def set_p2p_succl_context(num_channels: str):
    """Temporarily set SUCCL channel variables in ``os.environ``.

    Inside the ``with`` body ``SUCCL_MAX_NCHANNELS`` and
    ``SUCCL_MIN_NCHANNELS`` are set to ``num_channels`` and
    ``SUCCL_CUMEM_ENABLE`` to ``'1'``. On exit every tracked variable is
    restored to its previous value, or removed if it was not set before.

    Args:
        num_channels: value written to both channel-count variables.
    """
    tracked_vars = (
        'SUCCL_MAX_NCHANNELS',
        'SUCCL_MIN_NCHANNELS',
        'SUCCL_CUMEM_ENABLE',
        'SUCCL_BUFFSIZE',
        'SUCCL_PROTO',  # LL,LL128,SIMPLE
        'SUCCL_ALGO',  # RING,TREE
    )
    # Snapshot the current values (None means "not set").
    original_values: dict[str, Any] = {
        name: os.environ.get(name)
        for name in tracked_vars
    }

    logger.info("set_p2p_succl_context, original_values: %s", original_values)

    try:
        os.environ['SUCCL_MAX_NCHANNELS'] = num_channels
        os.environ['SUCCL_MIN_NCHANNELS'] = num_channels
        os.environ['SUCCL_CUMEM_ENABLE'] = '1'
        yield
    finally:
        for name, saved in original_values.items():
            if saved is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = saved
|
||||
|
||||
|
||||
@dataclass
class SendQueueItem:
    """One pending send: a tensor, its id and the peer it is destined for."""

    # Identifier under which the tensor is sent.
    tensor_id: str
    # Address of the remote peer.
    remote_address: str
    # The tensor to send.
    tensor: torch.Tensor
|
||||
|
||||
|
||||
class P2pSucclEngine:
|
||||
|
||||
def __init__(self,
             local_rank: int,
             config: KVTransferConfig,
             hostname: str = "",
             port_offset: int = 0,
             library_path: Optional[str] = None) -> None:
    """Set up the device, the ZMQ router socket and the worker threads.

    Args:
        local_rank: local device index; the device used is
            ``supa:{local_rank + device_cursor}`` where ``device_cursor``
            comes from the extra config (default 0).
        config: KV transfer configuration; ``kv_port``, ``kv_buffer_size``
            and the extra-config keys ``http_port``, ``proxy_ip``,
            ``proxy_port``, ``mem_pool_size_gb``, ``send_type`` and
            ``nccl_num_channels`` are read from it.
        hostname: address to bind to; when empty ``get_ip()`` is used.
        port_offset: added to ``kv_port`` to form the ZMQ port; also
            stored as ``self.rank``.
        library_path: passed to ``SCCLLibrary``.

    Raises:
        ValueError: if the resulting port is 0.
    """
    self.config = config
    self.rank = port_offset
    self.local_rank = local_rank
    self.device = torch.device(f"supa:{self.local_rank}")
    if config is not None:
        device_cursor = self.config.get_from_extra_config(
            "device_cursor", 0)
        self.device = torch.device(
            f"supa:{self.local_rank + int(device_cursor)}")
    SUPAPlatform.set_device(self.device)
    self.succl = SCCLLibrary(library_path)

    if not hostname:
        hostname = get_ip()
    port = int(self.config.kv_port) + port_offset
    if port == 0:
        raise ValueError("Port cannot be 0")
    self._hostname = hostname
    self._port = port

    # Each card corresponds to a ZMQ address.
    self.zmq_address = f"{self._hostname}:{self._port}"

    # The `http_port` must be consistent with the port of OpenAI.
    self.http_address = (
        f"{self._hostname}:"
        f"{self.config.kv_connector_extra_config['http_port']}")

    # If `proxy_ip` or `proxy_port` is `""`,
    # then the ping thread will not be enabled.
    proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
    proxy_port = self.config.get_from_extra_config("proxy_port", "")
    if proxy_ip == "" or proxy_port == "":
        self.proxy_address = ""
    else:
        # Formatting (rather than ``+``) also accepts a port that was
        # configured as an integer.
        self.proxy_address = f"{proxy_ip}:{proxy_port}"

    self.context = zmq.Context()
    self.router_socket = self.context.socket(zmq.ROUTER)
    self.router_socket.bind(f"tcp://{self.zmq_address}")

    self.poller = zmq.Poller()
    self.poller.register(self.router_socket, zmq.POLLIN)

    self.send_store_cv = threading.Condition()
    self.send_queue_cv = threading.Condition()
    self.recv_store_cv = threading.Condition()

    # Sends and receives share a single stream.
    self.send_stream = torch_br.supa.Stream()
    self.recv_stream = self.send_stream

    mem_pool_size_gb = float(
        self.config.get_from_extra_config("mem_pool_size_gb",
                                          DEFAULT_MEM_POOL_SIZE_GB))
    self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb *
                                                    1024**3))  # GB

    # The sending type includes three mutually exclusive options:
    # PUT, GET, PUT_ASYNC.
    self.send_type = self.config.get_from_extra_config(
        "send_type", "PUT_ASYNC")
    if self.send_type == "GET":
        # tensor_id: torch.Tensor
        self.send_store: dict[str, torch.Tensor] = {}
    else:
        # PUT or PUT_ASYNC
        # tensor_id: torch.Tensor
        self.send_queue: deque[SendQueueItem] = deque()

    # tensor_id: torch.Tensor/(addr, dtype, shape)
    self.recv_store: dict[str, Any] = {}
    self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
    self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
    self.socks: dict[str, Any] = {}  # remote_address: client socket
    self.comms: dict[str, Any] = {}  # remote_address: (succlComm_t, rank)

    self.buffer_size = 0
    self.buffer_size_threshold = float(self.config.kv_buffer_size)

    self.succl_num_channels = self.config.get_from_extra_config(
        "nccl_num_channels", "8")

    # Background threads are started only after all state above exists,
    # so none of them can observe a partially initialised engine.
    if self.send_type == "PUT_ASYNC":
        self._send_thread = threading.Thread(target=self.send_async,
                                             daemon=True)
        self._send_thread.start()

    self._listener_thread = threading.Thread(
        target=self.listen_for_requests, daemon=True)
    self._listener_thread.start()

    self._ping_thread = None
    if port_offset == 0 and self.proxy_address != "":
        self._ping_thread = threading.Thread(target=self.ping, daemon=True)
        self._ping_thread.start()

    logger.info(
        "💯P2pSucclEngine init, rank:%d, local_rank:%d, http_address:%s, "
        "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
        "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
        self.http_address, self.zmq_address, self.proxy_address,
        self.send_type, self.buffer_size_threshold,
        self.succl_num_channels)
|
||||
|
||||
def create_connect(self, remote_address: typing.Optional[str] = None):
    """Return the ``(socket, (comm, rank))`` pair used to talk to a peer.

    If no socket is cached for ``remote_address``, a DEALER socket whose
    identity is this engine's ZMQ address is connected to the peer. Unless a
    communicator for that peer is already registered, a fresh SUCCL unique id
    is sent to the peer in a ``NEW`` command and this side initialises the
    communicator as rank 0.

    Args:
        remote_address: ``host:port`` of the peer; must not be None.

    Returns:
        Tuple of the cached socket and the cached ``(comm, rank)`` entry.
    """
    assert remote_address is not None

    # Fast path: a socket for this peer has already been set up.
    if remote_address in self.socks:
        return self.socks[remote_address], self.comms[remote_address]

    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    sock.connect(f"tcp://{remote_address}")
    self.socks[remote_address] = sock

    # The peer may have connected to us first, in which case the listener
    # thread has already registered a communicator for it.
    if remote_address in self.comms:
        logger.info("👋comm exists, remote_address:%s, comms:%s",
                    remote_address, self.comms)
        return sock, self.comms[remote_address]

    unique_id = self.succl.succlGetUniqueId()
    handshake = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
    sock.send(msgpack.dumps(handshake))

    # The initiating side always takes rank 0.
    rank = 0
    SUPAPlatform.set_device(self.device)
    comm = self.succl.succlCommInitRank(2, unique_id, rank)
    self.comms[remote_address] = (comm, rank)
    logger.info("🤝succlCommInitRank Success, %s👉%s, MyRank:%s",
                self.zmq_address, remote_address, rank)

    return self.socks[remote_address], self.comms[remote_address]
|
||||
|
||||
def send_tensor(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[str] = None,
) -> bool:
    """Hand a tensor over for delivery according to ``self.send_type``.

    Args:
        tensor_id: Key under which the tensor is stored or announced.
        tensor: The tensor to deliver.
        remote_address: Peer address. When None the tensor is placed
            directly into the local ``recv_store`` and waiters are notified.

    Returns:
        True when the tensor was stored locally, queued (PUT_ASYNC), sent
        (PUT, result of ``send_sync``) or buffered for a later GET; False
        when a GET tensor is larger than the whole buffer threshold.
    """
    if remote_address is None:
        with self.recv_store_cv:
            self.recv_store[tensor_id] = tensor
            self.recv_store_cv.notify()
        return True

    item = SendQueueItem(tensor_id=tensor_id,
                         remote_address=remote_address,
                         tensor=tensor)

    if self.send_type == "PUT":
        return self.send_sync(item)

    if self.send_type == "PUT_ASYNC":
        with self.send_queue_cv:
            self.send_queue.append(item)
            self.send_queue_cv.notify()
        return True

    # GET: keep the tensor in a size-bounded store until the peer asks for it.
    with self.send_store_cv:
        tensor_size = tensor.element_size() * tensor.numel()
        if tensor_size > self.buffer_size_threshold:
            logger.warning(
                "❗[GET]tensor_id:%s, tensor_size:%d, is greater than "
                "buffer size threshold :%d, skip send to %s, rank:%d",
                tensor_id, tensor_size, self.buffer_size_threshold,
                remote_address, self.rank)
            return False

        # If this id is already buffered, release the old entry first so its
        # size is not counted twice in ``buffer_size``.
        previous = self.send_store.pop(tensor_id, None)
        if previous is not None:
            self.buffer_size -= previous.element_size() * previous.numel()

        # Evict in insertion order (oldest first) until the new tensor fits.
        while (self.buffer_size + tensor_size
               > self.buffer_size_threshold):
            assert len(self.send_store) > 0
            oldest_tensor_id = next(iter(self.send_store))
            oldest_tensor = self.send_store.pop(oldest_tensor_id)
            oldest_tensor_size = (oldest_tensor.element_size() *
                                  oldest_tensor.numel())
            self.buffer_size -= oldest_tensor_size
            logger.debug(
                "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                " buffer_size:%d, oldest_tensor_size:%d, rank:%d",
                remote_address, tensor_id, tensor_size, self.buffer_size,
                oldest_tensor_size, self.rank)

        self.send_store[tensor_id] = tensor
        self.buffer_size += tensor_size
        logger.debug(
            "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
            "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
            tensor_id, tensor_size, tensor.shape, self.rank,
            self.buffer_size,
            self.buffer_size / self.buffer_size_threshold * 100)
    return True
|
||||
|
||||
def recv_tensor(
    self,
    tensor_id: str,
    remote_address: typing.Optional[str] = None,
) -> typing.Optional[torch.Tensor]:
    """Obtain the tensor registered under ``tensor_id``.

    In PUT / PUT_ASYNC mode this blocks until the listener thread has placed
    an entry for ``tensor_id`` into ``recv_store``. An entry stored as an
    ``(addr, dtype, shape)`` tuple is loaded back from the memory pool; a
    plain tensor entry has its byte size subtracted from ``buffer_size``.

    In GET mode the tensor is requested from ``remote_address`` and received
    into a freshly allocated buffer.

    Args:
        tensor_id: Key of the tensor to fetch.
        remote_address: Peer to pull from (GET mode only).

    Returns:
        The tensor, or None when the stored entry is None (PUT modes), when
        no ``remote_address`` is given in GET mode, or when the peer answers
        with a non-zero ``ret``.
    """
    if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
        start_time = time.time()
        with self.recv_store_cv:
            while tensor_id not in self.recv_store:
                self.recv_store_cv.wait()
            tensor = self.recv_store[tensor_id]

        if tensor is None:
            # The listener stores None when receiving failed.
            duration = time.time() - start_time
            logger.warning(
                "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
                "rank:%d", remote_address, tensor_id, duration * 1000,
                self.rank)
        elif isinstance(tensor, tuple):
            # Spilled to the host memory pool: bring it back to the device.
            addr, dtype, shape = tensor
            tensor = self.pool.load_tensor(addr, dtype, shape, self.device)
        else:
            self.buffer_size -= (tensor.element_size() * tensor.numel())
        return tensor

    # GET
    if remote_address is None:
        return None

    if remote_address not in self.socks:
        self.create_connect(remote_address)

    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    data = {"cmd": "GET", "tensor_id": tensor_id}
    sock.send(msgpack.dumps(data))

    message = sock.recv()
    data = msgpack.loads(message)
    if data["ret"] != 0:
        logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
                       remote_address, tensor_id, data["ret"])
        return None

    with torch_br.supa.stream(self.recv_stream):
        tensor = torch_br._empty_ut_only(data["shape"],
                                         dtype=getattr(torch, data["dtype"]),
                                         tensor_type='BUFFER_ANY',
                                         device=self.device)

    # The two ranks of a pairwise communicator are 0 and 1, so the peer is
    # always ``rank ^ 1``.
    self.recv(comm, tensor, rank ^ 1, self.recv_stream)

    return tensor
|
||||
|
||||
def listen_for_requests(self):
    """Serve peer commands arriving on the ROUTER socket; never returns.

    Handled commands:
      * ``NEW`` - join the communicator described by the peer's unique id
        as rank 1 and register it under the peer's address.
      * ``PUT`` - allocate a buffer, acknowledge with ``b"0"``, receive the
        tensor and publish it in ``recv_store`` (spilling to the memory
        pool when the buffer threshold would be exceeded). On out-of-memory
        the peer is answered with ``b"1"`` and None is published instead.
      * ``GET`` - reply with the metadata of a tensor held in ``send_store``
        and then send it, or reply ``{"ret": 1}`` when it is not held.

    Any other command is logged and ignored.
    """
    while True:
        ready = dict(self.poller.poll())
        if self.router_socket not in ready:
            continue

        remote_address, message = self.router_socket.recv_multipart()
        data = msgpack.loads(message)
        cmd = data["cmd"]

        if cmd == "NEW":
            unique_id = self.succl.unique_id_from_bytes(
                bytes(data["unique_id"]))

            # The listening side always takes rank 1.
            rank = 1
            SUPAPlatform.set_device(self.device)
            comm = self.succl.succlCommInitRank(2, unique_id, rank)
            self.comms[remote_address.decode()] = (comm, rank)
            logger.info("🤝suclCommInitRank Success, %s👈%s, MyRank:%s",
                        self.zmq_address, remote_address.decode(), rank)

        elif cmd == "PUT":
            tensor_id = data["tensor_id"]
            try:
                with torch_br.supa.stream(self.recv_stream):
                    tensor = torch_br._empty_ut_only(
                        data["shape"],
                        dtype=getattr(torch, data["dtype"]),
                        tensor_type='BUFFER_ANY',
                        device=self.device)
                # Acknowledge only after the receive buffer exists.
                self.router_socket.send_multipart([remote_address, b"0"])
                comm, rank = self.comms[remote_address.decode()]
                self.recv(comm, tensor, rank ^ 1, self.recv_stream)

                tensor_size = tensor.element_size() * tensor.numel()
                over_budget = (self.buffer_size + tensor_size
                               > self.buffer_size_threshold)
                if over_budget:
                    # Store Tensor in memory pool; keep only a descriptor.
                    addr = self.pool.store_tensor(tensor)
                    tensor = (addr, tensor.dtype, tensor.shape)
                    logger.warning(
                        "🔴[PUT]Recv Tensor, Out Of Threshold, "
                        "%s👈%s, data:%s, addr:%d", self.zmq_address,
                        remote_address.decode(), data, addr)
                else:
                    self.buffer_size += tensor_size

            except torch.cuda.OutOfMemoryError:
                self.router_socket.send_multipart([remote_address, b"1"])
                tensor = None
                logger.warning(
                    "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
                    "data:%s", self.zmq_address, remote_address.decode(),
                    data)

            # Publish the result (tensor, pool descriptor or None) and wake
            # up a waiter in recv_tensor.
            with self.recv_store_cv:
                self.recv_store[tensor_id] = tensor
                self.have_received_tensor_id(tensor_id)
                self.recv_store_cv.notify()

        elif cmd == "GET":
            tensor_id = data["tensor_id"]
            with self.send_store_cv:
                tensor = self.send_store.pop(tensor_id, None)
                if tensor is None:
                    data = {"ret": 1}
                else:
                    data = {
                        "ret": 0,
                        "shape": tensor.shape,
                        "dtype": str(tensor.dtype).replace("torch.", ""),
                        "tensor_type": get_tensor_info(tensor)[0]['layout']
                    }
                    # LRU: re-insert so the entry becomes the newest one.
                    self.send_store[tensor_id] = tensor
                    self.have_sent_tensor_id(tensor_id)

            self.router_socket.send_multipart(
                [remote_address, msgpack.dumps(data)])

            if data["ret"] == 0:
                comm, rank = self.comms[remote_address.decode()]
                self.send(comm, tensor.to(self.device), rank ^ 1,
                          self.send_stream)

        else:
            logger.warning(
                "🚧Unexpected, Received message from %s, data:%s",
                remote_address, data)
|
||||
|
||||
def have_sent_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` as sent under its request id.

    The request id is the part of ``tensor_id`` before the first ``#``.
    """
    request_id, _, _ = tensor_id.partition('#')
    self.send_request_id_to_tensor_ids.setdefault(request_id,
                                                  set()).add(tensor_id)
|
||||
|
||||
def have_received_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` as received under its request id.

    The request id is the part of ``tensor_id`` before the first ``#``.
    """
    request_id, _, _ = tensor_id.partition('#')
    self.recv_request_id_to_tensor_ids.setdefault(request_id,
                                                  set()).add(tensor_id)
|
||||
|
||||
def send_async(self):
    """Drain ``send_queue`` forever, sending each item with ``send_sync``.

    Blocks on ``send_queue_cv`` while the queue is empty. After popping the
    last queued item the condition is notified so that ``wait_for_sent`` can
    observe the empty queue. The send itself happens outside the lock.
    """
    cv = self.send_queue_cv
    while True:
        with cv:
            while not self.send_queue:
                cv.wait()
            item = self.send_queue.popleft()
            queue_drained = not self.send_queue
            if queue_drained:
                cv.notify()
        self.send_sync(item)
|
||||
|
||||
def wait_for_sent(self):
    """Block until ``send_queue`` is empty (PUT_ASYNC mode only).

    For any other send type this is a no-op. The time spent waiting is
    logged at debug level.
    """
    if self.send_type != "PUT_ASYNC":
        return

    began = time.time()
    with self.send_queue_cv:
        while self.send_queue:
            self.send_queue_cv.wait()
    elapsed = time.time() - began
    logger.debug(
        "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
        " to be empty, rank:%d", elapsed * 1000, self.rank)
|
||||
|
||||
def send_sync(self, item: SendQueueItem) -> bool:
    """Push one queued tensor to its peer and wait for the transfer.

    A ``PUT`` header with the tensor's id, shape, dtype and layout is sent
    over the peer's socket. Only when the peer answers ``b"0"`` is the tensor
    payload sent through the communicator.

    Args:
        item: Tensor, tensor id and destination address.

    Returns:
        False when the item has no destination or the peer rejects the
        header; True after the tensor has been sent.
    """
    peer = item.remote_address
    if peer is None:
        return False
    if peer not in self.socks:
        self.create_connect(peer)

    tensor = item.tensor
    sock = self.socks[peer]
    comm, rank = self.comms[peer]

    data = {
        "cmd": "PUT",
        "tensor_id": item.tensor_id,
        "shape": tensor.shape,
        "dtype": str(tensor.dtype).replace("torch.", ""),
        "tensor_type": get_tensor_info(tensor)[0]['layout']
    }
    sock.send(msgpack.dumps(data))

    response = sock.recv()
    accepted = response == b"0"
    if not accepted:
        size_gb = tensor.element_size() * tensor.numel() / 1024**3
        logger.error(
            "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
            "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
            self.zmq_address, peer, rank, data, tensor.shape, size_gb,
            response.decode())
        return False

    # The peer of a two-rank communicator is always ``rank ^ 1``.
    self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)

    if self.send_type == "PUT_ASYNC":
        self.have_sent_tensor_id(item.tensor_id)

    return True
|
||||
|
||||
def get_finished(
        self, finished_req_ids: set[str], no_compile_layers
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Release buffers held for requests that finished generating tokens.

    For every ``<request_id>#<layer_name>`` entry present in ``recv_store``
    the entry is removed, the per-request tensor-id bookkeeping is dropped,
    and pool-backed entries (stored as tuples) are freed in the memory pool.

    Args:
        finished_req_ids: Ids of the finished requests.
        no_compile_layers: Iterable of layer names used to build tensor ids.

    Returns:
        Tuple of (sending/saving ids, recving/loading ids). Both sets are
        currently never populated, so both elements are always None.
    """
    # Clear the buffer upon request completion.
    for request_id in finished_req_ids:
        for layer_name in no_compile_layers:
            tensor_id = "#".join((request_id, layer_name))
            if tensor_id not in self.recv_store:
                continue
            with self.recv_store_cv:
                entry = self.recv_store.pop(tensor_id, None)
                self.send_request_id_to_tensor_ids.pop(request_id, None)
                self.recv_request_id_to_tensor_ids.pop(request_id, None)
            if isinstance(entry, tuple):
                # (addr, dtype, shape): the payload lives in the pool.
                self.pool.free(entry[0])

    # TODO:Retrieve requests that have already sent the KV cache.
    finished_sending: set[str] = set()

    # TODO:Retrieve requests that have already received the KV cache.
    finished_recving: set[str] = set()

    return finished_sending or None, finished_recving or None
|
||||
|
||||
def ping(self):
    """Announce this engine to the proxy every three seconds, forever.

    Each message carries the role (``"P"`` for a KV producer, otherwise
    ``"D"``) together with this engine's HTTP and ZMQ addresses.
    """
    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    sock.connect(f"tcp://{self.proxy_address}")

    role = "P" if self.config.is_kv_producer else "D"
    data = {
        "type": role,
        "http_address": self.http_address,
        "zmq_address": self.zmq_address
    }
    while True:
        sock.send(msgpack.dumps(data))
        time.sleep(3)
|
||||
|
||||
def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
    """Send ``tensor`` to rank ``dst`` over ``comm`` and wait for completion.

    Args:
        comm: Communicator handle to send on.
        tensor: Tensor to send; must live on ``self.device``.
        dst: Destination rank within ``comm``.
        stream: Stream to issue the send on; a new one is created when None.
    """
    assert tensor.device == self.device, (
        f"this succl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")
    stream = torch_br.supa.Stream() if stream is None else stream

    with torch_br.supa.stream(stream):
        buffer = buffer_type(tensor.data_ptr())
        count = tensor.numel()
        dtype = succlDataTypeEnum.from_torch(tensor.dtype)
        self.succl.succlSend(buffer, count, dtype, dst, comm,
                             suStream_t(stream.supa_stream))
    # Block the calling thread until the transfer has finished.
    stream.synchronize()
|
||||
|
||||
def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
    """Receive into ``tensor`` from rank ``src`` and wait for completion.

    Args:
        comm: Communicator handle to receive on.
        tensor: Pre-allocated destination; must live on ``self.device``.
        src: Source rank within ``comm``.
        stream: Stream to issue the receive on; a new one is created when
            None.
    """
    assert tensor.device == self.device, (
        f"this succl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")
    stream = torch_br.supa.Stream() if stream is None else stream

    with torch_br.supa.stream(stream):
        buffer = buffer_type(tensor.data_ptr())
        count = tensor.numel()
        dtype = succlDataTypeEnum.from_torch(tensor.dtype)
        self.succl.succlRecv(buffer, count, dtype, src, comm,
                             suStream_t(stream.supa_stream))
    # Block the calling thread until the data has arrived.
    stream.synchronize()
|
||||
|
||||
def close(self) -> None:
    """Join the engine's background threads.

    The listener thread is always joined; the sender thread only in
    PUT_ASYNC mode; the ping thread only when one was started.
    """
    workers = [self._listener_thread]
    if self.send_type == "PUT_ASYNC":
        workers.append(self._send_thread)
    if self._ping_thread is not None:
        workers.append(self._ping_thread)
    for worker in workers:
        worker.join()
|
||||
@@ -0,0 +1,280 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import atexit
|
||||
import ctypes
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
@dataclass
class MemoryBlock:
    """One region of the pool: its size in bytes and its start address."""
    size: int
    addr: int


class TensorMemoryPool:
    """A memory pool for managing pinned host memory allocations for tensors.

    This class implements a buddy allocation system to efficiently manage
    pinned host memory for tensor storage. It supports allocation,
    deallocation, and tensor storage/retrieval operations.

    Key Features:
    - Uses power-of-two block sizes for efficient buddy allocation
    - Supports splitting and merging of memory blocks
    - Provides methods to store CUDA tensors in pinned host memory
    - Allows loading tensors from pinned memory back to device
    - Releases its backing tensor in ``cleanup`` (run at interpreter exit)

    Attributes:
        max_block_size (int): Maximum block size (rounded up to a power of two)
        min_block_size (int): Minimum block size (rounded up to a power of two)
        free_lists (dict): Free blocks, keyed by block size and then address
        allocated_blocks (dict): Currently allocated blocks, keyed by address
        base_tensor (torch.Tensor): Base pinned memory tensor
        base_address (int): Base memory address of the pinned memory region

    Example:
        >>> pool = TensorMemoryPool(max_block_size=1024*1024)
        >>> tensor = torch.randn(100, device='cuda')
        >>> addr = pool.store_tensor(tensor)
        >>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
        ...                                  tensor.shape, 'cuda')
        >>> pool.free(addr)
    """

    def __init__(self, max_block_size: int, min_block_size: int = 512):
        """Initializes the memory pool with given size constraints.

        Args:
            max_block_size (int): Maximum size of memory blocks to manage
            min_block_size (int, optional): Minimum size of memory blocks
                to manage. Defaults to 512.

        Raises:
            ValueError: If block sizes are invalid or max_block_size is less
                than min_block_size
        """
        if max_block_size <= 0 or min_block_size <= 0:
            raise ValueError("Block sizes must be positive")
        if max_block_size < min_block_size:
            raise ValueError(
                "Max block size must be greater than min block size")

        self.max_block_size = self._round_to_power_of_two(max_block_size)
        self.min_block_size = self._round_to_power_of_two(min_block_size)

        self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
        self.allocated_blocks: dict[int, MemoryBlock] = {}

        self._initialize_free_lists()
        self._allocate_pinned_memory()

        # The registered bound method keeps this pool referenced until the
        # interpreter exits, at which point cleanup() runs.
        atexit.register(self.cleanup)

    def _round_to_power_of_two(self, size: int) -> int:
        """Return the smallest power of two that is >= ``size``."""
        return 1 << (size - 1).bit_length()

    def _initialize_free_lists(self):
        """Create an empty free list for every block size in the pool."""
        size = self.max_block_size
        while size >= self.min_block_size:
            self.free_lists[size] = {}
            size //= 2

    def _allocate_pinned_memory(self):
        """Allocate the backing pinned tensor and seed the largest free list."""
        # float32 elements are 4 bytes, so this is max_block_size bytes.
        self.base_tensor = torch.empty(self.max_block_size // 4,
                                       dtype=torch.float32,
                                       pin_memory=True)
        self.base_address = self.base_tensor.data_ptr()
        initial_block = MemoryBlock(size=self.max_block_size,
                                    addr=self.base_address)
        self.free_lists[self.max_block_size][
            initial_block.addr] = initial_block

        logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d",
                     self.base_address, self.max_block_size)

    def allocate(self, size: int) -> int:
        """Allocates a memory block of at least the requested size.

        Args:
            size (int): Minimum size of memory to allocate

        Returns:
            int: Address of the allocated memory block

        Raises:
            ValueError: If size is invalid or insufficient memory is available
        """
        if size <= 0:
            raise ValueError("Allocation size must be positive")

        required_size = self._round_to_power_of_two(
            max(size, self.min_block_size))
        if required_size > self.max_block_size:
            raise ValueError("Requested size exceeds maximum block size")

        # Take the smallest free block that is large enough and split it
        # down to the required size.
        current_size = required_size
        while current_size <= self.max_block_size:
            if self.free_lists[current_size]:
                _, block = self.free_lists[current_size].popitem()
                self._split_block(block, required_size)
                self.allocated_blocks[block.addr] = block
                return block.addr
            current_size *= 2

        raise ValueError("Insufficient memory")

    def _split_block(self, block: MemoryBlock, required_size: int):
        """Halve ``block`` in place until it is ``required_size`` bytes.

        The upper half produced by each split is returned to the free list
        of its size.
        """
        while (block.size > required_size
               and block.size // 2 >= self.min_block_size):
            buddy_size = block.size // 2
            buddy_addr = block.addr + buddy_size

            buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
            block.size = buddy_size

            self.free_lists[buddy_size][buddy.addr] = buddy

    def free(self, addr: int):
        """Frees an allocated memory block.

        Args:
            addr (int): Address of the block to free

        Raises:
            ValueError: If address is invalid or not allocated
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to free")

        block = self.allocated_blocks.pop(addr)
        self._merge_buddies(block)

    def _merge_buddies(self, block: MemoryBlock):
        """Return ``block`` to the free lists, coalescing with free buddies."""
        MAX_MERGE_DEPTH = 30
        depth = 0

        while depth < MAX_MERGE_DEPTH:
            # A block aligned to twice its size (relative to the pool base)
            # is the lower half of its pair; otherwise it is the upper half.
            buddy_offset = block.size if (block.addr - self.base_address) % (
                2 * block.size) == 0 else -block.size
            buddy_addr = block.addr + buddy_offset
            buddy = self.free_lists[block.size].get(buddy_addr)
            if buddy:
                del self.free_lists[buddy.size][buddy.addr]
                merged_addr = min(block.addr, buddy.addr)
                merged_size = block.size * 2
                block = MemoryBlock(size=merged_size, addr=merged_addr)
                depth += 1
            else:
                break
        self.free_lists[block.size][block.addr] = block

    def store_tensor(self, tensor: torch.Tensor) -> int:
        """Stores a CUDA tensor in pinned host memory.

        Args:
            tensor (torch.Tensor): CUDA tensor to store

        Returns:
            int: Address where the tensor is stored

        Raises:
            ValueError: If tensor is not on CUDA or allocation fails
        """
        if not tensor.is_cuda:
            raise ValueError("Only CUDA tensors can be stored")

        size = tensor.element_size() * tensor.numel()
        addr = self.allocate(size)
        block = self.allocated_blocks[addr]

        if block.size < size:
            self.free(addr)
            raise ValueError(
                f"Allocated block size {block.size} is smaller than "
                f"required size {size}")

        try:
            # View the raw block as a CPU tensor with the source's dtype
            # and shape, without copying.
            buffer = (ctypes.c_byte * block.size).from_address(block.addr)
            cpu_tensor = torch.frombuffer(buffer,
                                          dtype=tensor.dtype,
                                          count=tensor.numel()).reshape(
                                              tensor.shape)
        except ValueError as err:
            self.free(addr)
            raise ValueError(f"Failed to create tensor view: {err}") from err

        cpu_tensor.copy_(tensor)

        return addr

    def load_tensor(self, addr: int, dtype: torch.dtype,
                    shape: tuple[int, ...], device) -> torch.Tensor:
        """Loads a tensor from pinned host memory to the specified device.

        The block stays allocated; callers release it with ``free``.

        Args:
            addr (int): Address where tensor is stored
            dtype (torch.dtype): Data type of the tensor
            shape (tuple[int, ...]): Shape of the tensor
            device: Target device for the loaded tensor

        Returns:
            torch.Tensor: The loaded tensor on the specified device

        Raises:
            ValueError: If address is invalid or sizes don't match
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to load")

        block = self.allocated_blocks[addr]
        num_elements = math.prod(shape)
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        required_size = num_elements * dtype_size

        if required_size > block.size:
            raise ValueError("Requested tensor size exceeds block size")

        buffer = (ctypes.c_byte * block.size).from_address(block.addr)
        cpu_tensor = torch.frombuffer(buffer, dtype=dtype,
                                      count=num_elements).reshape(shape)

        cuda_tensor = torch.empty(shape, dtype=dtype, device=device)

        cuda_tensor.copy_(cpu_tensor)

        return cuda_tensor

    def cleanup(self):
        """Cleans up all memory resources and resets the pool state.

        Safe to call more than once and on an instance whose ``__init__``
        did not run to completion (``__del__`` calls this even then).
        """
        for name in ("free_lists", "allocated_blocks"):
            table = getattr(self, name, None)
            if table is not None:
                table.clear()
        if hasattr(self, 'base_tensor'):
            del self.base_tensor

    def __del__(self):
        self.cleanup()
|
||||
473
vllm_br/distributed/parallel_state.py
Normal file
473
vllm_br/distributed/parallel_state.py
Normal file
@@ -0,0 +1,473 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch_br
|
||||
|
||||
import vllm
|
||||
import vllm.distributed.parallel_state
|
||||
from vllm.distributed import GroupCoordinator
|
||||
from vllm.distributed.parallel_state import (_WORLD, TensorMetadata,
|
||||
_split_tensor_dict, get_pp_group,
|
||||
get_tp_group, get_world_group,
|
||||
init_model_parallel_group, logger)
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
@dataclass
class GraphCaptureContext:
    """Data handed to code running inside a graph-capture context."""
    # Stream on which the graph capture runs.
    stream: torch_br.supa.Stream
|
||||
|
||||
|
||||
@contextmanager
def graph_capture_(self,
                   graph_capture_context: Optional[GraphCaptureContext] = None
                   ):
    """Run the enclosed code on the graph-capture stream.

    Replacement for ``GroupCoordinator.graph_capture``. When no context is
    supplied, one is created around a new SUPA stream. The capture stream is
    made to wait for the current stream before it becomes the active stream
    for the duration of the ``with`` body.

    Yields:
        The ``GraphCaptureContext`` in use.
    """
    if graph_capture_context is None:
        graph_capture_context = GraphCaptureContext(torch_br.supa.Stream())
    capture_stream = graph_capture_context.stream

    # NOTE: no device-communicator capture context is entered here; only the
    # stream switch is performed.

    # ensure all initialization operations complete before attempting to
    # capture the graph on another stream
    active_stream = torch_br.supa.current_stream()
    if active_stream != capture_stream:
        capture_stream.wait_stream(active_stream)

    with torch_br.supa.stream(capture_stream):
        yield graph_capture_context


vllm.distributed.parallel_state.GroupCoordinator.graph_capture = graph_capture_
|
||||
|
||||
|
||||
@contextmanager
def graph_capture_supa(device: torch.device):
    """Context manager that surrounds code capturing a SUPA graph.

    Replacement for ``vllm.distributed.parallel_state.graph_capture``. A new
    stream is created on ``device`` and wrapped in a ``GraphCaptureContext``;
    that context is then entered through the ``graph_capture`` method of the
    tensor-parallel group and of the pipeline-parallel group, so the body
    runs on the dedicated capture stream instead of the stream that was
    current on entry.

    Args:
        device: Device on which the capture stream is created.

    Yields:
        The ``GraphCaptureContext`` holding the capture stream.
    """
    context = GraphCaptureContext(torch_br.supa.Stream(device=device))
    with get_tp_group().graph_capture(context):
        with get_pp_group().graph_capture(context):
            yield context


vllm.distributed.parallel_state.graph_capture = graph_capture_supa
|
||||
|
||||
|
||||
def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Returns True if distributed is not initialized (single process)
              or if the check itself fails.
    """
    try:
        # Read the world group from the module at call time. A name imported
        # with ``from ... import _WORLD`` only captures the value present at
        # import time and never sees a later assignment.
        world = vllm.distributed.parallel_state._WORLD
        if world is not None:
            return world.is_first_rank

        # If torch distributed is not initialized, assume single process
        if not torch.distributed.is_initialized():
            return True

        # Fallback to torch's global rank
        return torch.distributed.get_rank() == 0

    except Exception:
        # Deliberate best effort: if anything goes wrong, assume this is the
        # first rank.
        return True
|
||||
|
||||
|
||||
def generate_multi_node_parallel_groups(
    total_procs: int,
    tp_size: int,
    pp_size: int,
    dp_size: int,
) -> dict:
    """Return the hand-written rank groups for the supernode layout.

    Only one configuration is supported: 16 processes with tp_size=8,
    pp_size=2 and dp_size=1.

    Args:
        total_procs: Total number of processes (world size).
        tp_size: Tensor-parallel size.
        pp_size: Pipeline-parallel size.
        dp_size: Data-parallel size.

    Returns:
        Dict with the keys ``tp_groups``, ``pp_groups``, ``dp_groups`` and
        ``ep_groups``, each a list of rank lists.

    Raises:
        ValueError: If the configuration is not the supported one.
    """
    if (total_procs, tp_size, pp_size, dp_size) != (16, 8, 2, 1):
        raise ValueError(
            "Unsupported VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE parallel config of"
            f" tp_size: {tp_size} pp_size: {pp_size} dp_size: {dp_size}"
            f" (total_procs: {total_procs}). "
            "Currently only 'tp8pp2dp1' is allowed.")

    tp_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
    pp_groups = [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13],
                 [10, 14], [11, 15]]
    # Every rank forms its own data-parallel group.
    dp_groups = [[rank] for rank in range(16)]
    # Expert-parallel groups have the same membership as the TP groups but
    # are separate list objects.
    ep_groups = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]

    return {
        "tp_groups": tp_groups,
        "pp_groups": pp_groups,
        "dp_groups": dp_groups,
        "ep_groups": ep_groups,
    }
|
||||
|
||||
|
||||
# Signature synced with the vLLM v0.11 API; the body may still need to be re-synced with the upstream vLLM implementation.
|
||||
def initialize_model_parallel_cross_tp(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    decode_context_model_parallel_size: Optional[int] = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Builds, in this order, the TP, DCP, PP, DP and EP groups and stores them
    in the corresponding ``vllm.distributed.parallel_state`` globals. When
    ``envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE`` is set, the TP/PP/DP/EP
    rank lists come from ``generate_multi_node_parallel_groups`` instead of
    the default contiguous layout; the DCP groups always use the default
    layout.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        decode_context_model_parallel_size: number of ranks per decode
            context parallel (DCP) group; the rank grid is reshaped into
            rows of this size.
            NOTE(review): annotated Optional, but the value is used directly
            as a reshape dimension, so None looks unsupported - confirm.
        backend: name of torch distributed communication backend. Defaults
            to the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # The data-parallel size is not a parameter; it is read from the current
    # vLLM config when one is available, otherwise it stays 1.
    data_parallel_size = 1
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        # Supernode mode: use the pre-computed rank lists for every group
        # kind instead of slicing `all_ranks`.
        groups = generate_multi_node_parallel_groups(
            world_size, tensor_model_parallel_size,
            pipeline_model_parallel_size, data_parallel_size)
        logger.info("supernode reorganized groups: %s", groups)
    # Build the tensor model-parallel groups.
    assert vllm.distributed.parallel_state._TP is None, (
        "tensor model parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['tp_groups']
    else:
        group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    # message queue broadcaster is only used in tensor model parallel group
    vllm.distributed.parallel_state._TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp")

    # Build the DCP model-parallel groups.
    # global _DCP
    assert vllm.distributed.parallel_state._DCP is None, (
        "decode context model parallel group is already initialized")
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    # NOTE(review): this always slices the default `all_ranks` layout, even
    # in supernode mode - confirm that is intended when dcp_size > 1.
    group_ranks = all_ranks.reshape(
        -1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp")

    # Build the pipeline model-parallel groups.
    assert vllm.distributed.parallel_state._PP is None, (
        "pipeline model parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['pp_groups']
    else:
        group_ranks = all_ranks.transpose(2, 3).reshape(
            -1, pipeline_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp")

    # Build the data-parallel groups.
    assert vllm.distributed.parallel_state._DP is None, (
        "data parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['dp_groups']
    else:
        group_ranks = all_ranks.transpose(1, 3).reshape(
            -1, data_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._DP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="dp")

    # Build the expert-parallel groups: each one spans DP x TP ranks.
    assert vllm.distributed.parallel_state._EP is None, (
        "expert parallel group is already initialized")
    if envs.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        group_ranks = groups['ep_groups']
    else:
        group_ranks = all_ranks.transpose(1, 2).reshape(
            -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    vllm.distributed.parallel_state._EP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="ep")
    logger.info(
        "rank %s in world size %s is assigned as (br) "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
        vllm.distributed.parallel_state._DP.rank_in_group,
        vllm.distributed.parallel_state._PP.rank_in_group,
        vllm.distributed.parallel_state._TP.rank_in_group,
        vllm.distributed.parallel_state._EP.rank_in_group)


# Replace vLLM's initializer with the variant above so that callers of
# `vllm.distributed.parallel_state.initialize_model_parallel` get it.
vllm.distributed.parallel_state.initialize_model_parallel = initialize_model_parallel_cross_tp
|
||||
|
||||
|
||||
def send_tensor_dict(
    self,
    tensor_dict: dict[str, Union[torch.Tensor, Any]],
    dst: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
    """Send a dictionary of tensors and plain values to another rank.

    The non-tensor metadata is sent first with ``self.send_object``; every
    non-empty tensor is then sent with ``torch.distributed.send``. CPU
    tensors use ``self.cpu_group``; other tensors use ``self.device_group``
    after a ``torch.supa.synchronize()``.

    Args:
        tensor_dict: the values to send.
        dst: rank of the destination within this group. Defaults to the
            next rank, ``(rank_in_group + 1) % world_size``.
        all_gather_group: when given, a tensor whose element count is
            divisible by this group's world size is reshaped to
            ``(all_gather_size, -1)`` and only this rank's row is sent; the
            receiver is expected to rebuild it with an all-gather.
        all_gather_tensors: optional per-key override of that default
            decision (e.g. to disable it for tensors that are not fully
            replicated across the group).

    Returns:
        ``tensor_dict`` unchanged when torch.distributed is not initialized
        or the group has a single rank; otherwise ``None``.
    """
    # Nothing to send when there is no peer.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict

    if all_gather_group is None:
        all_gather_size, all_gather_rank = 1, 0
    else:
        all_gather_size = all_gather_group.world_size
        all_gather_rank = all_gather_group.rank_in_group

    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    # Custom CPU path: delegate the whole transfer to the communicator.
    if self.use_cpu_custom_send_recv:
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        self.device_communicator.send_tensor_dict(  # type: ignore
            tensor_dict, dst)
        return None

    assert isinstance(tensor_dict,
                      dict), f"Expecting a dictionary, got {type(tensor_dict)}"
    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    # The metadata goes through the object channel first so the receiver
    # knows which tensors to allocate.
    self.send_object(metadata_list, dst=dst)

    tensor_keys = [
        name for name, item in tensor_dict.items()
        if isinstance(item, torch.Tensor)
    ]
    assert len(tensor_keys) == len(tensor_list)

    for key, tensor in zip(tensor_keys, tensor_list):
        # Empty tensors are never put on the wire.
        if tensor.numel() == 0:
            continue

        # send-allgather: send only a slice, the receiver all-gathers.
        use_all_gather = (all_gather_group is not None
                          and tensor.numel() % all_gather_size == 0)
        if all_gather_tensors:
            use_all_gather = all_gather_tensors.get(key, use_all_gather)
        if use_all_gather:
            tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

        if tensor.is_cpu:
            torch.distributed.send(tensor,
                                   dst=self.ranks[dst],
                                   group=self.cpu_group)
        else:
            # Make sure the tensor is ready before handing it over.
            torch.supa.synchronize()
            torch.distributed.send(tensor,
                                   dst=self.ranks[dst],
                                   group=self.device_group)
    return None
|
||||
|
||||
|
||||
def recv_tensor_dict(
    self,
    src: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
    """Receive a dictionary of tensors and plain values from another rank.

    The metadata list is received first with ``self.recv_object``. For each
    ``TensorMetadata`` entry a tensor is allocated and, unless it is empty,
    filled with ``torch.distributed.recv``. CPU tensors use
    ``self.cpu_group``; other tensors use ``self.device_group`` followed by
    a ``torch.supa.synchronize()``.

    Args:
        src: rank of the source within this group. Defaults to the previous
            rank, ``(rank_in_group - 1) % world_size``.
        all_gather_group: when given, a tensor whose element count is
            divisible by this group's world size is received as one row of
            a ``(all_gather_size, -1)`` view and then rebuilt with
            ``all_gather_group.all_gather``.
        all_gather_tensors: optional per-key override of that default
            decision (e.g. to disable it for tensors that are not fully
            replicated across the group).

    Returns:
        ``None`` when torch.distributed is not initialized or the group has
        a single rank; otherwise the received dictionary.
    """
    # Nothing to receive when there is no peer.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return None

    if all_gather_group is None:
        all_gather_size, all_gather_rank = 1, 0
    else:
        all_gather_size = all_gather_group.world_size
        all_gather_rank = all_gather_group.rank_in_group

    if src is None:
        src = (self.rank_in_group - 1) % self.world_size
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Custom CPU path: delegate the whole transfer to the communicator.
    if self.use_cpu_custom_send_recv:
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        return self.device_communicator.recv_tensor_dict(  # type: ignore
            src)

    received: dict[str, Any] = {}
    for key, value in self.recv_object(src=src):
        # Plain values travel inside the metadata itself.
        if not isinstance(value, TensorMetadata):
            received[key] = value
            continue

        tensor = torch.empty(value.size,
                             dtype=value.dtype,
                             device=value.device)
        # Empty tensors are never put on the wire.
        if tensor.numel() == 0:
            received[key] = tensor
            continue

        # send-allgather: receive only a slice, then all-gather.
        use_all_gather = (all_gather_group is not None
                          and tensor.numel() % all_gather_size == 0)
        if all_gather_tensors:
            use_all_gather = all_gather_tensors.get(key, use_all_gather)

        if use_all_gather:
            orig_shape = tensor.shape
            tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

        if tensor.is_cpu:
            torch.distributed.recv(tensor,
                                   src=self.ranks[src],
                                   group=self.cpu_group)
        else:
            torch.distributed.recv(tensor,
                                   src=self.ranks[src],
                                   group=self.device_group)
            # Make sure the receive has completed before using the data.
            torch.supa.synchronize()

        if use_all_gather:
            tensor = all_gather_group.all_gather(  # type: ignore
                tensor, dim=0)
            tensor = tensor.reshape(orig_shape)

        received[key] = tensor
    return received
|
||||
|
||||
|
||||
# Monkey-patch vLLM's GroupCoordinator so that its tensor-dict send/recv
# methods use the implementations defined in this module.
vllm.distributed.GroupCoordinator.send_tensor_dict = send_tensor_dict
vllm.distributed.GroupCoordinator.recv_tensor_dict = recv_tensor_dict
|
||||
120
vllm_br/envs.py
Normal file
120
vllm_br/envs.py
Normal file
@@ -0,0 +1,120 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import pybrml
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
# The begin-* and end* here are used by the documentation generator
|
||||
# to extract the used env vars.
|
||||
|
||||
|
||||
# begin-env-vars-definition
|
||||
def check_allreduce_available():
    """Derive the default fused all-reduce setting from the P2P topology.

    Counts how many devices (device 0 itself included in the scan) report a
    direct P2P link with device 0 and maps that count to the default value
    of ``VLLM_BR_USE_FUSED_ALLREDUCE``:

    * 3 direct links -> 4
    * 4 direct links -> 8
    * anything else, or no visible device -> 1

    The BRML library is always shut down again, even when a query fails.

    Returns:
        int: 1, 4 or 8.
    """
    # Link type value reported by BRML that this module treats as a direct
    # P2P link (NOTE(review): meaning taken from the constant's name).
    P2P_DIRECT_LINK_TYPE = 2

    pybrml.brmlInit()
    try:
        device_count = pybrml.brmlDeviceGetCount()
        if device_count <= 0:
            # No device to inspect: fall back to the non-fused default
            # instead of failing at import time.
            return 1

        # Only the links of device 0 decide the result, so only that row of
        # the link matrix is queried.
        first_dev = pybrml.brmlDeviceGetHandleByIndex(0)
        all_reduce_count = sum(
            pybrml.brmlDeviceGetP2PStatus_v3(
                first_dev, pybrml.brmlDeviceGetHandleByIndex(
                    index)).type == P2P_DIRECT_LINK_TYPE
            for index in range(device_count))
    finally:
        pybrml.brmlShutdown()

    if all_reduce_count == 3:
        return 4
    if all_reduce_count == 4:
        return 8
    return 1


# Probed once at import time so that reading VLLM_BR_USE_FUSED_ALLREDUCE does
# not query the driver on every access.
_VLLM_BR_USE_FUSED_ALLREDUCE_CACHE = check_allreduce_available()
|
||||
|
||||
# Maps each supported environment variable to a zero-argument getter. The
# getters are evaluated lazily (on attribute access of this module), so a
# change to os.environ is picked up by the next read.
env_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_VERSION":
    lambda: os.getenv("VLLM_VERSION", None),
    # NOTE(review): returns the raw string when set, so "0" is truthy.
    # Left unchanged because callers may rely on the string value.
    "VLLM_BR_USE_PAGED_ATTN":
    lambda: os.getenv("VLLM_BR_USE_PAGED_ATTN", False),
    "VLLM_BR_WEIGHT_TYPE":
    lambda: os.getenv("VLLM_BR_WEIGHT_TYPE", "NUMA"),
    "VLLM_BR_QUANT_METHOD":
    lambda: os.getenv("VLLM_BR_QUANT_METHOD", "INT8"),
    # Defaults to the value probed from the device topology at import time.
    "VLLM_BR_USE_FUSED_ALLREDUCE":
    lambda: int(
        os.getenv("VLLM_BR_USE_FUSED_ALLREDUCE",
                  _VLLM_BR_USE_FUSED_ALLREDUCE_CACHE)),
    "VLLM_BR_EMBEDDING_S0B":
    lambda: bool(int(os.getenv("VLLM_BR_EMBEDDING_S0B", False))),
    # MoE (DeepSeek)
    "VLLM_BR_STATIC_MOE_DECODER_MAX_LEN":
    lambda: int(os.getenv("VLLM_BR_STATIC_MOE_DECODER_MAX_LEN", "256")),
    # NOTE: following are device properties
    "VLLM_BR_DEVICE_SPC_NUM":
    lambda: int(
        os.getenv(
            "VLLM_BR_DEVICE_SPC_NUM",
            torch_br.supa.get_device_properties(torch.device("supa")).
            max_compute_units)),
    "VLLM_BR_DEVICE_WARP_SIZE":
    lambda: int(os.getenv("VLLM_BR_DEVICE_WARP_SIZE", 32)),
    "VLLM_BR_USE_CPU_ALL_REDUCE":
    lambda: int(os.getenv("VLLM_BR_USE_CPU_ALL_REDUCE", 0)),
    "VLLM_SCCL_SO_PATH":
    lambda: os.getenv(
        "VLLM_SCCL_SO_PATH",
        "/usr/local/birensupa/base/latest/succl/lib/x86_64-linux-gnu/libsuccl.so"
    ),
    "VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
    lambda: bool(int(os.getenv("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", False))),
    "VLLM_PP_CPU_SEND_RECV":
    lambda: bool(int(os.getenv("VLLM_PP_CPU_SEND_RECV", False))),
    "VLLM_BR_USE_FP32_ALL_REDUCE":
    lambda: int(os.getenv("VLLM_BR_USE_FP32_ALL_REDUCE", 0)),
    # A plain bool() on the raw string would treat "0" and "false" as
    # enabled; those spellings (and the empty string) now disable the flag,
    # while every other non-empty value still enables it.
    "VLLM_BR_USE_MROPE_0_9_2":
    lambda: os.getenv("VLLM_BR_USE_MROPE_0_9_2", "0").strip().lower() not in
    ("", "0", "false", "no", "off"),
    "VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE":
    lambda: bool(int(os.getenv("VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE", "0"))),
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Resolve ``module.NAME`` by calling the getter registered for NAME.

    The getter runs on every access, so the environment variable is read
    lazily. Unknown names raise ``AttributeError``.
    """
    if name not in env_variables:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    getter = env_variables[name]
    return getter()
|
||||
|
||||
|
||||
def __dir__():
    """List the environment-variable names this module exposes."""
    return [*env_variables]
|
||||
15
vllm_br/executor/__init__.py
Normal file
15
vllm_br/executor/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
BIN
vllm_br/executor/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/executor/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
356
vllm_br/executor/ray_distributed_executor.py
Normal file
356
vllm_br/executor/ray_distributed_executor.py
Normal file
@@ -0,0 +1,356 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.ray_distributed_executor import RayDistributedExecutor
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
# from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm_br import envs as envs_br
|
||||
|
||||
if ray is not None:
|
||||
from ray.actor import ActorHandle
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
else:
|
||||
ActorHandle = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.executor.ray_distributed_executor import RayWorkerMetaData, logger
|
||||
|
||||
|
||||
def get_supernode_pp_tp_global_rank_map(tp_size, pp_size):
    """Map every (pp_rank, tp_rank) pair to a global rank.

    The global rank is ``pp_rank * pp_size + tp_rank`` plus a fixed shift
    that depends on which half of the TP group the rank is in (tp_rank < 4
    or not) and on whether it belongs to the first pipeline stage:

        first stage : +0 (tp_rank < 4), +4 (tp_rank >= 4)
        later stages: +2 (tp_rank < 4), +6 (tp_rank >= 4)

    For PP=2, TP=8 this yields
    ``[[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]``.

    Returns:
        A tuple ``(rank_map, tp_driver_rank)``: the mapping itself and the
        list of global ranks whose tp_rank is 0, in pipeline-stage order.
    """
    rank_map = {}
    for pp_rank in range(pp_size):
        # Shifts for the lower / upper half of the TP group of this stage.
        lower_shift, upper_shift = (0, 4) if pp_rank < 1 else (2, 6)
        stage_base = pp_rank * pp_size
        for tp_rank in range(tp_size):
            shift = lower_shift if tp_rank < 4 else upper_shift
            rank_map[(pp_rank, tp_rank)] = stage_base + tp_rank + shift

    # The first TP rank of each stage drives that stage's TP group.
    tp_driver_rank = [
        global_rank for (_, tp_rank), global_rank in rank_map.items()
        if tp_rank == 0
    ]
    return rank_map, tp_driver_rank
|
||||
|
||||
|
||||
def _init_workers_ray_br(self, placement_group: "PlacementGroup",
                         **ray_remote_kwargs):
    """Create the Ray workers, assign their ranks and load the model.

    Replacement for ``RayDistributedExecutor._init_workers_ray``. When
    ``envs_br.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE`` is set, the TP driver
    ranks and the ``pp_tp_workers`` layout come from
    ``get_supernode_pp_tp_global_rank_map`` instead of the contiguous
    ``rank % tensor_parallel_size == 0`` layout.

    Args:
        placement_group: Ray placement group whose bundles host the workers.
        **ray_remote_kwargs: extra keyword arguments forwarded to
            ``ray.remote``.

    Raises:
        ValueError: if no worker runs on the driver node while the SPMD
            worker mode is disabled.
        RuntimeError: if the number of node ids differs from the number of
            distinct IP addresses.
    """
    num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
    if envs_br.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
        # Only defined in supernode mode; every later use is guarded by the
        # same flag.
        rank_map, tp_driver_rank = get_supernode_pp_tp_global_rank_map(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

    # The driver dummy worker does not actually use any resources.
    # It holds the resource for the driver worker.
    self.driver_dummy_worker = None
    # The remaining workers are the actual ray actors.
    self.workers = []

    # Used in ray compiled DAG: indexed first by PP rank,
    # and then TP rank. In other words, the inner list is
    # the TP group of workers for a PP rank.
    self.pp_tp_workers = []

    if self.parallel_config.ray_workers_use_nsight:
        ray_remote_kwargs = self._configure_ray_workers_use_nsight(
            ray_remote_kwargs)

    logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

    # Create the workers.
    bundle_indices: List[int]
    if envs.VLLM_RAY_BUNDLE_INDICES:
        # Use the bundle indices specified by the user.
        bundle_indices = list(map(int,
                                  envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
        assert len(bundle_indices) == self.parallel_config.world_size, \
        ("VLLM_RAY_BUNDLE_INDICES must have the same size"
        f" as the world size, but got {bundle_indices=} "
        f"and {self.parallel_config.world_size=}")
        assert len(set(bundle_indices)) == len(bundle_indices), \
        ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
        f" but got {bundle_indices=}")
    else:
        # use the first N bundles that have GPU resources.
        bundle_indices = []
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if bundle.get(current_platform.ray_device_key, 0):
                bundle_indices.append(bundle_id)
        bundle_indices = bundle_indices[:self.parallel_config.world_size]

    worker_metadata: List[RayWorkerMetaData] = []
    driver_ip = get_ip()
    # One actor per bundle; `rank` here is only the creation order and is
    # remapped after sorting below.
    for rank, bundle_id in enumerate(bundle_indices):
        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_id,
        )

        if current_platform.ray_device_key == "GPU":
            # NV+AMD GPUs, and Intel XPUs
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                       rpc_rank=rank)
        else:
            # Other platforms request the device through a custom Ray
            # resource named by `ray_device_key`.
            worker = ray.remote(
                num_cpus=0,
                num_gpus=0,
                resources={current_platform.ray_device_key: num_gpus},
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                       rpc_rank=rank)
        worker_metadata.append(
            RayWorkerMetaData(worker=worker, created_rank=rank))

    worker_ips = ray.get([
        each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
        for each in worker_metadata
    ])

    for each, ip in zip(worker_metadata, worker_ips):
        each.ip = ip

    if not self.use_ray_spmd_worker:
        for i, each in enumerate(worker_metadata):
            # find and remove the dummy worker from the list
            worker = each.worker
            worker_ip = each.ip
            if self.driver_dummy_worker is None and worker_ip == driver_ip:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    vllm_config=self.vllm_config, rpc_rank=0)
                # Safe despite iterating the list: the loop exits right away.
                worker_metadata.pop(i)
                break

    logger.debug("workers: %s", worker_metadata)
    logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
    if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
        raise ValueError(
            "Ray does not allocate any GPUs on the driver node."
            f"Driver IP: {driver_ip}, worker IPs: {worker_ips}."
            "Consider adjusting the Ray placement group or running "
            "the driver on a GPU node.")

    # Number of workers per node IP, used as a sort key below.
    ip_counts: Dict[str, int] = {}
    for ip in worker_ips:
        ip_counts[ip] = ip_counts.get(ip, 0) + 1

    def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
        """
        Sort the workers based on 3 properties:
        1. If the worker is on the same node as the driver (vllm engine),
            it should be placed first.
        2. Then, if the worker is on a node with fewer workers, it should
            be placed first.
        3. Finally, if the work is on a node with smaller IP address, it
            should be placed first.
        """
        ip = item.ip
        return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

    # After sorting, the workers on the same node will be
    # close to each other, and the workers on the driver
    # node will be placed first.
    sorted_worker_metadata = sorted(worker_metadata,
                                    key=sort_by_driver_then_worker_ip)
    # Rank 0 is reserved for the driver worker unless SPMD mode is used.
    start_rank = 0 if self.use_ray_spmd_worker else 1
    for i, item in enumerate(sorted_worker_metadata):
        item.adjusted_rank = i + start_rank
    self.workers = [item.worker for item in sorted_worker_metadata]
    rerank_mapping = {
        item.created_rank: item.adjusted_rank
        for item in sorted_worker_metadata
    }
    self._run_workers("adjust_rank", rerank_mapping)

    # Get the set of GPU IDs used on each node.
    worker_node_and_gpu_ids = []
    for worker in [self.driver_dummy_worker] + self.workers:
        if worker is None:
            # driver_dummy_worker can be None when using ray spmd worker.
            continue
        worker_node_and_gpu_ids.append(
            ray.get(worker.get_node_and_gpu_ids.remote()) \
        ) # type: ignore

    node_workers = defaultdict(list)  # node id -> list of worker ranks
    node_gpus = defaultdict(list)  # node id -> list of gpu ids

    for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
        node_workers[node_id].append(i)
        # `gpu_ids` can be a list of strings or integers.
        # convert them to integers for consistency.
        # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
        # string sorting is not sufficient.
        # see https://github.com/vllm-project/vllm/issues/5590
        gpu_ids = [int(x) for x in gpu_ids]
        node_gpus[node_id].extend(gpu_ids)
    for node_id, gpu_ids in node_gpus.items():
        node_gpus[node_id] = sorted(gpu_ids)

    all_ips = set(worker_ips + [driver_ip])
    n_ips = len(all_ips)
    n_nodes = len(node_workers)

    if n_nodes != n_ips:
        raise RuntimeError(
            f"Every node should have a unique IP address. Got {n_nodes}"
            f" nodes with node ids {list(node_workers.keys())} and "
            f"{n_ips} unique IP addresses {all_ips}. Please check your"
            " network configuration. If you set `VLLM_HOST_IP`"
            " environment variable, make sure it is unique for"
            " each node.")

    # Set environment variables for the driver and workers.
    # Each worker gets the device-visibility variable listing every device
    # id used on its node.
    all_args_to_update_environment_variables = [{
        current_platform.device_control_env_var:
        ",".join(map(str, node_gpus[node_id])),
    } for (node_id, _) in worker_node_and_gpu_ids]

    # Environment variables to copy from driver to workers
    env_vars_to_copy = get_env_vars_to_copy(
        exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
        additional_vars=set(current_platform.additional_env_vars).union(
            self.ADDITIONAL_ENV_VARS),
        destination="workers")

    # Copy existing env vars to each worker's args
    for args in all_args_to_update_environment_variables:
        # TODO: refactor platform-specific env vars
        for name in env_vars_to_copy:
            if name in os.environ:
                args[name] = os.environ[name]

    self._env_vars_for_all_workers = (all_args_to_update_environment_variables)

    self._run_workers("update_environment_variables",
                      self._get_env_vars_to_be_updated())

    if len(node_gpus) == 1:
        # in single node case, we don't need to get the IP address.
        # the loopback address is sufficient
        # NOTE: a node may have several IP addresses, one for each
        # network interface. `get_ip()` might return any of them,
        # while they might not work for communication inside the node
        # if the network setup is complicated. Using the loopback address
        # solves this issue, as it always works for communication inside
        # the node.
        driver_ip = "127.0.0.1"
    distributed_init_method = get_distributed_init_method(
        driver_ip, get_open_port())

    # Initialize the actual workers inside worker wrapper.
    all_kwargs = []
    for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
        local_rank = node_workers[node_id].index(rank)
        if envs_br.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
            # Supernode layout: the TP drivers are the ranks listed in
            # `tp_driver_rank`.
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank in tp_driver_rank),
            )
        else:
            # Default layout: the first rank of every contiguous TP group
            # is its driver.
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
        all_kwargs.append(kwargs)
    self._run_workers("init_worker", all_kwargs)

    self._run_workers("init_device")
    self._run_workers("load_model",
                      max_concurrent_workers=self.parallel_config.
                      max_parallel_loading_workers)

    if self.use_ray_spmd_worker:
        if envs_br.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(
                        self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=8
                    # pp_tp_workers = [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]]
                    rank = rank_map[(pp_rank, tp_rank)]
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])
        else:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(
                        self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
                            ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

    # This is the list of workers that are rank 0 of each TP group EXCEPT
    # global rank 0. These are the workers that will broadcast to the
    # rest of the workers.
    self.tp_driver_workers = []
    # This is the list of workers that are not drivers and not the first
    # worker in a TP group. These are the workers that will be
    # broadcasted to.
    self.non_driver_workers = []

    # Enforce rank order for correct rank to return final output.
    for index, worker in enumerate(self.workers):
        # The driver worker is rank 0 and not in self.workers.
        rank = index + 1
        if envs_br.VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE:
            if rank in tp_driver_rank:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)
        else:
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)


# Install the implementation above in place of vLLM's own worker
# initialization.
RayDistributedExecutor._init_workers_ray = _init_workers_ray_br  # noqa: E501
|
||||
418
vllm_br/forward_context.py
Normal file
418
vllm_br/forward_context.py
Normal file
@@ -0,0 +1,418 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
|
||||
from vllm_br.config.compilation import SUPAGraphMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
|
||||
# Module-level logger.
logger = init_logger(__name__)

# Forward-time statistics are collected only when the configured logging
# interval is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# time.perf_counter() value recorded the last time the statistics were logged.
last_logging_time: float = 0
# time.perf_counter() value recorded when the tracked forward pass started.
forward_start_time: float = 0
# Minimum perf_counter() difference between two statistics log lines.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps a batch size to the list of forward times (in milliseconds) measured
# for it.
batchsize_forward_time: defaultdict = defaultdict(list)
|
||||
|
||||
|
||||
class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for supagraph dispatching. We should keep the number of
    items as minimal as possible to properly and uniquely describe the padded
    batch for supagraph.
    """
    # Number of tokens in the (padded) batch.
    num_tokens: int
    # Whether the batch is a uniform decode batch. False can also be used for
    # a uniform decode batch to dispatch to the supagraph supporting
    # non-uniform batches.
    uniform_decode: bool

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """
        Return a non-uniform version of the current batch descriptor.

        The token count is kept and ``uniform_decode`` is forced to False, so
        the result can be used as the lookup key for a graph that supports
        non-uniform batches. (Previously the flag was copied unchanged, which
        made this property return an identical descriptor.)
        """
        return BatchDescriptor(self.num_tokens, uniform_decode=False)
|
||||
|
||||
|
||||
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
|
||||
sequence_parallel_size: int) -> list[int]:
|
||||
sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
|
||||
sequence_parallel_size)
|
||||
|
||||
sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
|
||||
return sp_tokens.tolist()
|
||||
|
||||
|
||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
                                      sequence_parallel_size: int,
                                      max_num_tokens: int,
                                      chunk_idx: int) -> list[int]:
    """Return how many tokens each rank processes in chunk ``chunk_idx``.

    The per-rank totals come from ``_compute_sp_num_tokens``. For every rank
    the chunk size is the number of tokens left after ``chunk_idx`` full
    chunks of ``max_num_tokens``, capped at ``max_num_tokens``. A
    non-positive size is replaced by 1 so that ranks which are already done
    still take part in the step.
    """
    tokens_in_earlier_chunks = max_num_tokens * chunk_idx
    chunk_sizes = []
    # Take into account sharding if MoE activation is sequence parallel.
    for rank_total in _compute_sp_num_tokens(num_tokens_across_dp_cpu,
                                             sequence_parallel_size):
        size = min(max_num_tokens, rank_total - tokens_in_earlier_chunks)
        # ensure lockstep even if done
        chunk_sizes.append(size if size > 0 else 1)
    return chunk_sizes
|
||||
|
||||
|
||||
@dataclass
class DPMetadata:
    """Token-count metadata shared between data-parallel (DP) ranks.

    Instances are built with `DPMetadata.make`. `local_sizes` is only valid
    inside the `chunked_sizes` / `sp_local_sizes` context managers, which set
    it on entry and reset it to None on exit.
    """
    # Largest entry of num_tokens_across_dp_cpu (torch.max result, see make()).
    max_tokens_across_dp_cpu: torch.Tensor
    # Token count of every DP rank, indexed by DP rank.
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes context manager
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Gather the num_tokens across all DP ranks and return results in a
        CPU tensor of size dp_size.
        """
        from vllm.distributed.parallel_state import get_dp_group
        device = current_platform.device_type
        group = get_dp_group().device_group

        # Transferring this tensor from the device to the CPU introduces a
        # device sync point that could adversely affect performance of vllm
        # with async scheduling. This environment variable exists to quickly
        # switch the synchronization to a CPU all-reduce if we run into this
        # case.
        if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
            logger.info_once(
                "Using CPU all reduce to synchronize DP padding between ranks."
            )
            device = "cpu"
            group = get_dp_group().cpu_group
        # Each rank fills in only its own slot and leaves the others at zero;
        # the all_reduce below (called with its default reduce op) combines
        # the per-rank vectors into one.
        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = num_tokens
        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                         device=device,
                                         dtype=torch.int32)
        dist.all_reduce(num_tokens_tensor, group=group)
        return num_tokens_tensor.cpu()

    # Get the cumulative tokens across sequence parallel ranks.
    # In this case the input to the MoEs will be distributed w.r.t both
    # DP and TP rank.
    # When sp_size==1, this is just the cumulative num tokens across DP.
    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        # Round-up division of each DP rank's token count by sp_size.
        num_tokens_across_sp_cpu = (
            (self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
        # Repeat each quotient sp_size times (one entry per SP rank).
        num_tokens_across_sp_cpu = (
            num_tokens_across_sp_cpu.repeat_interleave(sp_size))
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

    @staticmethod
    def should_ubatch_across_dp(
            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
            padded_num_tokens_per_ubatch: int, dp_size: int,
            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
        """
        1. Decides if each DP rank is going to microbatch. Either all ranks
        run with microbatching or none of them do. If this function decides
        not to run with microbatching, it will "abort", meaning that no padding
        information will be returned to the caller. It will return (False, None)

        2. Determines the total number of tokens that each rank will run.
        All ranks will be padded out so that they run with the same number
        of tokens

        Returns: tuple[
            should_ubatch: Are all DP ranks going to microbatch
            num_tokens_after_padding: A tensor containing the total number of
            tokens per-microbatch for each DP rank including padding. Will be
            None if should_ubatch is False
        ]
        """

        device = current_platform.device_type
        # Row 0: original tokens per ubatch, row 1: padded tokens per ubatch,
        # row 2: 1 if the rank wants to microbatch, else 0. Each rank writes
        # only its own column before the all_reduce.
        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
        tensor[2][dp_rank] = 1 if should_ubatch else 0

        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(tensor, group=get_dp_group().device_group)

        # Microbatch only when every rank voted for it.
        result: bool = bool(torch.all(tensor[2] == 1).item())
        if not result:
            return result, None

        orig_num_tokens_tensor = tensor[0, :]
        padded_num_tokens_tensor = tensor[1, :]

        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
            logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
                         padded_max_num_tokens)
            return False, None
        return result, padded_num_tokens_tensor.cpu()

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        attn_metadata: Any,
        num_tokens: int,
        num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """Build a DPMetadata for the current forward pass.

        Requires parallel_config.data_parallel_size > 1. The local batch size
        is taken from attn_metadata when it exposes `num_prefill_tokens`,
        otherwise from num_tokens. When num_tokens_across_dp_cpu is not
        given it is obtained with `num_tokens_across_dp`.
        """

        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        if attn_metadata is not None and hasattr(attn_metadata,
                                                 "num_prefill_tokens"):
            # for v0 attention backends
            batchsize = attn_metadata.num_prefill_tokens + \
                attn_metadata.num_decode_tokens
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
        assert (num_tokens_across_dp_cpu is None
                or num_tokens_across_dp_cpu[dp_rank] == batchsize
                ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
        if num_tokens_across_dp_cpu is None:
            num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(self, sequence_parallel_size: int,
                      max_chunk_size_per_rank: int, chunk_idx: int):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                    we use SP between the layers to avoid
                                    redundant ops. We need this value to
                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        self.local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size,
            max_chunk_size_per_rank, chunk_idx)
        try:
            yield self.local_sizes
        finally:
            # Always clear the sizes, even if the body raised.
            self.local_sizes = None

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as self.chunked_sizes
        but without any chunking.
        """
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size)
        try:
            yield self.local_sizes
        finally:
            # Always clear the sizes, even if the body raised.
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
        """Return self.local_sizes; asserts that it is currently set."""
        assert self.local_sizes is not None
        return self.local_sizes
|
||||
|
||||
|
||||
@dataclass
class ForwardContext:
    """State made available to layers for the duration of one forward pass."""

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]

    # attn_metadata is set dynamically for each forward pass and has one of
    # these types:
    #   - AttentionMetadata for v0,
    #   - Dict[str, AttentionMetadata] for v1, mapping the layer_name of each
    #     attention layer to its attention metadata,
    #   - List[Dict[str, AttentionMetadata]] for DBO: a list of size two, one
    #     entry for each microbatch.
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
                         list[dict[str, "AttentionMetadata"]]]

    # set dynamically for each forward pass
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int

    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None

    # Graph style chosen at runtime; the default NONE means no graph is used.
    cudagraph_runtime_mode: SUPAGraphMode = SUPAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None

    ubatch_slices: Optional[UBatchSlices] = None

    def __post_init__(self):
        # Reject any runtime mode outside the four accepted ones.
        accepted_modes = (
            SUPAGraphMode.NONE,
            SUPAGraphMode.PIECEWISE,
            SUPAGraphMode.FULL,
            SUPAGraphMode.FULL_DECODE_ONLY,
        )
        assert self.cudagraph_runtime_mode in accepted_modes, \
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
|
||||
|
||||
|
||||
# _forward_context: Optional[ForwardContext] = None
|
||||
|
||||
|
||||
def get_forward_context() -> ForwardContext:
    """Return the forward context stored in ``vllm.forward_context``.

    Fails with an AssertionError when no context has been installed.
    """
    current = vllm.forward_context._forward_context
    assert current is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return current
|
||||
|
||||
|
||||
def create_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        dp_metadata: Optional[DPMetadata] = None,
        cudagraph_runtime_mode: SUPAGraphMode = SUPAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """Build a ForwardContext from the given per-pass values.

    ``no_compile_layers`` is taken from
    ``vllm_config.compilation_config.static_forward_context``; every other
    field is passed through unchanged.
    """
    static_layers = vllm_config.compilation_config.static_forward_context
    return ForwardContext(
        no_compile_layers=static_layers,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )
|
||||
|
||||
|
||||
@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
    """Temporarily install ``forward_context`` as the current forward context.

    The value previously stored in ``vllm.forward_context._forward_context``
    is put back when the ``with`` block exits, whether or not it raised.
    """
    context_module = vllm.forward_context
    saved_context = context_module._forward_context
    context_module._forward_context = forward_context
    try:
        yield
    finally:
        context_module._forward_context = saved_context
|
||||
|
||||
|
||||
@contextmanager
def set_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        cudagraph_runtime_mode: SUPAGraphMode = SUPAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """Install a forward context for the duration of the ``with`` block.

    A ForwardContext is created from the arguments (plus a DPMetadata when
    data parallelism is enabled and batch information is available) and made
    current through ``override_forward_context``. When batch-size tracking is
    enabled and ``attn_metadata`` is given, the elapsed time of the block is
    recorded per batch size and the per-batch-size medians are logged once
    the logging interval has elapsed.
    """
    global forward_start_time, last_logging_time, batchsize_logging_interval

    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    parallel_config = vllm_config.parallel_config
    has_batch_info = attn_metadata is not None or num_tokens is not None
    dp_metadata: Optional[DPMetadata] = None
    if parallel_config.data_parallel_size > 1 and has_batch_info:
        dp_metadata = DPMetadata.make(parallel_config, attn_metadata,
                                      num_tokens or 0, num_tokens_across_dp)

    forward_context = create_forward_context(attn_metadata, vllm_config,
                                             virtual_engine, dp_metadata,
                                             cudagraph_runtime_mode,
                                             batch_descriptor, ubatch_slices)

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if need_to_track_batchsize:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = (attn_metadata.num_prefill_tokens +
                             attn_metadata.num_decode_tokens)
            else:
                # for v1 attention backends
                batchsize = num_tokens
            # we use synchronous scheduling right now,
            # adding a sync point here should not affect
            # scheduling of the next batch
            from vllm.platforms import current_platform
            synchronize = current_platform.synchronize
            if synchronize is not None:
                synchronize()
            now = time.perf_counter()
            # time measurement is in milliseconds
            elapsed_ms = (now - forward_start_time) * 1000
            batchsize_forward_time[batchsize].append(elapsed_ms)
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                # Batch sizes seen at most once are skipped: they can be
                # cudagraph / profiling runs.
                forward_stats = [
                    (bs, len(times),
                     round(
                         torch.quantile(torch.tensor(times), q=0.5).item(),
                         2))
                    for bs, times in batchsize_forward_time.items()
                    if len(times) > 1
                ]
                # Most frequently seen batch sizes first.
                forward_stats.sort(key=lambda stat: stat[1], reverse=True)
                if forward_stats:
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"),
                                forward_stats)
|
||||
39
vllm_br/model_executor/__init__.py
Normal file
39
vllm_br/model_executor/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from vllm import ModelRegistry # noqa: F401
|
||||
from . import parameter
|
||||
from .layers import *
|
||||
from .model_loader import *
|
||||
from .models import *
|
||||
|
||||
# Names exported by `from <this package> import *`; only the `parameter`
# submodule is listed.
__all__ = [
    "parameter",
]
|
||||
|
||||
|
||||
def register_model():
    """Register Biren modified models.

    Currently a no-op: the registrations listed below are disabled.
    """
    # Disabled registrations, kept for reference:
    #
    # ModelRegistry.register_model(
    #     "GptOssForCausalLM",
    #     "vllm_br.model_executor.models.gpt_oss:GptOssForCausalLM")
    #
    # ModelRegistry.register_model(
    #     "Glm4MoeForCausalLM",
    #     "vllm_br.model_executor.models.glm4_moe:Glm4MoeForCausalLM")
    return None
|
||||
BIN
vllm_br/model_executor/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/model_executor/__pycache__/parameter.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/__pycache__/parameter.cpython-310.pyc
Normal file
Binary file not shown.
25
vllm_br/model_executor/layers/__init__.py
Normal file
25
vllm_br/model_executor/layers/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import vllm_br.model_executor.layers.activation
|
||||
import vllm_br.model_executor.layers.fused_moe
|
||||
import vllm_br.model_executor.layers.layernorm
|
||||
import vllm_br.model_executor.layers.linear
|
||||
import vllm_br.model_executor.layers.logits_processor
|
||||
import vllm_br.model_executor.layers.quantization
|
||||
import vllm_br.model_executor.layers.rotary_embedding
|
||||
import vllm_br.model_executor.layers.utils
|
||||
import vllm_br.model_executor.layers.vocab_parallel_embedding # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/model_executor/layers/__pycache__/linear.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/layers/__pycache__/linear.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/model_executor/layers/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/layers/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
31
vllm_br/model_executor/layers/activation.py
Normal file
31
vllm_br/model_executor/layers/activation.py
Normal file
@@ -0,0 +1,31 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
|
||||
|
||||
|
||||
@patch_to(SiluAndMul)
def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
    """Split ``x`` in half along the last dim and apply the SUPA silu-mul op.

    NOTE(review): fastcore's ``patch_to`` appears to attach the function to
    the class under the function's own name (``silu_and_mul_forward_oot``),
    not ``forward_oot`` -- confirm this is the attribute callers look up.
    """
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch_br.supa_silumul(first_half, second_half)  # type: ignore
|
||||
|
||||
|
||||
@patch_to(QuickGELU)
def quick_gelu_forward_oot(self, x: torch.Tensor) -> torch.Tensor:  # noqa:F811
    """Delegate to the layer's ``forward_native`` implementation."""
    native_forward = self.forward_native
    return native_forward(x)
|
||||
619
vllm_br/model_executor/layers/br_utils.py
Normal file
619
vllm_br/model_executor/layers/br_utils.py
Normal file
@@ -0,0 +1,619 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
import torch_br.supa._debug as supa_debug
|
||||
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
def align_n(n, align_size, spc_num=envs.VLLM_BR_DEVICE_SPC_NUM) -> int:
    """Return the per-SPC share of ``n`` rounded up to ``align_size``.

    ``n`` is first divided by ``spc_num`` (rounding up); the quotient is then
    rounded up to the next multiple of ``align_size``.
    """
    per_spc = (n + (spc_num - 1)) // spc_num
    aligned_units = (per_spc + (align_size - 1)) // align_size
    return aligned_units * align_size
|
||||
|
||||
|
||||
def _br_qweight_cvt(quant_method,
|
||||
qweight,
|
||||
qzeros,
|
||||
size_k,
|
||||
size_n,
|
||||
override_group_size=None):
|
||||
group_size = override_group_size or quant_method.quant_config.group_size
|
||||
curr_dev = qweight.device
|
||||
group_num = size_k // group_size if group_size > 0 else 1
|
||||
qweight = qweight.cpu().view(torch.int8).reshape(
|
||||
size_k // 4, size_n,
|
||||
4).permute(0, 2, 1).contiguous().reshape(group_num,
|
||||
size_k // group_num, size_n)
|
||||
if qzeros is not None and not torch.all(qzeros == 0):
|
||||
qzeros = qzeros.cpu().view(torch.int8).to(torch.int32) + 1
|
||||
qweight = (qweight.to(torch.int32) - qzeros.unsqueeze(1)).to(
|
||||
torch.int8)
|
||||
qwei_int8 = qweight.reshape(size_k, size_n).to(curr_dev)
|
||||
return qwei_int8
|
||||
|
||||
|
||||
def _numa_scales_cvt(scales, wn, spc_num):
|
||||
align_size = 32
|
||||
wn_block = (wn + spc_num - 1) // spc_num
|
||||
wn_block = (wn_block + align_size - 1) // align_size * align_size
|
||||
cvt_scales = torch.nn.functional.pad(scales, (0, spc_num * wn_block - wn),
|
||||
mode='constant',
|
||||
value=0)
|
||||
cvt_scales = cvt_scales.reshape(spc_num, wn_block).contiguous()
|
||||
return cvt_scales
|
||||
|
||||
|
||||
def cross_weight_32(t1, t2, spc_num, dim=1, need_pad=True):
    """Interleave ``t1`` and ``t2`` in 32-wide chunks along ``dim``.

    Both tensors are first zero-padded on their last dimension so the width
    is a multiple of 32 (for ``spc_num > 16`` each half is padded
    separately). The result alternates one 32-wide chunk of ``t1`` with the
    matching chunk of ``t2``. With ``need_pad`` the interleaved tensor is
    additionally zero-padded to a multiple of an ``spc_num``-dependent
    alignment (again per half when ``spc_num > 16``); otherwise it is
    returned as is.
    """
    pad = torch.nn.functional.pad
    dual_die = spc_num > 16

    def _pad_each_half(tensor, amount):
        # Split the last dim in two, zero-pad each half, and re-join.
        first, second = torch.chunk(tensor, 2, dim=-1)
        first = pad(first, (0, amount), "constant", 0)
        second = pad(second, (0, amount), "constant", 0)
        return torch.cat([first, second], dim=-1)

    width = t1.shape[dim]
    # NOTE: br166 must ensure dual-dies width are 32-aligned
    if dual_die:
        assert width % 2 == 0
        half_width = width // 2
        half_aligned = (half_width + 32 - 1) // 32 * 32
        half_pad = half_aligned - half_width
        if half_pad > 0:
            t1 = _pad_each_half(t1, half_pad)
            t2 = _pad_each_half(t2, half_pad)
        width = half_aligned * 2
    else:
        aligned = (width + 32 - 1) // 32 * 32
        t1 = pad(t1, (0, aligned - width), "constant", 0)
        t2 = pad(t2, (0, aligned - width), "constant", 0)
        width = aligned

    cnt = width // 32
    t1_chunks = torch.chunk(t1, cnt, dim)
    t2_chunks = torch.chunk(t2, cnt, dim)
    interleaved = [
        piece for i in range(cnt) for piece in (t1_chunks[i], t2_chunks[i])
    ]
    no_pad = torch.cat(interleaved, dim=dim)
    if not need_pad:
        return no_pad

    if dual_die:
        align = (spc_num // 2) * 32 * 2
        width_align = (width + align - 1) // align * align
        return _pad_each_half(no_pad, width_align - width)

    align = spc_num * 32 * 2  # 768
    total_width = width * 2
    width_align = (total_width + align - 1) // align * align
    return pad(no_pad, (0, width_align - total_width), "constant", 0)
|
||||
|
||||
|
||||
# # NOTE: bias and scales will not be converted to numa arch, still use uma weight for support fused_linear
|
||||
def _convert_to_uma_tensor(tensor,
                           align_size,
                           layout,
                           dtype,
                           do_transpose=False,
                           wk=None,
                           wn=None,
                           parallel_type="col_parallel"):
    """Copy ``tensor`` into a device tensor allocated with ``_empty_ut_only``.

    Args:
        tensor: source tensor.
        align_size: accepted but not used by this function.
        layout: "colmajor" or "linear_bias" (compared case-insensitively).
        dtype: dtype of the allocated device tensor.
        do_transpose: for "colmajor", swap the two dimensions before copying.
        wk, wn: optional sizes; when falsy they are derived from
            ``tensor.shape`` (for "linear_bias" ``wk`` is always recomputed).
        parallel_type: "col_parallel" or "row_parallel".

    Returns:
        The allocated device tensor holding the copied data.

    Raises:
        ValueError: if ``layout`` is neither "colmajor" nor "linear_bias".
    """

    assert parallel_type in ("col_parallel", "row_parallel")

    layout = layout.lower()
    if layout == "colmajor":
        wk = wk or tensor.shape[1]
        wn = wn or tensor.shape[0]
        d_shape = (wn, wk)
        if do_transpose:
            data = tensor.cpu().permute(1, 0).contiguous()
            d_shape = (wk, wn)
        else:
            data = tensor.cpu().contiguous()
        # The two allocations differ only in `axis` (0 for col_parallel,
        # 1 for row_parallel).
        # NOTE(review): presumably `sbp`/`axis` select how the tensor is
        # split on the device -- confirm against the torch_br docs.
        if parallel_type == "col_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                sbp="SB",
                axis=0)
        else:
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                is_numa=False,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=1,
                sbp="SB")
        # Synchronize the device before copying the data in.
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    elif layout == "linear_bias":
        axis = 0
        wn = wn or tensor.shape[-1]
        wk = 1
        data = tensor
        if len(data.shape) == 2 and data.shape[0] == 1:
            # (1, N) -> flat (N,); stays on axis 0.
            data = tensor.cpu().reshape(-1).contiguous()
        elif len(data.shape) == 2:
            # (K, N) with K != 1: kept as is, allocated as (wk, wn).
            axis = 1
            wk = data.shape[0]
        elif len(data.shape) == 3 and data.shape[1] == 1:
            # (K, 1, N) -> (K, N).
            data = tensor.cpu().reshape(
                (data.shape[0], data.shape[2])).contiguous()
            axis = 1
            wk = data.shape[0]
        d_shape = (wn, ) if axis == 0 else (wk, wn)
        # row_parallel allocates without `axis`/`sbp`; col_parallel passes
        # both.
        if parallel_type == "row_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout)
        elif parallel_type == "col_parallel":
            uma_tensor = torch_br._empty_ut_only(
                size=d_shape,
                dtype=dtype,
                device=torch.supa.current_device(),
                tensor_type=layout,
                axis=axis,
                sbp="SB")

        # Synchronize the device before copying the data in.
        torch.supa.synchronize()
        uma_tensor.copy_(data.to(torch.supa.current_device()))
    else:
        raise ValueError("uma tensor only support colmajor and linear_bias")
    return uma_tensor
|
||||
|
||||
|
||||
def _convert_to_numa_tensor_vit(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Copy ``tensor`` into a NUMA device tensor split into per-SPC blocks.

    The last dimension (``wn``) is cut into ``spc_num`` blocks whose width is
    rounded up to a multiple of ``align_size``; the tail is zero padded. When
    ``envs.VLLM_BR_DEVICE_SPC_NUM`` is greater than 16 the SPC count is halved
    and the data is laid out for two dies.

    The force-UMA debug flag is switched off for the duration of the call and
    is always restored afterwards, including when the conversion raises.

    Args:
        tensor: Source weight (``colmajor``) or bias/scale (``linear_bias``).
        align_size: Alignment applied to the width of every per-SPC block.
        layout: ``"colmajor"`` or ``"linear_bias"`` (case-insensitive).
        dtype: dtype of the allocated device tensor.
        do_transpose: Swap the two dimensions of ``tensor`` before splitting
            (``colmajor`` only).
        wk: Row count; defaults to ``tensor.shape[0]`` (``colmajor`` only).
        wn: Column count; defaults to ``tensor.shape[1]`` for ``colmajor``
            and ``tensor.shape[-1]`` for ``linear_bias``.
        parallel_type: ``"col_parallel"`` or ``"row_parallel"``; only
            consulted on the two-die path.
        pad_zeros: Two-die row-parallel bias only: fill the second die's half
            with zeros instead of a second copy of the bias.

    Returns:
        The tensor allocated through ``torch_br._empty_ut_only``.

    Raises:
        ValueError: If ``layout`` is neither ``colmajor`` nor ``linear_bias``.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    # try/finally: the flag is process-wide, so it must be restored even when
    # the conversion fails (unsupported layout, reshape mismatch, ...).
    try:
        spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
        layout = layout.lower()
        die_num = 1
        if spc_num > 16:
            # More than 16 SPCs: handle the device as two dies.
            spc_num = spc_num // 2
            die_num = 2
        die_spc_num = die_num * spc_num

        if layout == "colmajor":
            # NOTE(review): the defaults are read before any transpose, so
            # callers using do_transpose presumably pass wk/wn -- confirm.
            wk = wk or tensor.shape[0]
            wn = wn or tensor.shape[1]
            if die_num == 1:
                wn_block = (wn + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    data = tensor.cpu().permute(1, 0).contiguous()
                else:
                    data = tensor.cpu().contiguous()
                data = torch.nn.functional.pad(
                    data, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                # (wk, spc, block) -> (spc, wk, block)
                data = data.reshape(wk, spc_num,
                                    wn_block).permute(1, 0, 2).contiguous()
                torch.supa.synchronize()
                numa_tensor.copy_(data.to(torch.supa.current_device()))
            elif parallel_type == "col_parallel":
                # Two dies, column parallel: each die owns half of wn.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # Two dies, row parallel: each die owns half of wk.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)

        elif layout == "linear_bias":
            # NOTE: index -1 for both scales and bias
            wn = tensor.shape[-1] if wn is None else wn
            group_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
            if die_num == 1:
                wn_block = (wn + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(spc_num * group_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                data = torch.nn.functional.pad(tensor.cpu(),
                                               (0, spc_num * wn_block - wn),
                                               mode='constant',
                                               value=0)
                if group_num > 1:
                    # (group, spc, block) -> (spc * group, block)
                    data = data.type(dtype).reshape(
                        group_num, spc_num,
                        wn_block).permute(1, 0, 2).contiguous().reshape(
                            spc_num * group_num, wn_block)
                else:
                    data = data.type(dtype).reshape(spc_num,
                                                    wn_block).contiguous()
                torch.supa.synchronize()
                numa_tensor.copy_(data.to(torch.supa.current_device()))
            elif parallel_type == "col_parallel":
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = tensor.cpu().reshape(die_num, wn // die_num)
                bias = torch.nn.functional.pad(
                    bias, (0, spc_num * wn_block - wn // die_num, 0, 0),
                    mode='constant',
                    value=0)
                bias = bias.type(torch.float32).reshape(
                    die_spc_num, wn_block).contiguous()
                numa_tensor.copy_(bias)
            else:
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type="linear_bias")
                bias = torch.nn.functional.pad(tensor.cpu(),
                                               (0, spc_num * wn_block - wn),
                                               mode='constant',
                                               value=0)
                bias = bias.type(torch.float32).reshape(
                    spc_num, wn_block).contiguous()
                if pad_zeros:
                    # Second die receives zeros instead of the bias.
                    bias_zeros_die2 = torch.zeros((spc_num, wn_block),
                                                  dtype=bias.dtype)
                    bias = torch.concat([bias, bias_zeros_die2], dim=0)
                else:
                    # Both dies receive the same bias.
                    bias = torch.concat([bias, bias], dim=0)
                numa_tensor.copy_(bias)
        else:
            raise ValueError(f"Unsupported tensor_type: {layout}")
        torch.supa.synchronize()
    finally:
        supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
|
||||
|
||||
|
||||
def _convert_to_numa_tensor(tensor,
                            align_size,
                            layout,
                            dtype,
                            do_transpose=False,
                            wk=None,
                            wn=None,
                            parallel_type="col_parallel",
                            pad_zeros=False):
    """Copy a weight or bias into the device layout used by SUPA kernels.

    ``colmajor`` weights are split along ``wn`` into per-SPC blocks (width
    rounded up to ``align_size`` and zero padded) and stored in a NUMA
    tensor. ``linear_bias`` tensors are allocated with ``is_numa=False`` and
    are copied without being split. When ``envs.VLLM_BR_DEVICE_SPC_NUM`` is
    greater than 16 the SPC count is halved and the two-die path is used.

    The force-UMA debug flag is switched off for the duration of the call and
    is always restored afterwards, including when the conversion raises.

    Args:
        tensor: Source weight (``colmajor``) or bias/scale (``linear_bias``).
        align_size: Alignment applied to the width of every per-SPC block.
        layout: ``"colmajor"`` or ``"linear_bias"`` (case-insensitive).
        dtype: dtype of the allocated device tensor.
        do_transpose: Swap the two dimensions of ``tensor`` before splitting
            (``colmajor`` only).
        wk: Row count; defaults to ``tensor.shape[0]`` (``colmajor`` only).
        wn: Column count; defaults to ``tensor.shape[1]`` for ``colmajor``
            and ``tensor.shape[-1]`` for ``linear_bias``.
        parallel_type: ``"col_parallel"`` or ``"row_parallel"``; only
            consulted on the two-die path.
        pad_zeros: Accepted for signature compatibility; not read here.

    Returns:
        The tensor allocated through ``torch_br._empty_ut_only``.

    Raises:
        ValueError: If ``layout`` is neither ``colmajor`` nor ``linear_bias``.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    # try/finally: the flag is process-wide, so it must be restored even when
    # the conversion fails (unsupported layout, reshape mismatch, ...).
    try:
        spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
        layout = layout.lower()
        die_num = 1
        if spc_num > 16:
            # More than 16 SPCs: handle the device as two dies.
            spc_num = spc_num // 2
            die_num = 2
        die_spc_num = die_num * spc_num

        if layout == "colmajor":
            # NOTE(review): the defaults are read before any transpose, so
            # callers using do_transpose presumably pass wk/wn -- confirm.
            wk = wk or tensor.shape[0]
            wn = wn or tensor.shape[1]
            if die_num == 1:
                wn_block = (wn + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    data = tensor.cpu().permute(1, 0).contiguous()
                else:
                    data = tensor.cpu().contiguous()
                data = torch.nn.functional.pad(
                    data, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                # (wk, spc, block) -> (spc, wk, block)
                data = data.reshape(wk, spc_num,
                                    wn_block).permute(1, 0, 2).contiguous()
                torch.supa.synchronize()
                numa_tensor.copy_(data.to(torch.supa.current_device()))
            elif parallel_type == "col_parallel":
                # Two dies, column parallel: each die owns half of wn.
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous().reshape(
                        wk, die_num, wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(wk, die_spc_num,
                                        wn_block).permute(1, 0,
                                                          2).contiguous()
                numa_tensor.copy_(weight)
            else:
                # Two dies, row parallel: each die owns half of wk.
                wn_block = (wn + spc_num - 1) // spc_num
                # w_block must align with 32 (warp_size)
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout,
                    axis=0,
                    sbp="SS")
                if do_transpose:
                    weight = tensor.cpu().permute(1, 0).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0),
                    mode='constant',
                    value=0)
                weight = weight.reshape(
                    die_num, wk // die_num, spc_num,
                    wn_block).permute(0, 2, 1, 3).contiguous().reshape(
                        die_spc_num, wk // die_num, wn_block)
                numa_tensor.copy_(weight)

        elif layout == "linear_bias":
            # NOTE: bias and scales will not be converted to numa arch, still
            # use uma weight for support fused_linear
            # NOTE: index -1 for both scales and bias
            wn = tensor.shape[-1] if wn is None else wn
            expert_num = tensor.shape[-2] if len(tensor.shape) > 1 else 1
            bias_shape = (expert_num, wn) if expert_num > 1 else (wn, )
            if die_num == 1:
                numa_tensor = torch_br._empty_ut_only(size=bias_shape,
                                                      dtype=dtype,
                                                      is_numa=False,
                                                      device=tensor.device,
                                                      tensor_type=layout)
                # bias_shape is (expert_num, wn) or (wn,), i.e. exactly the
                # two reshape targets; one cast to dtype is sufficient.
                data = tensor.cpu().type(dtype).reshape(bias_shape)
                torch.supa.synchronize()
                numa_tensor.copy_(data.to(tensor.device))
            elif parallel_type == "col_parallel":
                axis = 1 if expert_num > 1 else 0
                numa_tensor = torch_br._empty_ut_only(
                    size=bias_shape,
                    dtype=dtype,
                    is_numa=False,
                    device=tensor.device,
                    tensor_type="buffer_any",
                    axis=axis,
                    sbp="SB")
                if expert_num == 1:
                    tensor = tensor.reshape(-1)
                numa_tensor.copy_(tensor.to(torch.supa.current_device()))
            else:
                numa_tensor = torch_br._empty_ut_only(
                    size=bias_shape,
                    dtype=dtype,
                    is_numa=False,
                    device=tensor.device,
                    tensor_type="linear_bias",
                    sbp="BB")
                bias = tensor.reshape(expert_num, wn).cpu().type(dtype)
                if expert_num == 1:
                    bias = bias.reshape(-1)
                numa_tensor.copy_(bias.to(torch.supa.current_device()))
        else:
            raise ValueError(f"Unsupported tensor_type: {layout}")
        torch.supa.synchronize()
    finally:
        supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor
|
||||
|
||||
|
||||
def _convert_to_crossed_numa_tensor(t1,
                                    t2,
                                    spc_num,
                                    dim=1,
                                    need_pad=True,
                                    layout="colmajor",
                                    do_transpose=False):
    """Interleave two weights and convert the result to a NUMA tensor.

    Equals to V0: cross_weight_32 + numa_weight_convert/_numa_weight_cvt.

    ``t1`` and ``t2`` are combined with ``cross_weight_32`` and the crossed
    weight is handed to ``_convert_to_numa_tensor`` with an alignment of 32
    and the crossed weight's own dtype.
    """
    crossed = cross_weight_32(t1, t2, spc_num, dim, need_pad)
    return _convert_to_numa_tensor(crossed, 32, layout, crossed.dtype,
                                   do_transpose)
|
||||
|
||||
|
||||
def _convert_to_numa_tensor_moe(tensor,
                                align_size,
                                layout,
                                dtype,
                                do_transpose=False,
                                wb=None,
                                wk=None,
                                wn=None,
                                parallel_type="col_parallel",
                                pad_zeros=False):
    """Convert a batched (3-D) weight to a two-die NUMA ``colmajor`` tensor.

    Only the two-die configuration is supported (asserted), i.e.
    ``envs.VLLM_BR_DEVICE_SPC_NUM`` must be greater than 16. ``wn`` is split
    into per-SPC blocks whose width is rounded up to ``align_size`` and zero
    padded.

    The force-UMA debug flag is switched off for the duration of the call and
    is always restored afterwards, including when an assertion or the
    conversion itself raises.

    Args:
        tensor: Source weight, indexed as ``(wb, wk, wn)`` (or
            ``(wb, wn, wk)`` together with ``do_transpose``).
        align_size: Alignment applied to the width of every per-SPC block.
        layout: Must be ``"colmajor"`` (case-insensitive).
        dtype: dtype of the allocated device tensor.
        do_transpose: Swap the last two dimensions before splitting.
        wb: Batch size; defaults to ``tensor.shape[0]``.
        wk: Row count; defaults to ``tensor.shape[1]``.
        wn: Column count; defaults to ``tensor.shape[2]``.
        parallel_type: ``"col_parallel"`` splits ``wn`` across the dies,
            ``"row_parallel"`` splits ``wk`` across the dies.
        pad_zeros: Accepted for signature compatibility; not read here.

    Returns:
        ``(numa_tensor, (die_spc_num, wk, wn_block))``.

    Raises:
        ValueError: If ``layout`` is not ``colmajor``.
    """
    assert parallel_type in ("col_parallel", "row_parallel")

    enable_force_uma = supa_debug.is_enable_force_uma()
    supa_debug.set_enable_force_uma(False)
    # try/finally: the flag is process-wide, so it must be restored even when
    # the die-count assertion or the conversion fails.
    try:
        spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
        layout = layout.lower()
        die_num = 1
        if spc_num > 16:
            spc_num = spc_num // 2
            die_num = 2
        die_spc_num = die_num * spc_num
        assert die_num == 2
        if layout == "colmajor":
            # NOTE(review): the defaults are read before any transpose, so
            # callers using do_transpose presumably pass wk/wn -- confirm.
            wb = wb or tensor.shape[0]
            wk = wk or tensor.shape[1]
            wn = wn or tensor.shape[2]
            if parallel_type == "col_parallel":
                wn_block = (wn // die_num + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num * wb, wk, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(
                        0, 2, 1).contiguous().reshape(wb, wk, die_num,
                                                      wn // die_num)
                else:
                    weight = tensor.cpu().contiguous().reshape(
                        wb, wk, die_num, wn // die_num)
                weight = torch.nn.functional.pad(
                    weight,
                    (0, spc_num * wn_block - wn // die_num, 0, 0, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                # (wb, wk, die*spc, block) -> (die*spc, wb, wk, block)
                weight = weight.reshape(wb, wk, die_spc_num,
                                        wn_block).permute(2, 0, 1, 3).reshape(
                                            wb * die_spc_num, wk,
                                            wn_block).contiguous()
                numa_tensor.copy_(weight)
            elif parallel_type == "row_parallel":
                wn_block = (wn + spc_num - 1) // spc_num
                wn_block = (wn_block + align_size -
                            1) // align_size * align_size
                numa_tensor = torch_br._empty_ut_only(
                    size=(die_spc_num * wb, wk // die_num, wn_block),
                    dtype=dtype,
                    is_numa=True,
                    device=torch.supa.current_device(),
                    tensor_type=layout)
                if do_transpose:
                    weight = tensor.cpu().permute(0, 2, 1).contiguous()
                else:
                    weight = tensor.cpu().contiguous()
                weight = torch.nn.functional.pad(
                    weight, (0, spc_num * wn_block - wn, 0, 0, 0, 0),
                    mode='constant',
                    value=0)
                # (wb, die, wk/die, spc, block) -> (die, spc, wb, wk/die, block)
                weight = weight.reshape(
                    wb, die_num, wk // die_num, spc_num,
                    wn_block).permute(1, 3, 0, 2, 4).contiguous().reshape(
                        die_spc_num * wb, wk // die_num, wn_block)
                numa_tensor.copy_(weight)
        else:
            raise ValueError(f"Unsupported tensor_type: {layout}")
        torch.supa.synchronize()
    finally:
        supa_debug.set_enable_force_uma(enable_force_uma)
    return numa_tensor, (die_spc_num, wk, wn_block)
|
||||
|
||||
|
||||
def is_br166_device():
    """Return True when the SUPA device reports 17 to 32 compute units."""
    device_props = torch_br.supa.get_device_properties(torch.device("supa"))
    compute_units = device_props.max_compute_units
    return bool(16 < compute_units <= 32)
|
||||
23
vllm_br/model_executor/layers/fused_moe/__init__.py
Normal file
23
vllm_br/model_executor/layers/fused_moe/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import layer, supa_moe # noqa: E402
|
||||
from .layer import * # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
"layer",
|
||||
"supa_moe",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
413
vllm_br/model_executor/layers/fused_moe/layer.py
Normal file
413
vllm_br/model_executor/layers/fused_moe/layer.py
Normal file
@@ -0,0 +1,413 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from torch_br.utils.tensor_methods import Sbp
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm_br import envs
|
||||
from ..br_utils import (_convert_to_crossed_numa_tensor,
|
||||
_convert_to_numa_tensor, align_n, cross_weight_32)
|
||||
from .supa_moe import (fused_moe_quant_device, fused_moe_quant_dyn,
|
||||
fused_oss_moe_dyn)
|
||||
|
||||
|
||||
@patch_to(UnquantizedFusedMoEMethod)
def forward_oot(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
):
    """Run the unquantized fused MoE on SUPA (out-of-tree forward).

    With ``activation == "swigluoai"`` the call is forwarded to
    ``fused_oss_moe_dyn`` and ``router_logits`` is passed through as-is.
    Otherwise ``router_logits`` is unpacked into three items (gating weight,
    shared gate-up weight, shared down weight) and the kernel is chosen by
    the number of rows in ``x``: ``fused_moe_quant_dyn`` when it exceeds
    ``envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN`` (prefill), else
    ``fused_moe_quant_device`` (decode).
    """
    if activation == "swigluoai":
        return fused_oss_moe_dyn(
            x,
            layer.w13_weight,
            layer.w13_bias,
            layer.w2_weight,
            layer.w2_bias,
            router_logits,
            top_k,
            layer.intermediate_size_per_partition,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size)

    num_rows = x.shape[0]
    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits
    is_prefill = num_rows > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
    # Both kernels take the same argument list; only the callee differs.
    moe_kernel = fused_moe_quant_dyn if is_prefill else fused_moe_quant_device
    return moe_kernel(
        x,
        shared_gate_up_weight,
        shared_down_weight,
        layer.w13_weight,
        layer.w2_weight,
        None,
        None,
        gating_weight,
        top_k,
        layer.intermediate_size_per_partition,
        renormalize=renormalize,
        inplace=True,
        use_grouped_topk=use_grouped_topk,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        tp_rank=get_tp_group().rank_in_group,
        global_rank=get_tp_group().rank,
        tp_size=get_tensor_model_parallel_world_size(),
        ep_rank=layer.ep_rank,
        ep_size=layer.ep_size)
|
||||
|
||||
|
||||
@patch_to(UnquantizedFusedMoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register the expert weights (and optional biases) as CPU parameters.

    Parameters are registered on ``layer`` in this order: ``w13_weight``,
    ``w13_bias`` (if ``self.moe.has_bias``), ``w2_weight``, ``w2_bias`` (if
    ``self.moe.has_bias``). Weights are uninitialised, biases are zeros, and
    none require grad. ``extra_weight_attrs`` is attached to each of them.
    """

    def _register(name, factory, *shape):
        # Allocate on CPU, register on the layer, then attach loader attrs.
        param = torch.nn.Parameter(factory(*shape,
                                           device="cpu",
                                           dtype=params_dtype),
                                   requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    gate_up_size = 2 * intermediate_size_per_partition

    # Fused gate_up_proj (column parallel)
    _register("w13_weight", torch.empty, num_experts, gate_up_size,
              hidden_size)
    if self.moe.has_bias:
        _register("w13_bias", torch.zeros, num_experts, gate_up_size)

    # down_proj (row parallel)
    _register("w2_weight", torch.empty, num_experts, hidden_size,
              intermediate_size_per_partition)
    if self.moe.has_bias:
        _register("w2_bias", torch.zeros, num_experts, hidden_size)
|
||||
|
||||
|
||||
@patch_to(UnquantizedFusedMoEMethod)
def process_weights_after_loading(self: UnquantizedFusedMoEMethod,
                                  layer: FusedMoE) -> None:
    """Repack the loaded expert weights and biases into SUPA device tensors.

    Each expert's ``w13_weight`` / ``w2_weight`` slice is transposed,
    converted to a NUMA ``colmajor`` tensor and copied into one pre-allocated
    bfloat16 device tensor per weight; the biases, when present, are copied
    into float32 device tensors. ``layer.<param>.data`` is replaced by the
    repacked tensor. For activations other than ``"swigluoai"`` the two
    halves of w13 (and of its bias) are interleaved with the crossing
    helpers first.
    """
    device = torch.supa.current_device()
    total_spc = envs.VLLM_BR_DEVICE_SPC_NUM
    dual_die = total_spc > 16
    dies = 2 if dual_die else 1
    spc_per_die = total_spc // dies
    # swigluoai activation, no need do interweave
    plain_w13 = layer.activation == "swigluoai"
    split_sbp = "SS" if dual_die else None

    # NOTE: w13_weight
    # after _load_w13, w13_weight is a colparallel weight, shape
    # [num_experts, 2 * intermediate_size_per_partition, hidden_size]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block]
    # (wn = aligned(2 * intermediate_size_per_partition, align_size=64))
    w13_align = 32 if plain_w13 else 64
    w13_block = align_n((layer.intermediate_size_per_partition * 2) // dies,
                        align_size=w13_align,
                        spc_num=spc_per_die)
    packed_w13 = torch_br._empty_ut_only(
        size=(total_spc * layer.local_num_experts, layer.hidden_size,
              w13_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=device,
        tensor_type="colmajor",
        axis=0,
        sbp=split_sbp)

    for expert_id in range(layer.local_num_experts):
        expert_w13 = layer.w13_weight[expert_id].transpose(0, 1).contiguous()
        if plain_w13:
            expert_numa = _convert_to_numa_tensor(expert_w13, w13_align,
                                                  'COLMAJOR',
                                                  expert_w13.dtype)
        else:
            half_1, half_3 = expert_w13.chunk(2, dim=1)
            expert_numa = _convert_to_crossed_numa_tensor(half_1,
                                                          half_3,
                                                          total_spc,
                                                          dim=1,
                                                          need_pad=True,
                                                          layout='COLMAJOR')
        expert_shape = expert_numa.shape
        hw_size = expert_shape[-2] * expert_shape[-1]
        # View this expert's region of the packed tensor and fill it.
        slot = packed_w13.view_as_usharp("COLMAJOR", expert_shape,
                                         Sbp.ss(0), expert_id * hw_size)
        slot.copy_(expert_numa)

    layer.w13_weight.data = packed_w13

    # NOTE: w13_bias
    if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
        packed_w13_bias = torch_br._empty_ut_only(
            size=(layer.local_num_experts,
                  layer.intermediate_size_per_partition * 2),
            dtype=torch.float32,
            is_numa=False,
            device=device,
            tensor_type="linear_bias",
            sbp="BB" if dual_die else None)
        for expert_id in range(layer.local_num_experts):
            bias_src = layer.w13_bias[expert_id]
            if not plain_w13:
                bias_1, bias_3 = bias_src.chunk(2, dim=-1)
                bias_src = cross_weight_32(
                    bias_1,
                    bias_3,
                    total_spc,
                    dim=0,
                    need_pad=False,
                )
            packed_w13_bias[expert_id].copy_(bias_src)
        layer.w13_bias.data = packed_w13_bias

    # NOTE: w2_weight
    # after _load_w2, w2_weight is a rowparallel weight, shape
    # [num_experts, hidden_size, intermediate_size_per_partition]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block]
    w2_align = 32
    w2_rows = layer.intermediate_size_per_partition
    w2_block = align_n(layer.hidden_size,
                       align_size=w2_align,
                       spc_num=spc_per_die)
    packed_w2 = torch_br._empty_ut_only(
        size=(total_spc * layer.local_num_experts, w2_rows // dies,
              w2_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=device,
        tensor_type="colmajor",
        axis=0,
        sbp=split_sbp)

    for expert_id in range(layer.local_num_experts):
        expert_w2 = layer.w2_weight[expert_id].transpose(0, 1).contiguous()
        expert_numa = _convert_to_numa_tensor(expert_w2,
                                              w2_align,
                                              'COLMAJOR',
                                              expert_w2.dtype,
                                              parallel_type="row_parallel")
        expert_shape = expert_numa.shape
        hw_size = expert_shape[-2] * expert_shape[-1]
        slot = packed_w2.view_as_usharp("COLMAJOR", expert_shape, Sbp.ss(0),
                                        expert_id * hw_size)
        slot.copy_(expert_numa)

    layer.w2_weight.data = packed_w2

    # NOTE: w2_bias
    if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
        packed_w2_bias = torch.zeros(
            (layer.local_num_experts, layer.hidden_size),
            dtype=torch.float32,
            device=device)
        for expert_id in range(layer.local_num_experts):
            bias_src = layer.w2_bias[expert_id]
            packed_w2_bias[expert_id].copy_(bias_src)

        layer.w2_bias.data = packed_w2_bias
|
||||
|
||||
|
||||
@patch_to(FusedMoE)
def forward(self: FusedMoE, hidden_states: torch.Tensor,
            router_logits: torch.Tensor):
    """Apply the layer's quant method and tag fused all-reduce outputs.

    ! router_logits is a tuple of gate, shared_experts.gate_up_proj,
    shared_experts.down_proj weights.

    The result gets ``all_reduced = True`` when the batch has at most
    ``envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN`` rows, the quant method is not
    ``"INT4"``, fused all-reduce is enabled and the (SPC count, TP size) pair
    is one of the supported combinations.
    """
    assert self.quant_method is not None
    assert self.dp_size == 1, 'dp_size > 1 is not supported for now, please refer v0.11.0 moe codes'

    output = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        global_num_experts=self.global_num_experts,
        expert_map=self.expert_map,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        custom_routing_function=self.custom_routing_function,
        scoring_func=self.scoring_func,
        e_score_correction_bias=self.e_score_correction_bias,
        activation=self.activation,
        apply_router_weight_on_input=self.apply_router_weight_on_input,
    )

    # NOTE: if using supa-moe-ccl kernel, add property `all_reduced` to the
    # final hidden states
    fused_ccl_configs = ((16, 4), (16, 8), (32, 2), (32, 4))
    tp_size = get_tensor_model_parallel_world_size()
    if (hidden_states.shape[0] <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
            and envs.VLLM_BR_QUANT_METHOD != "INT4"
            and envs.VLLM_BR_USE_FUSED_ALLREDUCE
            and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in fused_ccl_configs):
        output.all_reduced = True

    return output
|
||||
|
||||
|
||||
@patch_to(FusedMoE)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str,
              loaded_weight: torch.Tensor, tp_rank: int):
    """Load one half (``w1`` or ``w3``) of the fused w13 parameter.

    ``expert_data`` holds both halves along ``shard_dim``; the slice of
    ``loaded_weight`` belonging to ``tp_rank`` is copied (via CPU) into the
    first half for ``"w1"`` and into the second half for ``"w3"``.
    """
    # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
    half = expert_data.shape[shard_dim] // 2
    rank_slice = loaded_weight.narrow(shard_dim, half * tp_rank, half)

    if shard_id == "w1":
        # w1, gate_proj: first logical weight of w13.
        offset = 0
    else:
        # w3, up_proj: second logical weight of w13.
        assert shard_id == "w3"
        offset = half

    target = expert_data.narrow(shard_dim, offset, half)
    target.copy_(rank_slice.cpu())
|
||||
|
||||
|
||||
@patch_to(FusedMoE)
def _load_w2(self,
             expert_data: torch.Tensor,
             shard_dim: int,
             loaded_weight: torch.Tensor,
             tp_rank: int,
             load_full: bool = False):
    """Load the w2 (down_proj) weight of one expert.

    Unless ``load_full`` is set, only the slice of ``loaded_weight`` that
    belongs to ``tp_rank`` along ``shard_dim`` is used. The data is copied
    into ``expert_data`` via CPU.
    """
    # down_proj: "RowParallel" so tp sharding on input_dim
    if load_full:
        source = loaded_weight
    else:
        width = expert_data.shape[shard_dim]
        source = loaded_weight.narrow(shard_dim, width * tp_rank, width)
    # w2, down_proj: Load into only logical weight of w2.
    expert_data.copy_(source.cpu())
|
||||
|
||||
|
||||
def wrapper_FusedMoE_init(fn):
    """Wrap an ``__init__`` so ``e_score_correction_bias`` ends up as float32.

    The wrapped initializer runs first; afterwards, if the instance's
    ``e_score_correction_bias`` is not None, its ``.data`` is replaced by the
    result of ``.float()``.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)

        correction_bias = self.e_score_correction_bias
        if correction_bias is None:
            return
        correction_bias.data = correction_bias.float()

    return wrapper
|
||||
|
||||
|
||||
# Monkey-patch FusedMoE.__init__ at import time: after the original
# initializer runs, e_score_correction_bias (when set) is cast with .float().
FusedMoE.__init__ = wrapper_FusedMoE_init(FusedMoE.__init__) # noqa: E501
|
||||
518
vllm_br/model_executor/layers/fused_moe/supa_moe.py
Normal file
518
vllm_br/model_executor/layers/fused_moe/supa_moe.py
Normal file
@@ -0,0 +1,518 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from torch_br.utils.tensor_methods import Sbp
|
||||
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
# gpt-oss moe forward version
def fused_oss_moe_dyn(
    hidden_states: torch.Tensor,
    w13: torch.Tensor,
    w13_bias: torch.Tensor,
    w2: torch.Tensor,
    w2_bias: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """Run the biased (gpt-oss style) MoE FFN with per-expert token counts.

    Steps, as performed below:
      1. ``supa_moe_router_v2_infer`` routes the tokens.
      2. ``supa_permutation_infer`` groups ``hidden_states`` per local expert.
      3. Two ``supa_moe_fused_ffn_dyn_infer`` calls apply the gate/up
         projection (``act_swiglu_oai``, with ``w13_bias``) and the down
         projection (``act_default``, with ``w2_bias``), writing into
         pre-allocated per-expert output buffers.
      4. ``supa_unpermutation_infer`` scatters the results back.

    Parameters:
    - hidden_states: input activations; dim 0 is used as the token count.
    - w13, w13_bias: gate/up expert weights and bias.
    - w2, w2_bias: down expert weights and bias.
    - gating_weight: router weight; ``shape[-2]`` is taken as the total
      number of experts.
    - topk: number of experts selected per token (forwarded to the router).
    - intermediate_size: width of the gate/up output buffers.
    - e_score_correction_bias: forwarded to the router as ``gating_bias``.
    - ep_rank, ep_size: expert-parallel rank and size; they select this
      rank's contiguous slice of experts. Both are used in integer
      arithmetic, so passing None raises a TypeError.
    - renormalize, inplace, use_grouped_topk, num_expert_group, topk_group,
      custom_routing_function, scoring_func: accepted for signature
      compatibility; not read by this function.

    Returns:
    - torch.Tensor: the MoE output with a leading dimension of size 1
      (``unsqueeze(0)``). When no local expert received a token, zeros shaped
      like ``hidden_states`` are returned instead.
    """
    total_expert_num = gating_weight.shape[-2]
    # The first router output is not consumed by this function, so it is
    # discarded instead of being transposed through host memory.
    _, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
        hidden_states,
        gating_weight,
        topk,
        ep_size,
        ep_rank,
        gating_bias=e_score_correction_bias)

    cur_device = hidden_states.device
    # The router outputs are transposed on the host and moved back.
    indices_supa = indices_supa.cpu().permute(1, 0).contiguous().to(cur_device)
    indice_per_expert = indice_per_expert.cpu().permute(
        1, 0).contiguous().to(cur_device)
    prob_per_expert = prob_per_expert.cpu().permute(
        1, 0).contiguous().to(cur_device)

    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    local_expert_num = total_expert_num // ep_size  # type: ignore
    b_seq = hidden_states.shape[0]
    # Scratch index buffer; the token dimension is rounded up to a multiple
    # of 64.
    indices_trans_supa = torch_br._empty_ut_only(
        size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
        dtype=torch.int32,
        is_numa=False,
        device=hidden_states.device,
        tensor_type="colmajor",
        sbp="BB" if is_dual_die else None)
    tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
                                            minlength=total_expert_num)

    # Host-side token counts restricted to the experts owned by this rank.
    tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
    )[ep_rank * local_expert_num:(ep_rank + 1) *  # type: ignore
      local_expert_num]

    if any(count != 0 for count in tokens_per_expert_list):
        expert_tokens = torch_br.supa_permutation_infer(
            global_hidden_states=hidden_states,
            indices=indice_per_expert,
            tokens_per_expert=tokens_per_expert_list,
            indices_trans=indices_trans_supa)

        assert len(
            expert_tokens) == local_expert_num, "Number of experts mismatch"

        gate_up_outputs = []
        down_outputs = []
        cur_device = expert_tokens[0].device
        hidden_size = expert_tokens[0].shape[-1]
        for i in range(local_expert_num):
            if tokens_per_expert_list[i] == 0:
                # Idle expert: zero-row placeholders keep the lists aligned
                # with the expert index.
                gate_up_outputs.append(
                    torch.empty(size=(0, intermediate_size),
                                dtype=torch.bfloat16,
                                device=cur_device))
                down_outputs.append(
                    torch.empty(size=(0, hidden_size),
                                dtype=torch.float32,
                                device=cur_device))
                continue

            gate_up_output = torch_br._empty_ut_only(
                size=(tokens_per_expert_list[i], intermediate_size),
                dtype=torch.bfloat16,
                is_numa=False,
                device=cur_device,
                tensor_type="colmajor",
                sbp="SB" if is_dual_die else None,
                axis=1)
            gate_up_outputs.append(gate_up_output)

            down_output = torch_br._empty_ut_only(
                size=(tokens_per_expert_list[i], hidden_size),
                dtype=torch.float32,
                is_numa=False,
                device=cur_device,
                tensor_type="colmajor",
                sbp="BB" if is_dual_die else None)
            down_outputs.append(down_output)

        # Same value is needed by both FFN calls; compute it once.
        max_tokens = max(tokens_per_expert_list)

        torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
                                              expert_tokens,
                                              w13,
                                              tokens_per_expert_list,
                                              max_tokens,
                                              bias=w13_bias,
                                              act_mode="act_swiglu_oai")

        torch_br.supa_moe_fused_ffn_dyn_infer(down_outputs,
                                              gate_up_outputs,
                                              w2,
                                              tokens_per_expert_list,
                                              max_tokens,
                                              bias=w2_bias,
                                              act_mode="act_default")

        output = torch_br.supa_unpermutation_infer(
            input_list=down_outputs,
            indices=indices_trans_supa,
            probs=prob_per_expert,
            tokens_per_expert=tokens_per_expert_list)
    else:
        output = torch.zeros_like(hidden_states)

    return output.unsqueeze(0)
|
||||
|
||||
|
||||
def fused_moe_quant_dyn(
    hidden_states: torch.Tensor,
    shared_gate_up_weight: torch.Tensor,
    down_weight: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    tp_rank: Optional[int] = None,
    global_rank: Optional[int] = None,
    tp_size: Optional[int] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """
    Compute a Mixture of Experts (MoE) layer with per-expert dynamic token
    counts, optionally fused with a shared expert.

    Tokens are routed by a SUPA router kernel, grouped per local expert with
    `supa_permutation_infer`, passed through the gate/up (`act_swiglu`) and
    down projections, scattered back with `supa_unpermutation_infer`, and
    finally added to the shared-expert output when the grouped router
    produced one.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer; dim 0
        is used as the token count.
    - shared_gate_up_weight, down_weight: Shared-expert weights forwarded to
        the grouped router kernel. Both must be None when
        use_grouped_topk is False.
    - w13 (torch.Tensor): The gate/up expert weights.
    - w2 (torch.Tensor): The down expert weights.
    - w13_scale, w2_scale: Scales forwarded to the FFN kernels for w13 / w2.
        The per-expert loop tolerates None for either.
    - gating_weight (torch.Tensor): The router weight; its last dimension is
        taken as the total number of experts.
    - topk (int): The number of top-k experts to select.
    - intermediate_size (int): Width of the gate/up output buffers.
    - use_grouped_topk: If True, use the fused shared-expert + grouped-topk
        router; otherwise use `supa_moe_router_v2_infer`.
    - num_expert_group, topk_group: Required when use_grouped_topk is True;
        topk_group must be None otherwise.
    - scoring_func, e_score_correction_bias: Forwarded to the router
        (scoring_func only on the grouped path).
    - ep_rank, ep_size: Expert-parallel rank / size, used to select this
        rank's slice of experts. Both take part in integer arithmetic, so
        None is not accepted in practice.
    - renormalize, inplace, custom_routing_function, tp_rank, global_rank,
        tp_size: Accepted for signature compatibility; not read here.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer, with a
        leading dimension of size 1 (`unsqueeze(0)`).
    """
    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    total_expert_num = gating_weight.shape[-1]
    cur_device = hidden_states.device
    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        # NOTE(review): the kernel name suggests this is the prefill path --
        # confirm against callers.
        shared_output, _, indices_supa, indice_per_expert, prob_per_expert = torch_br.supa_fused_shared_router_prefill_v2_infer(
            hidden_states,
            shared_gate_up_weight,
            down_weight,
            gating_weight,
            intermediate_size,
            topk,
            num_expert_group,
            topk_group,
            ep_size,
            ep_rank,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)
        if is_dual_die:
            # Re-materialize the shared output in a fresh colmajor buffer
            # allocated without an sbp argument.
            shared_tmp = torch_br._empty_ut_only(size=shared_output.shape,
                                                 dtype=shared_output.dtype,
                                                 is_numa=False,
                                                 device=shared_output.device,
                                                 tensor_type="colmajor")
            shared_tmp.copy_(shared_output)
            shared_output = shared_tmp
    else:
        assert topk_group is None, "Only support non group topk router"
        assert shared_gate_up_weight is None and down_weight is None
        # This router receives the gating weight transposed.
        _, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
            hidden_states,
            gating_weight.permute(1, 0).contiguous(),
            topk,
            ep_size,
            ep_rank,
            gating_bias=e_score_correction_bias)
        shared_output = None

    indices_supa = indices_supa.permute(1, 0).contiguous()
    indice_per_expert = indice_per_expert.permute(1, 0).contiguous()
    prob_per_expert = prob_per_expert.permute(1, 0).contiguous()

    local_expert_num = total_expert_num // ep_size  # type: ignore
    b_seq = hidden_states.shape[0]
    # Scratch index buffer; the token dimension is rounded up to a multiple
    # of 64.
    indices_trans_supa = torch_br._empty_ut_only(
        size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
        dtype=torch.int32,
        is_numa=False,
        device=hidden_states.device,
        tensor_type="colmajor",
        sbp="BB" if is_dual_die else None)
    tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
                                            minlength=total_expert_num)

    # Host-side token counts restricted to the experts owned by this rank.
    tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
    )[ep_rank * local_expert_num:(ep_rank + 1) *  # type: ignore
      local_expert_num]
    # Number of local experts that received at least one token.
    topk_per_expert = sum(1 for x in tokens_per_expert_list if x != 0)

    if topk_per_expert > 0:
        expert_tokens = torch_br.supa_permutation_infer(
            global_hidden_states=hidden_states,
            indices=indice_per_expert,
            tokens_per_expert=tokens_per_expert_list,
            indices_trans=indices_trans_supa)

        assert len(
            expert_tokens) == local_expert_num, "Number of experts mismatch"

        supa_device = torch.supa.current_device()
        spc_num = torch_br.supa.get_device_properties(
            supa_device).max_compute_units

        out_expert_tokens = []
        # Hard-coded switch: while it is True, the per-expert loop below only
        # runs when total_expert_num == 128.
        use_moe_fused_ffn_dyn = True
        if not use_moe_fused_ffn_dyn or total_expert_num == 128:
            # Per-expert element counts, used as offsets into w13 / w2.
            w13_hw = w13.shape[-2] * w13.shape[-1]
            w2_hw = w2.shape[-2] * w2.shape[-1]

            for i in range(local_expert_num):
                expert_token = expert_tokens[i]
                if tokens_per_expert_list[i] == 0:
                    # Idle expert: its (empty) input stands in as the output.
                    out_expert_tokens.append(expert_token)
                    continue

                expert_gate_up_weight = w13.view_as_usharp(
                    "COLMAJOR", (spc_num, w13.shape[-2], w13.shape[-1]),
                    Sbp.ss(0), i * w13_hw)

                # NOTE: rebinds the `down_weight` parameter; the shared-expert
                # weight was already consumed by the router call above.
                down_weight = w2.view_as_usharp(
                    "COLMAJOR", (spc_num, w2.shape[-2], w2.shape[-1]),
                    Sbp.ss(0), i * w2_hw)

                expert_gate_up_scale = w13_scale[
                    i] if w13_scale is not None else None
                down_scale = w2_scale[i] if w2_scale is not None else None

                gate_up_output = torch_br._empty_ut_only(
                    size=(expert_token.shape[0], intermediate_size),
                    dtype=torch.bfloat16,
                    is_numa=False,
                    device=expert_token.device,
                    tensor_type="colmajor",
                    sbp="SB" if is_dual_die else None,
                    axis=1)
                torch_br.supa_fused_linear_infer(gate_up_output,
                                                 expert_token,
                                                 expert_gate_up_weight,
                                                 expert_gate_up_scale,
                                                 act_mode="act_swiglu")

                down_output = torch_br._empty_ut_only(
                    size=expert_token.shape,
                    dtype=torch.float32,
                    is_numa=False,
                    device=gate_up_output.device,
                    tensor_type="colmajor",
                    sbp="BB" if is_dual_die else None)

                torch_br.supa_fused_linear_infer(down_output, gate_up_output,
                                                 down_weight, down_scale)

                out_expert_tokens.append(down_output)
        else:
            gate_up_outputs = []
            cur_device = expert_tokens[0].device
            hidden_size = expert_tokens[0].shape[-1]
            for i in range(local_expert_num):
                if tokens_per_expert_list[i] == 0:
                    # Idle expert: zero-row placeholders keep both lists
                    # aligned with the expert index.
                    gate_up_outputs.append(
                        torch.empty(size=(0, intermediate_size),
                                    dtype=torch.bfloat16,
                                    device=cur_device))
                    out_expert_tokens.append(
                        torch.empty(size=(0, hidden_size),
                                    dtype=torch.float32,
                                    device=cur_device))
                    continue

                gate_up_output = torch_br._empty_ut_only(
                    size=(tokens_per_expert_list[i], intermediate_size),
                    dtype=torch.bfloat16,
                    is_numa=False,
                    device=cur_device,
                    tensor_type="colmajor",
                    sbp="SB" if is_dual_die else None,
                    axis=1)
                gate_up_outputs.append(gate_up_output)

                down_output = torch_br._empty_ut_only(
                    size=(tokens_per_expert_list[i], hidden_size),
                    dtype=torch.float32,
                    is_numa=False,
                    device=cur_device,
                    tensor_type="colmajor",
                    sbp="BB" if is_dual_die else None)
                out_expert_tokens.append(down_output)
            # Gate/up projection for all local experts in one call.
            torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
                                                  expert_tokens,
                                                  w13,
                                                  tokens_per_expert_list,
                                                  max(tokens_per_expert_list),
                                                  scales=w13_scale,
                                                  act_mode="act_swiglu")
            # Down projection for all local experts in one call.
            torch_br.supa_moe_fused_ffn_dyn_infer(out_expert_tokens,
                                                  gate_up_outputs,
                                                  w2,
                                                  tokens_per_expert_list,
                                                  max(tokens_per_expert_list),
                                                  scales=w2_scale,
                                                  act_mode="act_default")

        out_states = torch_br.supa_unpermutation_infer(
            input_list=out_expert_tokens,
            indices=indices_trans_supa,
            probs=prob_per_expert,
            tokens_per_expert=tokens_per_expert_list)

        output = out_states if shared_output is None else out_states + shared_output
    else:
        # No local expert received a token: fall back to the shared output,
        # or zeros when there is none.
        output = torch.zeros_like(
            hidden_states) if shared_output is None else shared_output

    return output.unsqueeze(0)
|
||||
|
||||
|
||||
def fused_moe_quant_device(
    hidden_states: torch.Tensor,
    shared_gate_up_weight: torch.Tensor,
    down_weight: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    tp_rank: Optional[int] = None,
    global_rank: Optional[int] = None,
    tp_size: Optional[int] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """
    Compute a Mixture of Experts (MoE) layer with a device-side router
    followed by a single fused expert-FFN kernel.

    A router kernel returns `(shared_output, masked_probs, hitted_experts)`;
    those three values, together with `hidden_states`, `w13`, `w2` and the
    scales, are then handed to one of the `supa_moe_fused_ffn*` kernels.
    `shared_output` is passed as the first argument of that kernel and is
    what this function returns.
    NOTE(review): presumably the FFN kernel accumulates the routed-expert
    result into `shared_output` in place -- confirm against torch_br.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer;
        `shape[-2]` is used as the sequence length for the all-reduce check.
    - shared_gate_up_weight, down_weight: Shared-expert weights forwarded to
        the grouped router kernels. Both must be None when topk_group is
        None.
    - w13 (torch.Tensor): The gate/up expert weights. An int32 dtype selects
        the `supa_moe_fused_ffn_s4_infer` kernel.
    - w2 (torch.Tensor): The down expert weights.
    - w13_scale, w2_scale: Scales forwarded to the FFN kernels.
    - gating_weight (torch.Tensor): The router weight; its last dimension is
        taken as the number of experts.
    - topk (int): The number of top-k experts to select.
    - intermediate_size (int): Forwarded to the grouped router kernels.
    - use_grouped_topk: Must be True whenever topk_group is given.
    - num_expert_group, topk_group: Grouped-topk parameters; topk_group being
        None selects `supa_moe_router_decoder_infer`.
    - scoring_func, e_score_correction_bias: Forwarded to the grouped router
        kernels only; the non-grouped router does not receive them.
    - tp_rank, tp_size, global_rank: Forwarded to the fused FFN + all-reduce
        kernel when that path is taken.
    - ep_rank, ep_size: Expert-parallel rank / size forwarded to the router;
        on the grouped path `ep_size > 1` selects the v2 router.
    - renormalize, inplace, custom_routing_function: Accepted for signature
        compatibility; not read here.

    Returns:
    - torch.Tensor: `shared_output` with a leading dimension of size 1
        (`unsqueeze(0)`).
    """
    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    expert_num = gating_weight.shape[-1]
    b_seq = hidden_states.shape[-2]
    if topk_group is None:
        assert shared_gate_up_weight is None and down_weight is None
        shared_output, masked_probs, hitted_experts = torch_br.supa_moe_router_decoder_infer(
            hidden_states, gating_weight, topk, ep_size, ep_rank)
    else:
        assert use_grouped_topk is True, "Only support group topk router"
        assert num_expert_group is not None and topk_group is not None
        # NOTE(review): when no correction bias is supplied, an
        # *uninitialized* float32 tensor (torch.empty) of length expert_num is
        # passed in its place -- confirm the kernels ignore its contents.
        if ep_size > 1:  # type: ignore
            shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_v2_infer(
                hidden_states,
                shared_gate_up_weight,
                down_weight,
                gating_weight,
                intermediate_size,
                topk,
                num_expert_group,
                topk_group,
                ep_size,
                ep_rank,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias
                if e_score_correction_bias is not None else torch.empty(
                    (expert_num),
                    dtype=torch.float32,
                    device=hidden_states.device))
        else:
            shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_infer(
                hidden_states,
                shared_gate_up_weight,
                down_weight,
                gating_weight,
                intermediate_size,
                topk,
                num_expert_group,
                topk_group,
                scoring_func,
                e_score_correction_bias=e_score_correction_bias
                if e_score_correction_bias is not None else torch.empty(
                    (expert_num),
                    dtype=torch.float32,
                    device=hidden_states.device))
        if is_dual_die:
            # Re-view the shared output as a COLMAJOR tensor with Sbp.bb().
            shared_output = shared_output.view_as_usharp(
                "COLMAJOR", shared_output.shape, Sbp.bb())

    if w13.dtype == torch.int32:
        # NOTE(review): int32 storage presumably means packed 4-bit weights
        # (kernel name "_s4_") -- confirm.
        torch_br.supa_moe_fused_ffn_s4_infer(shared_output, hidden_states, w13,
                                             w2, hitted_experts, masked_probs,
                                             w13_scale, w2_scale)
    else:
        # (VLLM_BR_DEVICE_SPC_NUM, tp_size) pairs for which the fused
        # FFN + all-reduce kernel is used.
        support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
        if envs.VLLM_BR_USE_FUSED_ALLREDUCE and b_seq <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and (
                envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types:
            # ffn+allreduce is only taken for the (spc, tp) pairs listed in
            # support_types and for sequences within the configured max length
            torch_br.supa_moe_fused_ffn_allreduce(shared_output, hidden_states,
                                                  w13, w2, hitted_experts,
                                                  masked_probs, tp_rank,
                                                  tp_size, global_rank, 0,
                                                  w13_scale, w2_scale)
        else:
            torch_br.supa_moe_fused_ffn_infer(shared_output, hidden_states,
                                              w13, w2, hitted_experts,
                                              masked_probs, w13_scale,
                                              w2_scale)

    return shared_output.unsqueeze(0)
|
||||
67
vllm_br/model_executor/layers/layernorm.py
Normal file
67
vllm_br/model_executor/layers/layernorm.py
Normal file
@@ -0,0 +1,67 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
import os
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from torch import Tensor, nn
|
||||
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
|
||||
@patch_to(RMSNorm)
def forward_oot(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """RMSNorm forward routed through the torch_br kernels.

    A bfloat16 weight is converted to float32 in place on first use. With a
    ``residual`` the fused add + RMSNorm kernel is called and both of its
    outputs are returned as a tuple. Without one, a 2-D input gains a leading
    axis, a 4-D input has ``squeeze(0)`` applied, and the plain RMSNorm
    kernel result is returned.
    """
    weight = self.weight
    if weight.data.dtype == torch.bfloat16:
        # One-time promotion; later calls see float32 and skip this.
        weight.data = weight.data.to(torch.float32)

    if residual is not None:
        normed, summed = torch_br.supa_add_rmsnorm_infer(  # type: ignore
            x, residual, self.weight.data, self.variance_epsilon)
        return normed, summed

    rank = len(x.shape)
    if rank == 2:
        x = x.unsqueeze(0)
    elif rank == 4:
        x = x.squeeze(0)

    return torch_br.supa_rmsnorm_infer(
        x,
        self.weight.data,
        self.variance_epsilon  # type: ignore
    )
|
||||
|
||||
|
||||
@patch_to(RMSNorm)
def enabled(cls) -> bool:
    """Report the RMSNorm custom op as enabled; always returns True."""
    # NOTE(review): this is attached with a plain `patch_to(RMSNorm)`, so the
    # first argument presumably receives the instance rather than the class
    # -- confirm against how callers invoke `enabled`.
    return True
|
||||
|
||||
|
||||
@patch_to(nn.LayerNorm)
def forward(self, input: Tensor) -> Tensor:
    """LayerNorm forward with an opt-in torch_br fused kernel.

    The ``USE_BR_FUSED_LAYERNORM`` environment variable is read on every
    call. When it is unset, or its lower-cased value is ``'false'``, ``'0'``
    or empty, the stock ``nn.functional.layer_norm`` is used; any other
    value selects ``torch_br.fused_layernorm``.
    """
    flag = os.environ.get("USE_BR_FUSED_LAYERNORM", 'False').lower()
    if flag in {'false', '0', ''}:
        return nn.functional.layer_norm(input, self.normalized_shape,
                                        self.weight, self.bias, self.eps)

    return torch_br.fused_layernorm(input, self.weight, self.bias, self.eps)
|
||||
767
vllm_br/model_executor/layers/linear.py
Normal file
767
vllm_br/model_executor/layers/linear.py
Normal file
@@ -0,0 +1,767 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_br
|
||||
import torch_br.supa._debug as supa_debug
|
||||
from fastcore.basics import patch_to
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group, split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import (adjust_bitsandbytes_4bit_shard,
|
||||
adjust_marlin_shard,
|
||||
adjust_scalar_to_fused_array)
|
||||
from vllm_br import envs
|
||||
from vllm_br.utils import get_grandparent_pid
|
||||
from .br_utils import (_convert_to_crossed_numa_tensor,
|
||||
_convert_to_numa_tensor, _convert_to_numa_tensor_vit,
|
||||
is_br166_device)
|
||||
|
||||
from vllm.model_executor.layers.linear import ( # isort:skip
|
||||
LinearBase, MergedColumnParallelLinear, QuantizationConfig,
|
||||
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod,
|
||||
QKVParallelLinear)
|
||||
|
||||
|
||||
def _should_skip_linear_post_process(layer, use_ds_mla, use_ds_mla_sparse):
|
||||
"""NOTE: SUPA: for MLA linears, we do process in MLA.process_weights_after_loading """
|
||||
# TODO: Hard code for native dsa op
|
||||
if use_ds_mla_sparse:
|
||||
MLA_LINEAR_NAMES = [
|
||||
"kv_b_proj",
|
||||
]
|
||||
else:
|
||||
MLA_LINEAR_NAMES = [
|
||||
"q_a_proj",
|
||||
"q_b_proj",
|
||||
# "q_proj",
|
||||
"kv_a_proj_with_mqa",
|
||||
"kv_b_proj",
|
||||
# "o_proj",
|
||||
]
|
||||
if use_ds_mla and not use_ds_mla_sparse:
|
||||
MLA_LINEAR_NAMES.append("o_proj")
|
||||
|
||||
skip = any(k in layer.prefix for k in MLA_LINEAR_NAMES)
|
||||
if skip:
|
||||
logger.debug(
|
||||
f'[SUPA] skip {layer.prefix} UnquantizedLinearMethod.process_weights_after_loading' # noqa: G004
|
||||
)
|
||||
return skip
|
||||
|
||||
|
||||
# NOTE: ReplicatedLinear, usually used in MoE as a gate module.
# In DeepseekV3, it needs to be transposed.
def process_weights_ReplicatedLinear(
        layer: ReplicatedLinear) -> Literal[True, False]:
    """Replace ``layer.weight`` storage with its contiguous transpose.

    The first two dimensions of the weight are swapped and the result is made
    contiguous before being written back to ``layer.weight.data``.

    Returns:
        Always True.
    """
    weight = layer.weight
    swapped = weight.data.transpose(0, 1)
    weight.data = swapped.contiguous()
    return True
|
||||
|
||||
|
||||
def _pack_shared_gate_up(gate: torch.Tensor, up: torch.Tensor, pad_cols: int,
                         block_nums: int, region_size: int,
                         spc_for_shared: int, spc_for_router: int,
                         align_size: int) -> torch.Tensor:
    """Interleave one gate/up weight pair into per-SPC regions.

    Both inputs are zero-padded by ``pad_cols`` columns on the last dim and
    cut into ``block_nums`` blocks of ``align_size`` columns. Each output
    block is ``[gate block | up block]``. The blocks are then regrouped into
    ``spc_for_shared`` regions of ``region_size`` columns (region index
    first), and ``spc_for_router`` all-zero regions are appended.

    Returns a tensor of shape
    ``(spc_for_shared + spc_for_router, hidden_size, region_size)``.
    """
    hidden_size = gate.shape[0]
    weight_dtype = gate.dtype

    def _to_blocks(weight: torch.Tensor) -> torch.Tensor:
        padded = torch.nn.functional.pad(weight, (0, pad_cols, 0, 0),
                                         mode='constant',
                                         value=0)
        return padded.reshape(hidden_size, block_nums,
                              align_size).contiguous()

    interleaved = torch.zeros([hidden_size, block_nums, align_size * 2],
                              dtype=weight_dtype,
                              device='supa')
    interleaved[:, :, :align_size] = _to_blocks(gate)
    interleaved[:, :, align_size:] = _to_blocks(up)

    regions = interleaved.reshape(hidden_size, spc_for_shared,
                                  region_size).permute(1, 0, 2).contiguous()

    # Zero-filled regions for the SPCs not used by the shared expert
    # ("invalid regions").
    invalid = torch.zeros([spc_for_router, hidden_size, region_size],
                          dtype=weight_dtype,
                          device='supa')
    return torch.cat([regions, invalid], dim=0)


def process_share_expert_weight(layer: MergedColumnParallelLinear):
    """Re-lay out a shared-expert gate_up weight for the SUPA kernels.

    ``layer.weight`` is transposed, split into gate and up halves along the
    last dim, interleaved block-wise per SPC region by
    ``_pack_shared_gate_up``, copied into a NUMA colmajor tensor allocated
    with ``torch_br._empty_ut_only`` and written back to
    ``layer.weight.data``.

    When ``envs.VLLM_BR_DEVICE_SPC_NUM > 16`` the device is treated as two
    dies: gate and up are each split once more along the last dim, every die
    half is packed separately, and the two results are concatenated along
    dim 0.
    """
    gate_up_weight = layer.weight.transpose(1, 0).contiguous()
    gate_weight, up_weight = torch.chunk(gate_up_weight, 2, dim=-1)

    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    is_br166 = die_spc_num > 16
    spc_num = die_spc_num // 2 if is_br166 else die_spc_num

    # Per-die SPC split between shared expert and router:
    # 4 SPCs -> 2 shared / 2 router, 12 -> 8 / 4, 16 -> 8 / 8.
    spc_for_shared = 2 if spc_num == 4 else 8
    spc_for_router = spc_num - spc_for_shared

    if is_br166:
        gate_d0, gate_d1 = torch.chunk(gate_weight, 2, dim=-1)
        up_d0, up_d1 = torch.chunk(up_weight, 2, dim=-1)
        pairs = [(gate_d0, up_d0), (gate_d1, up_d1)]
    else:
        pairs = [(gate_weight, up_weight)]

    # The layout geometry is derived once, from the first gate slice.
    align_size = 32
    im_size = pairs[0][0].shape[-1]
    n_align_size = (align_size * 2) * spc_for_shared
    # Combined gate+up width rounded up to a multiple of n_align_size.
    swiglu_w_aligned = ((
        (im_size * 2) + n_align_size - 1) // n_align_size) * n_align_size
    region_size = swiglu_w_aligned // spc_for_shared
    block_nums = (region_size // (align_size * 2)) * spc_for_shared
    pad_cols = (swiglu_w_aligned // 2) - im_size

    packed = [
        _pack_shared_gate_up(gate, up, pad_cols, block_nums, region_size,
                             spc_for_shared, spc_for_router, align_size)
        for gate, up in pairs
    ]
    if len(packed) == 1:
        gate_up_weight_whole = packed[0]
    else:
        gate_up_weight_whole = torch.cat(packed, dim=0)

    gate_up_weight_supa = torch_br._empty_ut_only(
        size=gate_up_weight_whole.shape,
        dtype=gate_weight.dtype,
        is_numa=True,
        device="supa",
        tensor_type="colmajor")
    gate_up_weight_supa.copy_(gate_up_weight_whole)

    layer.weight.data = gate_up_weight_supa
|
||||
|
||||
|
||||
# NOTE: MergedColumnParallelLinear, usually used in MergedGateUpMLPSiluL2
|
||||
def process_weights_QuantMergedColumnParallelLinear(
        layer: MergedColumnParallelLinear):
    """Convert the loaded tensors of a merged gate/up projection in place.

    Layers whose prefix contains ``shared_experts`` are handed to
    ``process_share_expert_weight``. For every other layer:

    * if the layer has a ``qweight`` it is chunked in two along the last
      dimension and passed to ``_convert_to_crossed_numa_tensor``;
      otherwise ``weight`` is transposed and passed to
      ``_convert_to_numa_tensor``;
    * ``scales`` and ``bias`` (when present and not ``None``) are chunked in
      two and converted with the ``"linear_bias"`` layout.
    """
    if 'shared_experts' in layer.prefix:
        process_share_expert_weight(layer)
        return

    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM

    def _interleave(tensor, **options):
        # Split into the two halves along the last dim and convert them
        # together with the crossed-NUMA helper.
        first_half, second_half = torch.chunk(tensor, 2, dim=-1)
        return _convert_to_crossed_numa_tensor(first_half,
                                               second_half,
                                               spc_num,
                                               dim=-1,
                                               do_transpose=False,
                                               **options)

    # NOTE: normal MLP gate_up, after load weight, convert to supa numa tensor
    if hasattr(layer, "qweight"):
        layer.qweight.data = _interleave(layer.qweight, need_pad=True)
    else:
        transposed = layer.weight.permute(1, 0).contiguous()
        layer.weight.data = _convert_to_numa_tensor(
            transposed,
            32,
            "colmajor",
            transposed.dtype,
            False,
            parallel_type="col_parallel")

    for attr_name in ("scales", "bias"):
        tensor = getattr(layer, attr_name, None)
        if tensor is not None:
            tensor.data = _interleave(tensor,
                                      need_pad=False,
                                      layout="linear_bias")
|
||||
|
||||
|
||||
def process_weights_MergedColumnParallelLinear(
        layer: MergedColumnParallelLinear):
    """Convert an unquantized merged gate/up projection in place.

    Layers whose prefix contains ``shared_experts`` are handed to
    ``process_share_expert_weight``. Otherwise the weight is transposed and
    either converted as a whole (when ``layer.no_need_cross`` is set and
    truthy) or chunked in two along the last dimension and converted with
    ``_convert_to_crossed_numa_tensor``. A non-``None`` bias is always
    chunked in two and converted with the ``"linear_bias"`` layout.
    """
    if 'shared_experts' in layer.prefix:
        # NOTE: by default, gate module and shared_expert(1) module will be
        # involved into calculation in 1 kernel
        process_share_expert_weight(layer)
        return

    transposed = layer.weight.permute(1, 0).contiguous()
    if getattr(layer, "no_need_cross", False):
        layer.weight.data = _convert_to_numa_tensor(
            transposed,
            align_size=32,
            layout="colmajor",
            dtype=transposed.dtype,
            do_transpose=False)
    else:
        gate_part, up_part = torch.chunk(transposed, 2, dim=-1)
        layer.weight.data = _convert_to_crossed_numa_tensor(
            gate_part,
            up_part,
            envs.VLLM_BR_DEVICE_SPC_NUM,
            dim=-1,
            need_pad=True,
            do_transpose=False)

    bias = getattr(layer, "bias", None)
    if bias is None:
        return

    gate_bias, up_bias = torch.chunk(bias, 2, dim=-1)
    bias.data = _convert_to_crossed_numa_tensor(
        gate_bias,
        up_bias,
        envs.VLLM_BR_DEVICE_SPC_NUM,
        dim=-1,
        need_pad=False,
        layout="linear_bias",
        do_transpose=False)
|
||||
|
||||
|
||||
@patch_to(UnquantizedLinearMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Convert a linear layer's weight and bias after checkpoint loading.

    Nothing is done when ``_should_skip_linear_post_process`` says so or when
    ``self.weight_type`` is not ``"NUMA"``. ``ReplicatedLinear`` and
    ``MergedColumnParallelLinear`` layers are first handled by their dedicated
    ``process_weights_*`` helpers, which may make the generic pass below
    unnecessary. The generic pass converts a 2-D ``weight`` (as bfloat16,
    ``"colmajor"``) and a non-``None`` ``bias`` (as float32,
    ``"linear_bias"``), using the ``_vit`` converter for layers whose prefix
    contains ``vision`` on a BR166 device.
    """
    if _should_skip_linear_post_process(
            layer, self.use_ds_mla,
            self.use_ds_mla_sparse) or self.weight_type != "NUMA":
        return

    needs_generic_pass = True
    do_transpose = True
    parallel_type = "col_parallel"
    # NOTE: all process_weights func should done before process_weights_after_loading
    if isinstance(layer, ReplicatedLinear):
        process_weights_ReplicatedLinear(layer)
        # Non-indexer layers with these output sizes need no generic pass
        # (160 is used by Glm4-Moe).
        needs_generic_pass = ("indexer" in layer.prefix
                              or layer.output_size not in (64, 160, 128, 256))
        do_transpose = False
    elif isinstance(layer, MergedColumnParallelLinear):
        process_weights_MergedColumnParallelLinear(layer)
        needs_generic_pass = False
        do_transpose = False
    elif isinstance(layer, RowParallelLinear):
        parallel_type = "row_parallel"

    if not needs_generic_pass or self.weight_type != "NUMA":
        return

    # process numa weight and bias
    if hasattr(layer, "weight") and len(layer.weight.shape) == 2:
        if 'vision' in layer.prefix and is_br166_device():
            convert_weight = _convert_to_numa_tensor_vit
        else:
            convert_weight = _convert_to_numa_tensor
        dim0 = layer.weight.data.shape[0]
        dim1 = layer.weight.data.shape[1]
        # wk/wn are swapped relative to the stored shape when the converter
        # is asked to transpose.
        wk, wn = (dim1, dim0) if do_transpose else (dim0, dim1)
        layer.weight.data = convert_weight(
            layer.weight,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.bfloat16,
            do_transpose=do_transpose,
            wk=wk,
            wn=wn,
            parallel_type=parallel_type)

    if hasattr(layer, "bias") and layer.bias is not None:
        pad_zeros = (parallel_type == "row_parallel")
        if 'vision' in layer.prefix and is_br166_device():
            # Row-parallel vision layers that reduce their results keep the
            # bias untouched.
            if pad_zeros and layer.reduce_results:
                return
            convert_bias = _convert_to_numa_tensor_vit
        else:
            convert_bias = _convert_to_numa_tensor
        layer.bias.data = convert_bias(
            layer.bias,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
|
||||
|
||||
|
||||
@patch_to(UnquantizedLinearMethod)
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the linear layer on ``x`` using the SUPA kernels.

    A 3-D ``layer.weight`` is treated as an already-converted NUMA weight and
    dispatched to the fused ``torch_br`` kernels; any other weight falls back
    to ``F.linear`` with the sublas API temporarily enabled.

    For ``RowParallelLinear`` layers outside the vision path this method also
    overwrites ``layer.reduce_results``: it is set to ``False`` (and the fused
    linear+allreduce kernel is used) only when the rank count, sequence length
    and (SPC count, TP size) combination are all supported.

    Args:
        layer: The linear layer owning ``weight`` (and ``prefix``).
        x: Input activations.
        bias: Optional bias to add.

    Returns:
        The layer output tensor.
    """
    is_vit = 'vision' in layer.prefix and is_br166_device()

    # numa weight is 3-dims; everything else goes through the generic path.
    if len(layer.weight.shape) != 3:
        supa_debug.set_enable_sublas_api(True)
        try:
            return F.linear(x, layer.weight, bias)
        finally:
            # Always restore the global flag, even if F.linear raises, so a
            # failure here cannot leave sublas enabled for later calls.
            supa_debug.set_enable_sublas_api(False)

    output_size = (layer.output_size_per_partition if hasattr(
        layer, "output_size_per_partition") else layer.output_size)
    act_mode = "act_default"
    if isinstance(layer, MergedColumnParallelLinear) and not (hasattr(
            layer, "no_need_cross") and layer.no_need_cross):
        # Crossed gate/up weights: the kernel applies swiglu and emits half
        # of the merged width.
        act_mode = "act_swiglu"
        output_size //= 2

    is_row = isinstance(layer, RowParallelLinear)

    if is_vit:
        if bias is None or (is_row and layer.reduce_results):
            return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                               output_w=output_size,
                                               activation_mode=act_mode)
        return torch_br.br_matmul_infer(x, layer.weight, bias, output_size)

    bias_list = [bias] if bias is not None else None
    if not is_row:
        return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                           output_w=output_size,
                                           bias=bias_list,
                                           activation_mode=act_mode)

    seq_len = x.shape[-2]
    tp_size = get_tensor_model_parallel_world_size()
    # TODO(CaoJun): This is WA, delete (16, 8) so that the test_vllm_model_accu_qwen25_72b_instruct can run through
    support_types = ((16, 4), (32, 2), (32, 4))
    # bypass tp8 and tp4pp2 allreduce
    pp_size = get_pp_group().world_size
    all_rank = tp_size * pp_size
    layer.reduce_results = not (
        all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
        and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
        (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
    if layer.reduce_results:
        return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                           output_w=output_size,
                                           bias=bias_list,
                                           activation_mode=act_mode)

    tp_rank = get_tp_group().rank_in_group
    global_rank = get_tp_group().rank
    rank_i = global_rank % tp_size
    assert rank_i == tp_rank
    return torch_br.supa_fused_linear_allreduce_opt(x, layer.weight,
                                                    output_size, tp_rank,
                                                    tp_size, global_rank, 0)
|
||||
|
||||
|
||||
@patch_to(LinearBase)
def __init__(
    self,
    input_size: int,
    output_size: int,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    *,
    return_bias: bool = True,
    disable_tp: bool = False,
):
    """Initialise the common state of a linear layer.

    Stores the size/bias arguments, resolves ``params_dtype`` (torch's default
    dtype when ``None``), picks the quantization method
    (``UnquantizedLinearMethod`` when no ``quant_config`` is given) and
    records the tensor-parallel rank and world size, which are ``0`` and ``1``
    when ``disable_tp`` is set.
    """
    super(LinearBase, self).__init__()

    # Keep input parameters
    self.input_size, self.output_size = input_size, output_size
    self.skip_bias_add = skip_bias_add
    self.params_dtype = (torch.get_default_dtype()
                         if params_dtype is None else params_dtype)
    # The quant method is resolved before ``return_bias``/``prefix`` and the
    # TP attributes are set, exactly as in the statement order below.
    self.quant_method = (UnquantizedLinearMethod()
                         if quant_config is None else
                         quant_config.get_quant_method(self, prefix=prefix))
    self.return_bias = return_bias
    self.prefix = prefix
    if disable_tp:
        self.tp_rank = 0
        self.tp_size = 1
    else:
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
@patch_to(RowParallelLinear)
def forward(
    self, input_
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
    """Row-parallel forward pass with an optional CPU (PCIe) all-reduce.

    The input is split along its last dimension unless
    ``self.input_is_parallel``. After the matmul, the partial result is
    all-reduced when ``self.reduce_results`` is set and ``tp_size > 1``; the
    PCIe all-reduce kernel is used when ``VLLM_BR_USE_CPU_ALL_REDUCE`` is
    non-zero, ``tp_size >= 4`` and the partial output's dimension 1 is at
    most 32.

    Returns:
        ``output`` alone when ``self.return_bias`` is false, otherwise
        ``(output, bias_or_None)`` where the bias is returned only when
        ``self.skip_bias_add`` is set.
    """
    cpu_all_reduce_enabled = envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0
    if cpu_all_reduce_enabled and not hasattr(self, "grandparent_pid"):
        # Looked up once and cached on the layer.
        self.grandparent_pid = get_grandparent_pid()

    local_input = input_
    if not self.input_is_parallel:
        rank = get_tensor_model_parallel_rank()
        shards = split_tensor_along_last_dim(input_,
                                             num_partitions=self.tp_size)
        local_input = shards[rank].contiguous()

    # Matrix multiply.
    assert self.quant_method is not None
    # Only fuse bias add into GEMM for rank 0 (this ensures that
    # bias will not get added more than once in TP>1 case)
    if self.tp_rank > 0 or self.skip_bias_add:
        fused_bias = None
    else:
        fused_bias = self.bias
    partial = self.quant_method.apply(self, local_input, bias=fused_bias)

    # ``reduce_results`` is read only after ``apply`` has run.
    if not (self.reduce_results and self.tp_size > 1):
        output = partial
    elif (cpu_all_reduce_enabled and self.tp_size >= 4
          and partial.shape[1] <= 32):
        # CPU all reduce will be applied.
        rank = get_tensor_model_parallel_rank()
        output = torch_br.supa_allreduce_pcie_infer(partial, rank,
                                                    self.tp_size,
                                                    self.grandparent_pid)
    else:
        output = tensor_model_parallel_all_reduce(partial)

    output_bias = self.bias if self.skip_bias_add else None
    if self.return_bias:
        return output, output_bias
    return output
|
||||
|
||||
|
||||
@patch_to(QKVParallelLinear)
def weight_loader(self,
                  param: Parameter,
                  loaded_weight: torch.Tensor,
                  loaded_shard_id: Optional[str] = None):
    """Copy a checkpoint tensor into the fused QKV parameter ``param``.

    Args:
        param: Destination parameter. Optional attributes read from it via
            ``getattr`` (``output_dim``, ``packed_dim``, ``is_metadata``,
            GGUF / bitsandbytes / Marlin flags, ...) select the special cases
            below.
        loaded_weight: Tensor read from the checkpoint.
        loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` for a single projection,
            or ``None`` when the checkpoint stores q/k/v already fused; in
            that case the tensor is sliced and this loader is called again
            once per shard.

    When ``envs.VLLM_BR_DEVICE_SPC_NUM > 16`` and ``output_dim`` is defined,
    the destination for a shard is a pair of half-sized views (see the
    comments at that branch) and the loaded shard is copied into them half by
    half.
    """

    # Special case for GGUF
    # initialize GGUF param after we know the quantize type
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type:
        idx_map = {"q": 0, "k": 1, "v": 2}
        if loaded_shard_id is not None:
            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
        else:
            # Fused on disk: the same weight type applies to q, k and v.
            param.shard_weight_type = {
                k: loaded_weight.item()
                for k in idx_map
            }
        return

    if is_gguf_weight:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        output_dim = getattr(param, "output_dim", None)
        shard_size = loaded_weight.size(output_dim) // tp_size
        start_idx = tp_rank * shard_size

        if loaded_shard_id is not None:
            # Keep this rank's slice in the param's container and stop here;
            # with no shard id, execution continues with the generic path.
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
            param.shard_id.append(loaded_shard_id)
            param.shard_id_map[loaded_shard_id] = len(param.data_container)
            param.data_container.append(loaded_weight)
            return

    param_data = param.data
    output_dim = getattr(param, "output_dim", None)
    # Special case for AQLM codebooks.
    is_metadata = getattr(param, "is_metadata", False)

    # Special case for per-tensor scales in fused case.
    needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

    if loaded_shard_id is None:
        # Loaded weight is already fused on disk (qkv).
        # (e.g., Phi-3's qkv_proj).
        if output_dim is None:
            if needs_scalar_to_array:
                param_data, loaded_weight = adjust_scalar_to_fused_array(
                    param_data, loaded_weight, 0)

            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return
        # Offsets/sizes here use the *total* head counts because the fused
        # checkpoint tensor is sliced before the per-shard call below.
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v",
             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
             self.total_num_kv_heads * self.head_size),
        ]
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        packed_dim = getattr(param, "packed_dim", None)
        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.total_num_heads * self.head_size),
                    "k": (self.total_num_heads * self.head_size,
                          self.total_num_kv_heads * self.head_size),
                    "v":
                    ((self.total_num_heads + self.total_num_kv_heads) *
                     self.head_size, self.total_num_kv_heads * self.head_size),
                    "total":
                    ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                     self.head_size, 0)
                }

                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, shard_id)

            loaded_weight_shard = loaded_weight.narrow(output_dim,
                                                       shard_offset,
                                                       shard_size)
            # Re-enter this loader for the single q/k/v slice.
            self.weight_loader(param, loaded_weight_shard, shard_id)
        return

    tp_rank = get_tensor_model_parallel_rank()
    assert loaded_shard_id in ["q", "k", "v"]

    # If output dim is defined, use the default loading process.
    if output_dim is not None:
        # Per-partition offsets/sizes inside the destination parameter.
        if loaded_shard_id == "q":
            shard_offset = 0
            shard_size = self.num_heads * self.head_size
        elif loaded_shard_id == "k":
            shard_offset = self.num_heads * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        elif loaded_shard_id == "v":
            shard_offset = (self.num_heads +
                            self.num_kv_heads) * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        # Special case for Quantized Weights.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        packed_dim = getattr(param, "packed_dim", None)
        if packed_dim == output_dim:
            shard_size = shard_size // param.pack_factor
            shard_offset = shard_offset // param.pack_factor

            # Special case for Marlin.
            shard_size, shard_offset = adjust_marlin_shard(
                param, shard_size, shard_offset)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        if use_bitsandbytes_4bit:
            orig_qkv_offsets = {
                "q": (0, self.num_heads * self.head_size),
                "k": (self.num_heads * self.head_size,
                      self.num_kv_heads * self.head_size),
                "v": ((self.num_heads + self.num_kv_heads) * self.head_size,
                      self.num_kv_heads * self.head_size),
                "total":
                ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, 0)
            }
            shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                param, orig_qkv_offsets, loaded_shard_id)

        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            # Destination becomes a pair of views, each ``shard_size // 2``
            # long: one starting at ``shard_offset // 2`` and one at the same
            # position shifted by half of the parameter's output dimension.
            half_w = param_data.shape[output_dim] // 2
            param_data = (param_data.narrow(output_dim, shard_offset // 2,
                                            shard_size // 2),
                          param_data.narrow(output_dim,
                                            shard_offset // 2 + half_w,
                                            shard_size // 2))
        else:
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)

        # k/v shards are shared by ``num_kv_head_replicas`` consecutive ranks.
        if loaded_shard_id == "q":
            shard_id = tp_rank
        else:
            shard_id = tp_rank // self.num_kv_head_replicas
        start_idx = shard_id * shard_size

        if not is_sharded_weight:
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

    # Special case for for AQLM codebooks.
    elif is_metadata:
        # metadata indicates fixed size concatenated along dim 0
        shard_size = loaded_weight.shape[0]
        shard_index = ["q", "k", "v"].index(loaded_shard_id)
        param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
    # Special case for per-tensor scales in fused case.
    elif needs_scalar_to_array:
        param_data, loaded_weight = adjust_scalar_to_fused_array(
            param_data, loaded_weight, loaded_shard_id)
    else:
        ignore_warning = getattr(param, "ignore_warning", False)
        if not ignore_warning:
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions.")

    if isinstance(param_data, tuple):
        # Pair of destination views (SPC_NUM > 16 branch above): copy the
        # first and second half of the loaded shard into them respectively.
        half_w = loaded_weight.shape[output_dim] // 2
        param_data[0].copy_(loaded_weight.narrow(output_dim, 0, half_w))
        param_data[1].copy_(loaded_weight.narrow(output_dim, half_w, half_w))
    else:
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
|
||||
72
vllm_br/model_executor/layers/logits_processor.py
Normal file
72
vllm_br/model_executor/layers/logits_processor.py
Normal file
@@ -0,0 +1,72 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
# TODO(shouqing): need to open this patch when fix hang in mtp
|
||||
@patch_to(LogitsProcessor)
def _get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Compute next-token logits and combine the vocab shards of all TP ranks.

    Each rank writes its shard into its own column slice of a zero-filled
    buffer that is ``tp_size`` shards wide; an all-reduce over that buffer
    then yields the full logits on every rank. The result is truncated to
    ``self.org_vocab_size`` columns.
    """
    # Get the logits for the next tokens.
    local_logits = lm_head.quant_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)

    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        # work around the hang in s1b copy to bb
        staged = torch_br._empty_ut_only(size=local_logits.shape,
                                         dtype=local_logits.dtype,
                                         is_numa=False,
                                         device=local_logits.device,
                                         tensor_type="colmajor")
        staged.copy_(local_logits)
        local_logits = staged

    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    shard_width = local_logits.shape[-1]
    gathered = torch.zeros((local_logits.shape[0], shard_width * tp_size),
                           dtype=local_logits.dtype,
                           device=local_logits.device)
    begin = shard_width * tp_rank
    gathered[:, begin:begin + shard_width].copy_(local_logits)
    logits = tensor_model_parallel_all_reduce(gathered)

    # Remove paddings in vocab (if any).
    if logits is None:
        return logits
    return logits[..., :self.org_vocab_size]
|
||||
19
vllm_br/model_executor/layers/quantization/__init__.py
Normal file
19
vllm_br/model_executor/layers/quantization/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import compressed_tensors, gptq
|
||||
|
||||
__all__ = ["gptq", 'compressed_tensors']
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,18 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from .compressed_tensors import *
|
||||
from .compressed_tensors_moe import *
|
||||
from .compressed_tensors_wNa16 import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,64 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, cast
|
||||
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||
CompressedTensorsConfig)
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsConfig, cls_method=True)
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
    """
    [PatchNote] add qkv_quantized param support

    Build the config object from the raw quantization ``config`` dict. In
    addition to the existing fields, ``qkv_quantized`` is read from the dict
    (defaulting to ``True``) and forwarded to the constructor.
    """
    init_kwargs: dict[str, Any] = {"config": config}
    init_kwargs["ignore"] = cast(list[str], config.get("ignore", []))
    init_kwargs["quant_format"] = cast(str, config.get("format"))
    init_kwargs["target_scheme_map"] = (
        cls._quantization_scheme_map_from_config(config=config))
    scheme_map, ignore_list = cls._parse_sparsity_config(config=config)
    init_kwargs["sparsity_scheme_map"] = scheme_map
    init_kwargs["sparsity_ignore_list"] = ignore_list
    init_kwargs["transform_config"] = config.get("transform_config")
    init_kwargs["qkv_quantized"] = cls.get_from_keys_or(config,
                                                        ["qkv_quantized"],
                                                        default=True)
    return cls(**init_kwargs)
|
||||
|
||||
|
||||
def wrapper_CompressedTensorsConfig_init(fn):
    """Wrap an ``__init__`` so it accepts an extra ``qkv_quantized`` keyword.

    The keyword (default ``True``) is removed from ``kwargs`` before ``fn`` is
    called, and stored on the instance as ``self.qkv_quantized`` afterwards.
    """

    def _init_with_qkv_flag(self, *args, **kwargs):
        flag = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        # Set after the wrapped initialiser has run.
        self.qkv_quantized = flag

    return wraps(fn)(_init_with_qkv_flag)
|
||||
|
||||
|
||||
# Rebind the constructor at import time so CompressedTensorsConfig(...)
# accepts the optional ``qkv_quantized`` keyword that the patched
# ``from_config`` passes, and exposes it as ``self.qkv_quantized``.
CompressedTensorsConfig.__init__ = wrapper_CompressedTensorsConfig_init(
    CompressedTensorsConfig.__init__)  # noqa: E501
|
||||
@@ -0,0 +1,594 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from compressed_tensors.compressors.quantized_compressors import (
|
||||
unpack_from_int32)
|
||||
from fastcore.basics import patch_to
|
||||
from torch_br.utils.tensor_methods import Sbp
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||
WNA16_SUPPORTED_BITS, CompressedTensorsMoEMethod,
|
||||
CompressedTensorsWNA16MoEMethod, CompressionFormat)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm_br import envs
|
||||
from ...br_utils import (_convert_to_crossed_numa_tensor,
|
||||
_convert_to_numa_tensor, align_n, cross_weight_32)
|
||||
from ...fused_moe.supa_moe import fused_moe_quant_device, fused_moe_quant_dyn
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsMoEMethod)
def get_moe_method(
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    layer: torch.nn.Module,
) -> "CompressedTensorsMoEMethod":
    """Select the MoE quantization method for ``layer``.

    NOTE:
        1. SUPA only supports CompressedTensorsWNA16MoEMethod without Marlin
        2. Only Linear targets are supported for MoE layers

    The scheme used is the one stored under the ``"Linear"`` key of
    ``quant_config.target_scheme_map``. If that key is missing, the first
    entry of the map is re-keyed to ``"Linear"``. The re-keying is done at
    most once: this function is called repeatedly on the same shared config,
    and re-keying on every call would rotate through the other entries and
    overwrite ``"Linear"`` with them.

    Raises:
        AssertionError: If the scheme map is empty.
        RuntimeError: If the scheme is not a WNA16 group/channel scheme.
    """
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    scheme_map = quant_config.target_scheme_map
    keys = list(scheme_map.keys())
    assert len(keys) > 0, ("No valid quant key!!!")

    # [Patch]: Only Linear target is supported for MoE layers, for temporary
    # compatibility, alias the first scheme as "Linear" when it is absent.
    target_key = "Linear"
    if target_key not in scheme_map:
        scheme_map[target_key] = scheme_map.pop(keys[0])

    weight_quant = scheme_map[target_key].get("weights")
    input_quant = scheme_map[target_key].get("input_activations")

    if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
        logger.info_once("Using CompressedTensorsWNA16MoEMethod")
        return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config)

    raise RuntimeError(
        f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def __init__(
    self,
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    moe: FusedMoEConfig,
):
    """Initialise the WNA16 MoE method from the ``"Linear"`` weight scheme.

    Reads ``num_bits``, ``strategy`` and ``group_size`` from the weight
    quantization arguments and derives ``packed_factor`` as
    ``32 // num_bits``.

    Raises:
        AssertionError: If the weight scheme is not symmetric.
        ValueError: If the quant format is not ``pack_quantized`` or the bit
            width is not in ``WNA16_SUPPORTED_BITS``.
    """
    super(CompressedTensorsWNA16MoEMethod, self).__init__(moe)
    self.quant_config = quant_config
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    config = self.quant_config.target_scheme_map["Linear"].get("weights")
    self.num_bits = config.num_bits
    self.packed_factor = 32 // config.num_bits
    self.strategy = config.strategy
    # [Patch]: SUPA use CompressedTensorsWNA16MoEMethod for both channel/group
    # strategies, so `config.strategy == "group"` and
    # `config.actorder != "group"` are intentionally not asserted here.
    self.group_size = config.group_size
    assert config.symmetric, (
        "Only symmetric quantization is supported for MoE")

    if not (self.quant_config.quant_format
            == CompressionFormat.pack_quantized.value
            and self.num_bits in WNA16_SUPPORTED_BITS):
        # Build ONE message string; passing the pieces as separate arguments
        # would make the exception render as a tuple of fragments.
        raise ValueError("For Fused MoE layers, only "
                         f"{CompressionFormat.pack_quantized.value} "
                         "is supported for the following bits: "
                         f"{WNA16_SUPPORTED_BITS}")
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate the CPU-side parameters the checkpoint loader fills in.

    Registers on ``layer``: packed int32 weights, scales, weight shapes,
    g_idx tensors and g_idx sort indices for both w13 and w2, and sets
    ``layer.a13_scale`` / ``layer.a2_scale`` to ``None``. For the "channel"
    strategy a single scale group is used and ``self.group_size`` is reset
    to -1.
    """
    # Will transpose the loaded weight along the intermediate and hidden dim
    # sizes. Will shard for TP along the transposed dims.
    extra_weight_attrs["is_transposed"] = True
    extra_weight_attrs["quant_method"] = self.strategy

    def _add_param(name, data):
        # Wrap, register, then tag with the shared loader attributes.
        param = torch.nn.Parameter(data, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)
        return param

    # Packed weights: `packed_factor` quantized values per int32 along dim 1.
    _add_param(
        "w13_weight_packed",
        torch.empty(num_experts,
                    hidden_size // self.packed_factor,
                    2 * intermediate_size_per_partition,
                    dtype=torch.int32,
                    device="cpu"))
    _add_param(
        "w2_weight_packed",
        torch.empty(num_experts,
                    intermediate_size_per_partition // self.packed_factor,
                    hidden_size,
                    dtype=torch.int32,
                    device="cpu"))

    w2_scales_size = intermediate_size_per_partition
    if self.strategy != "channel":
        num_groups_w2 = w2_scales_size // self.group_size
        num_groups_w13 = hidden_size // self.group_size
    else:
        num_groups_w2 = num_groups_w13 = 1
        self.group_size = -1

    # Scales start as ones in the requested parameter dtype.
    _add_param(
        "w13_weight_scale",
        torch.ones(num_experts,
                   num_groups_w13,
                   2 * intermediate_size_per_partition,
                   dtype=params_dtype,
                   device="cpu"))
    w2_scale = _add_param(
        "w2_weight_scale",
        torch.ones(num_experts,
                   num_groups_w2,
                   hidden_size,
                   dtype=params_dtype,
                   device="cpu"))
    set_weight_attrs(w2_scale, {"load_full_w2": False})

    # Two-element per-expert shape records (default dtype).
    for shape_name in ("w2_weight_shape", "w13_weight_shape"):
        _add_param(shape_name, torch.empty(num_experts, 2, device="cpu"))

    # g_idx tensors and their sort indices, in the original registration
    # order.
    int32_buffers = (
        ("w13_weight_g_idx", hidden_size),
        ("w2_weight_g_idx", intermediate_size_per_partition),
        ("w13_g_idx_sort_indices", hidden_size),
        ("w2_g_idx_sort_indices", intermediate_size_per_partition),
    )
    for buffer_name, width in int32_buffers:
        _add_param(
            buffer_name,
            torch.empty(num_experts,
                        width,
                        device="cpu",
                        dtype=torch.int32))

    layer.a13_scale = None
    layer.a2_scale = None
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def process_weights_after_loading(self: CompressedTensorsWNA16MoEMethod,
                                  layer: FusedMoE) -> None:
    """Repack the loaded MoE weights and scales into SUPA device tensors.

    For ``self.num_bits`` of 8 or 4, new tensors are allocated with
    ``torch_br._empty_ut_only`` on the current SUPA device, filled expert by
    expert from the loaded parameters, and assigned to the ``.data`` of
    ``w13_weight_packed``, ``w13_weight_scale``, ``w2_weight_packed`` and
    ``w2_weight_scale``. The shape / g_idx / sort-index parameters are then
    set to ``None``.

    Args:
        layer: the FusedMoE layer whose parameters are rewritten in place.

    Raises:
        ValueError: if ``self.num_bits`` is neither 4 nor 8.
    """
    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # More than 16 SPCs is treated as a dual-die device (see is_dual_die).
    die_num = 1 if die_spc_num <= 16 else 2
    spc_num = die_spc_num // die_num
    cur_device = torch.supa.current_device()
    is_dual_die = (die_spc_num > 16)

    if self.num_bits == 8:
        # NOTE: w13_weight
        # after _load_w13, w13_weight is a colparallel weight, shape
        # [num_experts, hidden_size // 4, 2 * intermediate_size_per_partition] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT8
        wk = layer.hidden_size
        wn = layer.intermediate_size_per_partition * 2
        align_size = 64
        wn_block = align_n(wn // die_num,
                           align_size=align_size,
                           spc_num=spc_num)

        supa_w13_weight_packed = torch_br._empty_ut_only(
            size=(die_spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int8,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor",
            axis=0,
            sbp="SS" if is_dual_die else None)

        for expert_id in range(layer.local_num_experts):
            expert_w13 = layer.w13_weight_packed[
                expert_id]  # hidden_size // 4, 2 * intermediate_size_per_partition
            expert_1, expert_3 = expert_w13.chunk(
                2, dim=1)  # two halves of the last dim, still packed in int32

            # Unpack each half to [hidden_size, intermediate_size_per_partition].
            unpacked_expert_1 = unpack_from_int32(
                expert_1, self.num_bits,
                torch.Size(
                    [layer.hidden_size,
                     layer.intermediate_size_per_partition]), 0)

            unpacked_expert_3 = unpack_from_int32(
                expert_3, self.num_bits,
                torch.Size(
                    [layer.hidden_size,
                     layer.intermediate_size_per_partition]), 0)

            pad_expert_w13 = _convert_to_crossed_numa_tensor(
                unpacked_expert_1,
                unpacked_expert_3,
                die_spc_num,
                dim=1,
                need_pad=True,
                layout='COLMAJOR',
                do_transpose=False)
            # Offset of this expert inside the destination: expert_id times
            # the product of the last two dims of the converted tensor.
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)

        layer.w13_weight_packed.data = supa_w13_weight_packed

        # NOTE: w13_scale
        # after _load_w13, w13_weight is a colparallel weight, shape
        # S8: [num_experts, 1, 2 * intermediate_size_per_partition]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, wn]
        supa_w13_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="linear_bias",
            sbp="BB" if is_dual_die else None)

        for expert_id in range(layer.local_num_experts):
            expert_w13_scales = layer.w13_weight_scale[expert_id]
            expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
                2, dim=1)  # scale halves matching expert_1 / expert_3
            crossed_expert_w13_scales = cross_weight_32(
                expert_1_scale.squeeze(),
                expert_3_scale.squeeze(),
                die_spc_num,
                dim=0,
                need_pad=False,
            )
            narrow_data = supa_w13_scales[expert_id]
            narrow_data.copy_(crossed_expert_w13_scales)

        layer.w13_weight_scale.data = supa_w13_scales

        # NOTE: w2_weight
        # after _load_w2, w2_weight is a colparallel weight, shape
        # [num_experts, intermediate_size_per_partition // 4, hidden_size] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT8
        wk = layer.intermediate_size_per_partition
        wn = layer.hidden_size
        align_size = 32
        wn_block = align_n(wn, align_size=align_size, spc_num=spc_num)

        supa_w2_weight_packed = torch_br._empty_ut_only(
            size=(die_spc_num * layer.local_num_experts, wk // die_num,
                  wn_block),
            dtype=torch.int8,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor",
            axis=0,
            sbp="SS" if is_dual_die else None)

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_packed[expert_id]

            unpacked_expert_2 = unpack_from_int32(
                expert_w2, self.num_bits,
                torch.Size(
                    [layer.intermediate_size_per_partition,
                     layer.hidden_size]), 0)

            # NOTE(review): the dtype passed here is that of the packed int32
            # source (expert_w2.dtype) while the destination is int8 --
            # confirm this is what _convert_to_numa_tensor expects.
            pad_expert_w2 = _convert_to_numa_tensor(
                unpacked_expert_2,
                align_size,
                'COLMAJOR',
                expert_w2.dtype,
                do_transpose=False,
                parallel_type="row_parallel")
            pad_expert_w2_shape = pad_expert_w2.shape
            hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
            narrow_data = supa_w2_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w2)

        layer.w2_weight_packed.data = supa_w2_weight_packed

        # NOTE: w2_scale
        # after _load_w2, w2_weight is a colparallel weight, shape
        # S8: [num_experts, 1, hidden_size]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, wn]
        supa_w2_scales = torch_br._empty_ut_only(size=(layer.local_num_experts,
                                                       wn),
                                                 dtype=torch.float32,
                                                 is_numa=False,
                                                 device=cur_device,
                                                 tensor_type="linear_bias")

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_scale[expert_id]
            # The step equals the requested size of dim 0, so this slice
            # selects only row `expert_id` (with a leading dim of length 1).
            narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
            narrow_data.copy_(expert_w2)

        layer.w2_weight_scale.data = supa_w2_scales

    elif self.num_bits == 4:
        # NOTE: w13_weight
        # after _load_w13, w13_weight is a colparallel weight, shape
        # [num_experts, hidden_size // 8, 2 * intermediate_size_per_partition] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT32
        wk = layer.hidden_size // 8
        wn = layer.intermediate_size_per_partition * 2
        wn_block = align_n(wn, align_size=64, spc_num=spc_num)

        # Unlike the 8-bit branch, no `axis`/`sbp` is passed and sizes use
        # spc_num rather than die_spc_num.
        supa_w13_weight_packed = torch_br._empty_ut_only(
            size=(spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int32,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w13 = layer.w13_weight_packed[
                expert_id]  # one expert's packed int32 weight
            # NOTE(review): this chunks along dim 0, whereas the 8-bit branch
            # chunks along dim 1 -- confirm it matches the loaded layout.
            expert_1, expert_3 = expert_w13.chunk(
                2, dim=0)  # each is a packed int4 weight

            pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
                                                             expert_3,
                                                             spc_num,
                                                             dim=1,
                                                             need_pad=True,
                                                             layout='COLMAJOR',
                                                             do_transpose=True)
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)

        layer.w13_weight_packed.data = supa_w13_weight_packed

        # NOTE: w13_scale
        # after _load_w13, w13_weight is a colparallel weight, shape
        # S4: [num_experts, hidden_size // 128, 2 * intermediate_size_per_partition]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, group_nums, wn]
        supa_w13_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts,
                  layer.hidden_size // self.group_size, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w13_scales = layer.w13_weight_scale[expert_id]
            expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
                2, dim=0)  # scale halves, split along dim 0 like the weight
            crossed_expert_w13_scales = cross_weight_32(
                expert_1_scale,
                expert_3_scale,
                spc_num,
                dim=1,
                need_pad=False,
            )
            narrow_data = supa_w13_scales[expert_id]
            narrow_data.copy_(crossed_expert_w13_scales)

        layer.w13_weight_scale.data = supa_w13_scales

        # NOTE: w2_weight
        # after _load_w2, w2_weight is a colparallel weight, shape
        # [num_experts, intermediate_size_per_partition // 8, hidden_size] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT32
        wk = layer.intermediate_size_per_partition // 8
        wn = layer.hidden_size
        wn_block = align_n(wn, align_size=32, spc_num=spc_num)

        supa_w2_weight_packed = torch_br._empty_ut_only(
            size=(spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int32,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_packed[expert_id]
            pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
                                                    spc_num,
                                                    'COLMAJOR',
                                                    expert_w2.dtype,
                                                    do_transpose=True)
            pad_expert_w2_shape = pad_expert_w2.shape
            hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
            narrow_data = supa_w2_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w2)

        layer.w2_weight_packed.data = supa_w2_weight_packed

        # NOTE: w2_scale
        # after _load_w2, w2_weight is a colparallel weight, shape
        # S4: [num_experts, intermediate_size_per_partition // 128, hidden_size]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, group_nums, wn]
        supa_w2_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts,
                  layer.intermediate_size_per_partition // self.group_size,
                  wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_scale[expert_id]
            # Step equals the requested size of dim 0: selects only this
            # expert's slab, keeping a leading dim of length 1.
            narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
            narrow_data.copy_(expert_w2)

        layer.w2_weight_scale.data = supa_w2_scales

    else:
        raise ValueError(
            f"Unsupported num_bits: {self.num_bits}. Only 4 and 8 are supported."
        )

    # Drop the other parameters registered by CompressedTensorsWNA16MoEMethod
    # (shapes, g_idx, sort indices) to reduce memory usage.
    layer.w13_weight_shape = None
    layer.w13_weight_g_idx = None
    layer.w13_g_idx_sort_indices = None

    layer.w2_weight_shape = None
    layer.w2_weight_g_idx = None
    layer.w2_g_idx_sort_indices = None
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def apply(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
) -> torch.Tensor:
    """Run the quantized fused MoE through one of the two SUPA kernels.

    ``router_logits`` is unpacked into three items: the gating weight and the
    shared gate-up / down weights. When ``x.shape[0]`` exceeds
    ``envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN`` the call goes to
    ``fused_moe_quant_dyn``; otherwise to ``fused_moe_quant_device``, which
    additionally receives the tensor-parallel rank/size information.
    ``global_num_experts``, ``expert_map``, ``apply_router_weight_on_input``
    and ``activation`` are accepted but not forwarded.
    """
    num_rows = x.shape[0]
    gate_w, shared_up_w, shared_down_w = router_logits
    take_dynamic_path = num_rows > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN

    # Arguments shared by both kernels.
    positional = (
        x,
        shared_up_w,
        shared_down_w,
        layer.w13_weight_packed,
        layer.w2_weight_packed,
        layer.w13_weight_scale,
        layer.w2_weight_scale,
        gate_w,
        top_k,
        layer.intermediate_size_per_partition,
    )
    keyword = dict(
        renormalize=renormalize,
        inplace=True,
        use_grouped_topk=use_grouped_topk,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )

    if take_dynamic_path:
        return fused_moe_quant_dyn(*positional,
                                   **keyword,
                                   ep_rank=layer.ep_rank,
                                   ep_size=layer.ep_size)

    return fused_moe_quant_device(
        *positional,
        **keyword,
        tp_rank=get_tp_group().rank_in_group,
        global_rank=get_tp_group().rank,
        tp_size=get_tensor_model_parallel_world_size(),
        ep_rank=layer.ep_rank,
        ep_size=layer.ep_size,
    )
|
||||
@@ -0,0 +1,267 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from compressed_tensors.compressors.quantized_compressors import (
|
||||
unpack_from_int32)
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.distributed import (get_pipeline_model_parallel_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
# yapf: enable
|
||||
from vllm_br import envs
|
||||
from ...br_utils import _convert_to_numa_tensor
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16)
def create_weights(self, layer: torch.nn.Module, input_size: int,
                   input_size_per_partition: int, output_size: int,
                   output_partition_sizes: list[int],
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):
    """Create the CPU-side WNA16 parameters for a linear layer.

    Registers ``weight_packed``, ``weight_scale`` and ``weight_shape`` on
    ``layer``; also ``weight_zero_point`` when the scheme is asymmetric and
    ``weight_g_idx`` when ``self.has_g_idx`` is set. Records
    ``output_size_per_partition`` and ``input_size_per_partition`` on
    ``self``.
    """
    out_features = sum(output_partition_sizes)
    self.output_size_per_partition = out_features

    # A group size of -1 means one group spanning the whole input dimension.
    effective_group = (input_size
                       if self.group_size == -1 else self.group_size)
    is_row_parallel = input_size != input_size_per_partition
    scales_replicated = marlin_repeat_scales_on_all_ranks(
        self.has_g_idx, self.group_size, is_row_parallel)

    num_groups = input_size // effective_group
    if not scales_replicated:
        # Scales are sharded with the weight: count groups per partition.
        assert input_size_per_partition % effective_group == 0
        num_groups = input_size_per_partition // effective_group

    packed_weight = PackedvLLMParameter(
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
        packed_factor=self.pack_factor,
        packed_dim=1,
        data=torch.empty(out_features,
                         input_size_per_partition // self.pack_factor,
                         dtype=torch.int32,
                         device="cpu"))

    scale_kwargs = dict(
        weight_loader=weight_loader,
        data=torch.empty(out_features,
                         num_groups,
                         device="cpu",
                         dtype=params_dtype),
    )
    zero_point_kwargs = dict(
        weight_loader=weight_loader,
        data=torch.zeros(out_features // self.pack_factor,
                         num_groups,
                         device="cpu",
                         dtype=torch.int32),
    )

    if scales_replicated:
        scale_param = ChannelQuantScaleParameter(output_dim=0, **scale_kwargs)
        if not self.symmetric:
            zero_point_param = PackedColumnParameter(
                output_dim=0,
                packed_dim=0,
                packed_factor=self.pack_factor,
                **zero_point_kwargs)
    else:
        scale_param = GroupQuantScaleParameter(output_dim=0,
                                               input_dim=1,
                                               **scale_kwargs)
        if not self.symmetric:
            zero_point_param = PackedvLLMParameter(
                input_dim=1,
                output_dim=0,
                packed_dim=0,
                packed_factor=self.pack_factor,
                **zero_point_kwargs)

    # Two-element int64 record of the weight shape before packing.
    shape_param = BasevLLMParameter(data=torch.empty(2,
                                                     dtype=torch.int64,
                                                     device="cpu"),
                                    weight_loader=weight_loader)

    layer.register_parameter("weight_packed", packed_weight)
    layer.register_parameter("weight_scale", scale_param)
    layer.register_parameter("weight_shape", shape_param)

    if not self.symmetric:
        layer.register_parameter("weight_zero_point", zero_point_param)

    # group index (for activation reordering)
    if self.has_g_idx:
        g_idx_param = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
                                       input_dim=0,
                                       weight_loader=weight_loader)
        layer.register_parameter("weight_g_idx", g_idx_param)

    self.input_size_per_partition = input_size_per_partition
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16)
def process_weights_after_loading(self: CompressedTensorsWNA16,
                                  layer: torch.nn.Module) -> None:
    """Unpack the loaded weight and convert weight/scale/bias for SUPA.

    The int32-packed weight is unpacked to
    ``[output_size_per_partition, input_size_per_partition]``, the scale is
    re-wrapped as a plain ``Parameter`` and cast to float32, and weight,
    scale and (if present) bias are passed through
    ``_convert_to_numa_tensor``.
    """
    self.num_bits = 32 // self.pack_factor

    unpacked_shape = torch.Size(
        [self.output_size_per_partition, self.input_size_per_partition])
    layer.weight_packed.data = unpack_from_int32(layer.weight_packed.data,
                                                 self.num_bits,
                                                 unpacked_shape, 1)

    layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                            requires_grad=False)
    layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)

    if isinstance(layer, RowParallelLinear):
        parallel_type = "row_parallel"
    else:
        parallel_type = "col_parallel"

    if hasattr(layer, 'weight_packed') and len(layer.weight_packed.shape) == 2:
        unpacked = layer.weight_packed.data
        # The weight is always transposed here, so K is dim 1 and N is dim 0.
        layer.weight_packed.data = _convert_to_numa_tensor(
            unpacked,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            do_transpose=True,
            wk=unpacked.shape[1],
            wn=unpacked.shape[0],
            parallel_type=parallel_type)

    if getattr(layer, 'weight_scale', None) is not None:
        layer.weight_scale.data = _convert_to_numa_tensor(
            layer.weight_scale.data.T,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=False)

    if getattr(layer, 'bias', None) is not None:
        # Zero padding is requested for the bias only on row-parallel layers.
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias.data,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=(parallel_type == "row_parallel"))
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16)
def apply_weights(self: CompressedTensorsWNA16,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the quantized linear layer to ``x`` on SUPA.

    When ``layer.weight_packed`` is 3-D (the converted NUMA layout) the fused
    kernels are used: ``supa_fused_linear_allreduce_opt`` for row-parallel
    layers that qualify for the fused all-reduce, ``br_fused_mlp_infer``
    otherwise. For any other weight rank the call falls back to
    ``sudnn_qmatmul_infer``.

    Side effect: for ``RowParallelLinear`` layers, ``layer.reduce_results``
    is overwritten on every call.

    Args:
        layer: linear layer holding ``weight_packed`` and ``weight_scale``.
        x: input activations.
        bias: optional bias forwarded to the kernel.

    Returns:
        The kernel output tensor.
    """
    # numa weight is 3-dims
    if len(layer.weight_packed.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear):
            # Merged gate/up projection: fuse SwiGLU, halving the output width.
            act_mode = "act_swiglu"
            output_size //= 2

        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            all_rank = tp_size * pp_size
            support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)

            if not layer.reduce_results:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x,
                    layer.weight_packed.data,
                    output_size,
                    tp_rank,
                    tp_size,
                    global_rank,
                    0,
                    scales=layer.weight_scale.data,
                    bias=bias,
                    act_mode=act_mode)

        # Single call site for the non-fused-allreduce path. The row-parallel
        # copy of this call previously spelled the keyword `activaion_mode`;
        # both paths now pass `activation_mode`.
        return torch_br.br_fused_mlp_infer(
            x, [layer.weight_packed.data],
            output_w=output_size,
            scales=[layer.weight_scale.data]
            if layer.weight_scale.data is not None else None,
            bias=[bias] if bias is not None else None,
            activation_mode=act_mode)

    xn = x.shape[0]
    xh = x.shape[1]
    ww = layer.weight_packed.shape[1]
    # TODO, hard code to skip dry_run stage
    if xh >= 4096:
        return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
    return torch_br.sudnn_qmatmul_infer(x,
                                        layer.weight_packed,
                                        layer.weight_scale,
                                        bias=bias)
|
||||
@@ -0,0 +1,34 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
    """
    Map a compressed-tensors k/v cache-scale param name to its vLLM name.

    A name qualifies when it ends with ".output_scale" and contains ".k_proj"
    (checked first) or ".v_proj". The ".<proj>.output_scale" substring is then
    replaced by ".attn.k_scale" / ".attn.v_scale".

    :param name: param name
    :return: matching param name for KV cache scale in vLLM, or None when the
        name does not qualify
    """
    if not name.endswith(".output_scale"):
        return None

    rewrites = (
        (".k_proj", ".attn.k_scale"),
        (".v_proj", ".attn.v_scale"),
    )
    for proj_marker, cache_scale_suffix in rewrites:
        if proj_marker in name:
            return name.replace(proj_marker + ".output_scale",
                                cache_scale_suffix)

    # If no matches, return None
    return None
|
||||
244
vllm_br/model_executor/layers/quantization/gptq.py
Normal file
244
vllm_br/model_executor/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,244 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.distributed import (get_pipeline_model_parallel_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
|
||||
GPTQLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm_br import envs
|
||||
from ..br_utils import _br_qweight_cvt, _convert_to_numa_tensor
|
||||
from ..linear import (process_weights_MergedColumnParallelLinear,
|
||||
process_weights_QuantMergedColumnParallelLinear,
|
||||
process_weights_ReplicatedLinear)
|
||||
|
||||
|
||||
@patch_to(GPTQConfig, cls_method=True)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Return the activation dtypes accepted by the patched GPTQ config.

    Only ``torch.bfloat16`` is listed.
    """
    supported_dtypes = [torch.bfloat16]
    return supported_dtypes
|
||||
|
||||
|
||||
@patch_to(GPTQConfig)
def get_quant_method(self: GPTQConfig, layer: torch.nn.Module,
                     prefix: str) -> Optional["GPTQLinearMethod"]:
    """Pick the quantization method for ``layer``.

    Delegates to ``get_linear_quant_method`` with ``GPTQLinearMethod`` as
    the linear-method class and returns its result unchanged.
    """
    return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
|
||||
|
||||
@patch_to(GPTQConfig, cls_method=True)
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
    """
    [PatchNote] add qkv_quantized param support

    Build a ``GPTQConfig`` from a quantization config dict. ``bits``,
    ``group_size`` and ``desc_act`` are read with ``get_from_keys``; the
    remaining keys are read with ``get_from_keys_or`` and fall back to the
    defaults given below. A ``dynamic`` value of ``None`` is replaced by an
    empty dict.
    """
    dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
    if dynamic is None:
        dynamic = {}

    # Dict displays evaluate top to bottom, so the lookups happen in the
    # same order as the constructor arguments are listed here.
    init_kwargs = {
        "weight_bits": cls.get_from_keys(config, ["bits"]),
        "group_size": cls.get_from_keys(config, ["group_size"]),
        "desc_act": cls.get_from_keys(config, ["desc_act"]),
        "lm_head_quantized": cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False),
        "autoround_version": cls.get_from_keys_or(config,
                                                  ["autoround_version"],
                                                  default=""),
        "modules_in_block_to_quantize": cls.get_from_keys_or(
            config, ["modules_in_block_to_quantize"], default=None),
        "qkv_quantized": cls.get_from_keys_or(config, ["qkv_quantized"],
                                              default=True),
    }
    init_kwargs["dynamic"] = dynamic
    return cls(**init_kwargs)
|
||||
|
||||
|
||||
def wrapper_GPTQConfig_init(fn):
    """Wrap an ``__init__``-style callable so it accepts ``qkv_quantized``.

    The returned function removes the ``qkv_quantized`` keyword (default
    ``True``) before calling ``fn``, then stores the value on the instance
    as ``self.qkv_quantized``.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        flag = kwargs.get("qkv_quantized", True)
        # ``fn`` does not know the extra keyword, so it is not forwarded.
        forwarded = {
            key: value
            for key, value in kwargs.items() if key != "qkv_quantized"
        }
        fn(self, *args, **forwarded)

        self.qkv_quantized = flag

    return wrapper
|
||||
|
||||
|
||||
# Replace GPTQConfig.__init__ with the wrapped version so the constructor
# accepts (and records) the extra ``qkv_quantized`` keyword.
GPTQConfig.__init__ = wrapper_GPTQConfig_init(
    GPTQConfig.__init__)  # noqa: E501
|
||||
|
||||
|
||||
@patch_to(GPTQLinearMethod)
def process_weights_after_loading(self: GPTQLinearMethod,
                                  layer: torch.nn.Module) -> None:
    """Post-process GPTQ weights of ``layer`` after checkpoint loading.

    Steps, in order:
      1. run the layer-type specific ``process_weights_*`` helper;
      2. convert an int32 ``qweight`` with ``_br_qweight_cvt``;
      3. cast ``scales`` to float32 and re-wrap ``qzeros``/``qweight``/
         ``g_idx``/``scales`` as non-trainable ``Parameter`` objects;
      4. when further processing is still required and
         ``self.weight_type`` is ``"NUMA"``, convert ``qweight``,
         ``scales`` and ``bias`` with ``_convert_to_numa_tensor``.

    NOTE(review): the nesting of this function was reconstructed from a
    flattened listing - confirm against the original file.
    """
    needs_numa_conversion = True
    is_quant_merged_col = False
    # NOTE: all process_weights func should done before process_weights_after_loading
    parallel_type = "col_parallel"
    if isinstance(layer, ReplicatedLinear):
        process_weights_ReplicatedLinear(layer)
        needs_numa_conversion = layer.output_size in (64, 256)
    elif isinstance(layer, MergedColumnParallelLinear):
        if hasattr(layer, "qweight"):
            is_quant_merged_col = True
        else:
            process_weights_MergedColumnParallelLinear(layer)
            needs_numa_conversion = False
    elif isinstance(layer, RowParallelLinear):
        parallel_type = "row_parallel"

    # NOTE: if use exllama, br gptq needs similar treatment
    # exllama needs to shuffle the weight after the weight is loaded
    # here we do the shuffle on first forward pass
    if layer.qweight.dtype == torch.int32:
        if hasattr(layer, 'input_size_per_partition'):
            input_size = layer.input_size_per_partition
        else:
            input_size = layer.input_size
        if hasattr(layer, 'output_size_per_partition'):
            output_size = layer.output_size_per_partition
        else:
            output_size = layer.output_size
        layer.qweight.data = _br_qweight_cvt(self, layer.qweight,
                                             layer.qzeros, input_size,
                                             output_size)
    if is_quant_merged_col:
        process_weights_QuantMergedColumnParallelLinear(layer)
        needs_numa_conversion = False

    layer.scales.data = layer.scales.to(torch.float32)

    # for torch.compile
    for attr_name in ("qzeros", "qweight", "g_idx", "scales"):
        setattr(
            layer, attr_name,
            Parameter(getattr(layer, attr_name).data, requires_grad=False))

    if not needs_numa_conversion or self.weight_type != "NUMA":
        return

    def _as_numa_bias(tensor, pad_zeros):
        # Shared conversion used for both ``scales`` and ``bias``.
        return _convert_to_numa_tensor(tensor,
                                       envs.VLLM_BR_DEVICE_WARP_SIZE,
                                       "linear_bias",
                                       torch.float32,
                                       parallel_type=parallel_type,
                                       pad_zeros=pad_zeros)

    # Only 2-D weights are converted.
    if hasattr(layer, 'qweight') and len(layer.qweight.shape) == 2:
        layer.qweight.data = _convert_to_numa_tensor(
            layer.qweight,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            parallel_type=parallel_type)

    if hasattr(layer, 'scales') and layer.scales is not None:
        layer.scales.data = _as_numa_bias(layer.scales, False)

    if hasattr(layer, 'bias') and layer.bias is not None:
        # The bias is zero-padded only for row-parallel layers.
        layer.bias.data = _as_numa_bias(layer.bias,
                                        parallel_type == "row_parallel")
|
||||
|
||||
|
||||
@patch_to(GPTQLinearMethod)
def apply(self: GPTQLinearMethod,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the quantized linear layer on ``x``.

    * 2-D ``qweight``: calls ``torch_br.sudnn_qmatmul_infer``, except that
      inputs with ``x.shape[1] >= 4096`` return a tensor of ones (see the
      TODO below).
    * 3-D ``qweight``: calls ``torch_br.br_fused_mlp_infer``; for
      ``RowParallelLinear`` layers ``layer.reduce_results`` is recomputed
      on every call and, when it is False,
      ``torch_br.supa_fused_linear_allreduce_opt`` is used instead.

    NOTE(review): nesting reconstructed from a flattened listing - confirm.
    """
    if len(layer.qweight.shape) != 3:
        num_rows = x.shape[0]
        hidden = x.shape[1]
        weight_width = layer.qweight.shape[1]
        # TODO, hard code to skip dry_run stage
        if hidden >= 4096:
            return torch.ones((num_rows, hidden, weight_width),
                              dtype=x.dtype,
                              device=x.device)
        return torch_br.sudnn_qmatmul_infer(x,
                                            layer.qweight,
                                            layer.scales,
                                            bias=bias)

    # numa weight is 3-dims
    if hasattr(layer, "output_size_per_partition"):
        output_size = layer.output_size_per_partition
    else:
        output_size = layer.output_size
    act_mode = "act_default"
    if isinstance(layer, MergedColumnParallelLinear):
        act_mode = "act_swiglu"
        output_size //= 2

    if isinstance(layer, RowParallelLinear):
        seq_len = x.shape[-2]
        tp_size = get_tensor_model_parallel_world_size()
        # bypass tp8 and tp4pp2 allreduce
        pp_size = get_pipeline_model_parallel_group().world_size
        all_rank = tp_size * pp_size
        support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
        layer.reduce_results = not (
            all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
            and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
            (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
        if not layer.reduce_results:
            tp_group = get_tp_group()
            tp_rank = tp_group.rank_in_group
            global_rank = tp_group.rank
            assert global_rank % tp_size == tp_rank
            return torch_br.supa_fused_linear_allreduce_opt(
                x,
                layer.qweight,
                output_size,
                tp_rank,
                tp_size,
                global_rank,
                0,
                scales=layer.scales,
                bias=bias,
                act_mode=act_mode)

    # Every remaining 3-D case uses the same fused kernel call.
    scale_list = [layer.scales] if layer.scales is not None else None
    bias_list = [bias] if bias is not None else None
    return torch_br.br_fused_mlp_infer(x, [layer.qweight],
                                       output_w=output_size,
                                       scales=scale_list,
                                       bias=bias_list,
                                       activation_mode=act_mode)
|
||||
924
vllm_br/model_executor/layers/rotary_embedding.py
Normal file
924
vllm_br/model_executor/layers/rotary_embedding.py
Normal file
@@ -0,0 +1,924 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import itertools
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.model_executor.layers.rotary_embedding
|
||||
import vllm.model_executor.models.chatglm
|
||||
import vllm.model_executor.models.deepseek_v2
|
||||
import vllm_br.envs as br_envs
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
_ROPE_DICT, DeepseekScalingRotaryEmbedding, DualChunkRotaryEmbedding,
|
||||
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
|
||||
Llama3RotaryEmbedding, Llama4VisionRotaryEmbedding, MRotaryEmbedding,
|
||||
NTKScalingRotaryEmbedding, Phi3LongRoPEScaledRotaryEmbedding,
|
||||
RotaryEmbedding, YaRNScalingRotaryEmbedding)
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
rotate_gptj, rotate_neox, yarn_find_correction_range,
|
||||
yarn_linear_ramp_mask)
|
||||
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
|
||||
yarn_get_mscale)
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import (
|
||||
apply_interleaved_rope)
|
||||
|
||||
|
||||
@patch_to(RotaryEmbedding)
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    op_type: str = "Half",  # FIXME: other op type not supported yet
) -> None:
    """Replacement constructor that registers SUPA rotary caches.

    ``MRotaryEmbedding`` and ``DeepseekScalingRotaryEmbedding`` instances
    get one ``cos_sin_cache`` buffer cast to ``dtype``; every other
    instance gets separate float32 ``sin_cache`` and ``cos_cache`` buffers.
    All buffers are registered as non-persistent.
    """
    logger.info('[Patch] RotaryEmbedding use SUPA RoPE')
    super(RotaryEmbedding, self).__init__()  # type: ignore
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position_embeddings
    self.base = base
    self.is_neox_style = is_neox_style
    self.dtype = dtype
    self.op_type = op_type  # FIXME: other op type not supported yet

    is_mrope = isinstance(self, MRotaryEmbedding)
    if is_mrope or isinstance(self, DeepseekScalingRotaryEmbedding):
        combined = self._compute_cos_sin_cache().to(dtype)
        # The two classes query the current device through different
        # namespaces; both calls are kept as they were.
        if is_mrope:
            device = torch.cuda.current_device()
        else:
            device = torch.supa.current_device()
        self.cos_sin_cache: torch.Tensor  # type: ignore
        self.register_buffer("cos_sin_cache",
                             combined.to(device),
                             persistent=False)
        return

    sin_table, cos_table = self._compute_cos_sin_cache()
    sin_table = sin_table.to(torch.float32)
    cos_table = cos_table.to(torch.float32)
    device = torch.cuda.current_device()
    for buffer_name, table in (("sin_cache", sin_table), ("cos_cache",
                                                          cos_table)):
        self.register_buffer(buffer_name, table.to(device), persistent=False)
|
||||
|
||||
|
||||
@patch_to(RotaryEmbedding)
def _compute_cos_sin_cache(
        self) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Compute the cos and sin cache.

    Returns a single ``cat((cos, sin))`` tensor for ``MRotaryEmbedding``
    instances and a ``(sin, cos)`` pair otherwise. For op types "Half" and
    "TeleChat" the frequencies are tiled with ``repeat(1, 2)``; for any
    other op type they are interleaved and the sine argument alternates
    sign (-1, +1, ...).
    """
    # NOTE(review): extent of the CPU device context reconstructed from a
    # flattened listing; it is assumed to cover every tensor factory call.
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.base)
        steps = torch.arange(self.max_position_embeddings, dtype=torch.float)

        angles = torch.einsum("i,j -> ij", steps, inv_freq)
        if isinstance(self, MRotaryEmbedding):
            return torch.cat((angles.cos(), angles.sin()), dim=-1)

        if self.op_type in ("Half", "TeleChat"):
            tiled = angles.repeat(1, 2)
            return tiled.sin(), tiled.cos()

        interleaved = angles.repeat_interleave(2, dim=-1)
        signs = torch.arange(interleaved.numel()) % 2 * 2 - 1
        signed = interleaved * signs.reshape_as(interleaved)
        return signed.sin(), interleaved.cos()
|
||||
|
||||
|
||||
@patch_to(RotaryEmbedding)
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embedding with ``torch_br.supa_rope_infer_v2``.

    Uses the registered ``sin_cache``/``cos_cache`` buffers together with
    ``self.op_type`` and ``self.rotary_dim``. ``offsets`` is accepted but
    not used.
    """
    kernel_options = {
        "rope_type": self.op_type,
        "rotary_size": self.rotary_dim,
    }
    rotated_query, rotated_key = torch_br.supa_rope_infer_v2(
        query, key, self.sin_cache, self.cos_cache, positions,
        self.head_size, **kernel_options)
    return rotated_query, rotated_key
|
||||
|
||||
|
||||
@patch_to(RotaryEmbedding, cls_method=True)
def enabled(cls) -> bool:
    """Report that this op is always enabled.

    The function takes ``cls``, so it is patched as a classmethod. This
    lets both ``RotaryEmbedding.enabled()`` and ``instance.enabled()``
    work; patched as a plain method, the class-level call failed with a
    missing-argument error.
    """
    return True
|
||||
|
||||
|
||||
class SupaDeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose sin/cos tables use YaRN-style scaled
    frequencies multiplied by a magnitude factor ``mscale``."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        """Store the scaling parameters, then run the base constructor.

        The attributes are set first because the base constructor calls
        ``_compute_cos_sin_cache``, which reads them.
        """
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        scale_num = yarn_get_mscale(self.scaling_factor, float(mscale))
        scale_den = yarn_get_mscale(self.scaling_factor,
                                    float(mscale_all_dim))
        self.mscale = float(scale_num / scale_den * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend interpolated and extrapolated inverse frequencies."""
        # NOTE(review): extent of the CPU device context reconstructed
        # from a flattened listing - confirm.
        with torch.device('cpu'):
            exponents = torch.arange(
                0, self.rotary_dim, 2, dtype=torch.float,
                device="cpu") / self.rotary_dim
            pos_freqs = self.base**exponents
            extrapolated = 1.0 / pos_freqs
            interpolated = 1.0 / (scaling_factor * pos_freqs)

            low, high = yarn_find_correction_range(
                self.beta_fast, self.beta_slow, self.rotary_dim, self.base,
                self.max_position_embeddings)
            # Get n-d rotational scaling corrected for extrapolation
            ramp = yarn_linear_ramp_mask(low,
                                         high,
                                         self.rotary_dim // 2,
                                         dtype=torch.float)
            blend = (1 - ramp) * self.extrapolation_factor
            inv_freq = interpolated * (1 - blend) + extrapolated * blend
        return inv_freq

    def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(sin, cos)`` tables over interleaved frequencies.

        Both tables are multiplied by ``self.mscale``; the sine argument
        alternates sign (-1, +1, ...) along the flattened table.
        """
        with torch.device('cpu'):
            inv_freq = self._compute_inv_freq(self.scaling_factor)
            steps = torch.arange(self.max_position_embeddings *
                                 self.scaling_factor,
                                 dtype=torch.float)

            angles = torch.einsum("i,j -> ij", steps, inv_freq)
            interleaved = angles.repeat_interleave(2, dim=-1)
            cos = interleaved.cos() * self.mscale
            signs = torch.arange(interleaved.numel()) % 2 * 2 - 1
            signed = interleaved * signs.reshape_as(interleaved)
            sin = signed.sin() * self.mscale
        return sin, cos
|
||||
|
||||
|
||||
@patch_to(DeepseekScalingRotaryEmbedding)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Compute YaRN-blended inverse frequencies on the CPU.

    Interpolated (``1 / (scaling_factor * f)``) and extrapolated
    (``1 / f``) frequencies are mixed with a linear ramp mask between the
    correction bounds returned by ``yarn_find_correction_range``.
    """
    # NOTE(review): extent of the CPU device context reconstructed from a
    # flattened listing - confirm.
    with torch.device('cpu'):
        exponents = torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float,
            device="cpu") / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(self.beta_fast,
                                               self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        ramp = yarn_linear_ramp_mask(low,
                                     high,
                                     self.rotary_dim // 2,
                                     dtype=torch.float)
        blend = (1 - ramp) * self.extrapolation_factor
        inv_freq = interpolated * (1 - blend) + extrapolated * blend
    return inv_freq
|
||||
|
||||
|
||||
@patch_to(DeepseekScalingRotaryEmbedding)
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the combined ``cat((cos, sin), dim=-1)`` cache on the CPU.

    Positions range over ``max_position_embeddings * scaling_factor`` and
    both halves are multiplied by ``self.mscale``.
    """
    # NOTE(review): extent of the CPU device context reconstructed from a
    # flattened listing - confirm.
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        steps = torch.arange(self.max_position_embeddings *
                             self.scaling_factor,
                             dtype=torch.float32)
        angles = torch.einsum("i,j -> ij", steps, inv_freq)
        scaled_cos = angles.cos() * self.mscale
        scaled_sin = angles.sin() * self.mscale
        return torch.cat((scaled_cos, scaled_sin), dim=-1)
|
||||
|
||||
|
||||
@patch_to(DeepseekScalingRotaryEmbedding)
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward().

    Rotates the first ``rotary_dim`` channels of ``query`` and ``key`` and
    leaves the remaining channels untouched. When ``query.shape[0]`` is
    above 1024 the rotation arithmetic runs on the CPU and the results are
    moved back to the original device.
    """
    assert key is not None
    self._match_cos_sin_cache_dtype(query)

    rot_dim = self.rotary_dim
    has_passthrough = rot_dim < self.head_size
    query_rot = query[..., :rot_dim]
    key_rot = key[..., :rot_dim]
    if has_passthrough:
        query_pass = query[..., rot_dim:]
        key_pass = key[..., rot_dim:]

    if offsets is not None:
        lookup = torch.add(positions, offsets)
    else:
        lookup = positions
    cos, sin = self.cos_sin_cache[lookup].chunk(2, dim=-1)
    if self.is_neox_style:
        # NOTE(woosuk): Here we assume that the positions tensor has the
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        rotate_fn = rotate_neox
    else:
        # The interleaved expansion is performed on the CPU and the
        # result is moved to the current SUPA device.
        supa_device = torch.supa.current_device()
        cos = cos.to('cpu').repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.to('cpu').repeat_interleave(2, dim=-1).unsqueeze(-2)
        cos = cos.to(supa_device)
        sin = sin.to(supa_device)
        rotate_fn = rotate_gptj

    home_device = query_rot.device
    offload_to_cpu = query.shape[0] > 1024
    if offload_to_cpu:
        query_rot = query_rot.to('cpu')
        key_rot = key_rot.to('cpu')
        cos = cos.to('cpu')
        sin = sin.to('cpu')
    query_rot = query_rot * cos + rotate_fn(query_rot) * sin
    key_rot = key_rot * cos + rotate_fn(key_rot) * sin
    if offload_to_cpu:
        query_rot = query_rot.to(home_device)
        key_rot = key_rot.to(home_device)

    if not has_passthrough:
        return query_rot, key_rot
    query = torch.cat((query_rot, query_pass), dim=-1)
    key = torch.cat((key_rot, key_pass), dim=-1)
    return query, key
|
||||
|
||||
|
||||
@patch_to(DeepseekScalingRotaryEmbedding)
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Out-of-tree forward: delegates to ``forward_native`` unchanged."""
    rotated_query, rotated_key = self.forward_native(positions, query, key,
                                                     offsets)
    return rotated_query, rotated_key
|
||||
|
||||
|
||||
@patch_to(YaRNScalingRotaryEmbedding)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Compute YaRN-blended inverse frequencies on the CPU.

    Interpolated and extrapolated frequencies are mixed with a linear ramp
    mask between the bounds from ``yarn_find_correction_range``.
    """
    # NOTE(review): extent of the CPU device context reconstructed from a
    # flattened listing - confirm.
    with torch.device('cpu'):
        exponents = torch.arange(0, self.rotary_dim, 2,
                                 dtype=torch.float) / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(self.beta_fast,
                                               self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        ramp = yarn_linear_ramp_mask(low,
                                     high,
                                     self.rotary_dim // 2,
                                     dtype=torch.float)
        blend = (1 - ramp) * self.extrapolation_factor
        inv_freq = interpolated * (1 - blend) + extrapolated * blend
    return inv_freq
|
||||
|
||||
|
||||
@patch_to(YaRNScalingRotaryEmbedding)
def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return ``(sin, cos)`` tables built on the CPU.

    Frequencies are tiled with ``repeat(1, 2)`` and both tables are
    multiplied by ``self.mscale``.
    """
    # NOTE(review): extent of the CPU device context reconstructed from a
    # flattened listing - confirm.
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        steps = torch.arange(self.max_position_embeddings *
                             self.scaling_factor,
                             dtype=torch.float32)

        angles = torch.einsum("i,j -> ij", steps, inv_freq).repeat(1, 2)
        scaled_cos = angles.cos() * self.mscale
        scaled_sin = angles.sin() * self.mscale
    return scaled_sin, scaled_cos
|
||||
|
||||
|
||||
def dtnamicNTK_compute_cos_sin_cache(
        self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the cos and sin cache.

    Returns ``(sin, cos)``. For op types "Half" and "TeleChat" the
    frequencies are tiled with ``repeat(1, 2)``; for any other op type
    they are interleaved and the sine argument alternates sign
    (-1, +1, ...).
    """
    # NOTE(review): extent of the CPU device context reconstructed from a
    # flattened listing; it is assumed to cover every tensor factory call.
    with torch.device('cpu'):
        inv_freq = self._compute_inv_freq(self.base)
        steps = torch.arange(self.max_position_embeddings, dtype=torch.float)

        angles = torch.einsum("i,j -> ij", steps, inv_freq)
        if self.op_type in ("Half", "TeleChat"):
            tiled = angles.repeat(1, 2)
            return tiled.sin(), tiled.cos()

        interleaved = angles.repeat_interleave(2, dim=-1)
        signs = torch.arange(interleaved.numel()) % 2 * 2 - 1
        signed = interleaved * signs.reshape_as(interleaved)
        return signed.sin(), interleaved.cos()
|
||||
|
||||
|
||||
def dynamicNTKScaling_rope_forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embedding with ``torch_br.supa_rope_infer_v2``.

    The rope type is "MRope" when the last dimensions of ``query`` and
    ``key`` differ, otherwise ``self.op_type``. ``offsets`` is accepted
    but not used.
    """
    if query.shape[-1] != key.shape[-1]:
        rope_type = "MRope"
    else:
        rope_type = self.op_type
    rotated_query, rotated_key = torch_br.supa_rope_infer_v2(
        query,
        key,
        self.sin_cache,
        self.cos_cache,
        positions,
        self.head_size,
        rope_type=rope_type)
    return rotated_query, rotated_key
|
||||
|
||||
|
||||
# Install the cache builder and forward pass defined above on the
# dynamic-NTK rotary embedding class.
DynamicNTKScalingRotaryEmbedding._compute_cos_sin_cache = dtnamicNTK_compute_cos_sin_cache
DynamicNTKScalingRotaryEmbedding.forward = dynamicNTKScaling_rope_forward
|
||||
|
||||
|
||||
def _apply_rotary_emb_torch(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                      is_neox_style: bool) -> torch.Tensor:
    """
    Apply rotary embedding via the pure-PyTorch implementation.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    rotated = _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
    return rotated
|
||||
|
||||
|
||||
def forward_MRotaryEmbedding_0_9_2(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward().

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    assert positions.ndim in (1, 2)
    assert key is not None

    num_tokens = positions.shape[-1]
    cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

    def _merge_sections(table: torch.Tensor) -> torch.Tensor:
        # Section i of the last axis is taken from row i of the table.
        pieces = table.split(self.mrope_section, dim=-1)
        return torch.cat([piece[idx] for idx, piece in enumerate(pieces)],
                         dim=-1)

    if positions.ndim == 2:
        assert self.mrope_section

        cos = _merge_sections(cos)
        sin = _merge_sections(sin)

    def _rotate(flat: torch.Tensor) -> torch.Tensor:
        # Rotate the leading rotary_dim channels of every head and keep
        # the remaining channels as they are.
        original_shape = flat.shape
        heads = flat.view(num_tokens, -1, self.head_size)
        rotated = _apply_rotary_emb(heads[..., :self.rotary_dim], cos, sin,
                                    self.is_neox_style)
        untouched = heads[..., self.rotary_dim:]
        return torch.cat((rotated, untouched),
                         dim=-1).reshape(original_shape)

    query = _rotate(query)
    key = _rotate(key)
    return query, key
|
||||
|
||||
|
||||
def forward_supa(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply multimodal rotary embedding with the SUPA RoPE kernel.

    When ``br_envs.VLLM_BR_USE_MROPE_0_9_2`` is set the call is forwarded
    to ``forward_MRotaryEmbedding_0_9_2``. Otherwise cos/sin are gathered
    from ``self.cos_sin_cache`` and passed to
    ``torch_br.supa_rope_infer_v2`` with ``rope_type="MRope"``.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    if br_envs.VLLM_BR_USE_MROPE_0_9_2:
        return forward_MRotaryEmbedding_0_9_2(self, positions, query, key)

    assert positions.ndim == 1 or positions.ndim == 2

    def data_in_supa(t: torch.Tensor) -> bool:
        return str(t.device).startswith('supa')

    def data_in_cpu(t: torch.Tensor) -> bool:
        return t.device == torch.device('cpu')

    if positions.ndim == 1:
        # Text-only positions: previously no branch assigned cos/sin for
        # this case, so the code below failed with UnboundLocalError.
        # NOTE(review): assumes the kernel accepts 1-D positions with
        # [num_tokens, rotary_dim // 2] tables - confirm on hardware.
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
    elif positions.shape[1] == 1:
        # use bypass for decode stage
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos[0]
        sin = sin[0]
    else:
        cos_sin = self.cos_sin_cache[positions.to(torch.int64)]
        cos, sin = cos_sin.chunk(2, dim=-1)
        assert self.mrope_section

        if self.mrope_interleaved:
            cos = apply_interleaved_rope(cos, self.mrope_section)
            sin = apply_interleaved_rope(sin, self.mrope_section)
        else:
            # Section i of the last axis is taken from row i.
            cos = torch.cat([
                m[i] for i, m in enumerate(
                    cos.split(self.mrope_section, dim=-1))
            ],
                            dim=-1)
            sin = torch.cat([
                m[i] for i, m in enumerate(
                    sin.split(self.mrope_section, dim=-1))
            ],
                            dim=-1)

    # Move CPU-resident inputs next to query/key when both are on SUPA.
    if data_in_supa(query) and data_in_supa(key):
        sin = sin.supa() if data_in_cpu(sin) else sin
        cos = cos.supa() if data_in_cpu(cos) else cos
        positions = positions.supa() if data_in_cpu(positions) else positions

    query, key = torch_br.supa_rope_infer_v2(query,
                                             key,
                                             sin.to(torch.float32),
                                             cos.to(torch.float32),
                                             positions.to(torch.int32),
                                             self.head_size,
                                             rope_type="MRope")
    return query, key
|
||||
|
||||
|
||||
# Route MRotaryEmbedding.forward through the SUPA implementation above.
MRotaryEmbedding.forward = forward_supa
|
||||
|
||||
|
||||
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    op_type: str = "Half",
) -> RotaryEmbedding:
    """Return a (cached) rotary embedding for the given configuration.

    Instances are cached in ``_ROPE_DICT``. ``op_type`` is part of the
    cache key: it is forwarded to ``RotaryEmbedding`` when no scaling is
    configured, so two requests that differ only in ``op_type`` must not
    share an instance.

    Args:
        rope_scaling: optional scaling config; its ``"rope_type"`` entry
            selects the embedding class.
        dtype: defaults to ``torch.get_default_dtype()``.
        partial_rotary_factor: when below 1.0, ``rotary_dim`` is scaled by
            it and truncated to int.
        dual_chunk_attention_config: when given, a
            ``DualChunkRotaryEmbedding`` is built.
        op_type: rope kernel variant passed to ``RotaryEmbedding``.

    Raises:
        ValueError: for an unknown ``rope_scaling["rope_type"]``.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None

    if dual_chunk_attention_config is not None:
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    # op_type is included so "Half" and e.g. "DeepSeek" embeddings with
    # otherwise equal settings get separate cache entries.
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dual_chunk_attention_args, dtype, op_type)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style, dtype,
                                              **extra_kwargs)
    elif not rope_scaling:
        rotary_emb = RotaryEmbedding(head_size,
                                     rotary_dim,
                                     max_position,
                                     base,
                                     is_neox_style,
                                     dtype,
                                     op_type=op_type)
    else:
        scaling_type = rope_scaling["rope_type"]

        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor, low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
                                                     max_position, base,
                                                     is_neox_style, dtype)
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype=torch.float32,
                    mrope_section=rope_scaling["mrope_section"],
                    mrope_interleaved=rope_scaling.get("mrope_interleaved",
                                                       False),
                )
            else:
                # NOTE(review): op_type is not forwarded here, unlike the
                # unscaled branch above - confirm this is intended.
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "ntk":
            scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get('mixed_b', None)
            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
                                                   max_position, base,
                                                   is_neox_style,
                                                   scaling_factor, dtype,
                                                   mixed_b)
        elif scaling_type == "dynamic":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor, dtype)
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "deepseek_yarn_supa":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = SupaDeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
|
||||
|
||||
|
||||
def deepseek_get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """Build the rotary embedding used by the DeepSeek model code.

    Every argument is forwarded positionally and unchanged to ``get_rope``;
    the literal ``"DeepSeek"`` is appended as the final positional argument.

    Returns:
        Whatever ``get_rope`` returns for these arguments.
    """
    # Positional order must mirror get_rope's signature, so the arguments are
    # collected once here and splatted in a single call.
    forwarded = (
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        rope_scaling,
        dtype,
        partial_rotary_factor,
        dual_chunk_attention_config,
        "DeepSeek",
    )
    return get_rope(*forwarded)
|
||||
|
||||
|
||||
def chatglm2_get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
    """Build the rotary embedding used by the ChatGLM model code.

    Every argument is forwarded positionally and unchanged to ``get_rope``,
    followed by one extra trailing positional argument.

    Returns:
        Whatever ``get_rope`` returns for these arguments.
    """
    # NOTE(review): the trailing tag is "DeepSeek" even though this wrapper is
    # installed for ChatGLM -- presumably intentional (same code path wanted);
    # confirm against get_rope before changing it.
    forwarded = (
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        rope_scaling,
        dtype,
        partial_rotary_factor,
        dual_chunk_attention_config,
        "DeepSeek",
    )
    return get_rope(*forwarded)
|
||||
|
||||
|
||||
# Rebind the `get_rope` attribute of three upstream vLLM modules to the
# functions defined in this file: the generic factory for the rotary_embedding
# layer module, and the two thin wrappers for the deepseek_v2 and chatglm
# model modules.
vllm.model_executor.layers.rotary_embedding.get_rope = get_rope
|
||||
vllm.model_executor.models.deepseek_v2.get_rope = deepseek_get_rope
|
||||
vllm.model_executor.models.chatglm.get_rope = chatglm2_get_rope
|
||||
|
||||
|
||||
@patch_to(MRotaryEmbedding)
def _glm4v_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value for GLM4V.

    Each token is labelled "image", "video" or "text"; runs of equal labels
    become segments, and every segment contributes a ``(3, n)`` block of
    (t, h, w) position ids offset by ``max + 1`` of the previous block.

    Args:
        input_tokens: Prompt token ids.
        hf_config: Supplies ``image_token_id``, ``video_start_token_id``,
            ``video_end_token_id`` and ``vision_config.spatial_merge_size``.
        image_grid_thw: Per-item ``[t, h, w]`` grids; a tensor is converted
            to nested lists.
        video_grid_thw: Only checked against ``None`` in this function.
        context_len: Start of the returned slice along the token axis.
        seq_len: End of the returned slice along the token axis.

    Returns:
        ``(llm_positions, mrope_position_delta)`` where the delta is
        ``llm_positions.max() + 1 - len(input_tokens)`` as a Python number.
    """
    image_token_id = hf_config.image_token_id
    video_start_token_id = hf_config.video_start_token_id
    video_end_token_id = hf_config.video_end_token_id
    spatial_merge_size = hf_config.vision_config.spatial_merge_size
    llm_pos_ids_list: list = []

    if image_grid_thw is None and video_grid_thw is None:
        # Text-only prompt: the three position rows are the same 0..n-1 ramp.
        text_len = len(input_tokens)
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1))
    else:
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()

        # Label every token. An image placeholder seen between the video
        # start/end markers counts as a video frame token.
        token_kinds: list[str] = []
        inside_video = False
        for token in input_tokens:
            if token == video_start_token_id:
                inside_video = True
            elif token == video_end_token_id:
                inside_video = False

            if token != image_token_id:
                token_kinds.append("text")
            elif inside_video:
                token_kinds.append("video")
            else:
                token_kinds.append("image")

        # Collapse consecutive identical labels into (label, start, end).
        segments: list[tuple[str, int, int]] = []
        for kind, members in itertools.groupby(enumerate(token_kinds),
                                               lambda pair: pair[1]):
            positions = [index for index, _ in members]
            segments.append((kind, positions[0], positions[-1] + 1))

        video_frame_num = 1
        mm_data_idx = 0
        for modality_type, start_idx, end_idx in segments:
            # Every new block continues right after the largest id so far.
            if len(llm_pos_ids_list) > 0:
                st_idx = llm_pos_ids_list[-1].max() + 1
            else:
                st_idx = 0

            if modality_type == "image":
                grid = image_grid_thw[mm_data_idx]
                t, h, w = grid[0], grid[1], grid[2]
                llm_grid_t = t
                llm_grid_h = h // spatial_merge_size
                llm_grid_w = w // spatial_merge_size

                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                    -1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                    llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                    llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index]) + st_idx)
                mm_data_idx += 1

            elif modality_type == "video":
                # NOTE(review): height/width come from image_grid_thw, not
                # video_grid_thw -- kept exactly as found; confirm intent.
                grid = image_grid_thw[mm_data_idx]
                t, h, w = video_frame_num, grid[1], grid[2]
                llm_grid_t = t
                llm_grid_h = h // spatial_merge_size
                llm_grid_w = w // spatial_merge_size

                for t_idx in range(llm_grid_t):
                    t_index = torch.tensor(t_idx).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                        1, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                        1, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)

                mm_data_idx += 1
                video_frame_num += 1

            else:
                text_len = end_idx - start_idx
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                # A text segment resets the running video frame counter.
                video_frame_num = 1

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    llm_positions = llm_positions[:, context_len:seq_len]
    mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
    return llm_positions, mrope_position_delta
|
||||
|
||||
|
||||
@patch_to(MRotaryEmbedding)
def get_input_positions_tensor_for_glm(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Dispatch to the position builder that matches ``hf_config``.

    Selection order:
      1. ``thinker_uses_mrope(hf_config)`` is true ->
         ``cls._omni_get_input_positions_tensor``.
      2. ``"glm4v"`` occurs in ``hf_config.model_type`` ->
         ``cls._glm4v_get_input_positions_tensor``.
      3. otherwise -> ``cls._vl_get_input_positions_tensor``.

    Returns:
        The ``(positions, delta)`` pair produced by the selected builder.
    """
    from vllm.transformers_utils.config import thinker_uses_mrope

    # Keyword arguments accepted by all three builders.
    shared_kwargs = dict(
        input_tokens=input_tokens,
        hf_config=hf_config,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        context_len=context_len,
        seq_len=seq_len,
    )

    if thinker_uses_mrope(hf_config):
        return cls._omni_get_input_positions_tensor(
            second_per_grid_ts=second_per_grid_ts,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
            **shared_kwargs,
        )

    if "glm4v" in hf_config.model_type:
        # The GLM4V helper receives `cls` explicitly as its first positional
        # argument, unlike the other two builders.
        return cls._glm4v_get_input_positions_tensor(cls, **shared_kwargs)

    return cls._vl_get_input_positions_tensor(
        second_per_grid_ts=second_per_grid_ts,
        **shared_kwargs,
    )
|
||||
65
vllm_br/model_executor/layers/utils.py
Normal file
65
vllm_br/model_executor/layers/utils.py
Normal file
@@ -0,0 +1,65 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import torch
|
||||
|
||||
import vllm
|
||||
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask
|
||||
|
||||
|
||||
def apply_penalties_fit(logits: torch.Tensor,
                        prompt_tokens_tensor: torch.Tensor,
                        output_tokens_tensor: torch.Tensor,
                        presence_penalties: torch.Tensor,
                        frequency_penalties: torch.Tensor,
                        repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to ``logits`` in place.

    Args:
        logits: Input logits of shape [num_seqs, vocab_size]; modified in
            place and also returned.
        prompt_tokens_tensor: Prompt tokens, padded to the longest prompt in
            the batch with ``vocab_size`` (a value that is not a valid token
            id) as the padding value.
        output_tokens_tensor: The output tokens tensor.
        presence_penalties: Presence penalties of shape (num_seqs, ).
        frequency_penalties: Frequency penalties of shape (num_seqs, ).
        repetition_penalties: Repetition penalties of shape (num_seqs, ).

    Returns:
        The same ``logits`` tensor after the in-place updates.
    """
    num_seqs, vocab_size = logits.shape

    _, seen_in_prompt = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                      vocab_size, num_seqs)
    output_counts, seen_in_output = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Broadcast the per-sequence repetition penalty across the vocabulary.
    per_token_penalty = repetition_penalties.unsqueeze(dim=1).repeat(
        1, vocab_size)

    # Tokens seen in neither prompt nor output get a neutral factor of 1.0.
    seen_anywhere = seen_in_prompt | seen_in_output
    penalties = torch.where(seen_anywhere, per_token_penalty, 1.0)

    # Positive logits are divided by the penalty, the rest are multiplied.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling

    # Frequency and presence terms follow the OpenAI API definition, see
    # https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_counts
    logits -= presence_penalties.unsqueeze(dim=1) * seen_in_output
    return logits
|
||||
|
||||
|
||||
# Rebind the upstream vLLM `apply_penalties` attribute to the implementation
# defined above in this file.
vllm.model_executor.layers.utils.apply_penalties = apply_penalties_fit
|
||||
139
vllm_br/model_executor/layers/vocab_parallel_embedding.py
Normal file
139
vllm_br/model_executor/layers/vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,139 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_br
|
||||
import torch_br.supa._debug as supa_debug
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
import vllm
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
@patch_to(UnquantizedEmbeddingMethod)
def process_weights_after_loading(self, module):
    """Re-allocate the embedding weight when VLLM_BR_EMBEDDING_S0B is set.

    With the flag enabled, the current weight is copied to the host, a new
    buffer is obtained from ``torch_br._empty_ut_only`` ("colmajor",
    ``sbp='SB'``) with the same shape and dtype, and the saved values are
    copied into it. With the flag disabled nothing happens.
    """
    if not envs.VLLM_BR_EMBEDDING_S0B:
        return

    weight = module.weight
    # Keep a host copy: the data pointer is replaced on the next statement.
    host_copy = weight.data.cpu()
    weight.data = torch_br._empty_ut_only(weight.shape,
                                          "colmajor",
                                          False,
                                          0,
                                          dtype=weight.dtype,
                                          sbp='SB')
    weight.data.copy_(host_copy)
|
||||
|
||||
|
||||
@patch_to(UnquantizedEmbeddingMethod)
def embedding(self, layer: torch.nn.Module,
              input_: torch.Tensor) -> torch.Tensor:
    """Look up embeddings for ``input_`` in ``layer.weight``.

    Without VLLM_BR_EMBEDDING_S0B this is a plain ``F.embedding`` call. With
    the flag set, the result is written by ``torch_br.out_embedding`` into a
    buffer of shape ``[1, input_.shape[0], layer.weight.shape[-1]]`` whose
    leading dimension is squeezed away before returning.
    """
    if not envs.VLLM_BR_EMBEDDING_S0B:
        return F.embedding(input_, layer.weight)

    weight = layer.weight
    out_shape = [1, input_.shape[0], weight.shape[-1]]
    y_supa = torch_br._empty_ut_only(
        out_shape,
        is_numa=False,
        dtype=weight.dtype,
        sbp='BB',
        tensor_type="colmajor",
    )
    torch_br.out_embedding(y_supa, weight.data, input_, -1, -1)
    # Drop the leading singleton dimension in place.
    y_supa.squeeze_(0)
    return y_supa
|
||||
|
||||
|
||||
@torch.jit.script
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Remap token ids to shard-local indices and flag ids outside the shard.

    Returns:
        ``(masked_input, invalid_mask)``. In the pointwise branch,
        ``masked_input`` is zero wherever the id lies in neither the original
        nor the added vocabulary range, and ``invalid_mask`` is true at
        exactly those positions.
    """
    # torch.jit.script will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    if spc_num <= 16:
        # Delegate to the torch_br operator for this SPC count.
        input_, inv_vocab_mask = torch_br.supa_embedding_mask_infer(
            input_, org_vocab_start_index, org_vocab_end_index,
            num_org_vocab_padding, added_vocab_start_index,
            added_vocab_end_index)
        return input_, inv_vocab_mask
    else:
        in_org_range = (input_ >= org_vocab_start_index) & (
            input_ < org_vocab_end_index)
        in_added_range = (input_ >= added_vocab_start_index) & (
            input_ < added_vocab_end_index)
        # Added-vocab ids land right after the padded original-vocab rows.
        org_span = org_vocab_end_index - org_vocab_start_index
        added_offset = added_vocab_start_index - org_span - num_org_vocab_padding
        valid_offset = (org_vocab_start_index *
                        in_org_range) + (added_offset * in_added_range)
        vocab_mask = in_org_range | in_added_range
        input_ = vocab_mask * (input_ - valid_offset)
        return input_, ~vocab_mask
|
||||
|
||||
|
||||
# Rebind the upstream vLLM helper of the same name to the scripted function
# defined above in this file.
vllm.model_executor.layers.vocab_parallel_embedding.get_masked_input_and_mask = get_masked_input_and_mask
|
||||
|
||||
|
||||
def vocab_parallel_embedding_forward(self, input_) -> torch.Tensor:
    """Forward pass installed as ``VocabParallelEmbedding.forward``.

    When ``self.tp_size > 1`` the ids are remapped with
    ``get_masked_input_and_mask`` and the embedding rows of ids flagged by
    the returned mask are zeroed. The result always goes through
    ``tensor_model_parallel_all_reduce`` before it is returned.
    """
    sharded = self.tp_size > 1

    if sharded:
        # Build the mask.
        shard = self.shard_indices
        masked_input, input_mask = get_masked_input_and_mask(
            input_,
            shard.org_vocab_start_index,
            shard.org_vocab_end_index,
            shard.num_org_vocab_padding,
            shard.added_vocab_start_index,
            shard.added_vocab_end_index,
        )
    else:
        masked_input = input_

    # Get the embeddings.
    output_parallel = self.quant_method.embedding(self, masked_input.long())

    # Mask the output embedding.
    if sharded:
        output_parallel.masked_fill_(input_mask.unsqueeze(-1),
                                     0)  # type: ignore

    # Reduce across all the model parallel GPUs.
    return tensor_model_parallel_all_reduce(output_parallel)
|
||||
|
||||
|
||||
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply ``layer.weight`` (and optional ``bias``) to ``x``.

    Installed as ``UnquantizedEmbeddingMethod.apply``.

    Args:
        layer: Module holding ``weight``; for 3-D weights it must also expose
            ``output_size_per_partition`` or ``output_size``.
        x: Input activations.
        bias: Optional bias.

    Returns:
        The result of ``torch_br.br_fused_mlp_infer`` for 3-D weights,
        otherwise the result of ``F.linear``.
    """
    # numa weight is 3-dims
    if len(layer.weight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        return torch_br.br_fused_mlp_infer(
            x, [layer.weight],
            output_w=output_size,
            bias=[bias] if bias is not None else None)

    # The sublas switch is process-wide debug state. Restore it in `finally`
    # so an exception raised by F.linear (e.g. a shape mismatch) cannot leave
    # it enabled for every later call.
    supa_debug.set_enable_sublas_api(True)
    try:
        return F.linear(x, layer.weight, bias)
    finally:
        supa_debug.set_enable_sublas_api(False)
|
||||
|
||||
|
||||
# Install the two functions defined above as methods on the upstream vLLM
# classes: `apply` on UnquantizedEmbeddingMethod and the forward pass on
# VocabParallelEmbedding.
UnquantizedEmbeddingMethod.apply = apply
|
||||
|
||||
VocabParallelEmbedding.forward = vocab_parallel_embedding_forward
|
||||
17
vllm_br/model_executor/model_loader/__init__.py
Normal file
17
vllm_br/model_executor/model_loader/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import default_loader # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
83
vllm_br/model_executor/model_loader/default_loader.py
Normal file
83
vllm_br/model_executor/model_loader/default_loader.py
Normal file
@@ -0,0 +1,83 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.model_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (initialize_model,
|
||||
set_default_torch_dtype)
|
||||
from .utils import process_weights_after_loading
|
||||
|
||||
|
||||
# Replacement for DefaultModelLoader.load_model (it is assigned onto that
# class at the end of this file). Steps, in order: build the model, place each
# parameter on CPU or on the target device according to its name, load the
# checkpoint weights, verify that every parameter was loaded (non-quantized
# models only), run post-load weight processing, and return the model in
# eval mode.
def load_model(self, vllm_config: VllmConfig,
|
||||
model_config: ModelConfig) -> nn.Module:
|
||||
device_config = vllm_config.device_config
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
model = initialize_model(vllm_config=vllm_config,
|
||||
model_config=model_config)
|
||||
# NOTE: on SUPA, the `with device` context may not take effect, so parameters are moved to the device manually
|
||||
# model = model.to(target_device)
|
||||
|
||||
# NOTE: move moe weight to cpu, reduce device memory usage, more layers can be moved to cpu if necessary
|
||||
# Despite its name, the list below also holds the dense MLP projections and
# the attention qkv/o projections. Entries are matched as substrings of the
# parameter name.
moe_packed_weights = [
|
||||
"mlp.experts.w13_weight_packed",
|
||||
"mlp.experts.w2_weight_packed",
|
||||
"mlp.gate_up_proj",
|
||||
"mlp.down_proj",
|
||||
"mlp.experts",
|
||||
"self_attn.qkv_proj",
|
||||
"self_attn.o_proj",
|
||||
]
|
||||
# `module` is a parameter here (named_parameters yields parameters, not
# sub-modules); only its `.data` is rebound.
for name, module in model.named_parameters():
|
||||
if any(s in name for s in moe_packed_weights):
|
||||
module.data = module.to("cpu")
|
||||
else:
|
||||
module.data = module.to(target_device)
|
||||
|
||||
torch.supa.empty_cache()
|
||||
|
||||
# Snapshot the parameter names before loading so missing weights can be
# detected afterwards.
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(
|
||||
self.get_all_weights(model_config, model))
|
||||
|
||||
torch.supa.empty_cache()
|
||||
|
||||
# NOTE(review): counter_before_loading_weights is not assigned in this
# function -- presumably it is set elsewhere on the loader (e.g. inside
# get_all_weights); confirm.
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info(
|
||||
"Loading weights took %.2f seconds",
|
||||
self.counter_after_loading_weights -
|
||||
self.counter_before_loading_weights)
|
||||
# We only enable strict check for non-quantized models
|
||||
# that have loaded weights tracking currently.
|
||||
if model_config.quantization is None and loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError("Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}")
|
||||
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
|
||||
# NOTE(review): the two earlier cache flushes use torch.supa.empty_cache();
# this one uses torch.cuda -- confirm that the CUDA call is redirected to
# SUPA, or switch it for consistency.
torch.cuda.empty_cache()
|
||||
return model.eval()
|
||||
|
||||
|
||||
# Install the function defined above as DefaultModelLoader.load_model.
DefaultModelLoader.load_model = load_model
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user