mirror of
https://github.com/handsomezhuzhu/vllm-npu-plugin.git
synced 2026-02-20 11:42:30 +00:00
23 lines
709 B
Python
23 lines
709 B
Python
from .context import (get_multistream_layer_context,
|
|
get_multistream_microbatch_context)
|
|
|
|
|
|
# vLLM v1 uses get_forward_context to get the attn_metadata,
|
|
# we can use this decorator to update the attn metadata
|
|
def set_multistream_support():
    """Return a decorator that updates ``attn_metadata`` on a context object.

    vLLM v1 uses ``get_forward_context`` to obtain the attention metadata;
    the decorator returned here can wrap such a getter so that the metadata
    on the returned context is replaced with the entry for the current
    micro-batch.

    Returns:
        A decorator. The decorated callable forwards all positional and
        keyword arguments to the original callable and returns the
        original's result. Before returning, ``result.attn_metadata`` is
        overwritten with ``attn_metadata[micro_batch_num]`` when both the
        layer index from ``get_multistream_layer_context()`` and the value
        from ``get_multistream_microbatch_context()`` differ from -1.
    """
    # Function-scope import keeps this block self-contained with respect to
    # the module's import header.
    import functools

    def decorator(func):

        # Keep the wrapped callable's name/docstring and expose it through
        # ``__wrapped__`` so introspection still sees the original.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            context = func(*args, **kwargs)
            # The second element of the layer context is not needed here.
            layer_index, _, attn_metadata = get_multistream_layer_context()
            micro_batch_num = get_multistream_microbatch_context()
            # NOTE(review): -1 presumably marks "no active multistream
            # layer / micro-batch" -- confirm against the .context module.
            if layer_index != -1 and micro_batch_num != -1:
                context.attn_metadata = attn_metadata[micro_batch_num]
            return context

        return wrapper

    return decorator
|