Signed-off-by: simpx <simpxx@gmail.com>
This commit is contained in:
@ -23,9 +23,9 @@ from vllm.transformers_utils.detokenizer_utils import (
|
||||
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
|
||||
MemorySnapshot, PlaceholderModule, StoreBoolean,
|
||||
bind_kv_cache, common_broadcastable_dtype,
|
||||
deprecate_kwargs, get_open_port, get_tcp_uri,
|
||||
is_lossless_cast, join_host_port, make_zmq_path,
|
||||
make_zmq_socket, memory_profiling,
|
||||
current_stream, deprecate_kwargs, get_open_port,
|
||||
get_tcp_uri, is_lossless_cast, join_host_port,
|
||||
make_zmq_path, make_zmq_socket, memory_profiling,
|
||||
merge_async_iterators, sha256, split_host_port,
|
||||
split_zmq_path, supports_kw, swap_dict_values)
|
||||
|
||||
@ -957,3 +957,41 @@ def test_convert_ids_list_to_tokens():
|
||||
]
|
||||
tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
|
||||
assert tokens == ['Hello', ',', ' world', '!']
|
||||
|
||||
|
||||
def test_current_stream_multithread():
|
||||
import threading
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
main_default_stream = torch.cuda.current_stream()
|
||||
child_stream = torch.cuda.Stream()
|
||||
|
||||
thread_stream_ready = threading.Event()
|
||||
thread_can_exit = threading.Event()
|
||||
|
||||
def child_thread_func():
|
||||
with torch.cuda.stream(child_stream):
|
||||
thread_stream_ready.set()
|
||||
thread_can_exit.wait(timeout=10)
|
||||
|
||||
child_thread = threading.Thread(target=child_thread_func)
|
||||
child_thread.start()
|
||||
|
||||
try:
|
||||
assert thread_stream_ready.wait(
|
||||
timeout=5), "Child thread failed to enter stream context in time"
|
||||
|
||||
main_current_stream = current_stream()
|
||||
|
||||
assert main_current_stream != child_stream, "Main thread's current_stream was contaminated by child thread"
|
||||
assert main_current_stream == main_default_stream, "Main thread's current_stream is not the default stream"
|
||||
|
||||
# Notify child thread it can exit
|
||||
thread_can_exit.set()
|
||||
|
||||
finally:
|
||||
# Ensure child thread exits properly
|
||||
child_thread.join(timeout=5)
|
||||
if child_thread.is_alive():
|
||||
pytest.fail("Child thread failed to exit properly")
|
||||
|
||||
Reference in New Issue
Block a user