Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)
fix: pure pytorch reshape_and_cache + _npu_flash_attention prefill
@@ -322,24 +322,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
         Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
         slot_mapping indices.
         """
         import torch_npu  # noqa: F401
 
         if kv_cache.numel() > 0:
             if self._key_cache is None:
                 self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
 
-            # Ensure contiguous tensors for the NPU op
-            key = key.contiguous()
-            value = value.contiguous()
-            slots = attn_metadata.slot_mapping.long()  # indices must be long
-
-            torch_npu._npu_reshape_and_cache(
-                key,
-                value,
-                self._key_cache,
-                self._value_cache,
-                slots,
-            )
+            slots = attn_metadata.slot_mapping
+            key_to_cache = key[:attn_metadata.num_actual_tokens]
+            val_to_cache = value[:attn_metadata.num_actual_tokens]
+
+            # Use pure-PyTorch indexing (ATB reshape_and_cache crashes on this env)
+            block_size = self._key_cache.shape[1]
+            block_idx = slots // block_size
+            block_offset = slots % block_size
+            self._key_cache[block_idx, block_offset] = key_to_cache
+            self._value_cache[block_idx, block_offset] = val_to_cache
 
         return key, value
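For context, the sketch below reproduces the arithmetic this hunk introduces, in isolation: a flat slot index is split into a block index and an in-block offset, and the new tokens are scattered into the paged KV cache with plain PyTorch advanced indexing. The cache layout, tensor sizes, and slot_mapping values are assumptions chosen for illustration, not taken from the plugin.

# Standalone sketch (not from the repo): pure-PyTorch paged KV cache write.
# Assumed cache layout: (num_blocks, block_size, num_kv_heads, head_size).
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 2, 64
num_actual_tokens = 5

key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
value_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# New tokens for this step and their flat slot indices
# (slot = block_idx * block_size + block_offset).
key = torch.randn(num_actual_tokens, num_kv_heads, head_size)
value = torch.randn(num_actual_tokens, num_kv_heads, head_size)
slot_mapping = torch.tensor([0, 1, 2, 16, 17])

# Same arithmetic as the diff: recover (block, offset) from the flat slot and
# scatter via advanced indexing instead of torch_npu._npu_reshape_and_cache.
block_idx = slot_mapping // block_size
block_offset = slot_mapping % block_size
key_cache[block_idx, block_offset] = key
value_cache[block_idx, block_offset] = value

# Sanity check: the written slots read back unchanged.
assert torch.allclose(key_cache[block_idx, block_offset], key)
assert torch.allclose(value_cache[block_idx, block_offset], value)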