[BugFix][310P][v0.18.0] Use CPU generator cache for sampling (#8624)
### What this PR does / why we need it? This PR introduces a caching mechanism for CPU-based `torch.Generator` objects in the `_random_sample_310p` function to optimize sampling performance. It includes unit tests for cache persistence and state recovery. Feedback highlights a critical bug where keying the cache by batch index instead of generator ID can break RNG reproducibility during request re-scheduling, and notes a potential memory leak in the global cache. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Tested via new unit tests in `tests/ut/_310p/sample/test_sampler_310.py` verifying cache logic and error handling. --------- Signed-off-by: csoulnd <daidaicurry@foxmail.com>
This commit is contained in:
201
tests/ut/_310p/sample/test_sampler_310.py
Normal file
201
tests/ut/_310p/sample/test_sampler_310.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
|
from vllm_ascend._310p.sample import sampler as sampler_310p
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRow:
|
||||||
|
def __init__(self):
|
||||||
|
self.generators = []
|
||||||
|
|
||||||
|
def exponential_(self, generator=None):
|
||||||
|
self.generators.append(generator)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeQ:
    """Test double for the exponential-noise tensor ``q`` used by the sampler.

    Exposes just enough surface for ``_random_sample_310p``: a ``shape``,
    no-op ``cpu()``/``npu()`` transfers, a default (generator-less)
    ``exponential_`` and per-row indexing that yields :class:`_FakeRow`
    recorders.
    """

    def __init__(self, batch_size):
        # Second dimension is arbitrary; only the batch dimension matters.
        self.shape = (batch_size, 4)
        # Flips to True if the sampler calls the generator-less overload.
        self.default_exponential_called = False
        rows = {}
        for idx in range(batch_size):
            rows[idx] = _FakeRow()
        self.rows = rows

    def cpu(self):
        """Pretend device transfer: return the same object."""
        return self

    def npu(self):
        """Pretend device transfer: return the same object."""
        return self

    def exponential_(self):
        """Flag that the default (no-generator) fill path was taken."""
        self.default_exponential_called = True
        return self

    def __getitem__(self, idx):
        """Row access: hand back the recording row for batch index *idx*."""
        return self.rows[idx]
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeCPUGenerator:
|
||||||
|
def __init__(self, device=None):
|
||||||
|
self.device = device
|
||||||
|
self.state = None
|
||||||
|
self.seed = None
|
||||||
|
|
||||||
|
def set_state(self, state):
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
def manual_seed(self, seed):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
|
||||||
|
class TestSampler310pGeneratorCache(TestBase):
    """Unit tests for the CPU-generator cache used by ``_random_sample_310p``.

    All torch/NPU touchpoints are patched out so the cache logic runs in
    isolation: ``torch.npu`` (created on the mock since it may not exist on
    the host), ``torch.Generator``, ``torch.empty_like``, and the module's
    stream helpers. Each test clears the module-level cache first so tests
    stay order-independent.
    """

    @patch.object(sampler_310p.torch, "npu", create=True)
    @patch.object(sampler_310p.torch, "Generator")
    @patch.object(sampler_310p.torch, "empty_like")
    @patch.object(sampler_310p, "global_stream", return_value=MagicMock())
    @patch.object(sampler_310p, "npu_stream_switch", return_value=nullcontext())
    def test_random_sample_310p_reuse_cpu_generator_cache(
        self,
        _mock_npu_stream_switch,
        _mock_global_stream,
        mock_empty_like,
        mock_generator_ctor,
        mock_npu,
    ):
        """Two calls with the same source generator reuse one cached CPU generator."""
        # Same source generator should reuse one cached CPU generator.
        sampler_310p._CPU_GENERATOR_CACHE_310P.clear()
        probs = MagicMock()
        # Keep the fluent chain probs.div_(q).argmax(...).view(-1) working.
        probs.div_.return_value = probs
        probs.argmax.return_value = probs
        probs.view.return_value = torch.tensor([0])

        # Fresh fake q tensor per call; empty_like yields them in order.
        fake_q_first = _FakeQ(batch_size=2)
        fake_q_second = _FakeQ(batch_size=2)

        npu_stream = MagicMock()
        generator = MagicMock()
        generator.get_state.return_value = b"state"
        generator.initial_seed.return_value = 7
        # Batch index 1 carries a per-request generator.
        generators = {1: generator}
        mock_empty_like.side_effect = [fake_q_first, fake_q_second]
        mock_generator_ctor.side_effect = _FakeCPUGenerator

        mock_npu.current_stream.return_value = npu_stream
        sampler_310p._random_sample_310p(probs, generators)
        sampler_310p._random_sample_310p(probs, generators)

        # Only one CPU generator was ever constructed (second call hit cache).
        self.assertEqual(mock_generator_ctor.call_count, 1)
        self.assertIn(1, sampler_310p._CPU_GENERATOR_CACHE_310P)
        cached_cpu_generator, source_generator_id = sampler_310p._CPU_GENERATOR_CACHE_310P[1]
        # Both calls routed the same cached generator into row 1.
        self.assertIs(fake_q_first.rows[1].generators[0], cached_cpu_generator)
        self.assertIs(fake_q_second.rows[1].generators[0], cached_cpu_generator)
        self.assertEqual(source_generator_id, id(generator))
        # State was synced via set_state; the seed fallback was never taken.
        self.assertEqual(cached_cpu_generator.state, b"state")
        self.assertIsNone(cached_cpu_generator.seed)
        # Stream sync happens once per sampling call.
        self.assertEqual(npu_stream.wait_stream.call_count, 2)

    @patch.object(sampler_310p.torch, "npu", create=True)
    @patch.object(sampler_310p.torch, "Generator")
    @patch.object(sampler_310p.torch, "empty_like")
    @patch.object(sampler_310p, "global_stream", return_value=MagicMock())
    @patch.object(sampler_310p, "npu_stream_switch", return_value=nullcontext())
    def test_random_sample_310p_fallback_to_initial_seed_when_set_state_failed(
        self,
        _mock_npu_stream_switch,
        _mock_global_stream,
        mock_empty_like,
        mock_generator_ctor,
        mock_npu,
    ):
        """State sync failure falls back to re-seeding from ``initial_seed``."""
        # If syncing generator state fails, fallback to initial seed.
        sampler_310p._CPU_GENERATOR_CACHE_310P.clear()
        probs = MagicMock()
        probs.div_.return_value = probs
        probs.argmax.return_value = probs
        probs.view.return_value = torch.tensor([1])

        fake_q = _FakeQ(batch_size=1)
        npu_stream = MagicMock()
        generator = MagicMock()
        # Both the read and the write side of the state sync raise here.
        generator.get_state.side_effect = RuntimeError("state read failed")
        generator.initial_seed.return_value = 1234
        generators = {0: generator}

        class _FailSetStateCPUGenerator(_FakeCPUGenerator):

            def set_state(self, state):
                raise RuntimeError("state set failed")

        mock_empty_like.return_value = fake_q
        mock_generator_ctor.side_effect = _FailSetStateCPUGenerator

        mock_npu.current_stream.return_value = npu_stream
        sampler_310p._random_sample_310p(probs, generators)

        cached_cpu_generator, source_generator_id = sampler_310p._CPU_GENERATOR_CACHE_310P[0]
        self.assertEqual(source_generator_id, id(generator))
        # The fallback path seeded the CPU generator from initial_seed().
        self.assertEqual(cached_cpu_generator.seed, 1234)
        self.assertIs(fake_q.rows[0].generators[0], cached_cpu_generator)
        self.assertEqual(npu_stream.wait_stream.call_count, 1)

    @patch.object(sampler_310p.torch, "npu", create=True)
    @patch.object(sampler_310p.torch, "Generator")
    @patch.object(sampler_310p.torch, "empty_like")
    @patch.object(sampler_310p, "global_stream", return_value=MagicMock())
    @patch.object(sampler_310p, "npu_stream_switch", return_value=nullcontext())
    def test_random_sample_310p_rebuild_cache_when_generator_identity_changes(
        self,
        _mock_npu_stream_switch,
        _mock_global_stream,
        mock_empty_like,
        mock_generator_ctor,
        mock_npu,
    ):
        """A different source generator object at the same index rebuilds the entry."""
        # A new source generator object should rebuild cache entry.
        sampler_310p._CPU_GENERATOR_CACHE_310P.clear()
        probs = MagicMock()
        probs.div_.return_value = probs
        probs.argmax.return_value = probs
        probs.view.return_value = torch.tensor([0])

        fake_q_first = _FakeQ(batch_size=1)
        fake_q_second = _FakeQ(batch_size=1)
        npu_stream = MagicMock()

        generator_first = MagicMock()
        generator_first.get_state.return_value = b"state-1"
        generator_first.initial_seed.return_value = 11

        generator_second = MagicMock()
        generator_second.get_state.return_value = b"state-2"
        generator_second.initial_seed.return_value = 22
        mock_empty_like.side_effect = [fake_q_first, fake_q_second]
        mock_generator_ctor.side_effect = _FakeCPUGenerator

        mock_npu.current_stream.return_value = npu_stream
        sampler_310p._random_sample_310p(probs, {0: generator_first})
        sampler_310p._random_sample_310p(probs, {0: generator_second})

        # Identity mismatch forces a second CPU generator construction.
        self.assertEqual(mock_generator_ctor.call_count, 2)
        first_cpu_generator = fake_q_first.rows[0].generators[0]
        second_cpu_generator = fake_q_second.rows[0].generators[0]
        self.assertIsNot(first_cpu_generator, second_cpu_generator)
        # Each rebuild synced state from its own source generator.
        self.assertEqual(first_cpu_generator.state, b"state-1")
        self.assertEqual(second_cpu_generator.state, b"state-2")
        cached_cpu_generator, source_generator_id = sampler_310p._CPU_GENERATOR_CACHE_310P[0]
        # The cache now holds the replacement, keyed to the new source identity.
        self.assertIs(cached_cpu_generator, second_cpu_generator)
        self.assertEqual(source_generator_id, id(generator_second))
|
||||||
@@ -25,6 +25,8 @@ from vllm_ascend.sample.sampler import (
|
|||||||
)
|
)
|
||||||
from vllm_ascend.utils import global_stream, npu_stream_switch
|
from vllm_ascend.utils import global_stream, npu_stream_switch
|
||||||
|
|
||||||
|
# Maps batch index -> (cached CPU torch.Generator, id() of the source generator).
# NOTE(review): keying by batch index rather than by the source generator's
# identity means a request re-scheduled to a different batch slot resyncs its
# RNG state from scratch, which can break sampling reproducibility — confirm
# whether a per-generator key is needed. Also, this module-level cache is never
# pruned, so stale entries persist for the process lifetime (bounded by the
# maximum batch index seen) — TODO confirm an eviction policy is unnecessary.
_CPU_GENERATOR_CACHE_310P: dict[int, tuple[torch.Generator, int]] = {}
|
||||||
|
|
||||||
|
|
||||||
def _random_sample_310p(
|
def _random_sample_310p(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
@@ -38,7 +40,18 @@ def _random_sample_310p(
|
|||||||
q.exponential_()
|
q.exponential_()
|
||||||
if generators:
|
if generators:
|
||||||
for i, generator in generators.items():
|
for i, generator in generators.items():
|
||||||
q[i].exponential_(generator=generator)
|
cache_entry = _CPU_GENERATOR_CACHE_310P.get(i)
|
||||||
|
if cache_entry is None or cache_entry[1] != id(generator):
|
||||||
|
cpu_generator = torch.Generator(device="cpu")
|
||||||
|
try:
|
||||||
|
# Keep RNG stream consistent with the original generator.
|
||||||
|
cpu_generator.set_state(generator.get_state())
|
||||||
|
except Exception:
|
||||||
|
cpu_generator.manual_seed(generator.initial_seed())
|
||||||
|
cache_entry = (cpu_generator, id(generator))
|
||||||
|
_CPU_GENERATOR_CACHE_310P[i] = cache_entry
|
||||||
|
cpu_generator, _ = cache_entry
|
||||||
|
q[i].exponential_(generator=cpu_generator)
|
||||||
q = q.npu()
|
q = q.npu()
|
||||||
torch.npu.current_stream().wait_stream(global_stream())
|
torch.npu.current_stream().wait_stream(global_stream())
|
||||||
return probs.div_(q).argmax(dim=-1).view(-1)
|
return probs.div_(q).argmax(dim=-1).view(-1)
|
||||||
|
|||||||
Reference in New Issue
Block a user