You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
manifest/tests/test_array_cache.py

82 lines
2.2 KiB
Python

"""Array cache test."""
from pathlib import Path
import numpy as np
import pytest
from manifest.caches.array_cache import ArrayCache
def test_init(tmpdir: Path) -> None:
    """Check the state of a newly constructed cache.

    Constructing an ``ArrayCache`` on a directory is expected to leave a
    ``hash2arrloc.sqlite`` file there and to start with both the file index
    and the offset at zero.
    """
    cache_dir = Path(tmpdir)
    cache = ArrayCache(cache_dir)

    index_file = cache_dir / "hash2arrloc.sqlite"
    assert index_file.exists()

    # Both write-position counters start at zero.
    assert (cache.cur_file_idx, cache.cur_offset) == (0, 0)
def test_put_get(tmpdir: Path) -> None:
    """Test putting and getting.

    Covers three things: an array larger than ``max_memmap_size`` is rejected
    with a ``ValueError``; stored arrays round-trip with their values, dtype
    and location metadata; and a new ``ArrayCache`` opened on the same
    directory still sees both entries.
    """
    cache = ArrayCache(tmpdir)

    # A 10x10 array (100 elements) must be rejected while the limit is 5.
    cache.max_memmap_size = 5
    arr = np.random.rand(10, 10)
    with pytest.raises(ValueError) as exc_info:
        cache.put("key", arr)
    assert str(exc_info.value) == "Array is too large to be cached. Max is 5"

    # With the limit raised to 120 the same array is accepted.
    cache.max_memmap_size = 120
    cache.put("key", arr)
    expected_key_loc = {
        "file_idx": 0,
        "offset": 0,
        "flatten_size": 100,
        "shape": (10, 10),
        "dtype": np.dtype("float64"),
    }
    assert np.allclose(cache.get("key"), arr)
    assert cache.get("key").dtype == arr.dtype
    assert cache.cur_file_idx == 0
    assert cache.cur_offset == 100
    assert cache.hash2arrloc["key"] == expected_key_loc

    # Request int64 explicitly: randint's default integer type depends on the
    # platform and NumPy version (it can be 32-bit), which would break the
    # "int64" metadata expectation below.
    arr2 = np.random.randint(0, 3, size=(10, 10), dtype=np.int64)
    cache.put("key2", arr2)
    # The second array is expected at offset 0 of file 1 rather than after
    # the first array in file 0.
    expected_key2_loc = {
        "file_idx": 1,
        "offset": 0,
        "flatten_size": 100,
        "shape": (10, 10),
        "dtype": np.dtype("int64"),
    }
    assert np.allclose(cache.get("key2"), arr2)
    assert cache.get("key2").dtype == arr2.dtype
    assert cache.cur_file_idx == 1
    assert cache.cur_offset == 100
    assert cache.hash2arrloc["key2"] == expected_key2_loc

    # Re-open the cache on the same directory: metadata and data must persist.
    cache = ArrayCache(tmpdir)
    assert cache.hash2arrloc["key"] == expected_key_loc
    assert cache.hash2arrloc["key2"] == expected_key2_loc
    assert np.allclose(cache.get("key"), arr)
    assert np.allclose(cache.get("key2"), arr2)
def test_contains_key(tmpdir: Path) -> None:
    """Check that ``contains_key`` is false before a put and true after it."""
    store = ArrayCache(tmpdir)
    sample = np.random.rand(10, 10)

    # Nothing has been stored yet, so the key must be reported as absent.
    assert not store.contains_key("key")

    store.put("key", sample)
    assert store.contains_key("key")