add qwen3

This commit is contained in:
Chranos
2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions

View File

@@ -0,0 +1 @@
from .backends import mlu_attn

View File

@@ -0,0 +1,58 @@
from typing import Optional, Type
import torch
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.mlu_attn import MLUFlashAttentionMetadata
from vllm_mlu.attention.backends.mlu_attn import MLUFlashAttentionImpl_V2
from .ring_attn import zigzag_ring_attn
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
get_context_model_parallel_world_size)
# Capture the original forward before hijacking so the wrapper can delegate
# to it for every non-context-parallel case.
vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_org = MLUFlashAttentionImpl_V2.forward


def vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_wraper(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: MLUFlashAttentionMetadata,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
    attn_type: AttentionType = AttentionType.DECODER,
    use_mla: bool = False,
) -> torch.Tensor:
    '''
    ==========================
    Modify by Context Parallel
    ==========================
    @brief: use ring attn when context parallel
    '''
    # Prefill with context parallelism enabled: each CP rank only holds a
    # zigzag slice of the sequence, so route through ring attention.
    if get_context_model_parallel_world_size() > 1 and attn_metadata.prefill_metadata:
        return zigzag_ring_attn(self,
                                query=query.view(-1, self.num_heads, self.head_size),
                                key=key.view(-1, self.num_kv_heads, self.head_size),
                                value=value.view(-1, self.num_kv_heads, self.head_size),
                                kv_cache=kv_cache,
                                attn_metadata=attn_metadata)
    '''
    =======================
    End of Context Parallel
    =======================
    '''
    # BUGFIX: forward `use_mla` to the original implementation; it was
    # previously accepted by the wrapper but silently dropped, so callers
    # passing use_mla=True would lose that flag.
    return vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_org(
        self,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
        k_scale=k_scale,
        v_scale=v_scale,
        attn_type=attn_type,
        use_mla=use_mla)


MluHijackObject.apply_hijack(MLUFlashAttentionImpl_V2,
                             MLUFlashAttentionImpl_V2.forward,
                             vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_wraper)

View File

