Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git, synced 2026-02-20 19:50:15 +00:00
fix: pure pytorch reshape_and_cache + _npu_flash_attention prefill
@@ -322,24 +322,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
         Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
         slot_mapping indices.
         """
         import torch_npu  # noqa: F401
 
         if kv_cache.numel() > 0:
             if self._key_cache is None:
                 self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
 
-            # Ensure contiguous tensors for the NPU op
-            key = key.contiguous()
-            value = value.contiguous()
-            slots = attn_metadata.slot_mapping.long()  # indices must be long
+            slots = attn_metadata.slot_mapping
+            key_to_cache = key[:attn_metadata.num_actual_tokens]
+            val_to_cache = value[:attn_metadata.num_actual_tokens]
 
-            torch_npu._npu_reshape_and_cache(
-                key,
-                value,
-                self._key_cache,
-                self._value_cache,
-                slots,
-            )
+            # Use pure-PyTorch indexing (ATB reshape_and_cache crashes on this env)
+            block_size = self._key_cache.shape[1]
+            block_idx = slots // block_size
+            block_offset = slots % block_size
+            self._key_cache[block_idx, block_offset] = key_to_cache
+            self._value_cache[block_idx, block_offset] = val_to_cache
 
         return key, value
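The patch drops the ATB `torch_npu._npu_reshape_and_cache` kernel in favour of plain tensor indexing: each flat slot index from `slot_mapping` is split into a block index and an in-block offset using the cache's block size, and the new key/value tensors are scattered into the paged cache with advanced indexing. Below is a minimal, CPU-only sketch of that write path; the shapes, sizes, and the cache layout `(num_blocks, block_size, num_kv_heads, head_dim)` are illustrative assumptions, not values taken from the plugin.

import torch

# Illustrative sizes (assumptions for the sketch).
num_blocks, block_size, num_kv_heads, head_dim = 4, 8, 2, 16
num_tokens = 5

# Paged KV cache: one entry per (block, in-block offset) pair.
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)
value_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)

# New K/V for the current step, plus the flat slot each token maps to
# (flat slot = block * block_size + offset).
key = torch.randn(num_tokens, num_kv_heads, head_dim)
value = torch.randn(num_tokens, num_kv_heads, head_dim)
slot_mapping = torch.tensor([3, 4, 5, 17, 18])

# Decompose the flat slot index exactly as the patch does.
block_idx = slot_mapping // block_size
block_offset = slot_mapping % block_size

# Advanced indexing scatters each token's K/V into its cache slot.
key_cache[block_idx, block_offset] = key
value_cache[block_idx, block_offset] = value

# Sanity check: slot 17 lands in block 2, offset 1 (17 = 2 * 8 + 1).
assert torch.equal(key_cache[2, 1], key[3])

The floor-division and modulo on the slot tensor are cheap elementwise ops, so this fallback trades the fused ATB kernel for a couple of small index tensors while avoiding the crash noted in the patch comment.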