Add note for deepseek related docs and remove unnecessary comments (#590)
### What this PR does / why we need it? Add notes for deepseek's patch and remove some of the unnecessary comments --------- Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
# Copyright 2023 The vLLM team.
|
# Copyright 2023 The vLLM team.
|
||||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
@@ -19,31 +19,11 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# <<<<<<< HEAD
|
|
||||||
# # Adapted from
|
# # Adapted from
|
||||||
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
|
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
|
||||||
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||||
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
||||||
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
||||||
# from typing import Optional, Union
|
|
||||||
|
|
||||||
# import torch
|
|
||||||
# from torch import nn
|
|
||||||
# from transformers import PretrainedConfig
|
|
||||||
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
|
||||||
# from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
|
||||||
# from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
||||||
# from vllm.model_executor.layers.layernorm import RMSNorm
|
|
||||||
# from vllm.model_executor.layers.linear import ReplicatedLinear
|
|
||||||
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
||||||
# from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
||||||
# from vllm.model_executor.layers.sampler import get_sampler
|
|
||||||
# from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
||||||
# ParallelLMHead, VocabParallelEmbedding)
|
|
||||||
# from vllm.model_executor.models.deepseek_v2 import ( # noqa
|
|
||||||
# DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
|
|
||||||
# DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE)
|
|
||||||
# =======
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
@@ -173,9 +153,6 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
if (self.tp_size > 1 and self.enable_mc2
|
if (self.tp_size > 1 and self.enable_mc2
|
||||||
and attn_metadata.num_prefills == 0):
|
and attn_metadata.num_prefills == 0):
|
||||||
# hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
|
||||||
# hidden_states, "sum", scatter_dim=0, group=self.tp_group
|
|
||||||
# )
|
|
||||||
chunks = torch.chunk(hidden_states,
|
chunks = torch.chunk(hidden_states,
|
||||||
get_tp_group().world_size,
|
get_tp_group().world_size,
|
||||||
dim=0)
|
dim=0)
|
||||||
@@ -365,29 +342,6 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
k_pe,
|
k_pe,
|
||||||
output_shape=hidden_states.shape)
|
output_shape=hidden_states.shape)
|
||||||
|
|
||||||
# def forward(
|
|
||||||
# self,
|
|
||||||
# positions: torch.Tensor,
|
|
||||||
# hidden_states: torch.Tensor,
|
|
||||||
# # torchair should pass below two parameters
|
|
||||||
# kv_cache: torch.Tensor = None,
|
|
||||||
# attn_metadata: AttentionMetadata = None,
|
|
||||||
# ) -> torch.Tensor:
|
|
||||||
# if self.q_lora_rank is not None:
|
|
||||||
# ckq = self.q_a_proj(hidden_states)[0]
|
|
||||||
# hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
|
||||||
# else:
|
|
||||||
# hidden_states_or_q_c = hidden_states
|
|
||||||
# if VLLM_ENABLE_GRAPH_MODE == '1':
|
|
||||||
# return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
|
|
||||||
# kv_cache, attn_metadata)
|
|
||||||
# else:
|
|
||||||
# kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
|
||||||
# [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
|
||||||
# kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
|
||||||
# return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, output_shape=hidden_states.shape)
|
|
||||||
# kv_cache, attn_metadata)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||||
|
|
||||||
|
|||||||
@@ -364,23 +364,6 @@ def select_experts(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If an unsupported scoring function is provided.
|
ValueError: If an unsupported scoring function is provided.
|
||||||
"""
|
"""
|
||||||
# assert hidden_states.shape[0] == router_logits.shape[0], (
|
|
||||||
# "Number of tokens mismatch")
|
|
||||||
# if os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1" and not is_prefill:
|
|
||||||
# topk_weight, topk_idx, _ = torch.ops.npu_inference.npu_moe_gating_top_k(
|
|
||||||
# router_logits,
|
|
||||||
# k=top_k, # topk当前写8
|
|
||||||
# bias=e_score_correction_bias,
|
|
||||||
# k_group=topk_group, # fix: 4
|
|
||||||
# group_count=num_expert_group, # fix 8
|
|
||||||
# group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
|
||||||
# renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
|
||||||
# norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
|
||||||
# # out_flag=False, # todo new api; 第三个输出是否输出
|
|
||||||
# # y2_flag=False, # old api; 第三个输出是否输出
|
|
||||||
# routed_scaling_factor=1,
|
|
||||||
# eps=float(1e-20))
|
|
||||||
# return topk_weight, topk_idx
|
|
||||||
|
|
||||||
if custom_routing_function is not None:
|
if custom_routing_function is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -483,8 +466,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
is_prefill=False,
|
is_prefill=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# assert router_logits.shape[
|
|
||||||
# 1] == global_num_experts, "Number of global experts mismatch"
|
|
||||||
# set prefill as false always, should fix this
|
# set prefill as false always, should fix this
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -670,7 +651,6 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
scatter_dim=0,
|
scatter_dim=0,
|
||||||
group=get_dp_group().device_group)
|
group=get_dp_group().device_group)
|
||||||
|
|
||||||
# if self.reduce_results and self.tp_size > 1:
|
|
||||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||||
final_hidden_states)
|
final_hidden_states)
|
||||||
|
|||||||
@@ -229,7 +229,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
|
|||||||
|
|
||||||
# TODO: Patch when aclnn ops avaiable
|
# TODO: Patch when aclnn ops avaiable
|
||||||
RotaryEmbedding.forward_oot = rope_forward_oot
|
RotaryEmbedding.forward_oot = rope_forward_oot
|
||||||
# DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
|
|
||||||
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
|
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
|
||||||
DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache
|
DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache
|
||||||
DeepseekScalingRotaryEmbedding.max_seq_len_cached = None
|
DeepseekScalingRotaryEmbedding.max_seq_len_cached = None
|
||||||
|
|||||||
@@ -1,3 +1,22 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from vllm/model_executor/models/qwen2_vl.py
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import vllm
|
import vllm
|
||||||
import vllm.distributed
|
import vllm.distributed
|
||||||
@@ -8,6 +27,40 @@ from torch.distributed.distributed_c10d import (Backend, PrefixStore,
|
|||||||
from torch.distributed.rendezvous import rendezvous
|
from torch.distributed.rendezvous import rendezvous
|
||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
|
|
||||||
|
# What's Patched and how it works:
|
||||||
|
# ** File: platform/patch_0_8_4/patch_distributed.py**
|
||||||
|
# 1. `vllm.distributed.parallel_state.destroy_model_parallel()`
|
||||||
|
# Why:
|
||||||
|
# vllm dose not support outside platform maintain its own `CoordinatorGroup`, vllm-ascend maintain EP and ETP
|
||||||
|
# inside of the repo, and needs a common interface to destroy them, this patch add the interface of destroy
|
||||||
|
# platform owned `CoordinatorGroup` to make sure all the CoordinateGroup can be properly destroyed
|
||||||
|
# How:
|
||||||
|
# Call platform method `destroy_platform_model_parallel` to destroy all the `CoordinateGroup`
|
||||||
|
# Related PR (if no, explain why): no related PR, we want add this ability into vllm
|
||||||
|
# Future Plan:
|
||||||
|
# Remove those patch when vllm merged them
|
||||||
|
# 2. `vllm.distributed.stateless_init_torch_distributed_process_group()`
|
||||||
|
# Why:
|
||||||
|
# The stateless process group can not be initialized except from gloo and nccl backend, vllm-ascend
|
||||||
|
# needs to initialize its own stateless process group for communication, so we add the platform related
|
||||||
|
# call to the `stateless_init_torch_distributed_process_group`, to enable other platform which may support
|
||||||
|
# stateless process group initialize method
|
||||||
|
# How:
|
||||||
|
# Call platform method `platform_has_backend_register` to judge if there is a stateless process group initialize
|
||||||
|
# method and call platform method `platform_register_backend` to initialize them
|
||||||
|
# Related PR (if no, explain why): no related PR, we want add this ability into vllm
|
||||||
|
# Future Plan:
|
||||||
|
# Remove those patch when vllm merged them
|
||||||
|
# 3. `ParallelConfig.get_next_dp_init_port`
|
||||||
|
# Why:
|
||||||
|
# We want to get dp port from env variable, so the multi-node inference can be properly initialized and run.
|
||||||
|
# How:
|
||||||
|
# Get the dp port from env variable enable multi-mode dp inference
|
||||||
|
# Related PR (if no, explain why): no related PR, we want add this ability into vllm
|
||||||
|
# Future Plan:
|
||||||
|
# Its a workaround in vllm-ascend to enable multi-node dp inference, maybe removed if vllm have better plan
|
||||||
|
# on multi-node dp inference implementation
|
||||||
|
|
||||||
|
|
||||||
def ascend_destroy_model_parallel():
|
def ascend_destroy_model_parallel():
|
||||||
"""Set the groups to none and destroy them."""
|
"""Set the groups to none and destroy them."""
|
||||||
|
|||||||
@@ -1,138 +1,32 @@
|
|||||||
import torch
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Adapted from vllm/model_executor/models/qwen2_vl.py
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
|
||||||
import vllm
|
import vllm
|
||||||
import vllm.distributed
|
import vllm.distributed
|
||||||
from torch.distributed import ProcessGroup
|
|
||||||
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
|
|
||||||
_get_default_timeout,
|
|
||||||
is_nccl_available)
|
|
||||||
from torch.distributed.rendezvous import rendezvous
|
|
||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
|
|
||||||
|
from vllm_ascend.patch.platform.patch_0_8_4.patch_distributed import (
|
||||||
|
ascend_destroy_model_parallel,
|
||||||
|
ascend_stateless_init_torch_distributed_process_group,
|
||||||
|
parallel_config_get_dp_port)
|
||||||
|
|
||||||
def ascend_destroy_model_parallel():
|
# All details of those patch please refer to vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py
|
||||||
"""Set the groups to none and destroy them."""
|
|
||||||
from vllm.distributed.parallel_state import _DP, _PP, _TP
|
|
||||||
if _TP:
|
|
||||||
_TP.destroy()
|
|
||||||
_TP = None
|
|
||||||
|
|
||||||
if _PP:
|
|
||||||
_PP.destroy()
|
|
||||||
_PP = None
|
|
||||||
|
|
||||||
if _DP:
|
|
||||||
_DP.destroy()
|
|
||||||
_DP = None
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
current_platform.destroy_platform_model_parallel()
|
|
||||||
|
|
||||||
|
|
||||||
def ascend_stateless_init_torch_distributed_process_group(
|
|
||||||
host: str, port: int, rank: int, world_size: int,
|
|
||||||
backend: str) -> ProcessGroup:
|
|
||||||
"""
|
|
||||||
A replacement for `torch.distributed.init_process_group` that does not
|
|
||||||
pollute the global state. The created ProcessGroup object can be used for
|
|
||||||
some operations such as `allreduce`, because it does not depend on the
|
|
||||||
global rank. However, some operations such as `broadcast` cannot be used
|
|
||||||
because it depends on the global rank.
|
|
||||||
|
|
||||||
# TODO: ask for help from PyTorch team if we need the `broadcast` operation.
|
|
||||||
|
|
||||||
This function is useful when we are not sure about the total number of
|
|
||||||
processes in the process group. For example, we may have process
|
|
||||||
1, 2, ..., 8 who want to communicate, and process 9 might be the same
|
|
||||||
process as process 1, or it might be a different process; process 10
|
|
||||||
might be the same process as process 5, or it might be a different process.
|
|
||||||
In this case, how can we reliably form a communication channel within
|
|
||||||
process 9 and 10, without affecting the communication channel within
|
|
||||||
process 1, 2, ..., 8?
|
|
||||||
|
|
||||||
One possible solution is to figure out if process 9 and 10 are the same
|
|
||||||
as process 1 and 5 beforehand, and then form a communication channel
|
|
||||||
based on the information, adjusting the ranks and world_size etc. However,
|
|
||||||
figuring out the information is not always easy, and it will interfere
|
|
||||||
with the main communication channel.
|
|
||||||
|
|
||||||
Our solution is to always form a communication channel with process 1, 2,
|
|
||||||
..., 8, and then use this function to form another communication channel
|
|
||||||
with process 9 and 10. This way, regardless of whether process 9 and 10
|
|
||||||
are the same as process 1 and 5, the main communication channel is
|
|
||||||
always formed with process 1, 2, ..., 8, and the additional communication
|
|
||||||
channel is formed with process 9 and 10.
|
|
||||||
"""
|
|
||||||
init_method = f"tcp://{host}:{port}"
|
|
||||||
backend = Backend(backend) # it is basically string
|
|
||||||
timeout = _get_default_timeout(backend)
|
|
||||||
|
|
||||||
store, rank, world_size = next(
|
|
||||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
|
||||||
store.set_timeout(timeout)
|
|
||||||
|
|
||||||
group_rank = rank
|
|
||||||
group_size = world_size
|
|
||||||
|
|
||||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
|
||||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
|
||||||
prefix_store = PrefixStore(init_method, store)
|
|
||||||
|
|
||||||
pg: ProcessGroup = ProcessGroup(
|
|
||||||
prefix_store,
|
|
||||||
group_rank,
|
|
||||||
group_size,
|
|
||||||
)
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
if backend == "gloo":
|
|
||||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
|
||||||
backend_class = ProcessGroupGloo(prefix_store,
|
|
||||||
group_rank,
|
|
||||||
group_size,
|
|
||||||
timeout=timeout)
|
|
||||||
backend_type = ProcessGroup.BackendType.GLOO
|
|
||||||
device = torch.device("cpu")
|
|
||||||
elif backend == "nccl":
|
|
||||||
assert is_nccl_available()
|
|
||||||
from torch.distributed.distributed_c10d import ProcessGroupNCCL
|
|
||||||
|
|
||||||
backend_options = ProcessGroupNCCL.Options()
|
|
||||||
backend_options._timeout = timeout
|
|
||||||
|
|
||||||
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
|
|
||||||
backend_options)
|
|
||||||
backend_type = ProcessGroup.BackendType.NCCL
|
|
||||||
device = torch.device("cuda")
|
|
||||||
elif current_platform.platform_has_backend_register():
|
|
||||||
current_platform.platform_register_backend()
|
|
||||||
return pg
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
|
|
||||||
|
|
||||||
pg._set_default_backend(backend_type)
|
|
||||||
backend_class._set_sequence_number_for_group()
|
|
||||||
|
|
||||||
pg._register_backend(device, backend_type, backend_class)
|
|
||||||
|
|
||||||
return pg
|
|
||||||
|
|
||||||
|
|
||||||
def parallel_config_get_dp_port(self) -> int:
|
|
||||||
"""
|
|
||||||
We might need to initialize process groups in multiple
|
|
||||||
processes that is related to data parallelism,
|
|
||||||
e.g. both in the worker and in the engine, which
|
|
||||||
can live in different processes. To avoid port conflicts, we
|
|
||||||
increment the port number each time we need to initialize a
|
|
||||||
new process group related to data parallelism.
|
|
||||||
"""
|
|
||||||
answer = self.data_parallel_master_port
|
|
||||||
self.data_parallel_master_port += 1
|
|
||||||
import os
|
|
||||||
|
|
||||||
# NOTE: Get port from envs directly when using torchrun
|
|
||||||
port = int(os.environ.get("MASTER_PORT", answer)) # type: ignore
|
|
||||||
return port
|
|
||||||
|
|
||||||
|
|
||||||
vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
|
vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
|
||||||
vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group
|
vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group
|
||||||
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
|
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
|
||||||
|
|||||||
@@ -835,7 +835,6 @@ class NPUModelRunner:
|
|||||||
assert num_blocks >= kv_cache_config.num_blocks
|
assert num_blocks >= kv_cache_config.num_blocks
|
||||||
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
||||||
# encounter OOM issue
|
# encounter OOM issue
|
||||||
num_blocks = num_blocks // 4
|
|
||||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
if isinstance(kv_cache_spec, FullAttentionSpec):
|
||||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||||
num_blocks, kv_cache_spec.block_size,
|
num_blocks, kv_cache_spec.block_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user