@@ -0,0 +1,216 @@
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from vllm import _mlu_ops as mlu_ops
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.mlu_attn import MLUFlashAttentionMetadata
from vllm.attention.ops.paged_attn import PagedAttention
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import get_context_model_parallel_group
from ...distributed.ring_comm import RingComm
# code references: https://github.com/zhuzilin/ring-flash-attention
def _update_out_and_lse(
out: torch.Tensor,
lse: torch.Tensor,
block_out: torch.Tensor,
block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
block_out = block_out.to(torch.float32)
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
return out, lse
def update_out_and_lse(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    slice_=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Initialize or update the accumulated (out, lse) pair.

    On the first call (``out is None``) the block result becomes the
    accumulator directly; afterwards the block is merged in, optionally
    restricted to ``slice_`` of the accumulator (in-place on that range).
    """
    if out is None:
        # First block: nothing to merge yet; slicing makes no sense here.
        if slice_ is not None:
            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
        return (
            block_out.to(torch.float32),
            block_lse.transpose(-2, -1).unsqueeze(dim=-1),
        )
    if slice_ is None:
        return _update_out_and_lse(out, lse, block_out, block_lse)
    # Merge only into the requested sub-range of the accumulator.
    merged_out, merged_lse = _update_out_and_lse(
        out[slice_], lse[slice_], block_out, block_lse)
    out[slice_], lse[slice_] = merged_out, merged_lse
    return out, lse
def get_half(pack_tensor, cu_seq_lens, first_half):
    """Gather the first or second half of every sequence in a packed tensor.

    ``pack_tensor`` concatenates all sequences along dim 0; ``cu_seq_lens``
    holds the cumulative start offsets ([0, len0, len0+len1, ...]). Returns
    the selected halves concatenated in batch order.
    """
    pieces = []
    num_seqs = cu_seq_lens.shape[0] - 1
    for i in range(num_seqs):
        lo = cu_seq_lens[i]
        hi = cu_seq_lens[i + 1]
        mid = (lo + hi) // 2
        if first_half:
            pieces.append(pack_tensor[lo: mid])
        else:
            pieces.append(pack_tensor[mid: hi])
    return torch.cat(pieces, dim=0)
def update_half(pack_tensor, half_tensor, cu_seq_lens, first_half):
    """Scatter per-sequence halves back into the packed tensor, in place.

    Inverse of get_half(): ``half_tensor`` packs the selected half of each
    sequence back-to-back (so its offsets are ``cu_seq_lens // 2``), and
    each piece is written over the matching half of ``pack_tensor``.
    """
    half_cu_seq_lens = cu_seq_lens // 2
    num_seqs = cu_seq_lens.shape[0] - 1
    for i in range(num_seqs):
        lo = cu_seq_lens[i]
        hi = cu_seq_lens[i + 1]
        mid = (lo + hi) // 2
        piece = half_tensor[half_cu_seq_lens[i]: half_cu_seq_lens[i + 1]]
        if first_half:
            pack_tensor[lo: mid] = piece
        else:
            pack_tensor[mid: hi] = piece
def zigzag_ring_attn(self,
                     query: torch.Tensor,  # [num_tokens, num_heads, head_size]
                     key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
                     value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
                     kv_cache: List[torch.Tensor],
                     attn_metadata: MLUFlashAttentionMetadata) -> torch.Tensor:
    """Ring attention over zigzag-partitioned prefill sequences.

    Each context-parallel rank holds two chunks of every sequence (one from
    the front half and one from the mirrored back half — the "zigzag"
    layout; see the get_half/update_half helpers). K/V blocks and their
    slot mapping rotate around the CP ring; at every step the local partial
    attention output is merged into the running (out, lse) accumulator via
    the online-softmax update.

    code references: https://github.com/zhuzilin/ring-flash-attention
    """
    num_tokens, _, _ = query.shape
    cu_seq_lens = attn_metadata.prefill_metadata.seq_start_loc
    batch_num = cu_seq_lens.shape[0] - 1
    block_seq_len = query.shape[0] // 2
    process_group = get_context_model_parallel_group().device_group
    # One communicator per rotated tensor so each transfer can be committed
    # and waited on independently.
    comm = RingComm(process_group)  # k
    comm_ = RingComm(process_group)  # v
    comm__ = RingComm(process_group)  # slot_mapping
    q, k, v = query, key, value
    # q1 = the second (mirrored-half) portion of the local queries; it is the
    # only part allowed to attend to K/V arriving from "later" ring ranks.
    if batch_num == 1:
        q1 = q[block_seq_len:]
    else:
        q1 = get_half(q, cu_seq_lens, False)
    slot_mapping = attn_metadata.slot_mapping
    out = None
    lse = None
    next_k, next_v = None, None
    next_slot_mapping = None

    def forward(q, k, v, causal):
        # Run local flash attention on (q, k, v); returns the partial output
        # plus its log-sum-exp (padding stripped in the multi-batch case).
        if batch_num == 1:
            seq = q.shape[0]
            seq_k = k.shape[0]
            cu_seq_lens_q = torch.arange(0, seq+1, seq, dtype=torch.int32, device=q.device)
            cu_seq_lens_kv = torch.arange(0, seq_k+1, seq_k, dtype=torch.int32, device=q.device)
            max_seq_len_q = seq
            max_seq_len_kv = seq_k
        else:
            max_seq_len_q = attn_metadata.prefill_metadata.max_seq_len
            max_seq_len_kv = attn_metadata.prefill_metadata.max_seq_len
            cu_seq_lens_q = cu_seq_lens
            cu_seq_lens_kv = cu_seq_lens
            # Half-length inputs (q1 / k0 / v0) need halved offsets/lengths.
            if q.shape[0] != cu_seq_lens[-1]:
                cu_seq_lens_q = cu_seq_lens // 2
                max_seq_len_q = max_seq_len_q // 2
            if k.shape[0] != cu_seq_lens[-1]:
                cu_seq_lens_kv = cu_seq_lens // 2
                max_seq_len_kv = max_seq_len_kv // 2
        alibi_slopes = None if self.alibi_slopes is None else \
            self.alibi_slopes.repeat(attn_metadata.num_prefills, 1)
        ouptuts = mlu_ops.flash_attention(q,
                                          k,
                                          v,
                                          None,
                                          cu_seq_lens_q,
                                          cu_seq_lens_kv,
                                          alibi_slopes,
                                          None,
                                          max_seq_len_q,
                                          max_seq_len_kv,
                                          self.scale,
                                          causal, -1, -1, torch.float, True)
        block_out, block_lse = ouptuts[0], ouptuts[1]
        if block_lse.shape[0] == 1:
            block_lse = block_lse[0]
        else:
            # block_lse shape is [batch, head_num_q, max_seq_q]; the empty
            # (padding) part is set to 0, so trim each batch to its real
            # length and repack to [head_num_q, total_seq_q].
            block_lse_list = []
            for batch in range(block_lse.shape[0]):
                block_lse_ = block_lse[batch][:, : cu_seq_lens_q[batch + 1] - cu_seq_lens_q[batch]]
                block_lse_list.append(block_lse_)
            block_lse = torch.cat(block_lse_list, dim=-1)
        return block_out, block_lse

    for step in range(comm.world_size):
        # Post the async ring exchange for the NEXT step's K/V before
        # computing on the current block, overlapping comm and compute.
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k.contiguous())
            next_v: torch.Tensor = comm_.send_recv(v.contiguous())
            next_slot_mapping: torch.Tensor = comm__.send_recv(slot_mapping)
            comm.commit()
            comm_.commit()
            comm__.commit()
        # call mlu_ops.reshape_paged_cache
        # Write whichever K/V block this rank currently holds into the paged
        # cache (quantized path for int8 caches), so after world_size steps
        # the full cache is populated.
        if kv_cache[0].numel() > 0:
            kv_cache_, kv_cache_scale_ = kv_cache
            key_cache, value_cache = kv_cache_[0], kv_cache_[1]
            if isinstance(kv_cache[0], torch.Tensor) and kv_cache[0].dtype == torch.int8:
                key_cache_scale, value_cache_scale = kv_cache_scale_[0], kv_cache_scale_[1]
                mlu_ops.quant_to_paged_cache(k,
                                             v,
                                             key_cache,
                                             value_cache,
                                             key_cache_scale,
                                             value_cache_scale,
                                             slot_mapping.flatten())
            else:
                mlu_ops.reshape_paged_cache(k,
                                            v,
                                            key_cache,
                                            value_cache,
                                            slot_mapping.flatten())
        if step == 0:
            # Local K/V block: full causal attention over all local queries.
            block_out, block_lse = forward(q, k, v, causal = True)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        elif step <= comm.rank:
            # Block from an earlier ring position: all local queries may
            # attend, but only to the first half of that block (non-causal).
            if batch_num == 1:
                k0 = k[:block_seq_len]
                v0 = v[:block_seq_len]
            else:
                k0 = get_half(k, cu_seq_lens, True)
                v0 = get_half(v, cu_seq_lens, True)
            block_out, block_lse = forward(q, k0, v0, causal = False)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        else:
            # Block from a later ring position: only the mirrored second
            # half of the local queries (q1) may attend to it, so merge the
            # result into just that slice of the accumulator.
            block_out, block_lse = forward(q1, k, v, causal = False)
            if batch_num == 1:
                out, lse = update_out_and_lse(out, lse, block_out, block_lse,
                                              slice_=(slice(block_seq_len, None)),)
            else:
                slice_out = get_half(out, cu_seq_lens, False)
                slice_lse = get_half(lse, cu_seq_lens, False)
                slice_out, slice_lse = update_out_and_lse(
                    slice_out, slice_lse, block_out, block_lse
                )
                update_half(out, slice_out, cu_seq_lens, False)
                update_half(lse, slice_lse, cu_seq_lens, False)
        # Complete the exchange posted above, then rotate the buffers.
        if step + 1 != comm.world_size:
            comm.wait()
            comm_.wait()
            comm__.wait()
            k = next_k
            v = next_v
            slot_mapping = next_slot_mapping
    out = out.to(q.dtype)
    return out.view(num_tokens, self.num_heads * self.head_size)

View File

@@ -0,0 +1 @@
from . import ring_comm

View File

@@ -0,0 +1,50 @@
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
# code references: https://github.com/zhuzilin/ring-flash-attention
class RingComm:
    """Batched point-to-point communicator for ring-style exchanges.

    Every rank sends to its successor and receives from its predecessor in
    the group. send_recv() only queues P2P ops; commit() launches the queued
    ops as one batch and wait() blocks until they all complete.

    code references: https://github.com/zhuzilin/ring-flash-attention
    """

    def __init__(self, process_group: dist.ProcessGroup):
        self._process_group = process_group
        self._ops = []
        self._reqs = None
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        # Ring neighbours, first as group-local ranks...
        successor = (self.rank + 1) % self.world_size
        predecessor = (self.rank - 1) % self.world_size
        # ...then translated to global ranks when a concrete group is given,
        # since P2POp addresses peers by global rank.
        if process_group is not None:
            successor = dist.get_global_rank(self._process_group, successor)
            predecessor = dist.get_global_rank(self._process_group, predecessor)
        self.send_rank = successor
        self.recv_rank = predecessor

    def send_recv(
        self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Queue a send to the successor and a receive from the predecessor.

        Returns the receive buffer (freshly allocated when ``recv_tensor``
        is None). Nothing is transferred until commit() is called.
        """
        if recv_tensor is None:
            result = torch.empty_like(to_send)
        else:
            result = recv_tensor
        self._ops.append(
            dist.P2POp(dist.isend, to_send, self.send_rank,
                       group=self._process_group))
        self._ops.append(
            dist.P2POp(dist.irecv, result, self.recv_rank,
                       group=self._process_group))
        return result

    def commit(self):
        """Launch every queued P2P op as a single batch."""
        if self._reqs is not None:
            raise RuntimeError("commit called twice")
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        """Block until the committed batch finishes, then reset state."""
        if self._reqs is None:
            raise RuntimeError("wait called before commit")
        for req in self._reqs:
            req.wait()
        self._reqs = None
        self._ops = []

View File

@@ -0,0 +1,2 @@
from . import gpu_executor
from . import ray_mlu_executor

View File

@@ -0,0 +1,40 @@
from typing import Any, Dict, Optional
from vllm.executor.gpu_executor import GPUExecutor
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__executor__gpu_executor__GPUExecutor___get_worker_kwargs(
    self,
    local_rank: int = 0,
    rank: int = 0,
    distributed_init_method: Optional[str] = None,
) -> Dict[str, Any]:
    """Return worker init args for a given rank."""
    if distributed_init_method is None:
        # BUGFIX: these helpers were referenced without being imported in
        # this hijack module, so this branch raised NameError. Import them
        # locally from vllm.utils (same source the upstream executor uses).
        from vllm.utils import (get_distributed_init_method, get_ip,
                                get_open_port)
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
    '''
    ==========================
    Modify by Context Parallel
    ==========================
    @brief: replace self.parallel_config.tensor_parallel_size with self.parallel_config.world_size.
    '''
    # NOTE: the "End of Context Parallel" marker used to sit as a dead
    # string statement after this return; the modified region ends here.
    return dict(
        vllm_config=self.vllm_config,
        local_rank=local_rank,
        rank=rank,
        distributed_init_method=distributed_init_method,
        is_driver_worker=(not self.parallel_config)
        or (rank % self.parallel_config.world_size == 0),
    )


MluHijackObject.apply_hijack(
    GPUExecutor,
    GPUExecutor._get_worker_kwargs,
    vllm__executor__gpu_executor__GPUExecutor___get_worker_kwargs)

View File

@@ -0,0 +1,246 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional
import vllm.envs as envs
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id)
from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG, VLLM_LATENCY_DEBUG_NO_DEVICE
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
from vllm.executor.ray_mlu_executor import RayMLUExecutor
from vllm_mlu.mlu_hijack_utils import MluHijackObject
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
def vllm__executor__ray_mlu_executor__RayMLUExecutor___init_workers_ray(
        self, placement_group: "PlacementGroup",
        **ray_remote_kwargs):
    """Create, sort and initialize the Ray worker actors for this engine.

    Hijacked copy of RayMLUExecutor._init_workers_ray. The functional change
    (marked "Modify by Context Parallel" below) selects TP driver workers
    with parallel_config.world_size instead of tensor_parallel_size so that
    context-parallel ranks are counted in each driver's group.
    """
    if (self.parallel_config.tensor_parallel_size == 1
            and self.parallel_config.pipeline_parallel_size == 1):
        # For single GPU case, we use a ray worker with constrained memory.
        num_gpus = self.cache_config.gpu_memory_utilization
    else:
        # Otherwise, the ray workers are allocated with a full GPU.
        num_gpus = 1
    # The driver dummy worker does not actually use any resources.
    # It holds the resource for the driver worker.
    self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
    # The remaining workers are the actual ray actors.
    self.workers: List[RayWorkerWrapper] = []
    # Used in ray compiled DAG: indexed first by PP rank,
    # and then TP rank. In other words, the inner list is
    # the TP group of workers for a PP rank.
    self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
    if self.parallel_config.ray_workers_use_nsight:
        ray_remote_kwargs = self._configure_ray_workers_use_nsight(
            ray_remote_kwargs)
    logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
    # Create the workers.
    driver_ip = get_ip()
    worker_wrapper_kwargs = self._get_worker_wrapper_args()
    for bundle_id, bundle in enumerate(placement_group.bundle_specs):
        # Skip CPU-only bundles; each worker actor needs a GPU slot.
        if not bundle.get("GPU", 0):
            continue
        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_id,
        )
        worker = ray.remote(
            num_cpus=0,
            num_gpus=num_gpus,
            scheduling_strategy=scheduling_strategy,
            **ray_remote_kwargs,
        )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
        if self.use_ray_spmd_worker:
            self.workers.append(worker)
        else:
            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    **worker_wrapper_kwargs)
            else:
                # Else, added to the list of workers.
                self.workers.append(worker)
    logger.debug("workers: %s", self.workers)
    logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
    if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
        raise ValueError(
            "Ray does not allocate any GPUs on the driver node. Consider "
            "adjusting the Ray placement group or running the driver on a "
            "GPU node.")
    worker_ips = [
        ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
        for worker in self.workers
    ]
    ip_counts: Dict[str, int] = {}
    for ip in worker_ips:
        ip_counts[ip] = ip_counts.get(ip, 0) + 1

    def sort_by_driver_then_worker_ip(worker):
        """
        Sort the workers based on 3 properties:
        1. If the worker is on the same node as the driver (vllm engine),
           it should be placed first.
        2. Then, if the worker is on a node with fewer workers, it should
           be placed first.
        3. Finally, if the work is on a node with smaller IP address, it
           should be placed first.
        """
        ip = ray.get(worker.get_node_ip.remote())
        return (ip != driver_ip, ip_counts[ip], ip)

    # After sorting, the workers on the same node will be
    # close to each other, and the workers on the driver
    # node will be placed first.
    self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
    # Get the set of GPU IDs used on each node.
    worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                use_dummy_driver=True)
    node_workers = defaultdict(list)  # node id -> list of worker ranks
    node_gpus = defaultdict(list)  # node id -> list of gpu ids
    for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
        node_workers[node_id].append(i)
        # `gpu_ids` can be a list of strings or integers.
        # convert them to integers for consistency.
        # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
        # string sorting is not sufficient.
        # see https://github.com/vllm-project/vllm/issues/5590
        gpu_ids = [int(x) for x in gpu_ids]
        node_gpus[node_id].extend(gpu_ids)
    for node_id, gpu_ids in node_gpus.items():
        node_gpus[node_id] = sorted(gpu_ids)
    all_ips = set(worker_ips + [driver_ip])
    n_ips = len(all_ips)
    n_nodes = len(node_workers)
    if n_nodes != n_ips:
        raise RuntimeError(
            f"Every node should have a unique IP address. Got {n_nodes}"
            f" nodes with node ids {list(node_workers.keys())} and "
            f"{n_ips} unique IP addresses {all_ips}. Please check your"
            " network configuration. If you set `VLLM_HOST_IP` or "
            "`HOST_IP` environment variable, make sure it is unique for"
            " each node.")
    VLLM_INSTANCE_ID = get_vllm_instance_id()
    # Set environment variables for the driver and workers.
    all_args_to_update_environment_variables = [({
        "MLU_VISIBLE_DEVICES":
        ",".join(map(str, node_gpus[node_id])),
        "VLLM_INSTANCE_ID":
        VLLM_INSTANCE_ID,
        "VLLM_TRACE_FUNCTION":
        str(envs.VLLM_TRACE_FUNCTION),
        **({
            "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
        } if envs.VLLM_ATTENTION_BACKEND is not None else {}),
        "VLLM_LATENCY_DEBUG":
        '1' if VLLM_LATENCY_DEBUG else '0',
        "VLLM_LATENCY_DEBUG_NO_DEVICE":
        '1' if VLLM_LATENCY_DEBUG_NO_DEVICE else '0',
    }, ) for (node_id, _) in worker_node_and_gpu_ids]
    self._env_vars_for_all_workers = (
        all_args_to_update_environment_variables)
    self._run_workers("update_environment_variables",
                      all_args=self._get_env_vars_to_be_updated())
    if len(node_gpus) == 1:
        # in single node case, we don't need to get the IP address.
        # the loopback address is sufficient
        # NOTE: a node may have several IP addresses, one for each
        # network interface. `get_ip()` might return any of them,
        # while they might not work for communication inside the node
        # if the network setup is complicated. Using the loopback address
        # solves this issue, as it always works for communication inside
        # the node.
        driver_ip = "127.0.0.1"
    distributed_init_method = get_distributed_init_method(
        driver_ip, get_open_port())
    # Initialize the actual workers inside worker wrapper.
    init_worker_all_kwargs = [
        self._get_worker_kwargs(
            local_rank=node_workers[node_id].index(rank),
            rank=rank,
            distributed_init_method=distributed_init_method,
        ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
    ]
    self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
    self._run_workers("init_device")
    self._run_workers("load_model",
                      max_concurrent_workers=self.parallel_config.
                      max_parallel_loading_workers)
    if self.use_ray_spmd_worker:
        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
            self.pp_tp_workers.append([])
            for tp_rank in range(
                    self.parallel_config.tensor_parallel_size):
                # PP=2, TP=4
                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                rank = (pp_rank * self.parallel_config.tensor_parallel_size
                        ) + tp_rank
                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                assert pp_rank < len(self.pp_tp_workers)
                self.pp_tp_workers[pp_rank].append(self.workers[rank])
    # This is the list of workers that are rank 0 of each TP group EXCEPT
    # global rank 0. These are the workers that will broadcast to the
    # rest of the workers.
    self.tp_driver_workers: List[RayWorkerWrapper] = []
    # This is the list of workers that are not drivers and not the first
    # worker in a TP group. These are the workers that will be
    # broadcasted to.
    self.non_driver_workers: List[RayWorkerWrapper] = []
    # Enforce rank order for correct rank to return final output.
    for index, worker in enumerate(self.workers):
        # The driver worker is rank 0 and not in self.workers.
        rank = index + 1
        '''
        ==========================
        Modify by Context Parallel
        ==========================
        @brief: replace tp size with world_size.
        '''
        if rank % self.parallel_config.world_size == 0:
            self.tp_driver_workers.append(worker)
        else:
            self.non_driver_workers.append(worker)
        '''
        =======================
        End of Context Parallel
        =======================
        '''


MluHijackObject.apply_hijack(RayMLUExecutor,
                             RayMLUExecutor._init_workers_ray,
                             vllm__executor__ray_mlu_executor__RayMLUExecutor___init_workers_ray)

View File

@@ -0,0 +1,6 @@
# Package entry point for the context-parallel demo hijacks. Importing each
# submodule registers its MLU hijack as a side effect, so the import ORDER
# matters: distributed parallel state comes first, then the components that
# depend on it.
print("Apply Context Parallel Demo!")
from . import distributed
from . import attention
from . import model_executor
from . import worker
from . import executor

View File

@@ -0,0 +1,2 @@
from .layers import rotary_embedding
from .layers import logits_processor

View File

@@ -0,0 +1,110 @@
from typing import Optional
import torch
import vllm
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import get_world_group
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.layers.logits_processor import LogitsProcessor, _prune_hidden_states, _apply_logits_processors
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
get_context_model_parallel_world_size, get_context_model_parallel_rank, get_tensor_model_parallel_world_size)
def vllm__module_executor__layers__logits_processor__LogitsProcessor__forward_wraper(
    self,
    lm_head: VocabParallelEmbedding,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hijacked LogitsProcessor.forward with context-parallel pruning.

    Identical to the upstream forward except that, when context parallel is
    active during prefill, the last-token hidden states are gathered with
    _prune_hidden_states_context_parallel (zigzag layout aware) instead of
    the stock _prune_hidden_states.
    """
    if self.logits_as_input:
        logits = hidden_states
    else:
        '''
        ==========================
        Modify by Context Parallel
        ==========================
        @brief: context parallel requires special handling of hidden_states and logits
        '''
        # self.attn_metadata is injected by the hijacked execute_model
        # (set via setattr before compute_logits); it is None outside the
        # context-parallel prefill path.
        if self.attn_metadata and get_context_model_parallel_world_size() > 1:
            hidden_states = _prune_hidden_states_context_parallel(hidden_states, sampling_metadata, self.attn_metadata)
        else:
            hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
        '''
        =======================
        End of Context Parallel
        =======================
        '''
        # Get the logits for the next tokens.
        logits = self._get_logits(hidden_states, lm_head, embedding_bias)
    if logits is not None:
        # Optional tanh soft-capping, then scaling, as configured.
        if self.soft_cap is not None:
            logits = logits / self.soft_cap
            logits = torch.tanh(logits)
            logits = logits * self.soft_cap
        if self.scale != 1.0:
            logits *= self.scale
        # Apply logits processors (if any).
        if sampling_metadata is not None:
            logits = _apply_logits_processors(logits, sampling_metadata)
    return logits
'''
==========================
Modify by Context Parallel
==========================
@brief: token num can be divisible by context_parallel_size * 2 after padding,
and then split to context parallel groups with zigzag method, now we
need to find the last valid tokens, and get the logits for the next tokens.
'''
def _prune_hidden_states_context_parallel(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    attn_metadata: AttentionMetadata
) -> torch.Tensor:
    # For each sequence: find which CP rank (and which local offset) holds
    # the last valid token under the zigzag split, pick that token's hidden
    # state on the owning rank, then broadcast it to every rank.
    select_hidden_states_list = []
    seq_start_loc = attn_metadata.prefill_metadata.seq_start_loc
    batch_num = seq_start_loc.shape[0] - 1
    for batch in range(batch_num):
        start = seq_start_loc[batch]
        end = seq_start_loc[batch + 1]
        # Local per-rank slice of this sequence: two chunks of
        # split_seq_len tokens each (front-half chunk + mirrored chunk).
        hidden_states_ = hidden_states[start : end]
        split_seq_len = hidden_states_.shape[0] // 2
        seq_len = attn_metadata.prefill_metadata.seq_lens[batch]
        last_id = seq_len - 1
        # Global chunk index containing the last valid token.
        idx = last_id // split_seq_len
        # Placeholder so non-owning ranks can participate in the broadcast.
        select_hidden_states = torch.zeros((1, hidden_states.shape[-1]), dtype = hidden_states.dtype, device = hidden_states.device)
        if idx < get_context_model_parallel_world_size():
            # Front-half chunk: owned by CP rank `idx`, stored in that
            # rank's first local chunk.
            target_cp_id = idx
            src_rank = get_tensor_model_parallel_world_size() * target_cp_id
            if get_context_model_parallel_rank() == target_cp_id:
                selected_token_indices = last_id - idx * split_seq_len
                select_hidden_states = hidden_states_[selected_token_indices].unsqueeze(0)
        else:
            # Back-half chunk: zigzag mirrors chunk `idx` onto CP rank
            # (2 * cp_size - 1 - idx), in that rank's second local chunk.
            target_cp_id = get_context_model_parallel_world_size() * 2 - 1 - idx
            src_rank = get_tensor_model_parallel_world_size() * target_cp_id
            if get_context_model_parallel_rank() == target_cp_id:
                selected_token_indices = last_id - idx * split_seq_len + split_seq_len
                select_hidden_states = hidden_states_[selected_token_indices].unsqueeze(0)
        # NOTE(review): src_rank = tp_size * cp_id assumes a TP-major global
        # rank layout — confirm against parallel_state's group construction.
        select_hidden_states = get_world_group().broadcast(select_hidden_states, src = src_rank)
        select_hidden_states_list.append(select_hidden_states)
    select_hidden_states = torch.cat(select_hidden_states_list, dim=0)
    return select_hidden_states
'''
=======================
End of Context Parallel
=======================
'''
MluHijackObject.apply_hijack(LogitsProcessor,
                             LogitsProcessor.forward,
                             vllm__module_executor__layers__logits_processor__LogitsProcessor__forward_wraper)

View File

@@ -0,0 +1,62 @@
from typing import Optional, Tuple
import torch
import vllm
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.rotary_embedding import MLURotaryEmbedding
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
get_context_model_parallel_world_size)
def vllm__module_executor__layers__rotary_embedding__MLURotaryEmbedding__forward_mlu_wraper(
    self,
    positions: torch.Tensor,
    x: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hijacked MLURotaryEmbedding.forward_mlu with context-parallel support.

    Applies rotary position embedding to ``x`` via mlu_ops.rotary_embedding.
    With context parallel enabled each rank holds a non-contiguous (zigzag)
    slice of the sequence, so discrete mode with explicit position ids is
    required.

    Raises:
        ValueError: if ``offsets`` is given (not supported).
    """
    from vllm import _mlu_ops as mlu_ops
    # ops.rotary_embedding()/batched_rotary_embedding()
    # are in-place operations that update the query and key tensors.
    if offsets is not None:
        raise ValueError(f"tmo.apply_rotary not support offsets yet.")
    # Lazily cache the cos/sin tables on the class the first time through.
    if not MLURotaryEmbedding.set_cos_sin:
        MLURotaryEmbedding.cos_, MLURotaryEmbedding.sin_ = self._get_cos_sin()
        MLURotaryEmbedding.set_cos_sin = True
    interleaved = not self.is_neox_style
    '''
    ==========================
    Modify by Context Parallel
    ==========================
    @brief: context parallel need discrete = True
    '''
    # BUGFIX: the original compared the *function object*
    # `get_context_model_parallel_world_size` to 1 (missing call
    # parentheses), which is always False and silently forced discrete
    # mode in every case, while also dead-coding the chunked/decode
    # branch above it. Call the function, and fold the original
    # chunked/decode handling back in: continuous mode is only valid for
    # a non-chunked prompt with no context parallelism.
    if (MLURotaryEmbedding.is_prompt and not MLURotaryEmbedding.is_chunked
            and get_context_model_parallel_world_size() == 1):
        position_ids = None
        discrete = False
    else:
        position_ids = positions
        discrete = True
    '''
    =======================
    End of Context Parallel
    =======================
    '''
    x = mlu_ops.rotary_embedding(x,
                                 MLURotaryEmbedding.sin_,
                                 MLURotaryEmbedding.cos_,
                                 position_ids,
                                 MLURotaryEmbedding.cu_seq_lens,
                                 interleaved,
                                 discrete,
                                 False,
                                 MLURotaryEmbedding.max_seq_len)
    return x


MluHijackObject.apply_hijack(MLURotaryEmbedding,
                             MLURotaryEmbedding.forward_mlu,
                             vllm__module_executor__layers__rotary_embedding__MLURotaryEmbedding__forward_mlu_wraper)

View File

@@ -0,0 +1,5 @@
from . import mlu_model_runner
from . import model_runner
from . import model_runner_base
from . import worker
from . import worker_base

View File

@@ -0,0 +1,256 @@
import torch
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union)
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu._mlu_utils import *
from vllm.worker.model_runner import (
TModelInputForGPU, ModelInputForGPU,
ModelInputForGPUWithSamplingMetadata,
ModelInputForGPUBuilder, GPUModelRunnerBase,
ModelRunner, CUDAGraphRunner,
LORA_WARMUP_RANK, _get_graph_batch_size,
_BATCH_SIZES_TO_CAPTURE, _NUM_WARMUP_ITERS
)
from vllm.worker.mlu_model_runner import MLUModelRunner
from vllm.sequence import (IntermediateTensors, SequenceGroupMetadata)
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.sampler import SamplerOutput
from ..zigzag_utils import get_context_model_parallel_world_size, zigzag_split
import vllm.envs as envs
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
_PAD_SLOT_ID = -1
@torch.inference_mode()
def vllm__worker__mlu_model_runner__MLUModelRunner__execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
self.attn_state.begin_forward(model_input)
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
# TODO(andoorve): We can remove this once all
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
else:
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start = torch.mlu.Event(enable_timing=True)
model_forward_end = torch.mlu.Event(enable_timing=True)
model_forward_start.record()
'''
=============================
Modify by vllm_mlu
=============================
@brief: add mlu metrics
'''
# Add time markers for model_executable+compute_logits
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
use_cuda_graph = ((prefill_meta is None and decode_meta.use_cuda_graph)
or use_context_mlugraph)
# if use_cuda_graph, the start timestamp will be inserted inside MLUGraphRunner.forward()
if not use_cuda_graph:
start = torch.mlu.Event(enable_timing=True)
start.record()
'''
==========================
Modify by Context Parallel
==========================
@brief: context parallel split input for model with zigzag method
'''
if get_context_model_parallel_world_size() > 1 and model_input.attn_metadata.prefill_metadata:
with set_forward_context(model_input.attn_metadata):
zigzag_input_ids, zigzag_positions, zigzag_attn_metadata = zigzag_split(model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata, _PAD_SLOT_ID)
hidden_or_intermediate_states = model_executable(
input_ids=zigzag_input_ids,
positions=zigzag_positions,
kv_caches=kv_caches,
attn_metadata=zigzag_attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs,
**seqlen_agnostic_kwargs)
else:
with set_forward_context(model_input.attn_metadata):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
#################################################################################################
# DEBUG #
#################################################################################################
# import os
# from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
# from from examples.cambricon_custom_funcvllm.mlu_hijack.distributed.parallel_state import (
# get_context_model_parallel_rank)
# from ..zigzag_utils import context_parallel_tensor_all_gather, diff1
# if get_context_model_parallel_world_size() > 1 and attn_metadata.prefill_metadata:
# hidden_states = context_parallel_tensor_all_gather(hidden_states, zigzag_attn_metadata, dim=0)
# if attn_metadata.prefill_metadata and (kv_caches[0] is not None):
# file_path = '/workspace/output_base_' + str(hidden_states.shape) + \
# '_tp_' + str(get_tensor_model_parallel_world_size()) + '.pth'
# if get_context_model_parallel_rank() == 0 and get_tensor_model_parallel_rank() == 0:
# if os.path.exists(file_path):
# print("##################compare################")
# hidden_states_base = torch.load(file_path)
# print("########output_diff1: ", diff1(hidden_states, hidden_states_base))
# else:
# print("##################save base################")
# torch.save(hidden_states, file_path)
'''
@brief: logits_processor in context parallel need attn_metadata param
'''
if get_context_model_parallel_world_size() > 1 and model_input.attn_metadata.prefill_metadata:
setattr(self.model.logits_processor, 'attn_metadata', zigzag_attn_metadata)
else:
setattr(self.model.logits_processor, 'attn_metadata', None)
'''
=======================
End of Context Parallel
=======================
'''
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
and hidden_or_intermediate_states is not None
and isinstance(hidden_or_intermediate_states,
IntermediateTensors)
and self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
hidden_or_intermediate_states.tensors["model_forward_time"] = (
torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
# Add time markers for model_executable+compute_logits
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
end_marker = torch.mlu.Event(enable_timing=True)
end_marker.record()
if use_cuda_graph:
self.time_markers = (model_executable.start, end_marker)
else:
self.time_markers = (start, end_marker)
'''
==================
End of MLU Hijack
==================
'''
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
# If there are multiple workers, we are still tracking the latency
# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output.model_forward_time = (orig_model_forward_time +
model_forward_time)
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None
indices = model_input.sampling_metadata.selected_token_indices
if model_input.is_prompt:
hidden_states = hidden_or_intermediate_states.index_select(
0, indices)
elif decode_meta.use_cuda_graph:
hidden_states = hidden_or_intermediate_states[:len(indices)]
else:
hidden_states = hidden_or_intermediate_states
output.hidden_states = hidden_states
return [output]
# Install the context-parallel-aware execute_model defined above in place of
# the stock MLUModelRunner implementation.
MluHijackObject.apply_hijack(MLUModelRunner,
                             MLUModelRunner.execute_model,
                             vllm__worker__mlu_model_runner__MLUModelRunner__execute_model)

