diff --git a/Cargo.lock b/Cargo.lock index b65475315..3eec2a64f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -691,6 +691,7 @@ dependencies = [ name = "diskann-utils" version = "0.46.0" dependencies = [ + "bytemuck", "cfg-if", "diskann-vector", "diskann-wide", diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 66cd65d18..19230977d 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -384,8 +384,8 @@ impl IndexLoad { let index_config = IndexConfiguration::new( self.distance.into(), - metadata.ndims, - metadata.npoints, + metadata.ndims(), + metadata.npoints(), num_frozen_pts, 1, config, diff --git a/diskann-benchmark/src/utils/datafiles.rs b/diskann-benchmark/src/utils/datafiles.rs index c4cf8b881..9c5057488 100644 --- a/diskann-benchmark/src/utils/datafiles.rs +++ b/diskann-benchmark/src/utils/datafiles.rs @@ -22,12 +22,11 @@ pub(crate) fn load_dataset(path: BinFile<'_>) -> anyhow::Result> where T: Copy + bytemuck::Pod, { - let (data, num_data, data_dim) = diskann_providers::utils::file_util::load_bin::( - &diskann_providers::storage::FileStorageProvider, - &path.0.to_string_lossy(), - 0, + let data = diskann_utils::io::read_bin::( + &mut diskann_providers::storage::FileStorageProvider + .open_reader(&path.0.to_string_lossy())?, )?; - Ok(Matrix::try_from(data.into(), num_data, data_dim).map_err(|err| err.as_static())?) + Ok(data) } /// Helper trait to load a `Matrix` from source files that potentially have a different diff --git a/diskann-disk/src/build/builder/build.rs b/diskann-disk/src/build/builder/build.rs index ca23daee7..8eabad038 100644 --- a/diskann-disk/src/build/builder/build.rs +++ b/diskann-disk/src/build/builder/build.rs @@ -26,10 +26,12 @@ use diskann_providers::{ }, storage::{AsyncIndexMetadata, DiskGraphOnly, PQStorage}, utils::{ - create_thread_pool, find_medoid_with_sampling, load_bin, save_bin_u32, RayonThreadPool, - VectorDataIterator, MAX_MEDOID_SAMPLE_SIZE, + create_thread_pool, find_medoid_with_sampling, RayonThreadPool, VectorDataIterator, + MAX_MEDOID_SAMPLE_SIZE, }, }; +use diskann_utils::io::{read_bin, write_bin}; +use diskann_utils::views::MatrixView; use tokio::task::JoinSet; use tracing::{debug, info}; @@ -880,9 +882,9 @@ impl StartPoint { path ))); } - let (data, _, _) = load_bin::(&mut reader.open_reader(path)?, 0)?; + let data = read_bin::(&mut reader.open_reader(path)?)?; - let start_point_id = data.first().ok_or_else(|| { + let start_point_id = data.try_get(0, 0).ok_or_else(|| { ANNError::log_invalid_file_format(format!("Start point ID file {} is empty", path)) })?; @@ -894,12 +896,9 @@ impl StartPoint { where StorageWriter: StorageWriteProvider, { - save_bin_u32( + write_bin( + MatrixView::row_vector(std::slice::from_ref(&self.0)), &mut storage_provider.create_for_write(path)?, - std::slice::from_ref(&self.0), - 1, - 1, - 0, )?; debug!("Saved start point ID {} to {}", self.0, path); Ok(()) @@ -915,7 +914,7 @@ mod start_point_tests { use std::io::Write; use diskann_providers::storage::VirtualStorageProvider; - use diskann_providers::utils::write_metadata; + use diskann_utils::io::Metadata; use super::*; @@ -976,7 +975,7 @@ mod start_point_tests { let mut file = storage_provider.create_for_write(file_path).unwrap(); let npts = 0; let dim = 1; - write_metadata(&mut file, npts, dim).unwrap(); + Metadata::new(npts, dim).unwrap().write(&mut file).unwrap(); } let result = StartPoint::load(file_path, &storage_provider); diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index d27182e83..c7f21b682 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -13,10 +13,11 @@ use diskann_providers::{ }, storage::PQStorage, utils::{ - load_bin, load_metadata_from_file, RayonThreadPool, SampleVectorReader, SamplingDensity, + load_metadata_from_file, RayonThreadPool, SampleVectorReader, SamplingDensity, READ_WRITE_BLOCK_SIZE, }, }; +use diskann_utils::io::read_bin; use rand::{seq::SliceRandom, Rng}; use tracing::info; @@ -129,7 +130,7 @@ where let metadata = load_metadata_from_file(storage_provider, shard_base_file)?; let mut index_config = base_config.clone(); - index_config.max_points = metadata.npoints; + index_config.max_points = metadata.npoints(); index_config.config = low_degree_params; Ok(index_config) @@ -145,13 +146,11 @@ where T: Default + bytemuck::Pod, { let storage_provider = self.storage_provider; - let (shard_ids, shard_size, _) = load_bin::( - &mut storage_provider.open_reader(shard_ids_file)?, - 0, - )?; + let shard_ids = read_bin::(&mut storage_provider.open_reader(shard_ids_file)?)?; + let shard_size = shard_ids.nrows(); info!("Loaded {} shard ids from {}", shard_size, shard_ids_file); - let max_id = shard_ids.iter().max().copied().unwrap_or(0); - let sampling_rate = shard_ids.len() as f64 / (max_id + 1) as f64; + let max_id = shard_ids.as_slice().iter().max().copied().unwrap_or(0); + let sampling_rate = shard_ids.as_slice().len() as f64 / (max_id + 1) as f64; let mut dataset_reader: SampleVectorReader = SampleVectorReader::new( dataset_file, @@ -172,7 +171,7 @@ where shard_base_cached_writer.write(&dim.to_le_bytes())?; let mut num_written: u32 = 0; - dataset_reader.read_vectors(shard_ids.iter().copied(), |vector_t| { + dataset_reader.read_vectors(shard_ids.as_slice().iter().copied(), |vector_t| { // Casting Pod type to bytes always succeeds (u8 has alignment of 1) let vector_bytes: &[u8] = bytemuck::must_cast_slice(vector_t); shard_base_cached_writer.write(vector_bytes)?; @@ -384,12 +383,9 @@ where Ok(()) } - fn read_idmap(&self, idmaps_path: String) -> std::io::Result> { - let (data, _npts, _dim) = load_bin::( - &mut self.storage_provider.open_reader(&idmaps_path)?, - 0, - )?; - Ok(data) + fn read_idmap(&self, idmaps_path: String) -> Result, diskann_utils::io::ReadBinError> { + let data = read_bin::(&mut self.storage_provider.open_reader(&idmaps_path)?)?; + Ok(data.into_inner().into_vec()) } fn merge_shards_and_cleanup( @@ -644,7 +640,7 @@ pub(crate) mod disk_index_builder_tests { test_utils::graph_data_type_utils::{ GraphDataF32VectorU32Data, GraphDataF32VectorUnitData, }, - utils::{file_util, BridgeErr, Timer}, + utils::Timer, }; use diskann_utils::test_data_root; use diskann_vector::{ @@ -787,15 +783,17 @@ pub(crate) mod disk_index_builder_tests { .unwrap(); assert_eq!( - self.params.dim, metadata.ndims, + self.params.dim, + metadata.ndims(), "Parameters dimension {} and data dimension {} are not equal", - self.params.dim, metadata.ndims + self.params.dim, + metadata.ndims(), ); let config = IndexConfiguration::new( self.params.metric, self.params.dim, - metadata.npoints, + metadata.npoints(), ONE, self.params.num_threads, config, @@ -1067,13 +1065,9 @@ pub(crate) mod disk_index_builder_tests { None, )?; - let (data, npoints, dim) = file_util::load_bin::( - storage_provider.as_ref(), - ¶ms.data_path, - 0, - )?; let data = - diskann_utils::views::Matrix::try_from(data.into(), npoints, dim).bridge_err()?; + read_bin::(&mut storage_provider.open_reader(¶ms.data_path)?)?; + let dim = data.ncols(); let distance = ::distance(params.metric, Some(dim)); // Here, we use elements of the dataset to search the dataset itself. diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 00c17f8c9..25916df4f 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -1059,11 +1059,9 @@ mod disk_provider_tests { test_utils::graph_data_type_utils::{ GraphDataF32VectorU32Data, GraphDataF32VectorUnitData, }, - utils::{ - create_thread_pool, file_util, load_aligned_bin, PQPathNames, ParallelIteratorInPool, - }, + utils::{create_thread_pool, load_aligned_bin, PQPathNames, ParallelIteratorInPool}, }; - use diskann_utils::test_data_root; + use diskann_utils::{io::read_bin, test_data_root}; use diskann_vector::distance::Metric; use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator}; use rstest::rstest; @@ -1189,10 +1187,10 @@ mod disk_provider_tests { ) -> Vec { const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin"; - let (data, _npts, _dim) = - file_util::load_bin::(storage_provider, ASSOCIATED_DATA_FILE, 0) + let data = + read_bin::(&mut storage_provider.open_reader(ASSOCIATED_DATA_FILE).unwrap()) .unwrap(); - data + data.into_inner().into_vec() } #[test] @@ -1339,10 +1337,9 @@ mod disk_provider_tests { storage_provider: &StorageReader, query_result_path: &str, ) -> Vec { - let (result, _, _) = - file_util::load_bin::(storage_provider, query_result_path, 0) - .unwrap(); - result + let result = + read_bin::(&mut storage_provider.open_reader(query_result_path).unwrap()).unwrap(); + result.into_inner().into_vec() } struct TestDiskSearchParams<'a, StorageType> { diff --git a/diskann-disk/src/search/provider/disk_vertex_provider.rs b/diskann-disk/src/search/provider/disk_vertex_provider.rs index 3351012ab..d59fd7760 100644 --- a/diskann-disk/src/search/provider/disk_vertex_provider.rs +++ b/diskann-disk/src/search/provider/disk_vertex_provider.rs @@ -307,7 +307,7 @@ mod disk_vertex_provider_tests { let metadata = load_metadata_from_file(storage_provider, data_path).unwrap(); let memory_budget = MemoryBudget::try_from_gb(1.0).unwrap(); - let num_pq_chunks = NumPQChunks::new_with(128, metadata.ndims).unwrap(); + let num_pq_chunks = NumPQChunks::new_with(128, metadata.ndims()).unwrap(); let disk_index_build_parameters = DiskIndexBuildParameters::new(memory_budget, QuantizationType::FP, num_pq_chunks); @@ -326,8 +326,8 @@ mod disk_vertex_provider_tests { let config = IndexConfiguration::new( diskann_vector::distance::Metric::L2, - metadata.ndims, - metadata.npoints, + metadata.ndims(), + metadata.npoints(), ONE, 1, config, diff --git a/diskann-disk/src/storage/disk_index_reader.rs b/diskann-disk/src/storage/disk_index_reader.rs index 0669878be..319207a2c 100644 --- a/diskann-disk/src/storage/disk_index_reader.rs +++ b/diskann-disk/src/storage/disk_index_reader.rs @@ -40,20 +40,20 @@ impl DiskIndexReader { let pq_compressed_data = PQStorage::load_pq_compressed_vectors_bin::( &pq_compressed_data_path, - metadata.npoints, + metadata.npoints(), pq_pivot_table.get_num_chunks(), storage_provider, )?; info!( "Loaded PQ centroids and in-memory compressed vectors. #points:{} #pq_chunks: {}", - metadata.npoints, + metadata.npoints(), pq_pivot_table.get_num_chunks() ); Ok(DiskIndexReader { phantom: PhantomData, pq_data: Arc::::new(PQData::new(pq_pivot_table, pq_compressed_data)?), - num_points: metadata.npoints, + num_points: metadata.npoints(), }) } diff --git a/diskann-disk/src/storage/quant/generator.rs b/diskann-disk/src/storage/quant/generator.rs index 89af3547d..8e5006d6f 100644 --- a/diskann-disk/src/storage/quant/generator.rs +++ b/diskann-disk/src/storage/quant/generator.rs @@ -12,12 +12,9 @@ use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult}; use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider}; use diskann_providers::{ forward_threadpool, - utils::{ - load_metadata_from_file, write_metadata, AsThreadPool, BridgeErr, ParallelIteratorInPool, - Timer, - }, + utils::{load_metadata_from_file, AsThreadPool, BridgeErr, ParallelIteratorInPool, Timer}, }; -use diskann_utils::views::{self}; +use diskann_utils::{io::Metadata, views}; use rayon::iter::IndexedParallelIterator; use tracing::info; @@ -115,7 +112,7 @@ where let timer = Timer::new(); let metadata = load_metadata_from_file(storage_provider, &self.data_path)?; - let (num_points, dim) = (metadata.npoints, metadata.ndims); + let (num_points, dim) = metadata.into_dims(); self.validate_params(num_points, storage_provider)?; @@ -135,8 +132,8 @@ where storage_provider.open_writer(compressed_path)? } else { let mut sp = storage_provider.create_for_write(compressed_path)?; - // write meatadata to header - write_metadata(&mut sp, num_points, self.quantizer.compressed_bytes())?; + // write metadata to header + Metadata::new(num_points, self.quantizer.compressed_bytes())?.write(&mut sp)?; sp }; @@ -279,8 +276,10 @@ mod generator_tests { use diskann::utils::read_exact_into; use diskann_providers::storage::VirtualStorageProvider; - use diskann_providers::utils::{ - create_thread_pool_for_test, read_metadata, save_bin_f32, save_bytes, + use diskann_providers::utils::{create_thread_pool_for_test, save_bytes}; + use diskann_utils::{ + io::{write_bin, Metadata}, + views::MatrixView, }; use rstest::rstest; use vfs::{FileSystem, MemoryFS}; @@ -384,12 +383,11 @@ mod generator_tests { let compressed_path = "/test_data/test_compressed.bin".to_string(); // Setup test data - let _ = save_bin_f32( + let data = create_test_data(num_points, dim); + let view = MatrixView::try_from(data.as_slice(), num_points, dim).unwrap(); + write_bin( + view, &mut storage_provider.create_for_write(data_path.as_str())?, - &create_test_data(num_points, dim), - num_points, - dim, - 0, )?; if offset > 0 { @@ -483,13 +481,13 @@ mod generator_tests { let mut r = storage_provider.open_reader(compressed_path.as_str())?; let mut reader = BufReader::new(&mut r); - let metadata = read_metadata(&mut reader)?; + let metadata = Metadata::read(&mut reader)?; let data: Vec = read_exact_into(&mut reader, expected_size)?; // Check header - assert_eq!(metadata.ndims as u32, output_dim); - assert_eq!(metadata.npoints, num_points); + assert_eq!(metadata.ndims_u32(), output_dim); + assert_eq!(metadata.npoints(), num_points); // Check compressed data content data.chunks_exact(output_dim as usize) @@ -542,14 +540,14 @@ mod generator_tests { let mut r = storage_provider.open_reader(compressed_path.as_str())?; let mut reader = BufReader::new(&mut r); - let metadata = read_metadata(&mut reader)?; + let metadata = Metadata::read(&mut reader)?; let data: Vec = read_exact_into(&mut reader, expected_size - 2 * std::mem::size_of::())?; // Check header - assert_eq!(metadata.ndims as u32, output_dim); - assert_eq!(metadata.npoints, num_points); + assert_eq!(metadata.ndims_u32(), output_dim); + assert_eq!(metadata.npoints(), num_points); // Check compressed data content data.chunks_exact(output_dim as usize) diff --git a/diskann-disk/src/storage/quant/pq/pq_generation.rs b/diskann-disk/src/storage/quant/pq/pq_generation.rs index 3fb2be293..32c97becc 100644 --- a/diskann-disk/src/storage/quant/pq/pq_generation.rs +++ b/diskann-disk/src/storage/quant/pq/pq_generation.rs @@ -186,12 +186,15 @@ mod pq_generation_tests { use diskann::ANNError; use diskann_providers::model::pq::generate_pq_pivots; use diskann_providers::model::GeneratePivotArguments; - use diskann_providers::storage::{PQStorage, StorageWriteProvider, VirtualStorageProvider}; - use diskann_providers::utils::{ - create_thread_pool_for_test, file_util::load_bin, save_bin_f32, AsThreadPool, + use diskann_providers::storage::{ + PQStorage, StorageReadProvider, StorageWriteProvider, VirtualStorageProvider, + }; + use diskann_providers::utils::{create_thread_pool_for_test, AsThreadPool}; + use diskann_utils::{ + io::{read_bin, write_bin}, + test_data_root, + views::{MatrixView, MutMatrixView}, }; - use diskann_utils::test_data_root; - use diskann_utils::views::{MatrixView, MutMatrixView}; use diskann_vector::distance::Metric; use rstest::rstest; use vfs::FileSystem; @@ -257,13 +260,11 @@ mod pq_generation_tests { let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5); let mut train_data: Vec = VALIDATION_DATA.to_vec(); - let _ = save_bin_f32( + write_bin( + MatrixView::try_from(train_data.as_slice(), ndata, dim).unwrap(), &mut storage_provider.create_for_write(data_path).unwrap(), - &train_data, - ndata, - dim, - 0, - ); + ) + .unwrap(); let pool = create_thread_pool_for_test(); generate_pq_pivots( @@ -309,12 +310,14 @@ mod pq_generation_tests { assert_eq!(compressor.table.nchunks(), num_chunks); assert!(&storage_provider.exists(pivot_file_name_compressor)); - let (compressor_pivots, cn, cd) = - load_bin::(&storage_provider, pivot_file_name_compressor, 0).unwrap(); - let (true_pivots, n, d) = load_bin::(&storage_provider, pivot_file_name, 0).unwrap(); - - assert_eq!(cn, n); - assert_eq!(cd, d); + let compressor_pivots = read_bin::( + &mut storage_provider + .open_reader(pivot_file_name_compressor) + .unwrap(), + ) + .unwrap(); + let true_pivots = + read_bin::(&mut storage_provider.open_reader(pivot_file_name).unwrap()).unwrap(); assert_eq!(compressor_pivots, true_pivots); } @@ -332,13 +335,11 @@ mod pq_generation_tests { let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5); - let _ = save_bin_f32( + write_bin( + MatrixView::try_from(VALIDATION_DATA.as_slice(), ndata, dim).unwrap(), &mut storage_provider.create_for_write(data_path).unwrap(), - &VALIDATION_DATA, - ndata, - dim, - 0, - ); + ) + .unwrap(); let pool = create_thread_pool_for_test(); let compressor = create_new_compressor( @@ -387,18 +388,23 @@ mod pq_generation_tests { assert!(compressor.is_ok()); - let (data, npts, dim) = - load_bin::(&storage_provider, TEST_PQ_DATA_PATH, 0).unwrap(); + let data_matrix = + read_bin::(&mut storage_provider.open_reader(TEST_PQ_DATA_PATH).unwrap()).unwrap(); + let npts = data_matrix.nrows(); let mut compressed_mat = vec![0_u8; num_chunks * npts]; let result = compressor.unwrap().compress( - MatrixView::try_from(&data, npts, dim).unwrap(), + data_matrix.as_view(), MutMatrixView::try_from(&mut compressed_mat, npts, num_chunks).unwrap(), ); assert!(result.is_ok()); - let (compressed_gt, _, _) = - load_bin::(&storage_provider, TEST_PQ_COMPRESSED_PATH, 0).unwrap(); - assert_eq!(compressed_gt, compressed_mat); + let compressed_gt = read_bin::( + &mut storage_provider + .open_reader(TEST_PQ_COMPRESSED_PATH) + .unwrap(), + ) + .unwrap(); + assert_eq!(compressed_gt.as_slice(), &compressed_mat); } #[rstest] diff --git a/diskann-providers/benches/benchmarks/copy_aligned_data_bench.rs b/diskann-providers/benches/benchmarks/copy_aligned_data_bench.rs index 818ba594d..e8ffae784 100644 --- a/diskann-providers/benches/benchmarks/copy_aligned_data_bench.rs +++ b/diskann-providers/benches/benchmarks/copy_aligned_data_bench.rs @@ -10,8 +10,9 @@ use diskann::ANNResult; use diskann_providers::{ model::PQCompressedData, storage::{StorageReadProvider, StorageWriteProvider, VirtualStorageProvider}, - utils::{copy_aligned_data, write_metadata}, + utils::copy_aligned_data, }; +use diskann_utils::io::Metadata; use rand::Rng; use tempfile::TempDir; @@ -53,7 +54,7 @@ fn generate_random_data( dims: usize, ) -> ANNResult<()> { let mut writer = BufWriter::new(writer); - write_metadata(&mut writer, npts, dims)?; + Metadata::new(npts, dims)?.write(&mut writer)?; let mut rng = diskann_providers::utils::create_rnd_in_tests(); let data: Vec = (0..dims).map(|_| rng.random()).collect(); diff --git a/diskann-providers/benches/benchmarks/diskann_bench.rs b/diskann-providers/benches/benchmarks/diskann_bench.rs index fee192d28..16446b823 100644 --- a/diskann-providers/benches/benchmarks/diskann_bench.rs +++ b/diskann-providers/benches/benchmarks/diskann_bench.rs @@ -6,6 +6,7 @@ use std::time::Duration; use criterion::Criterion; use diskann::provider::DefaultContext; +use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ index::diskann_async, model::graph::{ @@ -16,9 +17,9 @@ use diskann_providers::{ traits::AdHoc, }, storage::FileStorageProvider, - utils::{VectorDataIterator, create_thread_pool_for_bench, file_util::load_bin}, + utils::{VectorDataIterator, create_thread_pool_for_bench}, }; -use diskann_utils::views::MatrixView; +use diskann_utils::io::read_bin; use diskann_vector::distance::Metric; use tokio::runtime::Runtime; @@ -52,12 +53,16 @@ async fn test_sift_256_vectors_with_quant_vectors() { ) .unwrap(); - let (train_data, num_points, num_dim) = - load_bin(&storage_provider, get_test_file_path(file_path).as_str(), 0).unwrap(); + let train_data = read_bin::( + &mut storage_provider + .open_reader(get_test_file_path(file_path).as_str()) + .unwrap(), + ) + .unwrap(); let pool = create_thread_pool_for_bench(); let pq_chunk_table = diskann_async::train_pq( - MatrixView::try_from(&train_data, num_points, num_dim).unwrap(), + train_data.as_view(), 32, &mut diskann_providers::utils::create_rnd_in_tests(), &pool, @@ -74,8 +79,8 @@ async fn test_sift_256_vectors_with_quant_vectors() { .unwrap(); let provider_params = DefaultProviderParameters::simple( - num_points, - num_dim, + train_data.nrows(), + train_data.ncols(), Metric::L2, conf.max_degree_u32().get(), ); diff --git a/diskann-providers/benches/benchmarks_iai/copy_aligned_data_bench_iai.rs b/diskann-providers/benches/benchmarks_iai/copy_aligned_data_bench_iai.rs index 20dd2f76a..4d85f2b4b 100644 --- a/diskann-providers/benches/benchmarks_iai/copy_aligned_data_bench_iai.rs +++ b/diskann-providers/benches/benchmarks_iai/copy_aligned_data_bench_iai.rs @@ -9,8 +9,9 @@ use diskann::ANNResult; use diskann_providers::{ model::PQCompressedData, storage::{StorageReadProvider, StorageWriteProvider, VirtualStorageProvider}, - utils::{copy_aligned_data, write_metadata}, + utils::copy_aligned_data, }; +use diskann_utils::io::Metadata; use iai_callgrind::black_box; use rand::Rng; use tempfile::TempDir; @@ -55,7 +56,7 @@ fn generate_random_data( dims: usize, ) -> ANNResult<()> { let mut writer = BufWriter::new(writer); - write_metadata(&mut writer, npts, dims)?; + Metadata::new(npts, dims)?.write(&mut writer)?; let mut rng = diskann_providers::utils::create_rnd_in_tests(); let data: Vec = (0..dims).map(|_| rng.random()).collect(); diff --git a/diskann-providers/benches/benchmarks_iai/diskann_iai.rs b/diskann-providers/benches/benchmarks_iai/diskann_iai.rs index b691021a2..e29e85525 100644 --- a/diskann-providers/benches/benchmarks_iai/diskann_iai.rs +++ b/diskann-providers/benches/benchmarks_iai/diskann_iai.rs @@ -4,6 +4,7 @@ */ use diskann::provider::DefaultContext; +use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ index::diskann_async, model::graph::{ @@ -14,9 +15,9 @@ use diskann_providers::{ traits::AdHoc, }, storage::FileStorageProvider, - utils::{VectorDataIterator, create_thread_pool_for_bench, file_util::load_bin}, + utils::{VectorDataIterator, create_thread_pool_for_bench}, }; -use diskann_utils::views::MatrixView; +use diskann_utils::io::read_bin; use diskann_vector::distance::Metric; use tokio::runtime::Runtime; @@ -48,12 +49,16 @@ async fn test_sift_256_vectors_with_quant_vectors() { ) .unwrap(); - let (train_data, num_points, num_dim) = - load_bin(&storage_provider, get_test_file_path(file_path).as_str(), 0).unwrap(); + let train_data = read_bin::( + &mut storage_provider + .open_reader(get_test_file_path(file_path).as_str()) + .unwrap(), + ) + .unwrap(); let pool = create_thread_pool_for_bench(); let pq_chunk_table = diskann_async::train_pq( - MatrixView::try_from(&train_data, num_points, num_dim).unwrap(), + train_data.as_view(), 32, &mut diskann_providers::utils::create_rnd_in_tests(), &pool, @@ -70,8 +75,8 @@ async fn test_sift_256_vectors_with_quant_vectors() { .unwrap(); let provider_params = DefaultProviderParameters::simple( - num_points, - num_dim, + train_data.nrows(), + train_data.ncols(), Metric::L2, conf.max_degree_u32().get(), ); diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index a7472a6ce..f3f15749e 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -206,10 +206,11 @@ pub(crate) mod tests { }, layers::BetaFilter, }, + storage::StorageReadProvider, test_utils::{ assert_range_results_exactly_match, assert_top_k_exactly_match, groundtruth, is_match, }, - utils::{self, VectorDataIterator, create_rnd_from_seed_in_tests, file_util}, + utils::{self, VectorDataIterator, create_rnd_from_seed_in_tests}, }; // Callbacks for use with `simplified_builder`. @@ -2354,9 +2355,8 @@ pub(crate) mod tests { + diskann::provider::SetElement<[f32]>, { let storage = VirtualStorageProvider::new_overlay(test_data_root()); - let (data_vec, npoints, dim) = file_util::load_bin(&storage, file, 0).unwrap(); - let data = - Arc::new(Matrix::::try_from(data_vec.into_boxed_slice(), npoints, dim).unwrap()); + let mut reader = storage.open_reader(file).unwrap(); + let data = Arc::new(diskann_utils::io::read_bin::(&mut reader).unwrap()); let rng = &mut create_rnd_from_seed_in_tests(0xe058c9c57864dd1e); let random_index = rand::Rng::random_range(rng, 0..data.nrows()); @@ -3021,13 +3021,12 @@ pub(crate) mod tests { S::PruneStrategy: Clone, { let storage = VirtualStorageProvider::new_overlay(test_data_root()); - let (train_data, npoints, dim) = file_util::load_bin(&storage, file, 0).unwrap(); - - let train_data_view = - diskann_utils::views::MatrixView::try_from(&train_data, npoints, dim).unwrap(); + let mut reader = storage.open_reader(file).unwrap(); + let train_data = diskann_utils::io::read_bin::(&mut reader).unwrap(); + let (npoints, dim) = (train_data.nrows(), train_data.ncols()); let table = train_pq( - train_data_view, + train_data.as_view(), num_pq_chunks, &mut create_rnd_from_seed_in_tests(0xe3c52ef001bc7ade), 1, @@ -3043,11 +3042,11 @@ pub(crate) mod tests { parameters, file, startpoint, - train_data_view, + train_data.as_view(), ) .await; - (index, train_data_view.to_owned()) + (index, train_data) } #[rstest] diff --git a/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs b/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs index ed903d8e7..0800f0574 100644 --- a/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs +++ b/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs @@ -763,6 +763,7 @@ mod fixed_chunk_pq_table_test { use crate::storage::{StorageReadProvider, VirtualStorageProvider}; use approx::assert_relative_eq; + use diskann::error::ErrorContext; use diskann_utils::test_data_root; use diskann_vector::{ PureDistanceFunction, @@ -771,11 +772,7 @@ mod fixed_chunk_pq_table_test { use itertools::iproduct; use super::*; - use crate::{ - common::AlignedBoxWithSlice, - model::{NUM_PQ_CENTROIDS, pq::convert_types}, - utils::{file_exists, load_bin}, - }; + use crate::{common::AlignedBoxWithSlice, model::NUM_PQ_CENTROIDS, utils::read_bin_from}; const DIM: usize = 128; @@ -1287,68 +1284,64 @@ mod fixed_chunk_pq_table_test { num_pq_chunks: &usize, storage_provider: &StorageProvider, ) -> ANNResult { - if !file_exists(storage_provider, pq_pivots_path) { - return Err(ANNError::log_pq_error( - "ERROR: PQ k-means pivot file not found.", - )); - } - - let (data, offset_num, offset_dim) = - load_bin::(&mut storage_provider.open_reader(pq_pivots_path)?, 0)?; + let mut reader = storage_provider + .open_reader(pq_pivots_path) + .with_context(|| format!("ERROR: Opening PQ k-means pivot file {}", pq_pivots_path))?; - let file_offset_data = - convert_types(&data, offset_num * offset_dim, |x: u64| x.into_usize()); - - if offset_num != 4 { + let offsets = read_bin_from::(&mut reader, 0)?; + if offsets.nrows() != 4 { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file {}. \ Offsets don't contain correct metadata, \ # offsets = {}, but expecting 4.", - pq_pivots_path, offset_num + pq_pivots_path, + offsets.nrows() ))); } + let file_offset_data = offsets.map(|x| x.into_usize()); + + let pivots = read_bin_from::(&mut reader, file_offset_data[(0, 0)])?; - let (data, pq_center_num, dim) = load_bin::( - &mut storage_provider.open_reader(pq_pivots_path).unwrap(), - file_offset_data[0], - )?; - let pq_table = data.to_vec(); - if pq_center_num != NUM_PQ_CENTROIDS { + if pivots.nrows() != NUM_PQ_CENTROIDS { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.", - pq_pivots_path, pq_center_num, NUM_PQ_CENTROIDS + pq_pivots_path, + pivots.nrows(), + NUM_PQ_CENTROIDS ))); } + let dim = pivots.ncols(); - let (data, centroid_dim, nc) = load_bin::( - &mut storage_provider.open_reader(pq_pivots_path).unwrap(), - file_offset_data[1], - )?; - let centroids = data.to_vec(); - if centroid_dim != dim || nc != 1 { + let centroids = read_bin_from::(&mut reader, file_offset_data[(1, 0)])?; + if centroids.nrows() != dim || centroids.ncols() != 1 { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file {}. file_dim = {}, \ file_cols = {} but expecting {} entries in 1 dimension.", - pq_pivots_path, centroid_dim, nc, dim + pq_pivots_path, + centroids.nrows(), + centroids.ncols(), + dim ))); } - let (data, chunk_offset_num, nc) = load_bin::( - &mut storage_provider.open_reader(pq_pivots_path).unwrap(), - file_offset_data[2], - )?; - let chunk_offsets = convert_types(&data, chunk_offset_num * nc, |x: u32| x.into_usize()); - if chunk_offset_num != num_pq_chunks + 1 || nc != 1 { + let chunk_offsets_m = read_bin_from::(&mut reader, file_offset_data[(2, 0)])?; + if chunk_offsets_m.nrows() != num_pq_chunks + 1 || chunk_offsets_m.ncols() != 1 { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file at chunk offsets; \ file has nr={}, nc={} but expecting nr={} and nc=1.", - chunk_offset_num, - nc, + chunk_offsets_m.nrows(), + chunk_offsets_m.ncols(), num_pq_chunks + 1 ))); } + let chunk_offsets = chunk_offsets_m.map(|x| x.into_usize()); - Ok((dim, pq_table, centroids, chunk_offsets)) + Ok(( + dim, + pivots.into_inner().into_vec(), + centroids.into_inner().into_vec(), + chunk_offsets.into_inner().into_vec(), + )) } #[test] diff --git a/diskann-providers/src/model/pq/mod.rs b/diskann-providers/src/model/pq/mod.rs index 863dcffef..010a7fef3 100644 --- a/diskann-providers/src/model/pq/mod.rs +++ b/diskann-providers/src/model/pq/mod.rs @@ -36,12 +36,3 @@ pub use generate_pivot_arguments::{GeneratePivotArguments, GeneratePivotArgument pub mod quantizer_preprocess; pub use quantizer_preprocess::quantizer_preprocess; - -/// Convert all types within `src` using the provided closure. -pub(crate) fn convert_types(src: &[T], max: usize, f: F) -> Vec -where - T: Copy, - F: Fn(T) -> U, -{ - src.iter().copied().take(max).map(f).collect() -} diff --git a/diskann-providers/src/model/pq/pq_construction.rs b/diskann-providers/src/model/pq/pq_construction.rs index a79f95fc1..77d9f9333 100644 --- a/diskann-providers/src/model/pq/pq_construction.rs +++ b/diskann-providers/src/model/pq/pq_construction.rs @@ -21,7 +21,10 @@ use diskann_quantization::{ CompressInto, product::{BasicTableView, TransposedTable, train::TrainQuantizer}, }; -use diskann_utils::views::{MatrixView, MutMatrixView}; +use diskann_utils::{ + io::Metadata, + views::{MatrixView, MutMatrixView}, +}; use rand::{Rng, distr::Distribution}; use rayon::prelude::*; use tracing::info; @@ -32,7 +35,7 @@ use crate::{ storage::PQStorage, utils::{ AsThreadPool, BridgeErr, ParallelIteratorInPool, RandomProvider, Timer, - create_rnd_provider_from_seed, k_means_clustering, read_metadata, run_lloyds, + create_rnd_provider_from_seed, k_means_clustering, run_lloyds, }, }; @@ -702,8 +705,7 @@ where storage_provider.create_for_write(pq_storage.get_compressed_data_path())? }; - let metadata = read_metadata(uncompressed_data_reader)?; - let (num_points, dim) = (metadata.npoints, metadata.ndims); + let (num_points, dim) = Metadata::read(uncompressed_data_reader)?.into_dims(); let mut full_pivot_data: Vec; let centroid: Vec; @@ -1010,15 +1012,15 @@ mod pq_test { use diskann_utils::test_data_root; use rand_distr::{Distribution, Uniform}; use rstest::rstest; - use vfs::{MemoryFS, OverlayFS}; + use vfs::OverlayFS; use super::*; use crate::{ model::{ FixedChunkPQTable, - pq::{METADATA_SIZE, convert_types, debug}, + pq::{METADATA_SIZE, debug}, }, - utils::{ParallelIteratorInPool, create_thread_pool_for_test, load_bin}, + utils::{ParallelIteratorInPool, create_thread_pool_for_test, read_bin_from}, }; #[test] @@ -1037,7 +1039,6 @@ mod pq_test { #[test] fn generate_pq_pivots_test() { let storage_provider = VirtualStorageProvider::new_memory(); - type ReaderType = as StorageReadProvider>::Reader; let pivot_file_name = "/generate_pq_pivots_test3.bin"; let compressed_file_name = "/compressed2.bin"; @@ -1047,6 +1048,7 @@ mod pq_test { compressed_file_name, Some(pq_training_file_name), ); + let mut train_data: Vec = vec![ 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, @@ -1064,57 +1066,40 @@ mod pq_test { ) .unwrap(); - let (data, nr, nc) = load_bin::( - &mut storage_provider.open_reader(pivot_file_name).unwrap(), - 0, - ) - .unwrap(); - let file_offset_data = convert_types(&data, nr * nc, |x: u64| x.into_usize()); - assert_eq!(file_offset_data[0], METADATA_SIZE); - assert_eq!(nr, 4); - assert_eq!(nc, 1); - - let (data, nr, nc) = load_bin::( - &mut storage_provider.open_reader(pivot_file_name).unwrap(), - file_offset_data[0], - ) - .unwrap(); + let mut reader = storage_provider.open_reader(pivot_file_name).unwrap(); + let offsets = read_bin_from::(&mut reader, 0).unwrap(); + let file_offset_data = offsets.map(|x| x.into_usize()); + assert_eq!(file_offset_data[(0, 0)], METADATA_SIZE); + assert_eq!(offsets.nrows(), 4); + assert_eq!(offsets.ncols(), 1); - let full_pivot_data = data.to_vec(); - assert_eq!(full_pivot_data.len(), 16); - assert_eq!(nr, 2); - assert_eq!(nc, 8); + let pivots = read_bin_from::(&mut reader, file_offset_data[(0, 0)]).unwrap(); - let (data, nr, nc) = load_bin::( - &mut storage_provider.open_reader(pivot_file_name).unwrap(), - file_offset_data[1], - ) - .unwrap(); - let centroid = data.to_vec(); + assert_eq!(pivots.as_slice().len(), 16); + assert_eq!(pivots.nrows(), 2); + assert_eq!(pivots.ncols(), 8); + + let centroid = read_bin_from::(&mut reader, file_offset_data[(1, 0)]).unwrap(); assert_eq!( - centroid[0], + centroid[(0, 0)], (1.0f32 + 2.0f32 + 2.1f32 + 2.2f32 + 100.0f32) / 5.0f32 ); - assert_eq!(nr, 8); - assert_eq!(nc, 1); - - let (data, nr, nc) = load_bin::( - &mut storage_provider.open_reader(pivot_file_name).unwrap(), - file_offset_data[2], - ) - .unwrap(); - let chunk_offsets = convert_types(&data, nr * nc, |x: u32| x.into_usize()); - assert_eq!(chunk_offsets[0], 0); - assert_eq!(chunk_offsets[1], 4); - assert_eq!(chunk_offsets[2], 8); - assert_eq!(nr, 3); - assert_eq!(nc, 1); + assert_eq!(centroid.nrows(), 8); + assert_eq!(centroid.ncols(), 1); + + let chunk_offsets = read_bin_from::(&mut reader, file_offset_data[(2, 0)]) + .unwrap() + .map(|x| x.into_usize()); + assert_eq!(chunk_offsets[(0, 0)], 0); + assert_eq!(chunk_offsets[(1, 0)], 4); + assert_eq!(chunk_offsets[(2, 0)], 8); + assert_eq!(chunk_offsets.nrows(), 3); + assert_eq!(chunk_offsets.ncols(), 1); } #[test] fn generate_optimized_pq_pivots_test() { let storage_provider = VirtualStorageProvider::new_memory(); - type ReaderType = as StorageReadProvider>::Reader; let pivot_file_name = "/generate_pq_pivots_test3.bin"; let compressed_file_name = "/compressed2.bin"; @@ -1124,6 +1109,7 @@ mod pq_test { compressed_file_name, Some(pq_training_file_name), ); + let mut train_data: Vec = vec![ 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, @@ -1141,51 +1127,35 @@ mod pq_test { ) .unwrap(); - let (data, nr, nc) = load_bin::( - &mut storage_provider.open_reader(pivot_file_name).unwrap(), - 0, - ) - .unwrap(); - let file_offset_data = convert_types(&data, nr * nc, |x: u64| x.into_usize()); - assert_eq!(file_offset_data[0], METADATA_SIZE); - assert_eq!(nr, 4); - assert_eq!(nc, 1); - - let (data, nr, nc) = load_bin::( - &mut storage_provider.open_reader(pivot_file_name).unwrap(), - file_offset_data[0], - ) - .unwrap(); + let mut reader = storage_provider.open_reader(pivot_file_name).unwrap(); + let offsets = read_bin_from::(&mut reader, 0).unwrap(); + let file_offset_data = offsets.map(|x| x.into_usize()); + assert_eq!(file_offset_data[(0, 0)], METADATA_SIZE); + assert_eq!(offsets.nrows(), 4); + assert_eq!(offsets.ncols(), 1); - let full_pivot_data = data.to_vec(); - assert_eq!(full_pivot_data.len(), 16); - assert_eq!(nr, 2); - assert_eq!(nc, 8); + let pivots = read_bin_from::(&mut reader, file_offset_data[(0, 0)]).unwrap(); - let (data, nr, nc) = load_bin::( - &mut storage_provider.open_reader(pivot_file_name).unwrap(), - file_offset_data[1], - ) - .unwrap(); - let centroid = data.to_vec(); + assert_eq!(pivots.as_slice().len(), 16); + assert_eq!(pivots.nrows(), 2); + assert_eq!(pivots.ncols(), 8); + + let centroid = read_bin_from::(&mut reader, file_offset_data[(1, 0)]).unwrap(); assert_eq!( - centroid[0], + centroid[(0, 0)], (1.0f32 + 2.0f32 + 2.1f32 + 2.2f32 + 100.0f32) / 5.0f32 ); - assert_eq!(nr, 8); - assert_eq!(nc, 1); - - let (data, nr, nc) = load_bin::( - &mut storage_provider.open_reader(pivot_file_name).unwrap(), - file_offset_data[2], - ) - .unwrap(); - let chunk_offsets = convert_types(&data, nr * nc, |x: u32| x.into_usize()); - assert_eq!(chunk_offsets[0], 0); - assert_eq!(chunk_offsets[1], 4); - assert_eq!(chunk_offsets[2], 8); - assert_eq!(nr, 3); - assert_eq!(nc, 1); + assert_eq!(centroid.nrows(), 8); + assert_eq!(centroid.ncols(), 1); + + let chunk_offsets = read_bin_from::(&mut reader, file_offset_data[(2, 0)]) + .unwrap() + .map(|x| x.into_usize()); + assert_eq!(chunk_offsets[(0, 0)], 0); + assert_eq!(chunk_offsets[(1, 0)], 4); + assert_eq!(chunk_offsets[(2, 0)], 8); + assert_eq!(chunk_offsets.nrows(), 3); + assert_eq!(chunk_offsets.ncols(), 1); } #[rstest] @@ -1317,17 +1287,17 @@ mod pq_test { &pool, ) .unwrap(); - let (data, nr, nc) = load_bin::( + let compressed = read_bin_from::( &mut storage_provider .open_reader(pq_compressed_vectors_path) .unwrap(), 0, ) .unwrap(); - assert_eq!(nr, 5); - assert_eq!(nc, 2); - assert_eq!(data[0], data[2]); - assert_ne!(data[0], data[8]); + assert_eq!(compressed.nrows(), 5); + assert_eq!(compressed.ncols(), 2); + assert_eq!(compressed[(0, 0)], compressed[(1, 0)]); + assert_ne!(compressed[(0, 0)], compressed[(4, 0)]); storage_provider.delete(data_file).unwrap(); storage_provider.delete(pq_pivots_path).unwrap(); @@ -1506,14 +1476,13 @@ mod pq_test { }); // use pq generated by original function as the gt - let (original_pq_data, _nr, _nc) = - load_bin:: as StorageReadProvider>::Reader>( - &mut storage_provider - .open_reader(pq_compressed_vectors_path) - .unwrap(), - 0, - ) - .unwrap(); + let original_pq_data = read_bin_from::( + &mut storage_provider + .open_reader(pq_compressed_vectors_path) + .unwrap(), + 0, + ) + .unwrap(); let membuf_view = MatrixView::try_from(membuf_pq_data.as_slice(), num_train, num_pq_chunks).unwrap(); @@ -1672,7 +1641,6 @@ mod pq_test { fn pq_end_to_end_validation_with_codebook_test() { // Creates a new filesystem using a read/write MemoryFS with PhysicalFS as a fall-back read-only filesystem. let storage_provider = VirtualStorageProvider::new_overlay(test_data_root()); - type ReaderType = as StorageReadProvider>::Reader; let data_file = "/sift/siftsmall_learn.bin"; let pq_pivots_path = "/sift/siftsmall_learn_pq_pivots.bin"; @@ -1694,23 +1662,19 @@ mod pq_test { ) .expect("Failed to generate quantized data"); - let (data, nr, nc) = load_bin::( + let data = read_bin_from::( &mut storage_provider .open_reader(pq_compressed_vectors_path) .unwrap(), 0, ) .unwrap(); - let (gt_data, gt_nr, gt_nc) = load_bin::( + let gt_data = read_bin_from::( &mut storage_provider.open_reader(ground_truth_path).unwrap(), 0, ) .unwrap(); - assert_eq!(nr, gt_nr); - assert_eq!(nc, gt_nc); - for i in 0..data.len() { - assert_eq!(data[i], gt_data[i]); - } + assert_eq!(data, gt_data); } #[test] diff --git a/diskann-providers/src/storage/bin.rs b/diskann-providers/src/storage/bin.rs index 8f528c159..bf1f89d14 100644 --- a/diskann-providers/src/storage/bin.rs +++ b/diskann-providers/src/storage/bin.rs @@ -11,11 +11,9 @@ use diskann::{ ANNError, ANNResult, utils::{IntoUsize, VectorRepr}, }; +use diskann_utils::io::Metadata; -use crate::{ - model::graph::traits::AdHoc, - utils::{load_metadata_from_file, write_metadata}, -}; +use crate::{model::graph::traits::AdHoc, utils::load_metadata_from_file}; /// An simplified adaptor interface for allowing providers to use and [`load_graph`]. /// @@ -146,12 +144,12 @@ where tracing::info!( "Loading {} vectors with dimension {} from storage system {} into dataset...", - metadata.npoints, - metadata.ndims, + metadata.npoints(), + metadata.ndims(), path ); - let mut data = create(metadata.npoints, metadata.ndims)?; + let mut data = create(metadata.npoints(), metadata.ndims())?; let itr = crate::utils::VectorDataIterator::<_, AdHoc>::new(path, None, provider)?; for (i, (vector, _)) in itr.enumerate() { data.set_data(i.into_usize(), &vector)?; @@ -186,10 +184,9 @@ where let dim = data.dim(); let mut writer = provider.create_for_write(path)?; - let mut points_written: u32 = 0; + Metadata::new(points_written, dim)?.write(&mut writer)?; - write_metadata(&mut writer, points_written, dim)?; for i in 0..total { // The binding provides a stable address for the return item of `get_data`, // regardless of if `get_data` returns a borrowed slice or a copy. diff --git a/diskann-providers/src/storage/index_storage.rs b/diskann-providers/src/storage/index_storage.rs index 924e3eb34..b310e43a0 100644 --- a/diskann-providers/src/storage/index_storage.rs +++ b/diskann-providers/src/storage/index_storage.rs @@ -222,10 +222,7 @@ mod tests { provider::{Accessor, SetElement}, utils::{IntoUsize, ONE}, }; - use diskann_utils::{ - Reborrow, test_data_root, - views::{Matrix, MatrixView}, - }; + use diskann_utils::{Reborrow, test_data_root, views::MatrixView}; use diskann_vector::distance::Metric; use super::*; @@ -236,7 +233,7 @@ mod tests { common::{FullPrecision, NoDeletes, NoStore, TableBasedDeletes}, inmem::{self}, }, - utils::{create_rnd_from_seed_in_tests, file_util}, + utils::create_rnd_from_seed_in_tests, }; async fn build_index( @@ -271,8 +268,8 @@ mod tests { let file_path = "/sift/siftsmall_learn_256pts.fbin"; let train_data = { let storage = VirtualStorageProvider::new_overlay(test_data_root()); - let (train_data, npoints, dim) = file_util::load_bin(&storage, file_path, 0).unwrap(); - Matrix::::try_from(train_data.into(), npoints, dim).unwrap() + let mut reader = storage.open_reader(file_path).unwrap(); + diskann_utils::io::read_bin::(&mut reader).unwrap() }; let pq_bytes = 8; diff --git a/diskann-providers/src/storage/pq_storage.rs b/diskann-providers/src/storage/pq_storage.rs index 3445fabc1..e4ab781f6 100644 --- a/diskann-providers/src/storage/pq_storage.rs +++ b/diskann-providers/src/storage/pq_storage.rs @@ -2,25 +2,23 @@ * Copyright (c) Microsoft Corporation. * Licensed under the MIT license. */ -use std::io::Write; +use std::io::{Seek, SeekFrom, Write}; use super::{StorageReadProvider, StorageWriteProvider}; use diskann::{ ANNError, ANNResult, utils::{IntoUsize, VectorRepr}, }; +use diskann_utils::{ + io::{Metadata, write_bin}, + views::MatrixView, +}; use rand::Rng; use tracing::info; use crate::{ - model::{ - FixedChunkPQTable, NUM_PQ_CENTROIDS, PQCompressedData, - pq::{METADATA_SIZE, convert_types}, - }, - utils::{ - copy_aligned_data, gen_random_slice, load_bin, save_bin_f32, save_bin_u32, save_bin_u64, - write_metadata, - }, + model::{FixedChunkPQTable, NUM_PQ_CENTROIDS, PQCompressedData, pq::METADATA_SIZE}, + utils::{copy_aligned_data, gen_random_slice, read_bin_from, write_bin_from}, }; // Create types to make return values easier to understand @@ -60,7 +58,7 @@ impl PQStorage { where Storage: StorageWriteProvider, { - write_metadata(writer, npts, pq_chunk)?; + Metadata::new(npts, pq_chunk)?.write(writer)?; Ok(()) } @@ -99,30 +97,32 @@ impl PQStorage { let mut cumul_bytes: Vec = vec![0; 4]; cumul_bytes[0] = METADATA_SIZE; let writer = &mut storage_provider.create_for_write(&self.pivot_data_path)?; - // Write Pq centroids vectors at offset METADATA_SIZE(4096) - cumul_bytes[1] = cumul_bytes[0] - + save_bin_f32(writer, full_pivot_data, num_centers, dim, cumul_bytes[0])?; - // Write THE CENTROID of PQ CENTROID vectors at offset METADATA_SIZE(4096) + size of centroids vectors. - cumul_bytes[2] = cumul_bytes[1] + save_bin_f32(writer, centroid, dim, 1, cumul_bytes[1])?; + // Skip past the offset table — we'll write it last once we know all offsets. + writer.seek(SeekFrom::Start(cumul_bytes[0] as u64))?; + + // Write PQ centroid vectors + let pivot_view = MatrixView::try_from(full_pivot_data, num_centers, dim)?; + cumul_bytes[1] = cumul_bytes[0] + write_bin(pivot_view, writer)?; - // Because the writer only can write u32, u64 but not usize, so we need to convert the type first. - let chunk_offsets_u32 = - convert_types(chunk_offsets, chunk_offsets.len(), |x: usize| x as u32); + // Write the centroid of PQ centroid vectors + cumul_bytes[2] = cumul_bytes[1] + write_bin(MatrixView::column_vector(centroid), writer)?; - // Write PQ chunk offsets at offset METADATA_SIZE(4096) + size of PQ centroids vectors + size of the centroid vector. + // Write PQ chunk offsets + let chunk_offsets_u32: Vec = chunk_offsets.iter().map(|&x| x as u32).collect(); cumul_bytes[3] = cumul_bytes[2] - + save_bin_u32( + + write_bin( + MatrixView::column_vector(chunk_offsets_u32.as_slice()), writer, - &chunk_offsets_u32, - chunk_offsets_u32.len(), - 1, - cumul_bytes[2], )?; - // Write metadata at offset 0. - let cumul_bytes_u64 = convert_types(&cumul_bytes, cumul_bytes.len(), |x: usize| x as u64); - save_bin_u64(writer, &cumul_bytes_u64, cumul_bytes_u64.len(), 1, 0)?; + // Seek back to offset 0 and write the offset table. + let cumul_bytes_u64: Vec = cumul_bytes.iter().map(|&x| x as u64).collect(); + write_bin_from( + MatrixView::column_vector(cumul_bytes_u64.as_slice()), + writer, + 0, + )?; writer.flush()?; Ok(()) @@ -151,7 +151,8 @@ impl PQStorage { let writer = &mut storage_provider.create_for_write(&self.get_rotation_matrix_path())?; // Save the rotation matrix - save_bin_f32(writer, rotation_matrix, dimension, dimension, 0)?; + let view = MatrixView::try_from(rotation_matrix, dimension, dimension)?; + write_bin(view, writer)?; Ok(()) } @@ -171,9 +172,8 @@ impl PQStorage { Storage: StorageReadProvider, { let reader = &mut storage_provider.open_reader(&self.pivot_data_path)?; - let (_, file_num_centers, file_dim) = - load_bin::(reader, METADATA_SIZE)?; - Ok((file_num_centers, file_dim)) + reader.seek(SeekFrom::Start(METADATA_SIZE as u64))?; + Ok(Metadata::read(reader)?.into_dims()) } pub fn load_existing_pivot_data( @@ -192,56 +192,58 @@ impl PQStorage { where Storage: StorageReadProvider, { - // Load file offset data. File saved as offset data(4*1) -> pivot data(centroid num*dim) -> centroid of dim data(dim*1) -> chunk offset data(chunksize+1*1) - // Because we only can write u64 rather than usize, so the file stored as u64 type. Need to convert to usize when use. + // Load file offset data. File layout: offset table(4*1) -> pivot data(num_centers*dim) -> centroid(dim*1) -> chunk offsets(num_chunks+1*1) let reader = &mut storage_provider.open_reader(&self.pivot_data_path)?; - let (data, offset_num, nc) = load_bin::(reader, 0)?; - let file_offset_data = convert_types(&data, offset_num * nc, |x: u64| x.into_usize()); - if offset_num != 4 { + let offsets = read_bin_from::(reader, 0)?; + if offsets.nrows() != 4 { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file {}. Offsets don't contain correct \ metadata, # offsets = {}, but expecting 4.", - &self.pivot_data_path, offset_num + &self.pivot_data_path, + offsets.nrows() ))); } + let file_offset_data = offsets.map(|x| x.into_usize()); - info!(" Offset data: {:?}", file_offset_data); + info!(" Offset data: {:?}", file_offset_data.as_slice()); - let (data, pivot_num, pivot_dim) = - load_bin::(reader, file_offset_data[0])?; - let full_pivot_data = data; - if pivot_num != *num_centers || pivot_dim != *dim { + let pivots = read_bin_from::(reader, file_offset_data[(0, 0)])?; + if pivots.nrows() != *num_centers || pivots.ncols() != *dim { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file {}. file_num_centers = {}, \ file_dim = {} but expecting {} centers in {} dimensions.", - &self.pivot_data_path, pivot_num, pivot_dim, num_centers, dim + &self.pivot_data_path, + pivots.nrows(), + pivots.ncols(), + num_centers, + dim ))); } - let (data, centroid_dim, nc) = - load_bin::(reader, file_offset_data[1])?; - let centroid = data; - if centroid_dim != *dim || nc != 1 { + let centroid_m = read_bin_from::(reader, file_offset_data[(1, 0)])?; + if centroid_m.nrows() != *dim || centroid_m.ncols() != 1 { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file {}. file_dim = {}, \ file_cols = {} but expecting {} entries in 1 dimension.", - &self.pivot_data_path, centroid_dim, nc, dim + &self.pivot_data_path, + centroid_m.nrows(), + centroid_m.ncols(), + dim ))); } - let (data, chunk_offset_number, nc) = - load_bin::(reader, file_offset_data[2])?; - let chunk_offsets = convert_types(&data, chunk_offset_number * nc, |x: u32| x.into_usize()); - if chunk_offset_number != *num_pq_chunks + 1 || nc != 1 { + let chunk_offsets_m = read_bin_from::(reader, file_offset_data[(2, 0)])?; + if chunk_offsets_m.nrows() != *num_pq_chunks + 1 || chunk_offsets_m.ncols() != 1 { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file at chunk offsets; \ - file has nr={}, nc={} but expecting nr={} and nc=2.", - chunk_offset_number, - nc, + file has nr={}, nc={} but expecting nr={} and nc=1.", + chunk_offsets_m.nrows(), + chunk_offsets_m.ncols(), num_pq_chunks + 1 ))); } + let chunk_offsets = chunk_offsets_m.map(|x| x.into_usize()); let opq_rotation_matrix = if use_opq { self.read_opq_rotation_matrix(storage_provider)? @@ -251,9 +253,9 @@ impl PQStorage { }; Ok(( - full_pivot_data, - centroid, - chunk_offsets, + pivots.into_inner().into_vec(), + centroid_m.into_inner().into_vec(), + chunk_offsets.into_inner().into_vec(), opq_rotation_matrix, )) } @@ -280,11 +282,11 @@ impl PQStorage { } }; - let (data, _, _) = load_bin::(rotation_matrix_reader, 0)?; + let data = read_bin_from::(rotation_matrix_reader, 0)?; info!("OPQ rotation matrix load complete"); - Ok(data) + Ok(data.into_inner().into_vec()) } /// Load the compressed pq dataset from file @@ -331,51 +333,54 @@ impl PQStorage { info!("Loading PQ pivots from {}...", pq_pivots); let mut reader = storage_provider.open_reader(pq_pivots)?; - let (data, offset_num, offset_dim) = load_bin::(&mut reader, 0)?; - let file_offset_data = - convert_types(&data, offset_num * offset_dim, |x: u64| x.into_usize()); - if offset_num != 4 { + let offsets = read_bin_from::(&mut reader, 0)?; + if offsets.nrows() != 4 { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file {}. Offsets don't contain correct metadata, \ # offsets = {}, but expecting 4.", - pq_pivots, offset_num + pq_pivots, + offsets.nrows() ))); } + let file_offset_data = offsets.map(|x| x.into_usize()); - let (data, pivot_num, dim) = - load_bin::(&mut reader, file_offset_data[0])?; - let pq_table = data.to_vec(); - if pivot_num > NUM_PQ_CENTROIDS { + let pivots = read_bin_from::(&mut reader, file_offset_data[(0, 0)])?; + if pivots.nrows() > NUM_PQ_CENTROIDS { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.", - pq_pivots, pivot_num, NUM_PQ_CENTROIDS + pq_pivots, + pivots.nrows(), + NUM_PQ_CENTROIDS ))); } + let dim = pivots.ncols(); - let (data, centroid_dim, nc) = - load_bin::(&mut reader, file_offset_data[1])?; - let centroids = data.to_vec(); - if centroid_dim != dim || nc != 1 { + let centroids = read_bin_from::(&mut reader, file_offset_data[(1, 0)])?; + if centroids.nrows() != dim || centroids.ncols() != 1 { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file {}. file_dim = {}, file_cols = {} \ but expecting {} entries in 1 dimension.", - pq_pivots, centroid_dim, nc, dim + pq_pivots, + centroids.nrows(), + centroids.ncols(), + dim ))); } - let (data, chunk_offset_num, nc) = - load_bin::(&mut reader, file_offset_data[2])?; - let chunk_offsets = convert_types(&data, chunk_offset_num * nc, |x: u32| x.into_usize()); - if (chunk_offset_num != num_pq_chunks + 1 && num_pq_chunks as u32 != 0) || nc != 1 { + let chunk_offsets_m = read_bin_from::(&mut reader, file_offset_data[(2, 0)])?; + if (chunk_offsets_m.nrows() != num_pq_chunks + 1 && num_pq_chunks as u32 != 0) + || chunk_offsets_m.ncols() != 1 + { return Err(ANNError::log_pq_error(format_args!( "Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} \ but expecting nr={} and nc=1. The expected num_pq_chunks should be \ passed as 0 if we want to infer.", - chunk_offset_num, - nc, + chunk_offsets_m.nrows(), + chunk_offsets_m.ncols(), num_pq_chunks + 1 ))); } + let chunk_offsets = chunk_offsets_m.map(|x| x.into_usize()); let opq_rotation_matrix: Option> = if storage_provider.exists(&self.get_rotation_matrix_path()) { @@ -387,9 +392,9 @@ impl PQStorage { FixedChunkPQTable::new( dim, - pq_table.into(), - centroids.into(), - chunk_offsets.into(), + pivots.into_inner(), + centroids.into_inner(), + chunk_offsets.into_inner(), opq_rotation_matrix, ) } @@ -440,7 +445,7 @@ mod pq_storage_tests { use vfs::MemoryFS; use super::*; - use crate::utils::{gen_random_slice, read_metadata}; + use crate::utils::gen_random_slice; const DATA_FILE: &str = "/sift/siftsmall_learn.bin"; const PQ_PIVOT_PATH: &str = "/sift/siftsmall_learn_pq_pivots.bin"; @@ -471,10 +476,10 @@ mod pq_storage_tests { } let mut result_reader = storage_provider.open_reader(compress_pivot_path).unwrap(); - let metadata = read_metadata(&mut result_reader).unwrap(); + let metadata = Metadata::read(&mut result_reader).unwrap(); - assert_eq!(metadata.npoints, 100); - assert_eq!(metadata.ndims, 20); + assert_eq!(metadata.npoints(), 100); + assert_eq!(metadata.ndims(), 20); storage_provider.delete(compress_pivot_path).unwrap(); } diff --git a/diskann-providers/src/utils/file_util.rs b/diskann-providers/src/utils/file_util.rs index d5bf54784..7a4a053aa 100644 --- a/diskann-providers/src/utils/file_util.rs +++ b/diskann-providers/src/utils/file_util.rs @@ -8,22 +8,19 @@ use std::{ io, - io::{BufReader, Read, Seek}, + io::{BufReader, Read}, mem::size_of, }; use crate::storage::{StorageReadProvider, StorageWriteProvider}; use byteorder::{LittleEndian, ReadBytesExt}; use diskann::{ANNError, ANNResult, utils::IntoUsize}; -use diskann_utils::views::Matrix; +use diskann_utils::{io::Metadata, views::Matrix}; use tracing::info; use crate::{ common::AlignedBoxWithSlice, - utils::{ - DatasetDto, copy_aligned_data, - storage_utils::{Metadata, read_metadata}, - }, + utils::{DatasetDto, copy_aligned_data}, }; /// Read metadata of data file. @@ -32,13 +29,13 @@ pub fn load_metadata_from_file( file_name: &str, ) -> std::io::Result { let mut file = storage_provider.open_reader(file_name)?; - read_metadata(&mut file) + Metadata::read(&mut file) } /// Read metadata from data content. Use include_bytes! marco to get reference of a byte array. pub fn load_metadata_from_bytes(bytes: &[u8]) -> std::io::Result { let mut cursor = std::io::Cursor::new(bytes); - read_metadata(&mut cursor) + Metadata::read(&mut cursor) } /// Read the deleted vertex ids from file. @@ -112,8 +109,8 @@ pub fn load_aligned_bin( file_size = storage_provider.get_length(bin_file)? as usize; let mut file = storage_provider.open_reader(bin_file)?; - let metadata = read_metadata(&mut file)?; - (npts, dim) = (metadata.npoints, metadata.ndims); + let metadata = Metadata::read(&mut file)?; + (npts, dim) = (metadata.npoints(), metadata.ndims()); } let rounded_dim = dim.next_multiple_of(8); @@ -175,32 +172,6 @@ pub fn file_exists( storage_provider.exists(filename) } -/// Read data file -/// # Arguments -/// * `bin_file` - filename where the data is -/// * `file_offset` - data offset in file -/// * `data` - information data -/// * `npts` - number of points -/// * `ndims` - point dimension -pub fn load_bin( - storage_read_provider: &StorageReader, - bin_file: &str, - file_offset: usize, -) -> std::io::Result<(Vec, usize, usize)> { - let mut reader = storage_read_provider.open_reader(bin_file)?; - reader.seek(std::io::SeekFrom::Start(file_offset as u64))?; - let metadata = read_metadata(&mut reader)?; - let (npts, dim) = (metadata.npoints, metadata.ndims); - - let size = npts * dim * std::mem::size_of::(); - let mut buf = vec![0u8; size]; - reader.read_exact(&mut buf)?; - - let data: &[T] = bytemuck::cast_slice(&buf); - - Ok((data.to_vec(), npts, dim)) -} - /// Read data file /// # Arguments /// * `bin_file` - filename where the data is @@ -286,7 +257,6 @@ mod file_util_test { use vfs::{FileSystem, MemoryFS, SeekAndWrite}; use super::*; - use crate::utils::save_bin_u64; #[test] fn get_file_size_test() { @@ -315,8 +285,8 @@ mod file_util_test { let result = load_metadata_from_bytes(&data); assert!(result.is_ok()); let metadata = result.unwrap(); - assert_eq!(metadata.npoints, 200); - assert_eq!(metadata.ndims, 128); + assert_eq!(metadata.npoints(), 200); + assert_eq!(metadata.ndims(), 128); } #[test] @@ -332,8 +302,8 @@ mod file_util_test { } match load_metadata_from_file(&storage_provider, file_name) { Ok(metadata) => { - assert!(metadata.npoints == 200); - assert!(metadata.ndims == 128); + assert!(metadata.npoints() == 200); + assert!(metadata.ndims() == 128); } Err(_e) => {} } @@ -422,60 +392,6 @@ mod file_util_test { .expect("Failed to delete file"); } - #[test] - fn load_bin_test() { - let storage_provider = VirtualStorageProvider::new_memory(); - - let file_name = "/load_bin_test"; - let data = vec![0u64, 1u64, 2u64]; - let num_pts = data.len(); - let dims = 1; - { - let mut file_write = storage_provider.create_for_write(file_name).unwrap(); - let bytes_written = save_bin_u64(&mut file_write, &data, num_pts, dims, 0).unwrap(); - assert_eq!(bytes_written, 32); - } - - let (load_data, load_num_pts, load_dims) = - load_bin::>(&storage_provider, file_name, 0) - .unwrap(); - assert_eq!(load_num_pts, num_pts); - assert_eq!(load_dims, dims); - assert_eq!(load_data, data); - storage_provider - .filesystem() - .remove_file(file_name) - .unwrap(); - } - - #[test] - fn load_bin_offset_test() { - let storage_provider = VirtualStorageProvider::new_memory(); - - let offset: usize = 32; - let file_name = "/load_bin_offset_test"; - let data = vec![0u64, 1u64, 2u64]; - let num_pts = data.len(); - let dims = 1; - { - let mut file_write = storage_provider.create_for_write(file_name).unwrap(); - let bytes_written = - save_bin_u64(&mut file_write, &data, num_pts, dims, offset).unwrap(); - assert_eq!(bytes_written, 32); - } - - let (load_data, load_num_pts, load_dims) = - load_bin::>(&storage_provider, file_name, offset) - .unwrap(); - assert_eq!(load_num_pts, num_pts); - assert_eq!(load_dims, dims); - assert_eq!(load_data, data); - storage_provider - .filesystem() - .remove_file(file_name) - .unwrap(); - } - #[test] fn load_multivec_bin_test() { let storage_provider = VirtualStorageProvider::new_memory(); diff --git a/diskann-providers/src/utils/medoid.rs b/diskann-providers/src/utils/medoid.rs index 342dcb691..e9c0508d3 100644 --- a/diskann-providers/src/utils/medoid.rs +++ b/diskann-providers/src/utils/medoid.rs @@ -226,15 +226,17 @@ where let metadata = load_metadata_from_file(reader, path)?; // Calculate sampling rate based on max_sample_size - let sampling_rate = if max_sample_size == 0 || max_sample_size >= metadata.npoints { + let sampling_rate = if max_sample_size == 0 || max_sample_size >= metadata.npoints() { 1.0 // Use all points } else { - max_sample_size as f64 / metadata.npoints as f64 + max_sample_size as f64 / metadata.npoints() as f64 }; info!( "Finding medoid from {} points with max max_sample_size: {}, sampling_rate: {:.2}", - metadata.npoints, max_sample_size, sampling_rate + metadata.npoints(), + max_sample_size, + sampling_rate ); let centroid = calculate_centroid_with_sampling::(path, reader, sampling_rate, rng)?; @@ -252,7 +254,6 @@ mod tests { use std::{io::Write, num::NonZeroUsize}; use crate::storage::VirtualStorageProvider; - use crate::utils::write_metadata; use diskann::utils::VectorRepr; use diskann_quantization::{ CompressInto, @@ -260,7 +261,7 @@ mod tests { minmax::{DataMutRef, MinMaxQuantizer}, num::Positive, }; - use diskann_utils::ReborrowMut; + use diskann_utils::{ReborrowMut, io::Metadata}; use rand::{SeedableRng, rngs::StdRng}; use vfs::{FileSystem, MemoryFS}; @@ -283,7 +284,9 @@ mod tests { vectors[0].len() }; - write_metadata(&mut file, num_points, dimension)?; + Metadata::new(num_points, dimension) + .unwrap() + .write(&mut file)?; // Write vectors for vector in vectors { diff --git a/diskann-providers/src/utils/mod.rs b/diskann-providers/src/utils/mod.rs index 5c84f156c..87ff9df0b 100644 --- a/diskann-providers/src/utils/mod.rs +++ b/diskann-providers/src/utils/mod.rs @@ -67,9 +67,8 @@ pub mod generate_synthetic_labels_utils; mod storage_utils; pub use storage_utils::{ - Metadata, MetadataError, copy_aligned_data, load_bin, load_vector_ids, read_metadata, - save_bin_f32, save_bin_u32, save_bin_u64, save_bytes, save_data_in_base_dimensions, - write_metadata, + copy_aligned_data, load_vector_ids, read_bin_from, save_bytes, save_data_in_base_dimensions, + write_bin_from, }; mod sampling; diff --git a/diskann-providers/src/utils/normalizing_util.rs b/diskann-providers/src/utils/normalizing_util.rs index 8cbad1103..0ff98c9e6 100644 --- a/diskann-providers/src/utils/normalizing_util.rs +++ b/diskann-providers/src/utils/normalizing_util.rs @@ -9,11 +9,12 @@ use std::{ use crate::storage::{StorageReadProvider, StorageWriteProvider}; use diskann::{ANNError, ANNResult}; +use diskann_utils::io::Metadata; use diskann_vector::{Norm, norm::FastL2Norm}; use rayon::prelude::*; use tracing::info; -use super::{AsThreadPool, ParallelIteratorInPool, RayonThreadPool, read_metadata, write_metadata}; +use super::{AsThreadPool, ParallelIteratorInPool, RayonThreadPool}; use crate::forward_threadpool; /// The normalizing_utils derives from the DiskANN c++ utils. @@ -30,10 +31,9 @@ where let mut reader = BufReader::new(storage_provider.open_reader(in_file_name)?); let mut writer = BufWriter::new(storage_provider.create_for_write(out_file_name)?); - let metadata = read_metadata(&mut reader)?; - let (npts, ndims) = (metadata.npoints, metadata.ndims); - - write_metadata(&mut writer, npts, ndims)?; + let metadata = Metadata::read(&mut reader)?; + let (npts, ndims) = (metadata.npoints(), metadata.ndims()); + metadata.write(&mut writer)?; info!("Normalizing FLOAT vectors in file: {}", in_file_name); info!("Dataset: #pts = {}, # dims = {}", npts, ndims); @@ -139,10 +139,10 @@ pub fn normalize_data_internal( #[cfg(test)] mod normalizing_utils_test { use crate::storage::{StorageReadProvider, VirtualStorageProvider}; - use diskann_utils::test_data_root; + use diskann_utils::{io::read_bin, test_data_root}; use super::*; - use crate::utils::{create_thread_pool_for_test, storage_utils::*}; + use crate::utils::create_thread_pool_for_test; #[test] fn test_normalize_data_file() { @@ -154,21 +154,15 @@ mod normalizing_utils_test { let pool = create_thread_pool_for_test(); normalize_data_file(in_file_name, out_file_name, &storage_provider, &pool).unwrap(); - let (load_data, load_num_pts, load_dims) = - load_bin::(&mut storage_provider.open_reader(out_file_name).unwrap(), 0) - .unwrap(); + let load_data = + read_bin::(&mut storage_provider.open_reader(out_file_name).unwrap()).unwrap(); storage_provider .delete(out_file_name) .expect("Should be able to delete temp file"); - let (norm_data, norm_num_pts, norm_dims) = load_bin::( - &mut storage_provider.open_reader(norm_file_name).unwrap(), - 0, - ) - .unwrap(); + let norm_data = + read_bin::(&mut storage_provider.open_reader(norm_file_name).unwrap()).unwrap(); - assert_eq!(load_num_pts, norm_num_pts); - assert_eq!(load_dims, norm_dims); assert_eq!(load_data, norm_data); } } diff --git a/diskann-providers/src/utils/storage_utils.rs b/diskann-providers/src/utils/storage_utils.rs index 1bba7c3ce..ddb0f432d 100644 --- a/diskann-providers/src/utils/storage_utils.rs +++ b/diskann-providers/src/utils/storage_utils.rs @@ -6,106 +6,20 @@ //! Utilities for reading and writing data from the storage layer with generic reader/writer. //! This is a replacement for the functions file_util.rs with generic reader/writer. -use std::{ - convert::TryInto, - io::{BufReader, Read, Seek, SeekFrom, Write}, - mem, -}; +use std::io::{BufReader, Read, Seek, Write}; use bytemuck::Pod; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use diskann::{ANNError, ANNErrorKind, ANNResult, utils::read_exact_into}; -use diskann_wide::{LoHi, SplitJoin}; -use thiserror::Error; -use tracing::info; +use byteorder::{LittleEndian, ReadBytesExt}; +use diskann::ANNResult; +use diskann_utils::{ + io::{Metadata, ReadBinError, SaveBinError, read_bin, write_bin}, + views::{Matrix, MatrixView}, +}; use crate::utils::DatasetDto; const DEFAULT_BUF_SIZE: usize = 1024 * 1024; -/// Metadata containing number of points and dimensions -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct Metadata { - pub npoints: usize, - pub ndims: usize, -} - -/// Error type for metadata I/O operations -#[derive(Debug, Error)] -pub enum MetadataError { - #[error("num points conversion")] - NumPoints(#[source] T), - #[error("dim conversion")] - Dim(#[source] U), - #[error("writing binary results")] - Write(#[source] std::io::Error), -} - -impl From> for ANNError -where - T: std::error::Error + Send + Sync + 'static, - U: std::error::Error + Send + Sync + 'static, -{ - #[track_caller] - fn from(err: MetadataError) -> Self { - ANNError::new(ANNErrorKind::IOError, err) - } -} - -/// Read binary metadata header (number of points and dimension) from a reader. -/// -/// Reads 8 bytes total: -/// - First 4 bytes: number of points (u32, little-endian) -/// - Next 4 bytes: number of dimensions (u32, little-endian) -/// -/// # Returns -/// * `Ok(Metadata)` - Metadata containing number of points and dimensions -/// * `Err(io::Error)` - If reading fails -pub fn read_metadata(reader: &mut Reader) -> std::io::Result { - let raw = reader.read_u64::()?; - let bytes: [u8; 8] = bytemuck::cast(raw); - let LoHi { - lo: npts_bytes, - hi: ndims_bytes, - } = bytes.split(); - let npoints = u32::from_le_bytes(npts_bytes) as usize; - let ndims = u32::from_le_bytes(ndims_bytes) as usize; - Ok(Metadata { npoints, ndims }) -} - -/// Write binary metadata header (number of points and dimension) to a writer. -/// -/// Writes 8 bytes total: -/// - First 4 bytes: number of points (u32, little-endian) -/// - Next 4 bytes: number of dimensions (u32, little-endian) -/// -/// This unified function accepts both `u32` and `usize` values, handling conversion appropriately: -/// - `u32` values are written directly (no conversion overhead) -/// - `usize` values are safely converted using `TryInto` (returns error on overflow) -/// -/// # Returns -/// * `Ok(usize)` - Number of bytes written (always 8) -/// * `Err(MetadataError)` - If writing fails or conversion fails (usize > u32::MAX) -pub fn write_metadata( - writer: &mut Writer, - npts: N, - ndims: D, -) -> Result> -where - N: TryInto, - D: TryInto, - N::Error: std::error::Error + 'static, - D::Error: std::error::Error + 'static, -{ - let npts_u32 = npts.try_into().map_err(MetadataError::NumPoints)?; - let ndims_u32 = ndims.try_into().map_err(MetadataError::Dim)?; - - let bytes: [u8; 8] = LoHi::new(npts_u32.to_le_bytes(), ndims_u32.to_le_bytes()).join(); - writer.write_all(&bytes).map_err(MetadataError::Write)?; - - Ok(2 * std::mem::size_of::()) -} - /// Load a list of vector ids from the stream. pub fn load_vector_ids(reader: &mut Reader) -> std::io::Result<(usize, Vec)> { // The first 4 bytes are the number of vector ids. @@ -142,8 +56,8 @@ pub fn copy_aligned_data( ) -> std::io::Result<(usize, usize)> { let mut reader = BufReader::with_capacity(DEFAULT_BUF_SIZE, reader); - let metadata = read_metadata(&mut reader)?; - let (npts, dim) = (metadata.npoints, metadata.ndims); + let metadata = Metadata::read(&mut reader)?; + let (npts, dim) = metadata.into_dims(); let rounded_dim = dataset_dto.rounded_dim; let offset = pts_offset * rounded_dim; @@ -167,24 +81,22 @@ pub fn copy_aligned_data( /// # Arguments /// * `reader` - a stream reader. /// * `offset` - start offset of the data. -pub fn load_bin( - reader: &mut Reader, +pub fn read_bin_from( + reader: &mut (impl Read + Seek), offset: usize, -) -> std::io::Result<(Vec, usize, usize)> { - let mut reader = BufReader::new(reader); +) -> Result, ReadBinError> { reader.seek(std::io::SeekFrom::Start(offset as u64))?; - let metadata = read_metadata(&mut reader)?; - let (npts, dim) = (metadata.npoints, metadata.ndims); - - let size = npts * dim * std::mem::size_of::(); - - let buf: Vec = read_exact_into(&mut reader, npts * dim)?; - info!( - "bin: #pts = {}, #dims = {}, offset = {} size = {}B", - npts, dim, offset, size - ); + read_bin(reader) +} - Ok((buf, npts, dim)) +/// Write a matrix at the given byte offset. +pub fn write_bin_from( + data: MatrixView<'_, T>, + writer: &mut (impl Write + Seek), + offset: usize, +) -> Result { + writer.seek(std::io::SeekFrom::Start(offset as u64))?; + write_bin(data, writer) } /// Save the byte array to storage. @@ -196,7 +108,7 @@ pub fn save_bytes( offset: usize, ) -> ANNResult { writer.seek(std::io::SeekFrom::Start(offset as u64))?; - write_metadata(writer, npts, ndims)?; + Metadata::new(npts, ndims)?.write(writer)?; writer.write_all(data)?; writer.flush()?; @@ -222,7 +134,7 @@ pub fn save_data_in_base_dimensions() + npts * ndims * (std::mem::size_of::()); writer.seek(std::io::SeekFrom::Start(offset as u64))?; - write_metadata(writer, npts, ndims)?; + Metadata::new(npts, ndims)?.write(writer)?; for i in 0..npts { let start = i * aligned_dim; @@ -236,103 +148,16 @@ pub fn save_data_in_base_dimensions { - /// Write data into the storage system. - pub fn $name( - writer: &mut W, - data: &[$t], - num_pts: usize, - dims: usize, - offset: usize, - ) -> ANNResult { - writer.seek(SeekFrom::Start(offset as u64))?; - let bytes_written = num_pts * dims * mem::size_of::<$t>() + 2 * mem::size_of::(); - - write_metadata(writer, num_pts, dims)?; - info!( - "bin: #pts = {}, #dims = {}, size = {}B", - num_pts, dims, bytes_written - ); - - for item in data.iter() { - writer.$write_func::(*item)?; - } - - writer.flush()?; - - info!("Finished writing bin."); - Ok(bytes_written) - } - }; -} - -save_bin!(save_bin_f32, f32, write_f32); -save_bin!(save_bin_u64, u64, write_u64); -save_bin!(save_bin_u32, u32, write_u32); - #[cfg(test)] mod storage_util_test { use crate::storage::{StorageReadProvider, StorageWriteProvider, VirtualStorageProvider}; + use byteorder::{LittleEndian, WriteBytesExt}; + use std::io::SeekFrom; use tempfile::tempfile; use super::*; pub const DIM_8: usize = 8; - #[test] - fn read_metadata_test() { - let file_name = "/test_read_metadata_test.bin"; - let data = [200, 0, 0, 0, 128, 0, 0, 0]; // 200 and 128 in little endian bytes (u32) - let storage_provider = VirtualStorageProvider::new_memory(); - { - let mut file = storage_provider - .create_for_write(file_name) - .expect("Could not create file"); - file.write_all(&data) - .expect("Should be able to write sample file"); - } - - let mut reader = storage_provider.open_reader(file_name).unwrap(); - match read_metadata(&mut reader) { - Ok(metadata) => { - assert_eq!(metadata.npoints, 200); - assert_eq!(metadata.ndims, 128); - } - Err(_e) => {} - } - storage_provider - .delete(file_name) - .expect("Should be able to delete sample file"); - } - - #[test] - fn read_metadata_i32_compatibility_test() { - // Test that read_metadata (u32) can read data written as i32 - let file_name = "/test_read_metadata_i32_compat.bin"; - let npts = 200i32; - let dims = 128i32; - let storage_provider = VirtualStorageProvider::new_memory(); - { - let mut file = storage_provider - .create_for_write(file_name) - .expect("Could not create file"); - // Write as i32 (old format) - file.write_i32::(npts).unwrap(); - file.write_i32::(dims).unwrap(); - } - - // Read as u32 (new format) - let mut reader = storage_provider.open_reader(file_name).unwrap(); - let metadata = read_metadata(&mut reader).unwrap(); - - assert_eq!(metadata.npoints, 200); - assert_eq!(metadata.ndims, 128); - - storage_provider - .delete(file_name) - .expect("Should be able to delete sample file"); - } - #[test] fn load_vector_ids_test() { let file_name = "/load_vector_ids_test"; @@ -356,56 +181,46 @@ mod storage_util_test { } #[test] - fn load_bin_test() { - let file_name = "/load_bin_test"; + fn test_read_bin_from() { + let file_name = "/read_bin_from"; let data = vec![0u64, 1u64, 2u64]; - let num_pts = data.len(); - let dims = 1; let storage_provider = VirtualStorageProvider::new_memory(); - let bytes_written = save_bin_u64( + let view = MatrixView::column_vector(data.as_slice()); + let bytes_written = write_bin_from( + view, &mut storage_provider.create_for_write(file_name).unwrap(), - &data, - num_pts, - dims, 0, ) .unwrap(); assert_eq!(bytes_written, 32); - let (load_data, load_num_pts, load_dims) = - load_bin::(&mut storage_provider.open_reader(file_name).unwrap(), 0).unwrap(); - assert_eq!(load_num_pts, num_pts); - assert_eq!(load_dims, dims); - assert_eq!(load_data, data); + let loaded = + read_bin_from::(&mut storage_provider.open_reader(file_name).unwrap(), 0).unwrap(); + assert_eq!(loaded.as_view(), view); storage_provider.delete(file_name).unwrap(); } #[test] - fn load_bin_offset_test() { + fn test_read_bin_from_offset_test() { let offset: usize = 32; - let file_name = "/load_bin_offset_test"; + let file_name = "/read_bin_from_offset_test"; let data = vec![0u64, 1u64, 2u64]; - let num_pts = data.len(); - let dims = 1; let storage_provider = VirtualStorageProvider::new_memory(); - let bytes_written = save_bin_u64( + let view = MatrixView::column_vector(data.as_slice()); + let bytes_written = write_bin_from( + view, &mut storage_provider.create_for_write(file_name).unwrap(), - &data, - num_pts, - dims, offset, ) .unwrap(); assert_eq!(bytes_written, 32); - let (load_data, load_num_pts, load_dims) = load_bin::( + let loaded = read_bin_from::( &mut storage_provider.open_reader(file_name).unwrap(), offset, ) .unwrap(); - assert_eq!(load_num_pts, num_pts); - assert_eq!(load_dims, dims); - assert_eq!(load_data, data); + assert_eq!(loaded.as_view(), view); storage_provider.delete(file_name).unwrap(); } @@ -451,17 +266,18 @@ mod storage_util_test { } #[test] - fn save_bin_test() { + fn write_bin_from_test() { let data = vec![0u64, 1u64, 2u64]; let num_pts = data.len(); let dims = 1; let mut file = tempfile().unwrap(); - let bytes_written = save_bin_u64::<_>(&mut file, &data, num_pts, dims, 0).unwrap(); + let view = MatrixView::column_vector(data.as_slice()); + let bytes_written = write_bin_from(view, &mut file, 0).unwrap(); assert_eq!(bytes_written, 32); let mut buffer = vec![]; file.seek(SeekFrom::Start(0)).unwrap(); - let metadata = read_metadata(&mut file).unwrap(); + let metadata = Metadata::read(&mut file).unwrap(); file.read_to_end(&mut buffer).unwrap(); let data_read: Vec = buffer @@ -469,83 +285,10 @@ mod storage_util_test { .map(|b| u64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]])) .collect(); - assert_eq!(num_pts, metadata.npoints); - assert_eq!(dims, metadata.ndims); + assert_eq!(num_pts, metadata.npoints()); + assert_eq!(dims, metadata.ndims()); assert_eq!(data, data_read); } - - #[test] - fn write_metadata_unified_test() { - let mut buffer = Vec::new(); - - // Test with u32 values (no conversion) - let result = write_metadata(&mut buffer, 200u32, 128u32); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), 8); - - // Test with usize values (safe conversion) - buffer.clear(); - let result = write_metadata(&mut buffer, 200usize, 128usize); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), 8); - - // Test mixed types - buffer.clear(); - let result = write_metadata(&mut buffer, 200usize, 128u32); - assert!(result.is_ok()); - - // Verify the written data - let mut cursor = std::io::Cursor::new(&buffer); - let metadata = read_metadata(&mut cursor).unwrap(); - assert_eq!(metadata.npoints, 200); - assert_eq!(metadata.ndims, 128); - } - - #[test] - fn metadata_error_types_test() { - // Test NumPoints error - let large_value = u32::MAX as usize + 1; - let result = write_metadata(&mut Vec::new(), large_value, 128usize); - assert!(matches!(result, Err(MetadataError::NumPoints(_)))); - - // Test Dim error - let result = write_metadata(&mut Vec::new(), 128usize, large_value); - assert!(matches!(result, Err(MetadataError::Dim(_)))); - - // Test Write error - struct FailingWriter; - impl std::io::Write for FailingWriter { - fn write(&mut self, _: &[u8]) -> std::io::Result { - Err(std::io::Error::new( - std::io::ErrorKind::PermissionDenied, - "fail", - )) - } - fn flush(&mut self) -> std::io::Result<()> { - Ok(()) - } - } - - let result = write_metadata(&mut FailingWriter, 200u32, 128u32); - assert!(matches!(result, Err(MetadataError::Write(_)))); - } - - #[test] - fn metadata_error_to_ann_error_test() { - use diskann::{ANNError, ANNErrorKind}; - - // Test MetadataError -> ANNError conversion - let large_value = u32::MAX as usize + 1; - let result = write_metadata(&mut Vec::new(), large_value, 128usize); - let metadata_err = result.unwrap_err(); - let ann_error: ANNError = metadata_err.into(); - - assert_eq!(ann_error.kind(), ANNErrorKind::IOError); - - // Check that the error message contains information about the conversion - let error_str = ann_error.to_string(); - assert!(error_str.contains("num points conversion")); - } } #[cfg(test)] diff --git a/diskann-providers/src/utils/vector_data_iterator.rs b/diskann-providers/src/utils/vector_data_iterator.rs index a34ba283a..fe70d8f30 100644 --- a/diskann-providers/src/utils/vector_data_iterator.rs +++ b/diskann-providers/src/utils/vector_data_iterator.rs @@ -11,9 +11,10 @@ use std::{ use crate::storage::StorageReadProvider; use diskann::{ANNError, ANNErrorKind, utils::read_exact_into}; +use diskann_utils::io::Metadata; use thiserror::Error; -use crate::{model::graph::traits::GraphDataType, utils::read_metadata}; +use crate::model::graph::traits::GraphDataType; /// An iterator over the vector and associated data pairs in a dataset loaded from the storage provider. pub struct VectorDataIterator { @@ -39,16 +40,14 @@ impl ) -> std::io::Result> { let mut dataset_reader = read_provider.open_reader(vector_stream)?; - let vector_metadata = read_metadata(&mut dataset_reader)?; - let (vector_npts, vector_dim) = (vector_metadata.npoints, vector_metadata.ndims); + let (vector_npts, vector_dim) = Metadata::read(&mut dataset_reader)?.into_dims(); let (associated_data_reader, associated_data_length) = if let Some(associated_data_stream) = associated_data_stream { let mut associated_data_reader = read_provider.open_reader(&associated_data_stream)?; - let associated_metadata = read_metadata(&mut associated_data_reader)?; - let (num_pts, length) = (associated_metadata.npoints, associated_metadata.ndims); + let (num_pts, length) = Metadata::read(&mut associated_data_reader)?.into_dims(); if num_pts != vector_npts { return Err(std::io::Error::new( diff --git a/diskann-tools/src/bin/generate_minmax.rs b/diskann-tools/src/bin/generate_minmax.rs index 87c0b68dd..7ced9e2ef 100644 --- a/diskann-tools/src/bin/generate_minmax.rs +++ b/diskann-tools/src/bin/generate_minmax.rs @@ -11,7 +11,7 @@ use std::{ use anyhow::{Context, Result}; use clap::Parser; -use diskann_providers::utils::write_metadata; +use diskann_providers::storage::StorageReadProvider; use diskann_quantization::{ algorithms::transforms::{DoubleHadamard, TargetDim}, alloc::GlobalAllocator, @@ -19,6 +19,7 @@ use diskann_quantization::{ num::Positive, CompressInto, }; +use diskann_utils::io::Metadata; use half::f16; use rand::{rngs::StdRng, SeedableRng}; @@ -101,20 +102,15 @@ where diskann_quantization::bits::Unsigned: diskann_quantization::bits::Representation, { // Load input data - let (input_data, num_points, dim) = diskann_providers::utils::file_util::load_bin::( - &diskann_providers::storage::FileStorageProvider, - input_path, - 0, + let input_data = diskann_utils::io::read_bin::( + &mut diskann_providers::storage::FileStorageProvider + .open_reader(input_path) + .with_context(|| format!("Failed to open {}", input_path))?, ) .with_context(|| format!("Failed to load data from {}", input_path))?; - if input_data.len() != num_points * dim { - anyhow::bail!( - "Data size mismatch: expected {} elements, got {}", - num_points * dim, - input_data.len() - ); - } + let num_points = input_data.nrows(); + let dim = input_data.ncols(); println!("Input file: {} points, {} dimensions", num_points, dim); @@ -142,7 +138,8 @@ where let mut writer = BufWriter::new(output_file); // Write output header: num_points (u32) and bytes_per_vector (u32) - write_metadata(&mut writer, num_points, bytes_per_vector) + Metadata::new(num_points, bytes_per_vector)? + .write(&mut writer) .context("Failed to write metadata header")?; println!("Processing {} vectors...", num_points); @@ -152,12 +149,7 @@ where // Process vectors one by one for i in 0..num_points { // Get input vector - let start_idx = i * dim; - let end_idx = start_idx + dim; - let input_vector: Vec = input_data[start_idx..end_idx] - .iter() - .map(|x| (*x).into()) - .collect(); + let input_vector = input_data.row(i); // Create buffer for quantized data with proper alignment let mut quantized_buffer = vec![0u8; bytes_per_vector]; @@ -169,7 +161,7 @@ where // Compress the vector let loss_x = quantizer - .compress_into(input_vector.as_slice(), quantized_data) + .compress_into(input_vector, quantized_data) .with_context(|| format!("Failed to compress vector {}", i))?; loss += loss_x.as_f32(); diff --git a/diskann-tools/src/bin/subsample_bin.rs b/diskann-tools/src/bin/subsample_bin.rs index 891d71ec6..6612ea91b 100644 --- a/diskann-tools/src/bin/subsample_bin.rs +++ b/diskann-tools/src/bin/subsample_bin.rs @@ -15,8 +15,9 @@ use rand_distr::{Distribution, StandardUniform}; use diskann::utils::VectorRepr; use diskann_providers::storage::FileStorageProvider; use diskann_providers::storage::StorageWriteProvider; -use diskann_providers::utils::{random, write_metadata, SampleVectorReader, SamplingDensity}; +use diskann_providers::utils::{random, SampleVectorReader, SamplingDensity}; use diskann_tools::utils::DataType; +use diskann_utils::io::Metadata; /// Subsamples vectors from a DiskANN style binary file. #[derive(Parser, Debug)] @@ -83,7 +84,7 @@ where let mut writer = storage_provider.create_for_write(&output_file)?; // Write metadata with a temporary count, then fix it after sampling. - write_metadata(&mut writer, npts, dims)?; + Metadata::new(npts, dims)?.write(&mut writer)?; let mut sampled_count: u32 = 0; reader.read_vectors(sampled_indices, |vec_t| { @@ -94,7 +95,7 @@ where // Rewrite metadata at the start of the file with the actual sampled count. writer.seek(SeekFrom::Start(0))?; - write_metadata(&mut writer, sampled_count, dims)?; + Metadata::new(sampled_count, dims)?.write(&mut writer)?; println!( "Wrote {} points to sample file {}", diff --git a/diskann-tools/src/utils/build_disk_index.rs b/diskann-tools/src/utils/build_disk_index.rs index 9fe248f5b..916c24c8a 100644 --- a/diskann-tools/src/utils/build_disk_index.rs +++ b/diskann-tools/src/utils/build_disk_index.rs @@ -113,17 +113,17 @@ where let metadata = load_metadata_from_file(storage_provider, parameters.data_path)?; - if metadata.ndims != parameters.dim_values.dim() { + if metadata.ndims() != parameters.dim_values.dim() { return Err(ANNError::log_index_config_error( format!("{:?}", parameters.dim_values), - format!("dim_values must match with data_dim {}", metadata.ndims), + format!("dim_values must match with data_dim {}", metadata.ndims()), )); } let index_configuration = IndexConfiguration::new( parameters.metric, - metadata.ndims, - metadata.npoints, + metadata.ndims(), + metadata.npoints(), ONE, parameters.num_threads, config, diff --git a/diskann-tools/src/utils/build_pq.rs b/diskann-tools/src/utils/build_pq.rs index e753f6589..becdeb1d6 100644 --- a/diskann-tools/src/utils/build_pq.rs +++ b/diskann-tools/src/utils/build_pq.rs @@ -48,10 +48,11 @@ pub fn build_pq( let metadata = load_metadata_from_file(storage_provider, parameters.data_path)?; info!( "Compressing dim-{} data into {} chunks(bytes) for PQ", - metadata.ndims, num_pq_chunks + metadata.ndims(), + num_pq_chunks ); - let p_val = MAX_PQ_TRAINING_SET_SIZE / (metadata.npoints as f64); + let p_val = MAX_PQ_TRAINING_SET_SIZE / (metadata.npoints() as f64); let timer = Timer::new(); let storage_provider = FileStorageProvider; diff --git a/diskann-tools/src/utils/cmd_tool_error.rs b/diskann-tools/src/utils/cmd_tool_error.rs index 9479c1f37..f84527a76 100644 --- a/diskann-tools/src/utils/cmd_tool_error.rs +++ b/diskann-tools/src/utils/cmd_tool_error.rs @@ -53,6 +53,20 @@ impl From for CMDToolError { } } } +impl From for CMDToolError { + fn from(err: diskann_utils::io::ReadBinError) -> Self { + CMDToolError { + details: err.to_string(), + } + } +} +impl From for CMDToolError { + fn from(err: diskann_utils::io::SaveBinError) -> Self { + CMDToolError { + details: err.to_string(), + } + } +} impl From for CMDToolError { fn from(err: diskann::graph::config::ConfigError) -> Self { CMDToolError { @@ -69,12 +83,12 @@ impl From for CMDToolError { } } -impl From> for CMDToolError +impl From> for CMDToolError where T: std::error::Error + Send + Sync + 'static, U: std::error::Error + Send + Sync + 'static, { - fn from(err: diskann_providers::utils::MetadataError) -> Self { + fn from(err: diskann_utils::io::MetadataError) -> Self { // Leverage the existing conversion chain: MetadataError -> ANNError -> CMDToolError let ann_error: diskann::ANNError = err.into(); ann_error.into() diff --git a/diskann-tools/src/utils/gen_associated_data_from_range.rs b/diskann-tools/src/utils/gen_associated_data_from_range.rs index 4752a892e..7d91ec865 100644 --- a/diskann-tools/src/utils/gen_associated_data_from_range.rs +++ b/diskann-tools/src/utils/gen_associated_data_from_range.rs @@ -6,7 +6,7 @@ use std::io::Write; use diskann_providers::storage::StorageWriteProvider; -use diskann_providers::utils::write_metadata; +use diskann_utils::io::Metadata; use super::CMDResult; @@ -23,7 +23,7 @@ pub fn gen_associated_data_from_range( let int_length: u32 = 1; // Write the number of integers and the length of each integer as little endian - write_metadata(&mut file, num_ints, int_length)?; + Metadata::new(num_ints, int_length)?.write(&mut file)?; // Write the integers from the range as little endian for i in start..=end { diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index ad7da38f2..bea41b1e6 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -17,11 +17,12 @@ use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider}; use diskann_providers::{ common::AlignedBoxWithSlice, model::graph::traits::GraphDataType, - utils::{ - create_thread_pool, file_util, write_metadata, ParallelIteratorInPool, VectorDataIterator, - }, + utils::{create_thread_pool, file_util, ParallelIteratorInPool, VectorDataIterator}, +}; +use diskann_utils::{ + io::{read_bin, Metadata}, + views::Matrix, }; -use diskann_utils::views::Matrix; use diskann_vector::{distance::Metric, DistanceFunction}; use itertools::Itertools; use rayon::prelude::*; @@ -121,10 +122,10 @@ pub fn compute_ground_truth_from_datafiles< }; // Load the query file - let (raw_query_data, query_num, query_dim) = file_util::load_bin::< - Data::VectorDataType, - StorageProvider, - >(storage_provider, query_file, 0)?; + let query_data = + read_bin::(&mut storage_provider.open_reader(query_file)?)?; + let query_num = query_data.nrows(); + let query_dim = query_data.ncols(); let mut query_bitmaps: Option> = None; if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels) @@ -135,7 +136,7 @@ pub fn compute_ground_truth_from_datafiles< )?); } - let queries: Vec<_> = raw_query_data.chunks(query_dim).collect(); + let queries: Vec<_> = query_data.row_iter().collect(); // Load the vector filters let vector_filters = match vector_filters_file { @@ -375,11 +376,11 @@ pub fn compute_range_search_ground_truth_from_datafiles< )?; // Load the query file - let (raw_query_data, query_num, query_dim) = file_util::load_bin::< - Data::VectorDataType, - StorageProvider, - >(storage_provider, query_file, 0)?; - let queries: Vec<_> = raw_query_data.chunks(query_dim).collect(); + let query_data = + read_bin::(&mut storage_provider.open_reader(query_file)?)?; + let query_num = query_data.nrows(); + let query_dim = query_data.ncols(); + let queries: Vec<_> = query_data.row_iter().collect(); let query_aligned_dim = query_dim.next_multiple_of(8); let ground_truth_result = compute_range_search_ground_truth_from_data::< @@ -427,7 +428,7 @@ fn write_range_search_ground_truth() as usize; // Metadata - write_metadata(&mut file, number_of_queries, total_number_of_neighbors)?; + Metadata::new(number_of_queries, total_number_of_neighbors)?.write(&mut file)?; // Write queue sizes array. let mut queue_sizes_buffer = vec![0; queue_sizes.len() * size_of::()]; @@ -467,7 +468,7 @@ fn write_ground_truth( ) -> CMDResult<()> { let mut file = storage_provider.create_for_write(ground_truth_file)?; - write_metadata(&mut file, number_of_queries, number_of_neighbors)?; + Metadata::new(number_of_queries, number_of_neighbors)?.write(&mut file)?; let mut gt_ids: Vec = Vec::with_capacity(number_of_neighbors * number_of_queries); let mut gt_distances: Vec = Vec::with_capacity(number_of_neighbors * number_of_queries); diff --git a/diskann-tools/src/utils/random_data_generator.rs b/diskann-tools/src/utils/random_data_generator.rs index 8798d2533..a3c3863e0 100644 --- a/diskann-tools/src/utils/random_data_generator.rs +++ b/diskann-tools/src/utils/random_data_generator.rs @@ -6,10 +6,8 @@ use std::io::{BufWriter, Write}; use byteorder::{LittleEndian, WriteBytesExt}; -use diskann_providers::{ - storage::StorageWriteProvider, - utils::{math_util, write_metadata}, -}; +use diskann_providers::{storage::StorageWriteProvider, utils::math_util}; +use diskann_utils::io::Metadata; use diskann_vector::Half; use crate::utils::{CMDResult, CMDToolError, DataType}; @@ -75,7 +73,7 @@ pub fn write_random_data_writer( }); } - write_metadata(&mut writer, number_of_vectors, number_of_dimensions)?; + Metadata::new(number_of_vectors, number_of_dimensions)?.write(&mut writer)?; let block_size = 131072; let nblks = u64::div_ceil(number_of_vectors, block_size); diff --git a/diskann-tools/src/utils/relative_contrast.rs b/diskann-tools/src/utils/relative_contrast.rs index 7e3c9b0b9..3b521e4e6 100644 --- a/diskann-tools/src/utils/relative_contrast.rs +++ b/diskann-tools/src/utils/relative_contrast.rs @@ -4,8 +4,9 @@ */ use diskann::{utils::VectorRepr, ANNError}; +use diskann_providers::model::graph::traits::GraphDataType; use diskann_providers::storage::StorageReadProvider; -use diskann_providers::{model::graph::traits::GraphDataType, utils::file_util::load_bin}; +use diskann_utils::io::read_bin; use rand::Rng; use crate::utils::{CMDResult, CMDToolError}; @@ -57,9 +58,16 @@ pub fn compute_relative_contrast< rng: &mut R, ) -> CMDResult { // Load base, query, and ground truth data - let (base_flat, nb, dim) = load_bin::(storage_provider, base_file, 0)?; - let (query_flat, nq, _) = load_bin::(storage_provider, query_file, 0)?; - let (gt_flat, _, ngt) = load_bin::(storage_provider, gt_file, 0)?; + let base_data = + read_bin::(&mut storage_provider.open_reader(base_file)?)?; + let query_data = + read_bin::(&mut storage_provider.open_reader(query_file)?)?; + let gt_data = read_bin::(&mut storage_provider.open_reader(gt_file)?)?; + + let nb = base_data.nrows(); + let dim = base_data.ncols(); + let nq = query_data.nrows(); + let ngt = gt_data.ncols(); tracing::info!( "Loaded base: {} points, query: {} points, dimension: {}, ground truth neighbors: {}", @@ -70,10 +78,9 @@ pub fn compute_relative_contrast< ); // Reshape flat vectors into 2D vectors - let base: Vec> = base_flat.chunks(dim).map(|x| x.to_vec()).collect(); - let query: Vec> = - query_flat.chunks(dim).map(|x| x.to_vec()).collect(); - let gt: Vec> = gt_flat.chunks(ngt).map(|x| x.to_vec()).collect(); + let base: Vec> = base_data.row_iter().map(|x| x.to_vec()).collect(); + let query: Vec> = query_data.row_iter().map(|x| x.to_vec()).collect(); + let gt: Vec> = gt_data.row_iter().map(|x| x.to_vec()).collect(); let mut mean_rc = 0.0; @@ -111,7 +118,7 @@ pub fn compute_relative_contrast< mod relative_contrast_tests { use diskann_providers::storage::{StorageWriteProvider, VirtualStorageProvider}; use diskann_providers::utils::random; - use diskann_providers::utils::write_metadata; + use diskann_utils::io::Metadata; use diskann_vector::distance::Metric; use half::f16; use rand::Rng; @@ -144,7 +151,10 @@ mod relative_contrast_tests { let base_file_path = "/base.bin"; { let mut base_writer = storage_provider.create_for_write(base_file_path).unwrap(); - write_metadata(&mut base_writer, num_vectors, dim).unwrap(); + Metadata::new(num_vectors, dim) + .unwrap() + .write(&mut base_writer) + .unwrap(); for value in &base { base_writer.write_all(&value.to_le_bytes()).unwrap(); } @@ -154,7 +164,10 @@ mod relative_contrast_tests { let query_file_path = "/query.bin"; { let mut query_writer = storage_provider.create_for_write(query_file_path).unwrap(); - write_metadata(&mut query_writer, num_queries, dim).unwrap(); + Metadata::new(num_queries, dim) + .unwrap() + .write(&mut query_writer) + .unwrap(); for value in &query { query_writer.write_all(&value.to_le_bytes()).unwrap(); } @@ -228,7 +241,10 @@ mod relative_contrast_tests { let mut query_writer = storage_provider .create_for_write(query_file_path) .expect("Failed to create query file in memory"); - write_metadata(&mut query_writer, num_queries, dim).expect("Failed to write metadata"); + Metadata::new(num_queries, dim) + .expect("Failed to create metadata") + .write(&mut query_writer) + .expect("Failed to write metadata"); for value in &query { query_writer .write_all(&value.to_le_bytes()) diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index 1a2936f5e..888988ca0 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -21,8 +21,9 @@ use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider}; use diskann_providers::{ model::graph::traits::GraphDataType, storage::{get_compressed_pq_file, get_pq_pivot_file}, - utils::{create_thread_pool, load_aligned_bin, save_bin_u32, ParallelIteratorInPool}, + utils::{create_thread_pool, load_aligned_bin, ParallelIteratorInPool}, }; +use diskann_utils::{io::write_bin, views::MatrixView}; use diskann_vector::distance::Metric; use opentelemetry::global::BoxedSpan; #[cfg(feature = "perf_test")] @@ -412,12 +413,17 @@ where "{}_{}_idx_uint32.bin", parameters.result_output_prefix, l_value ); - save_bin_u32( - &mut storage_provider.create_for_write(&cur_result_path)?, + let view = MatrixView::try_from( query_result_ids[test_id].as_slice(), query_num, parameters.recall_at as usize, - 0, + ) + .map_err(|e| CMDToolError { + details: e.to_string(), + })?; + write_bin( + view, + &mut storage_provider.create_for_write(&cur_result_path)?, )?; } diff --git a/diskann-tools/src/utils/search_index_utils.rs b/diskann-tools/src/utils/search_index_utils.rs index d2b0751ea..ddbb0c062 100644 --- a/diskann-tools/src/utils/search_index_utils.rs +++ b/diskann-tools/src/utils/search_index_utils.rs @@ -8,7 +8,7 @@ use bytemuck::cast_slice; use diskann::{ANNError, ANNResult}; use diskann_providers::model::graph::traits::GraphDataType; use diskann_providers::storage::StorageReadProvider; -use diskann_providers::utils::read_metadata; +use diskann_utils::io::Metadata; use tracing::{error, info}; use crate::utils::CMDToolError; @@ -366,8 +366,8 @@ pub fn load_truthset( let actual_file_size = storage_provider.get_length(bin_file)? as usize; let mut file = storage_provider.open_reader(bin_file)?; - let metadata = read_metadata(&mut file)?; - let (npts, dim) = (metadata.npoints, metadata.ndims); + let metadata = Metadata::read(&mut file)?; + let (npts, dim) = metadata.into_dims(); info!("Metadata: #pts = {npts}, #dims = {dim}... "); @@ -420,8 +420,8 @@ pub fn load_truthset_with_associated_data( ) -> ANNResult> { let mut file = storage_provider.open_reader(bin_file)?; - let metadata = read_metadata(&mut file)?; - let (npts, dim) = (metadata.npoints, metadata.ndims); + let metadata = Metadata::read(&mut file)?; + let (npts, dim) = metadata.into_dims(); info!("Metadata: #pts = {}, #dims = {}...", npts, dim); @@ -469,8 +469,8 @@ pub fn load_range_truthset( ) -> ANNResult { let mut file = storage_provider.open_reader(bin_file)?; - let metadata = read_metadata(&mut file)?; - let (npts, total_ids) = (metadata.npoints, metadata.ndims); + let metadata = Metadata::read(&mut file)?; + let (npts, total_ids) = metadata.into_dims(); let mut buffer = [0; size_of::()]; info!("Metadata: #pts = {}, #totalIds = {}", npts, total_ids); diff --git a/diskann-utils/Cargo.toml b/diskann-utils/Cargo.toml index 06c161978..a72c22e77 100644 --- a/diskann-utils/Cargo.toml +++ b/diskann-utils/Cargo.toml @@ -15,6 +15,7 @@ thiserror.workspace = true diskann-vector = { workspace = true } rand = { workspace = true } diskann-wide = { workspace = true } +bytemuck = { workspace = true, features = ["must_cast"] } [lints] workspace = true diff --git a/diskann-utils/src/io.rs b/diskann-utils/src/io.rs new file mode 100644 index 000000000..0ac45b03f --- /dev/null +++ b/diskann-utils/src/io.rs @@ -0,0 +1,328 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Read and write vectors in the DiskANN binary format. +//! +//! The binary format is: +//! - 8-byte header +//! - `npoints` (u32 LE) +//! - `ndims` (u32 LE) +//! - Payload: `npoints × ndims` elements of `T`, tightly packed in row-major order + +use std::io::{Read, Seek, Write}; + +use diskann_wide::{LoHi, SplitJoin}; +use thiserror::Error; + +use crate::views::{Matrix, MatrixView}; + +/// Read a matrix of `T` from the DiskANN binary format (see [module docs](self)). +/// +/// Validates that the reader contains enough data before allocating. +pub fn read_bin(reader: &mut (impl Read + Seek)) -> Result, ReadBinError> +where + T: bytemuck::Pod, +{ + let metadata = Metadata::read(reader)?; + let (npoints, ndims) = (metadata.npoints(), metadata.ndims()); + let type_size = std::mem::size_of::(); + + let expected_bytes = npoints + .checked_mul(ndims) + .and_then(|n| n.checked_mul(type_size)) + .ok_or(ReadBinError::Overflow { + npoints: metadata.npoints_u32(), + ndims: metadata.ndims_u32(), + type_size, + })?; + + let data_start = reader.stream_position()?; + let end = reader.seek(std::io::SeekFrom::End(0))?; + let available = end - data_start; + reader.seek(std::io::SeekFrom::Start(data_start))?; + + if available < expected_bytes as u64 { + return Err(ReadBinError::SizeMismatch { + expected: expected_bytes as u64, + available, + npoints: metadata.npoints_u32(), + ndims: metadata.ndims_u32(), + type_size, + }); + } + + let mut data = Matrix::new(::zeroed(), npoints, ndims); + + reader.read_exact(bytemuck::must_cast_slice_mut::(data.as_mut_slice()))?; + Ok(data) +} + +/// Write a matrix of `T` in the DiskANN binary format (see [module docs](self)). +/// +/// Returns the total number of bytes written. +pub fn write_bin(data: MatrixView<'_, T>, writer: &mut impl Write) -> Result +where + T: bytemuck::Pod, +{ + let metadata = + Metadata::new(data.nrows(), data.ncols()).map_err(|_| SaveBinError::DimensionOverflow { + nrows: data.nrows(), + ncols: data.ncols(), + })?; + let bytes = metadata.write(writer)?; + writer.write_all(bytemuck::must_cast_slice::(data.as_slice()))?; + Ok(bytes + std::mem::size_of_val(data.as_slice())) +} + +/// 8-byte header at the start of a DiskANN binary file: `npoints` and `ndims` as little-endian u32. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Metadata { + npoints: u32, + ndims: u32, +} + +impl Metadata { + /// Construct from any integer types that fit in `u32`. + pub fn new(npoints: T, ndims: U) -> Result> + where + T: TryInto, + U: TryInto, + { + Ok(Self { + npoints: npoints.try_into().map_err(MetadataError::NumPoints)?, + ndims: ndims.try_into().map_err(MetadataError::Dim)?, + }) + } + + /// Number of points as `usize`. + pub fn npoints(&self) -> usize { + self.npoints as usize + } + + /// Number of points as `u32`. + pub fn npoints_u32(&self) -> u32 { + self.npoints + } + + /// Number of dimensions as `usize`. + pub fn ndims(&self) -> usize { + self.ndims as usize + } + + /// Number of dimensions as `u32`. + pub fn ndims_u32(&self) -> u32 { + self.ndims + } + + /// Destructure into (`npoints`, `ndims`) as `usize`. + pub fn into_dims(&self) -> (usize, usize) { + (self.npoints(), self.ndims()) + } + + /// Deserialize the 8-byte header from a reader. + pub fn read(reader: &mut R) -> std::io::Result + where + R: Read, + { + let mut bytes = [0u8; 8]; + reader.read_exact(&mut bytes)?; + + let LoHi { + lo: npts_bytes, + hi: ndims_bytes, + } = bytes.split(); + + let npoints = u32::from_le_bytes(npts_bytes); + let ndims = u32::from_le_bytes(ndims_bytes); + Ok(Metadata { npoints, ndims }) + } + + /// Serialize the 8-byte header to a writer. Returns the number of bytes written (always 8). + pub fn write(&self, writer: &mut W) -> std::io::Result + where + W: Write, + { + let bytes: [u8; 8] = LoHi::new(self.npoints.to_le_bytes(), self.ndims.to_le_bytes()).join(); + writer.write_all(&bytes)?; + Ok(2 * std::mem::size_of::()) + } +} + +#[derive(Debug, Error)] +pub enum MetadataError { + #[error("num points conversion")] + NumPoints(#[source] T), + #[error("dim conversion")] + Dim(#[source] U), +} + +/// Error type for [`read_bin`]. +#[derive(Debug, Error)] +pub enum ReadBinError { + /// The reader has fewer bytes remaining than the header declares. + #[error( + "binary data too short: header declares {npoints} points × {ndims} dims × {type_size} bytes = \ + {expected} bytes, but only {available} bytes available" + )] + SizeMismatch { + expected: u64, + available: u64, + npoints: u32, + ndims: u32, + type_size: usize, + }, + + /// `npoints * ndims` overflows `usize` (corrupt or malicious header). + #[error( + "header dimensions overflow: {npoints} points × {ndims} dims × {type_size} bytes overflows" + )] + Overflow { + npoints: u32, + ndims: u32, + type_size: usize, + }, + + /// Underlying IO failure. + #[error(transparent)] + Io(#[from] std::io::Error), +} + +/// Error type for [`write_bin`]. +#[derive(Debug, Error)] +pub enum SaveBinError { + /// Matrix dimensions exceed `u32::MAX` and cannot be represented in the binary header. + #[error("dimensions overflow u32: {nrows} rows × {ncols} cols")] + DimensionOverflow { nrows: usize, ncols: usize }, + + /// Underlying IO failure. + #[error(transparent)] + Io(#[from] std::io::Error), +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use crate::views::Init; + + use super::*; + + #[test] + fn round_trip_f32() { + let mut counter = 1.0f32; + let matrix = Matrix::::new( + Init(|| { + let v = counter; + counter += 1.0; + v + }), + 3, + 4, + ); + + assert_eq!( + matrix.as_slice(), + &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0] + ); + + let mut buf = Vec::new(); + let written = write_bin(matrix.as_view(), &mut buf).unwrap(); + assert_eq!(written, 8 + 3 * 4 * 4); + + let mut cursor = Cursor::new(&buf); + let loaded = read_bin::(&mut cursor).unwrap(); + assert_eq!(loaded.nrows(), 3); + assert_eq!(loaded.ncols(), 4); + assert_eq!(loaded.as_slice(), matrix.as_slice()); + } + + #[test] + fn read_bin_size_mismatch() { + // Header says 10 points × 4 dims of f32, but only provide 8 bytes of payload + let mut buf = Vec::new(); + let metadata = Metadata::new(10u32, 4u32).unwrap(); + metadata.write(&mut buf).unwrap(); + buf.extend_from_slice(&[0u8; 8]); + + let mut cursor = Cursor::new(&buf); + let err = read_bin::(&mut cursor).unwrap_err(); + + match err { + ReadBinError::SizeMismatch { + expected, + available, + npoints, + ndims, + type_size, + } => { + assert_eq!(expected, 10 * 4 * 4); + assert_eq!(available, 8); + assert_eq!(npoints, 10); + assert_eq!(ndims, 4); + assert_eq!(type_size, 4); + } + other => panic!("expected SizeMismatch, got: {other}"), + } + } + + #[test] + fn read_bin_overflow() { + // Header with huge values that overflow usize multiplication + let mut buf = Vec::new(); + buf.extend_from_slice(&u32::MAX.to_le_bytes()); + buf.extend_from_slice(&u32::MAX.to_le_bytes()); + + let mut cursor = Cursor::new(&buf); + let err = read_bin::(&mut cursor).unwrap_err(); + + match err { + ReadBinError::Overflow { + npoints, + ndims, + type_size, + } => { + assert_eq!(npoints, u32::MAX); + assert_eq!(ndims, u32::MAX); + assert_eq!(type_size, 4); + } + other => panic!("expected Overflow, got: {other}"), + } + } + + #[test] + fn read_bin_error_message_is_informative() { + let mut buf = Vec::new(); + let metadata = Metadata::new(100u32, 32u32).unwrap(); + metadata.write(&mut buf).unwrap(); + // no payload + + let mut cursor = Cursor::new(&buf); + let err = read_bin::(&mut cursor).unwrap_err(); + let msg = err.to_string(); + + assert!(msg.contains("100 points"), "missing npoints: {msg}"); + assert!(msg.contains("32 dims"), "missing ndims: {msg}"); + assert!(msg.contains("12800 bytes"), "missing expected: {msg}"); + assert!( + msg.contains("0 bytes available"), + "missing available: {msg}" + ); + } + + #[test] + fn metadata_read_write_round_trip() { + let mut buf = Vec::new(); + let metadata = Metadata::new(200u32, 128u32).unwrap(); + metadata.write(&mut buf).unwrap(); + + let mut cursor = Cursor::new(&buf); + let loaded = Metadata::read(&mut cursor).unwrap(); + assert_eq!(loaded, metadata); + } +} diff --git a/diskann-utils/src/lib.rs b/diskann-utils/src/lib.rs index f399d3ac1..cd8c1b84d 100644 --- a/diskann-utils/src/lib.rs +++ b/diskann-utils/src/lib.rs @@ -3,6 +3,9 @@ * Licensed under the MIT license. */ +#[cfg(not(target_endian = "little"))] +compile_error!("diskann-utils assumes little-endian targets"); + pub mod reborrow; pub use reborrow::{Reborrow, ReborrowMut}; @@ -11,6 +14,7 @@ pub use lifetime::WithLifetime; pub mod future; +pub mod io; pub mod sampling; // Views diff --git a/diskann-utils/src/views.rs b/diskann-utils/src/views.rs index 149ff7dd5..a9352918c 100644 --- a/diskann-utils/src/views.rs +++ b/diskann-utils/src/views.rs @@ -228,6 +228,21 @@ where self.nrows } + /// Create a new [`Matrix`] by applying the closure `f` to each element. + /// + /// The returned matrix has the same shape as `self`. + pub fn map(&self, f: F) -> Matrix + where + F: FnMut(&T::Elem) -> R, + { + let data: Box<[_]> = self.as_slice().iter().map(f).collect(); + Matrix { + data, + nrows: self.nrows(), + ncols: self.ncols(), + } + } + /// Return the underlying data as a slice. pub fn as_slice(&self) -> &[T::Elem] { self.data.as_slice() @@ -268,6 +283,19 @@ where ncols, } } + + /// Construct a new `MatrixBase` over the raw data. + /// + /// The returned `MatrixBase` will only have a single column with contents equal to `data`. + pub fn column_vector(data: T) -> Self { + let nrows = data.as_slice().len(); + Self { + data, + nrows, + ncols: 1, + } + } + /// Return row `row` if `row < self.nrows()`. Otherwise, return `None`. pub fn get_row(&self, row: usize) -> Option<&[T::Elem]> { if row < self.nrows() { @@ -557,6 +585,18 @@ where self.as_mut_slice().as_mut_ptr() } + /// Return the value at the specified `row` and `col`. + /// + /// If either index is out-of-bounds, return `None`. + pub fn try_get(&self, row: usize, col: usize) -> Option<&T::Elem> { + if row >= self.nrows() || col >= self.ncols() { + None + } else { + // SAFETY: We just verified that `row` and `col` are in-bounds. + Some(unsafe { self.get_unchecked(row, col) }) + } + } + /// Returns a reference to an element without boundschecking. /// /// # Safety @@ -1169,6 +1209,7 @@ mod tests { #[should_panic(expected = "row 3 is out of bounds (max: 3)")] fn test_index_panics_row() { let m = Matrix::::new(0, 3, 7); + assert!(m.try_get(3, 2).is_none()); let _ = m[(3, 2)]; } @@ -1176,6 +1217,7 @@ mod tests { #[should_panic(expected = "col 7 is out of bounds (max: 7)")] fn test_index_panics_col() { let m = Matrix::::new(0, 3, 7); + assert!(m.try_get(2, 7).is_none()); let _ = m[(2, 7)]; } @@ -1373,6 +1415,7 @@ mod tests { assert_eq!(m.nrows(), 1); assert_eq!(m.ncols(), 1); assert_eq!(m[(0, 0)], 42); + assert_eq!(*m.try_get(0, 0).unwrap(), 42); // Test single row matrix let m = Matrix::new(7, 1, 5); @@ -1398,6 +1441,8 @@ mod tests { assert_eq!(m.ncols(), 1); assert_eq!(m[(0, 0)], 10); assert_eq!(m[(1, 0)], 20); + assert_eq!(*m.try_get(0, 0).unwrap(), 10); + assert_eq!(*m.try_get(1, 0).unwrap(), 20); assert_eq!(m.row(0), &[10]); assert_eq!(m.row(1), &[20]); @@ -1407,9 +1452,84 @@ mod tests { assert_eq!(m.ncols(), 2); assert_eq!(m[(0, 0)], 10); assert_eq!(m[(0, 1)], 20); + assert_eq!(*m.try_get(0, 0).unwrap(), 10); + assert_eq!(*m.try_get(0, 1).unwrap(), 20); assert_eq!(m.row(0), &[10, 20]); } + #[test] + fn test_row_vector() { + let data = vec![1, 2, 3]; + let m = MatrixView::row_vector(data.as_slice()); + assert_eq!(m.nrows(), 1); + assert_eq!(m.ncols(), 3); + assert_eq!(m.as_slice(), &[1, 2, 3]); + assert_eq!(m.row(0), &[1, 2, 3]); + + // Empty + let empty: &[i32] = &[]; + let m = MatrixView::row_vector(empty); + assert_eq!(m.nrows(), 1); + assert_eq!(m.ncols(), 0); + + // Owned + let m = Matrix::row_vector(vec![10u64, 20].into_boxed_slice()); + assert_eq!(m.nrows(), 1); + assert_eq!(m.ncols(), 2); + assert_eq!(m[(0, 0)], 10); + assert_eq!(m[(0, 1)], 20); + } + + #[test] + fn test_column_vector() { + let data = vec![1, 2, 3]; + let m = MatrixView::column_vector(data.as_slice()); + assert_eq!(m.nrows(), 3); + assert_eq!(m.ncols(), 1); + assert_eq!(m.as_slice(), &[1, 2, 3]); + assert_eq!(m[(0, 0)], 1); + assert_eq!(m[(1, 0)], 2); + assert_eq!(m[(2, 0)], 3); + assert_eq!(m.row(0), &[1]); + assert_eq!(m.row(1), &[2]); + assert_eq!(m.row(2), &[3]); + + // Empty + let empty: &[i32] = &[]; + let m = MatrixView::column_vector(empty); + assert_eq!(m.nrows(), 0); + assert_eq!(m.ncols(), 1); + + // Owned + let m = Matrix::column_vector(vec![10u64, 20].into_boxed_slice()); + assert_eq!(m.nrows(), 2); + assert_eq!(m.ncols(), 1); + assert_eq!(m[(0, 0)], 10); + assert_eq!(m[(1, 0)], 20); + } + + #[test] + fn test_map() { + let m = Matrix::try_from(vec![1u32, 2, 3, 4].into(), 2, 2).unwrap(); + let doubled = m.map(|&x| x * 2); + assert_eq!(doubled.as_slice(), &[2, 4, 6, 8]); + assert_eq!(doubled.nrows(), 2); + assert_eq!(doubled.ncols(), 2); + + // Type-changing map + let as_f64 = m.map(|&x| x as f64); + assert_eq!(as_f64.as_slice(), &[1.0, 2.0, 3.0, 4.0]); + } + + #[test] + fn test_try_get() { + let m = Matrix::try_from(vec![1, 2, 3, 4, 5, 6].into(), 2, 3).unwrap(); + assert_eq!(m.try_get(0, 0), Some(&1)); + assert_eq!(m.try_get(1, 2), Some(&6)); + assert_eq!(m.try_get(2, 0), None); + assert_eq!(m.try_get(0, 3), None); + } + #[test] fn test_subview() { let data = make_test_matrix(); diff --git a/diskann/src/error/ann_error.rs b/diskann/src/error/ann_error.rs index 040e06037..d6f1f7318 100644 --- a/diskann/src/error/ann_error.rs +++ b/diskann/src/error/ann_error.rs @@ -576,6 +576,48 @@ impl From for ANNError { } } +impl From for ANNError { + #[track_caller] + fn from(err: diskann_utils::io::ReadBinError) -> Self { + ANNError::new(ANNErrorKind::IOError, err) + } +} + +impl From for ANNError { + #[track_caller] + fn from(err: diskann_utils::io::SaveBinError) -> Self { + ANNError::new(ANNErrorKind::IOError, err) + } +} + +impl From> for ANNError +where + T: std::error::Error + Send + Sync + 'static, + U: std::error::Error + Send + Sync + 'static, +{ + #[track_caller] + fn from(err: diskann_utils::io::MetadataError) -> Self { + ANNError::new(ANNErrorKind::IOError, err) + } +} + +impl From for ANNError { + #[track_caller] + fn from(err: diskann_utils::views::TryFromErrorLight) -> Self { + ANNError::new(ANNErrorKind::DimensionMismatchError, err) + } +} + +impl From> for ANNError +where + T: diskann_utils::views::DenseData, +{ + #[track_caller] + fn from(err: diskann_utils::views::TryFromError) -> Self { + Self::from(err.as_static()) + } +} + /// An internal wrapper for error types that also tracks the file and line information /// for where the error was first converted and where context was propagated. #[derive(Debug)] @@ -1501,4 +1543,49 @@ Caused by: let ann_err = ANNError::log_build_interrupted(message); assert_eq!(ann_err.kind(), ANNErrorKind::BuildInterrupted); } + + #[test] + fn from_read_bin_error() { + let err = diskann_utils::io::ReadBinError::SizeMismatch { + expected: 100, + available: 50, + npoints: 10, + ndims: 5, + type_size: 2, + }; + let ann_err = ANNError::from(err); + assert_eq!(ann_err.kind(), ANNErrorKind::IOError); + } + + #[test] + fn from_save_bin_error() { + let err = diskann_utils::io::SaveBinError::DimensionOverflow { nrows: 1, ncols: 1 }; + let ann_err = ANNError::from(err); + assert_eq!(ann_err.kind(), ANNErrorKind::IOError); + } + + #[test] + fn from_metadata_error() { + let err = diskann_utils::io::Metadata::new(u64::MAX, 1u32).unwrap_err(); + let ann_err = ANNError::from(err); + assert_eq!(ann_err.kind(), ANNErrorKind::IOError); + } + + #[test] + fn from_try_from_error_light() { + let data: &[f32] = &[1.0, 2.0, 3.0]; + let light = diskann_utils::views::MatrixView::try_from(data, 2, 2) + .unwrap_err() + .as_static(); + let ann_err = ANNError::from(light); + assert_eq!(ann_err.kind(), ANNErrorKind::DimensionMismatchError); + } + + #[test] + fn from_try_from_error() { + let data: &[f32] = &[1.0, 2.0, 3.0]; + let err = diskann_utils::views::MatrixView::try_from(data, 2, 2).unwrap_err(); + let ann_err = ANNError::from(err); + assert_eq!(ann_err.kind(), ANNErrorKind::DimensionMismatchError); + } }