[CI]Fixed the spell check function in typos.toml (#6753)
### What this PR does / why we need it?
The incorrect regular expression syntax `.*[UE4M3|ue4m3].*` actually
ignores all words containing any of the following characters: `u, e, 4,
m, 3, |`
```yaml
extend-ignore-identifiers-re = [".*Unc.*", ".*_thw",
".*UE8M0.*", ".*[UE4M3|ue4m3].*", ".*eles.*", ".*fo.*", ".*ba.*",
".*ot.*", ".*[Tt]h[rR].*"]
```
===fix===>
```yaml
extend-ignore-identifiers-re = [".*Unc.*", ".*_thw",
".*UE8M0.*", ".*(UE4M3|ue4m3]).*", ".*eles.*", ".*fo.*", ".*ba.*",
".*ot.*", ".*[Tt]h[rR].*"]
```
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main:
9562912cea
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
4
.github/workflows/labled_doctest.yaml
vendored
4
.github/workflows/labled_doctest.yaml
vendored
@@ -44,11 +44,11 @@ jobs:
|
|||||||
# Each version should be tested
|
# Each version should be tested
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
vllm_verison: [releases-v0.13.0, releases-v0.13.0-openeuler, main, main-openeuler]
|
vllm_version: [releases-v0.13.0, releases-v0.13.0-openeuler, main, main-openeuler]
|
||||||
name: vLLM Ascend test
|
name: vLLM Ascend test
|
||||||
runs-on: linux-aarch64-a2b3-1
|
runs-on: linux-aarch64-a2b3-1
|
||||||
container:
|
container:
|
||||||
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/vllm-ascend:${{ matrix.vllm_verison }}
|
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/vllm-ascend:${{ matrix.vllm_version }}
|
||||||
steps:
|
steps:
|
||||||
- name: Check NPU/CANN and git info
|
- name: Check NPU/CANN and git info
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -146,10 +146,10 @@ jobs:
|
|||||||
# Each version should be tested
|
# Each version should be tested
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
vllm_verison: [releases-v0.13.0, releases-v0.13.0-openeuler, main, main-openeuler]
|
vllm_version: [releases-v0.13.0, releases-v0.13.0-openeuler, main, main-openeuler]
|
||||||
runs-on: linux-aarch64-a2b3-1
|
runs-on: linux-aarch64-a2b3-1
|
||||||
container:
|
container:
|
||||||
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/vllm-ascend:${{ matrix.vllm_verison }}
|
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/vllm-ascend:${{ matrix.vllm_version }}
|
||||||
steps:
|
steps:
|
||||||
- name: Check NPU/CANN and git info
|
- name: Check NPU/CANN and git info
|
||||||
run: |
|
run: |
|
||||||
@@ -183,4 +183,4 @@ jobs:
|
|||||||
|
|
||||||
# Run real test
|
# Run real test
|
||||||
echo "Test:"
|
echo "Test:"
|
||||||
/vllm-workspace/vllm-ascend/tests/e2e/run_doctests.sh
|
/vllm-workspace/vllm-ascend/tests/e2e/run_doctests.sh
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
default_install_hook_types:
|
default_install_hook_types:
|
||||||
- pre-commit
|
- pre-commit
|
||||||
- commit-msg
|
- commit-msg
|
||||||
|
|
||||||
default_stages:
|
default_stages:
|
||||||
- pre-commit # Run locally
|
- pre-commit # Run locally
|
||||||
- manual # Run in CI
|
- manual # Run in CI
|
||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.14.0
|
rev: v0.14.0
|
||||||
@@ -11,6 +13,7 @@ repos:
|
|||||||
- id: ruff-check
|
- id: ruff-check
|
||||||
args: [--output-format, github, --fix]
|
args: [--output-format, github, --fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|
||||||
- repo: https://github.com/codespell-project/codespell
|
- repo: https://github.com/codespell-project/codespell
|
||||||
rev: v2.4.1
|
rev: v2.4.1
|
||||||
hooks:
|
hooks:
|
||||||
@@ -22,6 +25,7 @@ repos:
|
|||||||
]
|
]
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
- tomli
|
- tomli
|
||||||
|
|
||||||
- repo: https://github.com/crate-ci/typos
|
- repo: https://github.com/crate-ci/typos
|
||||||
rev: v1.32.0
|
rev: v1.32.0
|
||||||
hooks:
|
hooks:
|
||||||
@@ -30,6 +34,7 @@ repos:
|
|||||||
"--force-exclude",
|
"--force-exclude",
|
||||||
"--exclude", "csrc/**"
|
"--exclude", "csrc/**"
|
||||||
]
|
]
|
||||||
|
|
||||||
# - repo: https://github.com/pre-commit/mirrors-clang-format
|
# - repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
# rev: v20.1.3
|
# rev: v20.1.3
|
||||||
# hooks:
|
# hooks:
|
||||||
@@ -37,17 +42,20 @@ repos:
|
|||||||
# files: ^csrc/.*\.(cpp|hpp|cc|hh|cxx|hxx)$
|
# files: ^csrc/.*\.(cpp|hpp|cc|hh|cxx|hxx)$
|
||||||
# types_or: [c++]
|
# types_or: [c++]
|
||||||
# args: [--style=google, --verbose]
|
# args: [--style=google, --verbose]
|
||||||
|
|
||||||
- repo: https://github.com/igorshubovych/markdownlint-cli
|
- repo: https://github.com/igorshubovych/markdownlint-cli
|
||||||
rev: v0.45.0
|
rev: v0.45.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: markdownlint
|
- id: markdownlint
|
||||||
exclude: '.*\.inc\.md$|.*report_template\.md$|.*contributors\.md$|.*PULL_REQUEST_TEMPLATE\.md$'
|
exclude: '.*\.inc\.md$|.*report_template\.md$|.*contributors\.md$|.*PULL_REQUEST_TEMPLATE\.md$'
|
||||||
stages: [manual] # Only run in CI
|
stages: [manual] # Only run in CI
|
||||||
|
|
||||||
- repo: https://github.com/rhysd/actionlint
|
- repo: https://github.com/rhysd/actionlint
|
||||||
rev: v1.7.7
|
rev: v1.7.7
|
||||||
hooks:
|
hooks:
|
||||||
- id: actionlint
|
- id: actionlint
|
||||||
exclude: '.*\.github/workflows/scripts/.*\.ya?ml$'
|
exclude: '.*\.github/workflows/scripts/.*\.ya?ml$'
|
||||||
|
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
- id: png-lint
|
- id: png-lint
|
||||||
|
|||||||
@@ -25,5 +25,5 @@ msgid "Adding a New Multi-Modal Model"
|
|||||||
msgstr "添加新的多模态模型"
|
msgstr "添加新的多模态模型"
|
||||||
|
|
||||||
#: ../../developer_guide/modeling/adding_a_new_multimodal_model.md:3
|
#: ../../developer_guide/modeling/adding_a_new_multimodal_model.md:3
|
||||||
msgid "**_Comming soon ..._**"
|
msgid "**_Coming soon ..._**"
|
||||||
msgstr "**_敬请期待 ..._**"
|
msgstr "**_敬请期待 ..._**"
|
||||||
|
|||||||
@@ -636,7 +636,7 @@ This is the 3rd release candidate of v0.9.1 for vLLM Ascend. Please follow the [
|
|||||||
- Fix incorrect req block length in ascend scheduler [#2394](https://github.com/vllm-project/vllm-ascend/pull/2394)
|
- Fix incorrect req block length in ascend scheduler [#2394](https://github.com/vllm-project/vllm-ascend/pull/2394)
|
||||||
- Fix header include issue in rope [#2398](https://github.com/vllm-project/vllm-ascend/pull/2398)
|
- Fix header include issue in rope [#2398](https://github.com/vllm-project/vllm-ascend/pull/2398)
|
||||||
- Fix mtp config bug [#2412](https://github.com/vllm-project/vllm-ascend/pull/2412)
|
- Fix mtp config bug [#2412](https://github.com/vllm-project/vllm-ascend/pull/2412)
|
||||||
- Fix error info and adapt `attn_metedata` refactor [#2402](https://github.com/vllm-project/vllm-ascend/pull/2402)
|
- Fix error info and adapt `attn_metadata` refactor [#2402](https://github.com/vllm-project/vllm-ascend/pull/2402)
|
||||||
- Fix torchair runtime error caused by configuration mismatches and `.kv_cache_bytes` file missing [#2312](https://github.com/vllm-project/vllm-ascend/pull/2312)
|
- Fix torchair runtime error caused by configuration mismatches and `.kv_cache_bytes` file missing [#2312](https://github.com/vllm-project/vllm-ascend/pull/2312)
|
||||||
- Move `with_prefill` allreduce from cpu to npu [#2230](https://github.com/vllm-project/vllm-ascend/pull/2230)
|
- Move `with_prefill` allreduce from cpu to npu [#2230](https://github.com/vllm-project/vllm-ascend/pull/2230)
|
||||||
|
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class ProxyState:
|
|||||||
# No lock needed - atomic operation
|
# No lock needed - atomic operation
|
||||||
self.prefillers[server_idx].aborted_requests.add(request_id)
|
self.prefillers[server_idx].aborted_requests.add(request_id)
|
||||||
|
|
||||||
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
|
def acquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
|
||||||
"""
|
"""
|
||||||
Get the set of aborted requests and clear it.
|
Get the set of aborted requests and clear it.
|
||||||
This is used to release kv cache in prefiller node.
|
This is used to release kv cache in prefiller node.
|
||||||
@@ -325,7 +325,7 @@ async def send_request_to_service(
|
|||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
base_delay: float = 0.2,
|
base_delay: float = 0.2,
|
||||||
):
|
):
|
||||||
proxy_state.aquire_aborted_prefiller_requests(prefiller_id)
|
proxy_state.acquire_aborted_prefiller_requests(prefiller_id)
|
||||||
req_data = req_data.copy()
|
req_data = req_data.copy()
|
||||||
req_data["stream"] = False
|
req_data["stream"] = False
|
||||||
req_data["max_tokens"] = 1
|
req_data["max_tokens"] = 1
|
||||||
|
|||||||
@@ -241,7 +241,7 @@ class ProxyState:
|
|||||||
return
|
return
|
||||||
self.prefillers[server_idx].aborted_requests.add(request_id)
|
self.prefillers[server_idx].aborted_requests.add(request_id)
|
||||||
|
|
||||||
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
|
def acquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
|
||||||
"""
|
"""
|
||||||
Get the set of aborted requests and clear it.
|
Get the set of aborted requests and clear it.
|
||||||
This is used to release kv cache in prefiller node.
|
This is used to release kv cache in prefiller node.
|
||||||
@@ -582,7 +582,7 @@ async def send_request_to_service(
|
|||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
base_delay: float = 0.2,
|
base_delay: float = 0.2,
|
||||||
):
|
):
|
||||||
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(prefiller_id)
|
aborted_requests = proxy_state.acquire_aborted_prefiller_requests(prefiller_id)
|
||||||
req_data = req_data.copy()
|
req_data = req_data.copy()
|
||||||
req_data["kv_transfer_params"] = {
|
req_data["kv_transfer_params"] = {
|
||||||
"do_remote_decode": True,
|
"do_remote_decode": True,
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ def calculate_average(lst):
|
|||||||
return total / count
|
return total / count
|
||||||
|
|
||||||
|
|
||||||
def layer_imblance_polt(y_list, label_names, device_num, output_path, file_name):
|
def layer_imbalance_plot(y_list, label_names, device_num, output_path, file_name):
|
||||||
plt.rcParams["font.sans-serif"] = ["Arial"]
|
plt.rcParams["font.sans-serif"] = ["Arial"]
|
||||||
plt.rcParams["axes.unicode_minus"] = False
|
plt.rcParams["axes.unicode_minus"] = False
|
||||||
x = [i for i in range(58)]
|
x = [i for i in range(58)]
|
||||||
@@ -160,4 +160,4 @@ if __name__ == "__main__":
|
|||||||
save_matrix_to_json(output_path, file_name, np.array(global_deployment))
|
save_matrix_to_json(output_path, file_name, np.array(global_deployment))
|
||||||
label_names = ["default deployment max load", "balanced load max load", "balanced load avg load"]
|
label_names = ["default deployment max load", "balanced load max load", "balanced load avg load"]
|
||||||
new_file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}.png"
|
new_file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}.png"
|
||||||
layer_imblance_polt(y_list, label_names, num_devices, output_path, new_file_name)
|
layer_imbalance_plot(y_list, label_names, num_devices, output_path, new_file_name)
|
||||||
|
|||||||
@@ -283,10 +283,10 @@ async def _select_instance(api: str, req_data: Any, request_length: int):
|
|||||||
request_id = await proxy_state.next_req_id()
|
request_id = await proxy_state.next_req_id()
|
||||||
# Select dp server based on priority score
|
# Select dp server based on priority score
|
||||||
server_idx = proxy_state.select_server(priority_score)
|
server_idx = proxy_state.select_server(priority_score)
|
||||||
choosen_server = proxy_state.dp_servers[server_idx]
|
chosen_server = proxy_state.dp_servers[server_idx]
|
||||||
logger.debug(f"Choose server {choosen_server.url} to process request {request_id}")
|
logger.debug(f"Choose server {chosen_server.url} to process request {request_id}")
|
||||||
return InstanceInfo(
|
return InstanceInfo(
|
||||||
request_id=request_id, server_idx=server_idx, priority_score=priority_score, server_state=choosen_server
|
request_id=request_id, server_idx=server_idx, priority_score=priority_score, server_state=chosen_server
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -29,11 +29,11 @@ dp_rpc_port = args.dp_rpc_port
|
|||||||
vllm_start_port = args.vllm_start_port
|
vllm_start_port = args.vllm_start_port
|
||||||
|
|
||||||
|
|
||||||
def run_command(visiable_devices, dp_rank, vllm_engine_port):
|
def run_command(visible_devices, dp_rank, vllm_engine_port):
|
||||||
command = [
|
command = [
|
||||||
"bash",
|
"bash",
|
||||||
"./run_dp_template.sh",
|
"./run_dp_template.sh",
|
||||||
visiable_devices,
|
visible_devices,
|
||||||
str(vllm_engine_port),
|
str(vllm_engine_port),
|
||||||
str(dp_size),
|
str(dp_size),
|
||||||
str(dp_rank),
|
str(dp_rank),
|
||||||
@@ -55,8 +55,8 @@ if __name__ == "__main__":
|
|||||||
for i in range(dp_size_local):
|
for i in range(dp_size_local):
|
||||||
dp_rank = dp_rank_start + i
|
dp_rank = dp_rank_start + i
|
||||||
vllm_engine_port = vllm_start_port + i
|
vllm_engine_port = vllm_start_port + i
|
||||||
visiable_devices = ",".join(str(x) for x in range(i * tp_size, (i + 1) * tp_size))
|
visible_devices = ",".join(str(x) for x in range(i * tp_size, (i + 1) * tp_size))
|
||||||
process = multiprocessing.Process(target=run_command, args=(visiable_devices, dp_rank, vllm_engine_port))
|
process = multiprocessing.Process(target=run_command, args=(visible_devices, dp_rank, vllm_engine_port))
|
||||||
processes.append(process)
|
processes.append(process)
|
||||||
process.start()
|
process.start()
|
||||||
|
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -198,7 +198,7 @@ class build_and_install_aclnn(Command):
|
|||||||
try:
|
try:
|
||||||
print("Running bash build_aclnn.sh ...")
|
print("Running bash build_aclnn.sh ...")
|
||||||
subprocess.check_call(["bash", "csrc/build_aclnn.sh", ROOT_DIR, envs.SOC_VERSION])
|
subprocess.check_call(["bash", "csrc/build_aclnn.sh", ROOT_DIR, envs.SOC_VERSION])
|
||||||
print("buid_aclnn.sh executed successfully!")
|
print("build_aclnn.sh executed successfully!")
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
print(f"Error running build_aclnn.sh: {e}")
|
print(f"Error running build_aclnn.sh: {e}")
|
||||||
raise SystemExit(e.returncode)
|
raise SystemExit(e.returncode)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import yaml
|
|||||||
# isort: off
|
# isort: off
|
||||||
from tests.e2e.nightly.multi_node.scripts.utils import (
|
from tests.e2e.nightly.multi_node.scripts.utils import (
|
||||||
CONFIG_BASE_PATH, DEFAULT_SERVER_PORT, get_all_ipv4, get_cluster_ips,
|
CONFIG_BASE_PATH, DEFAULT_SERVER_PORT, get_all_ipv4, get_cluster_ips,
|
||||||
get_net_interface, setup_logger, get_avaliable_port)
|
get_net_interface, setup_logger, get_available_port)
|
||||||
# isort: on
|
# isort: on
|
||||||
setup_logger()
|
setup_logger()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -202,7 +202,7 @@ class MultiNodeConfig:
|
|||||||
master_ip = (self.disagg_cfg.master_ip_for_node(
|
master_ip = (self.disagg_cfg.master_ip_for_node(
|
||||||
self.cur_index, self.nodes)
|
self.cur_index, self.nodes)
|
||||||
if self.disagg_cfg else self.nodes[0].ip)
|
if self.disagg_cfg else self.nodes[0].ip)
|
||||||
self.proxy_port = get_avaliable_port()
|
self.proxy_port = get_available_port()
|
||||||
|
|
||||||
self.envs = DistEnvBuilder(
|
self.envs = DistEnvBuilder(
|
||||||
cur_node=self.cur_node,
|
cur_node=self.cur_node,
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ def get_cluster_ips(word_size: int = 2) -> list[str]:
|
|||||||
return [resolver(dns) for dns in get_cluster_dns_list(word_size)]
|
return [resolver(dns) for dns in get_cluster_dns_list(word_size)]
|
||||||
|
|
||||||
|
|
||||||
def get_avaliable_port(start_port: int = 6000, end_port: int = 7000) -> int:
|
def get_available_port(start_port: int = 6000, end_port: int = 7000) -> int:
|
||||||
import socket
|
import socket
|
||||||
for port in range(start_port, end_port):
|
for port in range(start_port, end_port):
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from vllm_ascend.utils import enable_custom_op
|
|||||||
enable_custom_op()
|
enable_custom_op()
|
||||||
|
|
||||||
|
|
||||||
class TestDisptachFFNCombine:
|
class TestDispatchFFNCombine:
|
||||||
|
|
||||||
def __init__(self, rank, world_size, port):
|
def __init__(self, rank, world_size, port):
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
@@ -208,7 +208,7 @@ class TestDisptachFFNCombine:
|
|||||||
|
|
||||||
|
|
||||||
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
|
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
|
||||||
op = TestDisptachFFNCombine(rank, world_size, port)
|
op = TestDispatchFFNCombine(rank, world_size, port)
|
||||||
op.generate_hcom()
|
op.generate_hcom()
|
||||||
out1 = op.run_tensor_list()
|
out1 = op.run_tensor_list()
|
||||||
q.put(out1)
|
q.put(out1)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from vllm_ascend.utils import enable_custom_op
|
|||||||
enable_custom_op()
|
enable_custom_op()
|
||||||
|
|
||||||
|
|
||||||
class TestDisptachFFNCombine:
|
class TestDispatchFFNCombine:
|
||||||
|
|
||||||
def __init__(self, rank, world_size, port):
|
def __init__(self, rank, world_size, port):
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
@@ -208,7 +208,7 @@ class TestDisptachFFNCombine:
|
|||||||
|
|
||||||
|
|
||||||
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
|
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
|
||||||
op = TestDisptachFFNCombine(rank, world_size, port)
|
op = TestDispatchFFNCombine(rank, world_size, port)
|
||||||
op.generate_hcom()
|
op.generate_hcom()
|
||||||
out1 = op.run_tensor_list()
|
out1 = op.run_tensor_list()
|
||||||
q.put(out1)
|
q.put(out1)
|
||||||
|
|||||||
@@ -124,10 +124,10 @@ def create_test_data(
|
|||||||
|
|
||||||
logits = torch.randn(num_reqs, vocab_size, device=device, dtype=dtype)
|
logits = torch.randn(num_reqs, vocab_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
repetiton_penalty = torch.ones(num_reqs, device=device, dtype=torch.float32)
|
repetition_penalty = torch.ones(num_reqs, device=device, dtype=torch.float32)
|
||||||
for i in range(num_reqs):
|
for i in range(num_reqs):
|
||||||
if torch.rand(1) > 0.3:
|
if torch.rand(1) > 0.3:
|
||||||
repetiton_penalty[i] = torch.rand(1, device=device).item() * 0.8 + 0.6
|
repetition_penalty[i] = torch.rand(1, device=device).item() * 0.8 + 0.6
|
||||||
|
|
||||||
frequency_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32)
|
frequency_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32)
|
||||||
for i in range(num_reqs):
|
for i in range(num_reqs):
|
||||||
@@ -168,7 +168,7 @@ def create_test_data(
|
|||||||
output_bin_counts[state_idx, token] = count
|
output_bin_counts[state_idx, token] = count
|
||||||
|
|
||||||
sampling_metadata = SamplingMetadata(
|
sampling_metadata = SamplingMetadata(
|
||||||
repetition_penalty=repetiton_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -217,4 +217,3 @@ def test_apply_penalties_and_temperature(
|
|||||||
atol = 1e-02
|
atol = 1e-02
|
||||||
rtol = 1e-02
|
rtol = 1e-02
|
||||||
assert torch.allclose(logits_triton, logits_pytorch_result, atol=atol, rtol=rtol)
|
assert torch.allclose(logits_triton, logits_pytorch_result, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class TestCumsumGroupList(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
support_combine = [(0, 0), (1, 0), (0, 1)]
|
support_combine = [(0, 0), (1, 0), (0, 1)]
|
||||||
unsupport_combine = [(0, 2), (2, 1), (1, 2)]
|
unsupported_combine = [(0, 2), (2, 1), (1, 2)]
|
||||||
|
|
||||||
def test_cumsum_group_list_supported_conversion(self):
|
def test_cumsum_group_list_supported_conversion(self):
|
||||||
for src_list_type, dst_list_type in self.support_combine:
|
for src_list_type, dst_list_type in self.support_combine:
|
||||||
@@ -38,7 +38,7 @@ class TestCumsumGroupList(unittest.TestCase):
|
|||||||
|
|
||||||
def test_cumsum_group_list_unsupported_conversion_notimplementederror(
|
def test_cumsum_group_list_unsupported_conversion_notimplementederror(
|
||||||
self):
|
self):
|
||||||
for src_list_type, dst_list_type in self.unsupport_combine:
|
for src_list_type, dst_list_type in self.unsupported_combine:
|
||||||
with self.subTest(src=src_list_type, dst=dst_list_type):
|
with self.subTest(src=src_list_type, dst=dst_list_type):
|
||||||
with self.assertRaises(NotImplementedError) as excinfo:
|
with self.assertRaises(NotImplementedError) as excinfo:
|
||||||
cumsum_group_list(self.glist_dict[0], src_list_type,
|
cumsum_group_list(self.glist_dict[0], src_list_type,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[files]
|
[files]
|
||||||
# these files may be written in non english words
|
# these files may be written in non english words
|
||||||
extend-exclude = []
|
extend-exclude = [".pre-commit-config.yaml",]
|
||||||
ignore-hidden = true
|
ignore-hidden = true
|
||||||
ignore-files = true
|
ignore-files = true
|
||||||
ignore-dot = true
|
ignore-dot = true
|
||||||
@@ -17,9 +17,9 @@ ignore-hex = true
|
|||||||
identifier-leading-digits = false
|
identifier-leading-digits = false
|
||||||
locale = "en"
|
locale = "en"
|
||||||
extend-ignore-identifiers-re = [".*Unc.*", ".*_thw",
|
extend-ignore-identifiers-re = [".*Unc.*", ".*_thw",
|
||||||
".*UE8M0.*", ".*[UE4M3|ue4m3].*", ".*eles.*", ".*fo.*", ".*ba.*",
|
".*UE8M0.*", ".*(UE4M3|ue4m3).*", ".*eles.*", ".*fo.*", ".*ba.*",
|
||||||
".*ot.*", ".*[Tt]h[rR].*"]
|
".*ot.*", ".*[Tt]h[rR].*"]
|
||||||
extend-ignore-words-re = ["CANN", "cann","ND","alog"]
|
extend-ignore-words-re = ["CANN", "cann","ND","alog","nd","BA","datas","ful","udo",]
|
||||||
extend-ignore-re = []
|
extend-ignore-re = []
|
||||||
|
|
||||||
[default.extend-identifiers]
|
[default.extend-identifiers]
|
||||||
|
|||||||
@@ -144,14 +144,14 @@ class AscendConfig:
|
|||||||
if os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0") == "1":
|
if os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0") == "1":
|
||||||
MAX_PREFETCH_WEIGHT_SIZE: int = 18 * 1024 * 1024
|
MAX_PREFETCH_WEIGHT_SIZE: int = 18 * 1024 * 1024
|
||||||
gate_up_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
|
gate_up_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
|
||||||
down_prefetch_szie = int(os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
|
down_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
|
||||||
self.weight_prefetch_config.set_mlp_pre_version_compatibale_config(
|
self.weight_prefetch_config.set_mlp_pre_version_compatibale_config(
|
||||||
gate_up_prefetch_size, down_prefetch_szie
|
gate_up_prefetch_size, down_prefetch_size
|
||||||
)
|
)
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
f"MLP weight prefetch enabled from env variable VLLM_ASCEND_ENABLE_PREFETCH_MLP."
|
f"MLP weight prefetch enabled from env variable VLLM_ASCEND_ENABLE_PREFETCH_MLP."
|
||||||
f"gate_up_prefetch_size={gate_up_prefetch_size}, "
|
f"gate_up_prefetch_size={gate_up_prefetch_size}, "
|
||||||
f"down_prefetch_szie={down_prefetch_szie}."
|
f"down_prefetch_size={down_prefetch_size}."
|
||||||
)
|
)
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"VLLM_ASCEND_ENABLE_PREFETCH_MLP is deprecated and will be removed in a v0.16.0 version. "
|
"VLLM_ASCEND_ENABLE_PREFETCH_MLP is deprecated and will be removed in a v0.16.0 version. "
|
||||||
|
|||||||
@@ -34,13 +34,13 @@ else:
|
|||||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
# computation-communication tiling block is 512
|
# computation-communication tiling block is 512
|
||||||
ALLREDUCE_NORM_FUSE_THREHOLD = 512
|
ALLREDUCE_NORM_FUSE_THRESHOLD = 512
|
||||||
|
|
||||||
|
|
||||||
def get_compile_range_and_extra_stream_check():
|
def get_compile_range_and_extra_stream_check():
|
||||||
def check_func(match: Match) -> bool:
|
def check_func(match: Match) -> bool:
|
||||||
compile_range = get_pass_context().compile_range
|
compile_range = get_pass_context().compile_range
|
||||||
return extra_stream_scope_check(match) and compile_range.start > ALLREDUCE_NORM_FUSE_THREHOLD
|
return extra_stream_scope_check(match) and compile_range.start > ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||||
|
|
||||||
return check_func
|
return check_func
|
||||||
|
|
||||||
@@ -176,5 +176,5 @@ class MatmulAllReduceAddRMSNormPass(VllmInductorPass):
|
|||||||
"""
|
"""
|
||||||
Check if the pass is applicable for the current configuration.
|
Check if the pass is applicable for the current configuration.
|
||||||
"""
|
"""
|
||||||
applicable = compile_range.start > ALLREDUCE_NORM_FUSE_THREHOLD
|
applicable = compile_range.start > ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||||
return applicable
|
return applicable
|
||||||
|
|||||||
@@ -86,9 +86,9 @@ class BudgetRefiner:
|
|||||||
return k
|
return k
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_max_budget(self, num_deocde_tokens, num_decode):
|
def _get_max_budget(self, num_decode_tokens, num_decode):
|
||||||
"""Get the maximum budget according to the number of decoding tokens and the decoding requests."""
|
"""Get the maximum budget according to the number of decoding tokens and the decoding requests."""
|
||||||
aligned_ctx = self._align_key(num_deocde_tokens, self.context_keys)
|
aligned_ctx = self._align_key(num_decode_tokens, self.context_keys)
|
||||||
aligned_dnum = self._align_key(num_decode, self.dnum_keys)
|
aligned_dnum = self._align_key(num_decode, self.dnum_keys)
|
||||||
if aligned_ctx is None or aligned_dnum is None:
|
if aligned_ctx is None or aligned_dnum is None:
|
||||||
return self.default_budget
|
return self.default_budget
|
||||||
@@ -99,7 +99,7 @@ class BudgetRefiner:
|
|||||||
# For debug.
|
# For debug.
|
||||||
# logger.info(
|
# logger.info(
|
||||||
# f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, "
|
# f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, "
|
||||||
# f"raw ctx,dnum {num_deocde_tokens, num_decode}"
|
# f"raw ctx,dnum {num_decode_tokens, num_decode}"
|
||||||
# )
|
# )
|
||||||
return budget
|
return budget
|
||||||
|
|
||||||
@@ -114,8 +114,8 @@ class BudgetRefiner:
|
|||||||
num_decode = len(num_decode_token_lst)
|
num_decode = len(num_decode_token_lst)
|
||||||
if num_decode <= 0:
|
if num_decode <= 0:
|
||||||
return budget
|
return budget
|
||||||
num_deocde_tokens = sum(num_decode_token_lst) / num_decode
|
num_decode_tokens = sum(num_decode_token_lst) / num_decode
|
||||||
return self._get_max_budget(num_deocde_tokens, num_decode)
|
return self._get_max_budget(num_decode_tokens, num_decode)
|
||||||
|
|
||||||
|
|
||||||
class SchedulerDynamicBatch(Scheduler):
|
class SchedulerDynamicBatch(Scheduler):
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ class HCCLLibrary:
|
|||||||
path_to_library_cache: dict[str, Any] = {}
|
path_to_library_cache: dict[str, Any] = {}
|
||||||
|
|
||||||
# class attribute to store the mapping from library path
|
# class attribute to store the mapping from library path
|
||||||
# to the correspongding directory
|
# to the corresponding directory
|
||||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
def __init__(self, so_file: str | None = None):
|
def __init__(self, so_file: str | None = None):
|
||||||
|
|||||||
@@ -1316,8 +1316,8 @@ class MooncakeConnectorWorker:
|
|||||||
"""
|
"""
|
||||||
prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
|
prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
|
||||||
if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
|
if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
|
||||||
choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
|
chosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
|
||||||
remote_handshake_port_list = [[x + meta.remote_port for x in choosen_rank_list]]
|
remote_handshake_port_list = [[x + meta.remote_port for x in chosen_rank_list]]
|
||||||
local_block_ids_list, remote_block_ids_list = [meta.local_block_ids], [meta.remote_block_ids]
|
local_block_ids_list, remote_block_ids_list = [meta.local_block_ids], [meta.remote_block_ids]
|
||||||
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
|
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
|
||||||
|
|
||||||
@@ -1563,8 +1563,8 @@ class MooncakeConnectorWorker:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
else: # TODO: support prefill context parallel and pipeline parallel open at the same time
|
else: # TODO: support prefill context parallel and pipeline parallel open at the same time
|
||||||
choosen_rank_list = self._get_remote_rank(remote_req_id, prefill_tp_size)
|
chosen_rank_list = self._get_remote_rank(remote_req_id, prefill_tp_size)
|
||||||
remote_handshake_port_list = [[x + meta.remote_port] for x in choosen_rank_list]
|
remote_handshake_port_list = [[x + meta.remote_port] for x in chosen_rank_list]
|
||||||
for i in range(tp_num_need_pulls * self._prefill_pp_size):
|
for i in range(tp_num_need_pulls * self._prefill_pp_size):
|
||||||
assert self.kv_recv_thread is not None
|
assert self.kv_recv_thread is not None
|
||||||
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
|
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
|
||||||
@@ -1651,8 +1651,8 @@ class MooncakeConnectorWorker:
|
|||||||
or self.use_sparse
|
or self.use_sparse
|
||||||
):
|
):
|
||||||
tp_ori_data = tp_ori_data.reshape(-1, num_groups)
|
tp_ori_data = tp_ori_data.reshape(-1, num_groups)
|
||||||
choosen_group = tp_ori_data[:, [rand_group_index]]
|
chosen_group = tp_ori_data[:, [rand_group_index]]
|
||||||
flattened = choosen_group.reshape(-1).tolist()
|
flattened = chosen_group.reshape(-1).tolist()
|
||||||
tp_sampled_nums = [
|
tp_sampled_nums = [
|
||||||
flattened[i : i + tp_num_need_pulls] for i in range(0, len(flattened), tp_num_need_pulls)
|
flattened[i : i + tp_num_need_pulls] for i in range(0, len(flattened), tp_num_need_pulls)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -741,7 +741,7 @@ class MooncakeLayerwiseConnectorScheduler:
|
|||||||
computed_tokens.get(req_id, 0) + scheduled_tokens - spec_decode_tokens
|
computed_tokens.get(req_id, 0) + scheduled_tokens - spec_decode_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_tranfer_task(req_id, send_req_info: SendReqInfo, chunk_finish=False):
|
def add_transfer_task(req_id, send_req_info: SendReqInfo, chunk_finish=False):
|
||||||
(
|
(
|
||||||
local_block_ids,
|
local_block_ids,
|
||||||
local_transed_tokens,
|
local_transed_tokens,
|
||||||
@@ -771,7 +771,7 @@ class MooncakeLayerwiseConnectorScheduler:
|
|||||||
# whether chunk finish
|
# whether chunk finish
|
||||||
chunk_finish = send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids)
|
chunk_finish = send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids)
|
||||||
|
|
||||||
add_tranfer_task(req_id, send_req_info, chunk_finish=chunk_finish)
|
add_transfer_task(req_id, send_req_info, chunk_finish=chunk_finish)
|
||||||
if chunk_finish:
|
if chunk_finish:
|
||||||
self._reqs_need_send_layerwise.pop(req_id)
|
self._reqs_need_send_layerwise.pop(req_id)
|
||||||
return meta
|
return meta
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class D2DExpertWeightLoader:
|
|||||||
self.updated_log2phy_map = log2phy_map
|
self.updated_log2phy_map = log2phy_map
|
||||||
|
|
||||||
def asyn_expert_weight_transfer(self, reqs):
|
def asyn_expert_weight_transfer(self, reqs):
|
||||||
# Only when send/recv tasks are parsed into self.comm_op_list, d2d send/recv tasks can be luanched
|
# Only when send/recv tasks are parsed into self.comm_op_list, d2d send/recv tasks can be launched
|
||||||
if self.state != ExpertWeightUpdateState.READY:
|
if self.state != ExpertWeightUpdateState.READY:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -80,7 +80,7 @@ class D2DExpertWeightLoader:
|
|||||||
self.state = ExpertWeightUpdateState.TRANSFERRING
|
self.state = ExpertWeightUpdateState.TRANSFERRING
|
||||||
|
|
||||||
def update_expert_map_and_weight(self, reqs):
|
def update_expert_map_and_weight(self, reqs):
|
||||||
# Only after send/recv tasks have been luanched, expert_map and weight can be updated
|
# Only after send/recv tasks have been launched, expert_map and weight can be updated
|
||||||
if self.state != ExpertWeightUpdateState.TRANSFERRING:
|
if self.state != ExpertWeightUpdateState.TRANSFERRING:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -130,8 +130,8 @@ def jsq_placement(X, pieces, M, stage_weights):
|
|||||||
score = 0.0
|
score = 0.0
|
||||||
for s in range(n_stage):
|
for s in range(n_stage):
|
||||||
tmp_sj = loads[s, j] + w[s]
|
tmp_sj = loads[s, j] + w[s]
|
||||||
numer_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s]
|
number_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s]
|
||||||
score += stage_weights[s] * (numer_sj / denom[s])
|
score += stage_weights[s] * (number_sj / denom[s])
|
||||||
if score < best_val:
|
if score < best_val:
|
||||||
best_val = score
|
best_val = score
|
||||||
best_j = j
|
best_j = j
|
||||||
|
|||||||
@@ -195,10 +195,10 @@ class NPUPlatform(Platform):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
|
if vllm_config.additional_config.get("ascend_compilation_config", {}).get("fuse_allreduce_rms", True):
|
||||||
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD
|
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||||
|
|
||||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||||
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
|
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
|
||||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||||
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -208,10 +208,10 @@ class NPUPlatform(Platform):
|
|||||||
|
|
||||||
npugraph_ex_config = ascend_config.npugraph_ex_config
|
npugraph_ex_config = ascend_config.npugraph_ex_config
|
||||||
if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
|
if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
|
||||||
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD
|
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THRESHOLD
|
||||||
|
|
||||||
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||||
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
|
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THRESHOLD)
|
||||||
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||||
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -558,7 +558,7 @@ class NPUPlatform(Platform):
|
|||||||
Args:
|
Args:
|
||||||
attn_metadata (dict[str, Any]): attention metadata for all layers.
|
attn_metadata (dict[str, Any]): attention metadata for all layers.
|
||||||
vllm_config (VllmConfig): configuration of vllm.
|
vllm_config (VllmConfig): configuration of vllm.
|
||||||
dp_metadata (DpMetada): metadata for data parallelism.
|
dp_metadata (Dpmetadata): metadata for data parallelism.
|
||||||
lack of typehint because of circular import.
|
lack of typehint because of circular import.
|
||||||
virtual_engine (int, optional): index of virtual engine. Defaults to 0.
|
virtual_engine (int, optional): index of virtual engine. Defaults to 0.
|
||||||
num_tokens (int | None, optional): number of tokens. Defaults to None.
|
num_tokens (int | None, optional): number of tokens. Defaults to None.
|
||||||
|
|||||||
@@ -941,7 +941,7 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
|
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
|
||||||
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
|
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
|
||||||
# _r1_ ____r2____ ___r3__
|
# _r1_ ____r2____ ___r3__
|
||||||
token_offests = self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
|
token_offsets = self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
|
||||||
|
|
||||||
# Expand starting positions to match token pattern
|
# Expand starting positions to match token pattern
|
||||||
# [0, q1, q1 + q2] ->
|
# [0, q1, q1 + q2] ->
|
||||||
@@ -952,7 +952,7 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
# [0, 1, // req 1
|
# [0, 1, // req 1
|
||||||
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
|
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
|
||||||
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
|
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
|
||||||
token_indices_np = token_offests + old_query_start_locs_expanded
|
token_indices_np = token_offsets + old_query_start_locs_expanded
|
||||||
token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)
|
token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)
|
||||||
|
|
||||||
common_attn_metadata.slot_mapping[: token_indices.shape[0]].copy_(
|
common_attn_metadata.slot_mapping[: token_indices.shape[0]].copy_(
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class Proposer:
|
|||||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor=None,
|
batch_descriptor=None,
|
||||||
):
|
):
|
||||||
"""Called by dummy_run in modle_runner"""
|
"""Called by dummy_run in model_runner"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def generate_token_ids(
|
def generate_token_ids(
|
||||||
|
|||||||
@@ -2390,7 +2390,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
to be reshaped to the desired shape before being used by the models.
|
to be reshaped to the desired shape before being used by the models.
|
||||||
|
|
||||||
NOTE: To support prefill disaggregation, we need to split kvcache tensor into
|
NOTE: To support prefill disaggregation, we need to split kvcache tensor into
|
||||||
k_cahce and v cache, and the addr of both are aligned by 2M
|
k_cache and v cache, and the addr of both are aligned by 2M
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kv_cache_config: The KV cache config
|
kv_cache_config: The KV cache config
|
||||||
|
|||||||
@@ -459,9 +459,9 @@ class PCPManager:
|
|||||||
# draft_len of each request [1, 2, 1]
|
# draft_len of each request [1, 2, 1]
|
||||||
# then prev_draft_token_indices is [0, 2, 3, 4]
|
# then prev_draft_token_indices is [0, 2, 3, 4]
|
||||||
prev_draft_token_indices.extend(range(start, start + draft_len))
|
prev_draft_token_indices.extend(range(start, start + draft_len))
|
||||||
num_commmon_tokens = len(sample_flattened_indices)
|
num_common_tokens = len(sample_flattened_indices)
|
||||||
|
|
||||||
if num_commmon_tokens == 0:
|
if num_common_tokens == 0:
|
||||||
# No requests in common with the previous iteration
|
# No requests in common with the previous iteration
|
||||||
# So input_ids.cpu will have all the input ids.
|
# So input_ids.cpu will have all the input ids.
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user