View File

@@ -0,0 +1,35 @@
from typing import (Any, Dict, Optional)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from examples.cambricon_custom_func.context_parallel.mlu_hijack.worker.model_runner_base import vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict
from vllm.worker.model_runner_base import _init_attn_metadata_from_tensor_dict
@classmethod
def vllm__worker__model_runner__ModelInputForGPUWithSamplingMetadata__from_broadcasted_tensor_dict(
    cls,
    tensor_dict: Dict[str, Any],
    attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForGPUWithSamplingMetadata":
    """Rebuild a ModelInputForGPUWithSamplingMetadata from a broadcast dict.

    Context-parallel hijack: sampling-metadata reconstruction is routed
    through the hijacked helper rather than the stock vLLM one, so that
    non-driver ranks rebuild full sampling metadata locally.
    """
    # Force the context-parallel variant of the sampling-metadata init.
    fields = vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict(
        tensor_dict)
    if attn_backend is None:
        return cls(**fields)
    fields = _init_attn_metadata_from_tensor_dict(attn_backend, fields)
    return cls(**fields)
# Swap in the classmethod above for the stock from_broadcasted_tensor_dict.
MluHijackObject.apply_hijack(
    ModelInputForGPUWithSamplingMetadata,
    ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict,
    vllm__worker__model_runner__ModelInputForGPUWithSamplingMetadata__from_broadcasted_tensor_dict
)

View File

@@ -0,0 +1,74 @@
from typing import (Any, Dict)
from vllm.model_executor.sampling_metadata import SequenceGroupToSample
from vllm.worker import model_runner_base
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize SamplingMetadata based on broadcastable
    SamplingMetadata fields.

    Context-parallel variant: non-driver workers receive the raw
    ``seq_group_metadata`` list via broadcast and rebuild a complete
    ``SamplingMetadata`` locally, instead of only the broadcastable subset.

    Args:
        tensor_dict: Broadcast payload. The keys "selected_token_indices"
            and "seq_group_metadata" are consumed (removed) here.

    Returns:
        The same dict, with a "sampling_metadata" entry added whenever
        "selected_token_indices" was present.
    """
    from vllm.model_executor import SamplingMetadata

    selected_token_indices = tensor_dict.pop("selected_token_indices", None)
    if selected_token_indices is None:
        return tensor_dict

    # Pop unconditionally: the key must not leak into the later
    # cls(**tensor_dict) call even when the list is empty.
    seq_group_metadata_list = tensor_dict.pop("seq_group_metadata", None)
    if seq_group_metadata_list:
        seq_groups = [
            _make_sequence_group_to_sample(seq_group_metadata)
            for seq_group_metadata in seq_group_metadata_list
        ]
        # NOTE(review): num_prompts counts every group, including decode
        # groups -- kept as in the original hijack; confirm this is intended.
        tensor_dict["sampling_metadata"] = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            num_prompts=len(seq_groups),
        )
    else:
        # An empty SamplingMetadata to signal that the worker should skip
        # sampling.
        tensor_dict["sampling_metadata"] = SamplingMetadata(
            seq_groups=None,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            num_prompts=0,
        )
    return tensor_dict


def _make_sequence_group_to_sample(seq_group_metadata) -> "SequenceGroupToSample":
    """Build one SequenceGroupToSample from a broadcast SequenceGroupMetadata."""
    seq_data = seq_group_metadata.seq_data
    seq_ids = list(seq_data.keys())
    is_prompt = seq_group_metadata.is_prompt
    if is_prompt:
        # Prompt phase: the whole prompt is both the sequence and the query.
        # NOTE(review): only the first sequence's prompt length is used --
        # assumes one sequence per group during prefill; confirm.
        seq_len = query_len = next(iter(seq_data.values())).get_prompt_len()
    else:
        seq_len = None
        query_len = 1
    return SequenceGroupToSample(
        seq_ids,
        seq_group_metadata.sampling_params,
        seq_data,
        seq_len,
        query_len,
        None,  # Generator
        is_prompt,
        [],  # prompt_logprob_indices
        seq_ids,  # sample_indices; NOTE(review): reuses seq ids -- verify
    )
# Replace the module-level helper so every caller inside vLLM picks up the
# context-parallel reconstruction path.
MluHijackObject.apply_hijack(
    model_runner_base,
    model_runner_base._init_sampling_metadata_from_tensor_dict,
    vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict
)

View File

@@ -0,0 +1,23 @@
from vllm.worker.worker import Worker
from vllm_mlu.mlu_hijack_utils import MluHijackObject
@property
def vllm__worker__worker__Worker__do_metadata_broadcast(self) -> bool:
    '''
    =============================
    Modify by Context Parallel
    =============================
    @brief: do metadata broadcast if cp or tp > 1.
    '''
    # Broadcast whenever more than one worker exists in the whole world
    # (covers tensor-, pipeline- and context-parallel configurations).
    return self.parallel_config.world_size > 1
'''
==========================
End of Context Parallel
==========================
'''


# Install the property above in place of Worker.do_metadata_broadcast.
MluHijackObject.apply_hijack(
    Worker,
    Worker.do_metadata_broadcast,
    vllm__worker__worker__Worker__do_metadata_broadcast)

View File

@@ -0,0 +1,121 @@
import dataclasses
from typing import Any, Dict, Optional, Tuple, Union
import torch
from vllm.config import ObservabilityConfig, VllmConfig
from vllm.distributed.parallel_state import get_world_group
from vllm.sequence import ExecuteModelRequest
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerInputBase)
from vllm.worker.worker_base import (extract_previous_hidden_states,
LocalOrDistributedWorkerBase,
WorkerInput)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0
):
    """Broadcast a (possibly mixed tensor/metadata) dict from rank `src`
    over the world group.

    When torch.distributed is not initialized (single-process run) the
    input dict is returned unchanged.
    """
    if torch.distributed.is_initialized():
        return get_world_group().broadcast_tensor_dict(tensor_dict, src)
    # Single process: nothing to broadcast.
    return tensor_dict
def vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_driver_input_and_broadcast(
    self, execute_model_req: ExecuteModelRequest
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
    """Build the driver's inputs and broadcast them to the other workers.

    Context-parallel hijack: the raw seq_group_metadata list is added to
    the broadcast payload so non-driver ranks can rebuild SamplingMetadata
    locally.
    """
    assert self.is_driver_worker

    worker_input: WorkerInput = self.prepare_worker_input(
        execute_model_req=execute_model_req)
    model_input: ModelRunnerInputBase = (
        self.model_runner.prepare_model_input(
            execute_model_req.seq_group_metadata_list,
            execute_model_req.virtual_engine,
            execute_model_req.finished_requests_ids))
    extra_kwargs = extract_previous_hidden_states(execute_model_req)

    if self.do_metadata_broadcast:
        payload = worker_input.as_broadcastable_tensor_dict()
        payload.update(model_input.as_broadcastable_tensor_dict())
        payload.update(extra_kwargs)
        # Ship the full seq-group metadata so every context-parallel rank
        # can reconstruct sampling metadata on its own.
        payload['seq_group_metadata'] = (
            execute_model_req.seq_group_metadata_list)
        broadcast_tensor_dict(payload, src=0)

    if execute_model_req.async_callback:
        model_input = dataclasses.replace(  # type: ignore
            model_input,
            async_callback=execute_model_req.async_callback)
    return model_input, worker_input, extra_kwargs
def vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_worker_input_from_broadcast(
    self
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
        str, torch.Tensor]]]:
    """Receive the driver's broadcast and rebuild this worker's inputs.

    Returns None when the driver broadcasts an empty dict (stop signal for
    the worker's execution loop).
    """
    assert self.do_metadata_broadcast
    assert not self.is_driver_worker

    payload = broadcast_tensor_dict(src=0)
    if not payload:
        # Empty broadcast: the driver has no work for us.
        return None

    # Keep this order: both constructors consume keys from the payload.
    worker_input = WorkerInput.from_broadcasted_tensor_dict(payload)
    model_input = (
        self.model_runner.make_model_input_from_broadcasted_tensor_dict(
            payload))
    hidden_state_kwargs = extract_previous_hidden_states(payload)
    return model_input, worker_input, hidden_state_kwargs
def vllm__worker__worker_base__LocalOrDistributedWorkerBase__prepare_input(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]]:
    """
    Prepare the inputs to ModelRunner and workers.
    """
    if not self.is_driver_worker:
        # Followers block on the driver's broadcast.
        return self._get_worker_input_from_broadcast()
    if execute_model_req is not None:
        return self._get_driver_input_and_broadcast(execute_model_req)
    if self.do_metadata_broadcast:
        # No more requests for now: all workers run an infinite loop on
        # broadcast_tensor_dict, and an empty broadcast tells them to stop
        # their execution loop.
        broadcast_tensor_dict({}, src=0)
    return None
# Install the context-parallel variants of the worker-base broadcast path:
# driver-side construction, follower-side reception, and the common
# prepare_input entry point.
MluHijackObject.apply_hijack(
    LocalOrDistributedWorkerBase,
    LocalOrDistributedWorkerBase._get_driver_input_and_broadcast,
    vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_driver_input_and_broadcast)
MluHijackObject.apply_hijack(
    LocalOrDistributedWorkerBase,
    LocalOrDistributedWorkerBase._get_worker_input_from_broadcast,
    vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_worker_input_from_broadcast)
MluHijackObject.apply_hijack(
    LocalOrDistributedWorkerBase,
    LocalOrDistributedWorkerBase.prepare_input,
    vllm__worker__worker_base__LocalOrDistributedWorkerBase__prepare_input)

View File

@@ -0,0 +1,149 @@
from typing import Dict, Optional, Sequence, List
import torch
import torch.distributed as dist
from torch import nn
from torch.nn import functional as F
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
get_context_model_parallel_rank, get_context_model_parallel_world_size, get_context_model_parallel_group)
from vllm.distributed.utils import divide
from vllm.attention import AttentionMetadata
import copy
def diff1(result: torch.Tensor, baseline: torch.Tensor) -> float:
    """Return the relative L1 error: sum(|baseline - result|) / sum(|baseline|).

    Both tensors are flattened, cast to float32 and moved to CPU; their
    flattened shapes must match. When the baseline is identically zero, a
    small epsilon keeps the division finite (the result is then 0 iff the
    error is 0).
    """
    result = result.flatten().float().to('cpu')
    baseline = baseline.flatten().float().to('cpu')
    assert result.shape == baseline.shape
    abs_error = torch.abs(baseline - result)
    denominator = torch.sum(torch.abs(baseline)).item()
    # Guard against division by zero only when the baseline sums to zero.
    eps = 0.0 if denominator > 0 else 1e-9
    # Use a distinct name: the original shadowed the function name `diff1`
    # with its own result variable.
    relative_error = torch.sum(abs_error) / (denominator + eps)
    return relative_error.item()
def get_pad_seq(seq_len: int, pad: int) -> int:
    """Round `seq_len` up to the nearest multiple of `pad` (ceiling division)."""
    return (seq_len + pad - 1) // pad * pad
# Gather the partial results of a batch from the context-parallel group and
# restore the order that existed before the zigzag split.
def context_parallel_tensor_all_gather_(input_, dim=-1):
    """All-gather `input_` across the CP group along `dim`, undoing the
    zigzag chunk ordering so the output is in original sequence order."""
    world_size = get_context_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    # Each rank holds a pair of zigzag half-chunks, so `dim` must be even.
    assert input_.size()[dim] % 2 == 0, (f"input tensor split dim % 2 != 0")

    gathered = [
        torch.empty(input_.shape, dtype=input_.dtype, device=input_.device)
        for _ in range(world_size)
    ]
    torch.distributed.all_gather(
        gathered, input_, group=get_context_model_parallel_group())

    # Rank r holds chunks (r, 2*world_size-1-r): collect the leading halves
    # forward and the trailing halves backward to rebuild original order.
    heads = []
    tails = []
    for chunk in gathered:
        head, tail = torch.split(chunk, chunk.shape[dim] // 2, dim=dim)
        heads.append(head)
        tails.insert(0, tail)
    return torch.cat(heads + tails, dim=dim).contiguous()
# Gather the partial results of each batch on the context parallel groups together,
# place them in the order before zigzag splitting, and remove the pad part.
# This function is used for debugging
def context_parallel_tensor_all_gather(input, attn_metadata, dim=-1):
    """Debug helper: per-batch CP all-gather that undoes zigzag split + padding.

    For each batch delimited by prefill seq_start_loc, gathers the partial
    tensor across the CP group, restores pre-split order, trims it back to
    the true (unpadded) sequence length, then concatenates the batches.
    """
    if dim < 0:
        dim += input.dim()
    # Full-slice prefix so the per-batch slicing below addresses `dim`.
    # BUGFIX: the original computed `slice_ + (slice(None))` in a loop and
    # discarded the result (and `(slice(None))` is not even a tuple), so
    # `slice_` stayed empty and any dim > 0 silently sliced dim 0.
    slice_ = tuple(slice(None) for _ in range(dim))
    select_list = []
    seq_start_loc = attn_metadata.prefill_metadata.seq_start_loc
    batch_num = seq_start_loc.shape[0] - 1
    for batch in range(batch_num):
        start = seq_start_loc[batch].item()
        end = seq_start_loc[batch + 1].item()
        batch_slice = slice_ + (slice(start, end), )
        gathered = context_parallel_tensor_all_gather_(input[batch_slice],
                                                       dim=dim)
        # Drop the zigzag padding: keep only the real sequence length.
        trim_slice = slice_ + (
            slice(None, attn_metadata.prefill_metadata.seq_lens[batch]), )
        select_list.append(gathered[trim_slice])
    return torch.cat(select_list, dim=dim)
# Pad one dimension of a tensor so that it is divisible by
# context_parallel_size * 2, then carve out this rank's zigzag pair of chunks.
def zigzag_split_(tensor: torch.Tensor, dim = -1, pad_value=0):
    """Zigzag-split `tensor` along `dim` for the local context-parallel rank.

    Pads `dim` up to a multiple of 2 * cp_world_size, splits it into that
    many equal chunks, and returns chunk[rank] concatenated with its mirror
    chunk[2*cp - 1 - rank], made contiguous.
    """
    if dim < 0:
        dim = tensor.dim() + dim
    rank = get_context_model_parallel_rank()
    num_chunks = get_context_model_parallel_world_size() * 2

    # Right-pad only `dim`; F.pad's pad tuple runs from the last dim backward.
    pad_len = get_pad_seq(tensor.shape[dim], num_chunks) - tensor.shape[dim]
    pad_spec = (0, 0) * (tensor.dim() - dim - 1) + (0, pad_len) + (0, 0) * dim
    padded = F.pad(tensor, pad_spec, value=pad_value)

    chunk_size = divide(padded.size()[dim], num_chunks)
    chunks = torch.split(padded, chunk_size, dim=dim)
    # Rank r owns the r-th chunk plus its mirror chunk from the tail.
    return torch.cat((chunks[rank], chunks[num_chunks - 1 - rank]),
                     dim=dim).contiguous()
# Split each batch of input_ids, positions, attn_metadata.slot_mapping with zigzag method,
# and update prefill_metadata.seq_start_loc and prefill_metadata.max_seq_len
def zigzag_split(input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 attn_metadata: AttentionMetadata,
                 pad_slot_id: int):
    """Zigzag-split every prefill batch for context parallelism.

    Each batch's input_ids / positions / slot_mapping slice (delimited by
    prefill seq_start_loc) is padded and zigzag-split via `zigzag_split_`,
    then re-concatenated.  A deep copy of `attn_metadata` is updated with
    the new per-rank sequence layout so the caller's metadata is untouched.

    Args:
        input_ids: Flattened token ids for all batches.
        positions: Flattened position ids, same layout as input_ids.
        attn_metadata: Attention metadata; prefill_metadata must be set.
        pad_slot_id: Pad value used for padded slot-mapping entries.

    Returns:
        (zigzag_input_ids, zigzag_positions, zigzag_attn_metadata)
    """
    zigzag_input_ids: List[torch.Tensor] = []
    zigzag_positions: List[torch.Tensor] = []
    zigzag_slot_mapping: List[torch.Tensor] = []
    # Deep-copy so the caller's metadata is left untouched.
    zigzag_attn_metadata = copy.deepcopy(attn_metadata)
    seq_lens: List[int] = []
    seq_start_loc = attn_metadata.prefill_metadata.seq_start_loc
    batch_num = seq_start_loc.shape[0] - 1
    for batch in range(batch_num):
        start, end = seq_start_loc[batch], seq_start_loc[batch + 1]
        input_ids_ = input_ids[start : end]
        positions_ = positions[start : end]
        zigzag_input_ids_ = zigzag_split_(input_ids_)
        zigzag_positions_ = zigzag_split_(positions_)
        zigzag_input_ids.append(zigzag_input_ids_)
        zigzag_positions.append(zigzag_positions_)
        # Per-rank sequence length after padding + zigzag split.
        seq_lens.append(zigzag_input_ids_.shape[0])
        slot_mapping_ = attn_metadata.slot_mapping[start : end]
        # Padded positions must map to the pad slot id, not to slot 0.
        zigzag_slot_mapping_ = zigzag_split_(slot_mapping_, pad_value=pad_slot_id)
        zigzag_slot_mapping.append(zigzag_slot_mapping_)
    # The list names are rebound to the concatenated tensors below.
    zigzag_input_ids = torch.cat(zigzag_input_ids, dim=0)
    zigzag_positions = torch.cat(zigzag_positions, dim=0)
    zigzag_slot_mapping = torch.cat(zigzag_slot_mapping, dim=0)
    max_seq_len = max(seq_lens)
    seq_lens_tensor = torch.tensor(seq_lens,
                                   dtype=torch.int,
                                   device=input_ids.device)
    # Rebuild cumulative start locations (seq_start_loc[0] stays 0).
    seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                dtype=torch.int32,
                                device=input_ids.device)
    torch.cumsum(seq_lens_tensor,
                 dim=0,
                 dtype=seq_start_loc.dtype,
                 out=seq_start_loc[1:])
    zigzag_attn_metadata.prefill_metadata.seq_start_loc = seq_start_loc
    zigzag_attn_metadata.prefill_metadata.query_start_loc = seq_start_loc
    zigzag_attn_metadata.prefill_metadata.max_seq_len = max_seq_len
    zigzag_attn_metadata.slot_mapping = zigzag_slot_mapping
    return zigzag_input_ids, zigzag_positions, zigzag_attn_metadata