use std::{ fs, io::{self, Cursor, Read, Write}, marker::PhantomData, path::Path, }; use memmap::{Mmap, MmapOptions}; type LSize = u32; const LEN_SIZE: usize = std::mem::size_of::(); const BINCODE_CFG: bincode::config::Configuration = bincode::config::standard(); trait ErrorToString { type Output; fn str_err(self) -> std::result::Result; } impl ErrorToString for std::result::Result where E: std::error::Error, { type Output = T; fn str_err(self) -> std::result::Result { self.map_err(|e| e.to_string()) } } pub struct WriterOpts { pub compress_lvl: i32, pub data_buf_size: usize, pub out_buf_size: usize, pub current_buf_size: usize, } impl Default for WriterOpts { fn default() -> Self { Self { compress_lvl: 1, data_buf_size: 500 * 1024 * 1024, out_buf_size: 200 * 1024 * 1024, current_buf_size: 100 * 1024, } } } pub struct Writer where T: bincode::Encode, { out: io::BufWriter, data_buf: Cursor>, cur_buf_raw: Cursor>, table: Vec, compress_lvl: i32, _t: PhantomData<*const T>, } impl Writer where T: bincode::Encode, { pub fn new>(path: P, opts: WriterOpts) -> Result { let out = fs::File::create(path).str_err()?; let out = io::BufWriter::with_capacity(opts.out_buf_size, out); let data_buf: Vec = Vec::with_capacity(opts.data_buf_size); let data_buf = Cursor::new(data_buf); let cur_buf_raw: Vec = Vec::with_capacity(opts.current_buf_size); let cur_buf_raw = Cursor::new(cur_buf_raw); let compress_lvl = opts.compress_lvl; let table: Vec = vec![]; Ok(Self { out, data_buf, cur_buf_raw, table, compress_lvl, _t: PhantomData, }) } pub fn push(&mut self, item: T) -> Result<(), String> { let pos: LSize = self.data_buf.position() as LSize; let item_data = bincode::encode_to_vec(item, BINCODE_CFG).str_err()?; let mut zencoder = zstd::stream::raw::Encoder::new(self.compress_lvl).str_err()?; zencoder .set_pledged_src_size(item_data.len() as u64) .str_err()?; self.cur_buf_raw.set_position(0); let mut cur_buf_z = zstd::stream::zio::Writer::new(&mut self.cur_buf_raw, zencoder); cur_buf_z.write_all(&item_data).str_err()?; cur_buf_z.finish().str_err()?; cur_buf_z.flush().str_err()?; self.table.push(pos); let (cur_buf_raw, _) = cur_buf_z.into_inner(); let size = cur_buf_raw.position(); cur_buf_raw.set_position(0); let mut chunk = cur_buf_raw.take(size); io::copy(&mut chunk, &mut self.data_buf).str_err()?; Ok(()) } pub fn load(&mut self, iter: &mut I) -> Result<(), String> where I: Iterator, { let hint = iter.size_hint(); let hint = std::cmp::max(hint.0, hint.1.unwrap_or(0)); if hint > 0 { self.table.reserve(hint); } for item in iter { self.push(item)?; } Ok(()) } pub fn finish(mut self) -> Result<(), String> { // finish tab let pos: LSize = self.data_buf.position() as LSize; self.table.push(pos); // write tab let tab_size = (self.table.len() * LEN_SIZE) as LSize; for pos in self.table { let pos_data = (pos + tab_size).to_le_bytes(); self.out.write_all(&pos_data).str_err()?; } // copy data let data_size = self.data_buf.position(); self.data_buf.set_position(0); let mut data = self.data_buf.take(data_size); io::copy(&mut data, &mut self.out).str_err()?; self.out.flush().str_err()?; Ok(()) } } pub struct Reader where T: bincode::Decode, { mmap: Mmap, count: usize, first_pos: LSize, _t: PhantomData<*const T>, } impl Reader where T: bincode::Decode, { pub fn new>(path: P, _buf_size: usize) -> Result { let file = fs::File::open(path).str_err()?; let mmap = unsafe { MmapOptions::new().map(&file).str_err()? }; // read first pos and records count let first_data: [u8; LEN_SIZE] = mmap[0..LEN_SIZE].try_into().str_err()?; let first_pos = LSize::from_le_bytes(first_data); let tab_len = (first_pos as usize) / LEN_SIZE; let count = tab_len - 1; Ok(Self { mmap, count, first_pos, _t: PhantomData, }) } pub fn len(&self) -> usize { self.count } pub fn get(&mut self, index: usize) -> Result { if index >= self.len() { return Err("index out of range".into()); } let next_pos: usize = (index + 1) * LEN_SIZE; let next_end: usize = next_pos + LEN_SIZE; // read item data pos let data_pos = if 0 == index { self.first_pos } else { let tab_pos: usize = index * LEN_SIZE; let pos_curr_data: [u8; LEN_SIZE] = self.mmap[tab_pos..next_pos].try_into().str_err()?; LSize::from_le_bytes(pos_curr_data) } as usize; // read next item pos let pos_next_data: [u8; LEN_SIZE] = self.mmap[next_pos..next_end].try_into().str_err()?; let data_pos_next = LSize::from_le_bytes(pos_next_data) as usize; // read & unpack item data let reader = io::Cursor::new(self.mmap[data_pos..data_pos_next].as_ref()); let data = zstd::decode_all(reader).str_err()?; // decode item let item: (T, usize) = bincode::decode_from_slice(&data, BINCODE_CFG).str_err()?; Ok(item.0) } pub fn iter(&mut self) -> ReaderIter<'_, T> { ReaderIter::new(self) } } pub struct ReaderIter<'a, T> where T: bincode::Decode, { reader: &'a mut Reader, index: Option, } impl<'a, T> ReaderIter<'a, T> where T: bincode::Decode, { fn new(reader: &'a mut Reader) -> Self { ReaderIter { reader, index: None, } } } impl<'a, T> Iterator for ReaderIter<'a, T> where T: bincode::Decode, { type Item = T; fn next(&mut self) -> Option { if self.index.is_none() && self.reader.len() != 0 { self.index = Some(0); } match self.index { Some(i) if i < self.reader.len() => self.nth(i), _ => None, } } fn nth(&mut self, n: usize) -> Option { if self.reader.len() <= n { return None; } self.index = Some(n + 1); let item = self.reader.get(n); match item { Ok(item) => Some(item), Err(_) => None, } } fn size_hint(&self) -> (usize, Option) { let len = self.reader.len(); if self.index.is_none() { return (len, Some(len)); } let index = self.index.unwrap(); let rem = if len > index + 1 { len - (index + 1) } else { 0 }; (rem, Some(rem)) } fn count(self) -> usize where Self: Sized, { self.reader.len() } } impl<'a, T> ExactSizeIterator for ReaderIter<'a, T> where T: bincode::Decode, { fn len(&self) -> usize { self.reader.len() } } #[cfg(test)] mod test { use super::*; use tempfile::tempdir; #[derive(bincode::Encode, bincode::Decode, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] struct TestData { num: u64, test: String, } fn gen_data(count: usize) -> impl Iterator { (0..count).into_iter().map(|i| TestData { num: i as u64, test: "test".repeat(i), }) } #[test] fn test_write_read() { let dir = tempdir().expect("tempdir"); let tmpfile = dir.path().join("test.tmp"); let opts = WriterOpts { compress_lvl: 1, data_buf_size: 10 * 1024 * 1024, out_buf_size: 10 * 1024 * 1024, current_buf_size: 4096, }; let mut writer: Writer = Writer::new(&tmpfile, opts).expect("new writer"); let items_iter = gen_data(5); let items: Vec = items_iter.collect(); writer.load(&mut items.clone().into_iter()).expect("load"); writer.finish().expect("finish write"); let mut reader: Reader = Reader::new(&tmpfile, 2048).expect("new reader"); assert_eq!(items.len(), reader.len()); for (idx, item) in items.iter().enumerate() { let ritem = reader.get(idx).expect("get"); assert_eq!(*item, ritem); } } #[test] fn test_write_read_iter() { let dir = tempdir().expect("tempdir"); let tmpfile = dir.path().join("test.tmp"); let opts = WriterOpts { compress_lvl: 1, data_buf_size: 10 * 1024 * 1024, out_buf_size: 10 * 1024 * 1024, current_buf_size: 4096, }; let mut writer: Writer = Writer::new(&tmpfile, opts).expect("new writer"); let items_iter = gen_data(10); let items: Vec = items_iter.collect(); writer.load(&mut items.clone().into_iter()).expect("load"); writer.finish().expect("finish write"); let mut reader: Reader = Reader::new(&tmpfile, 2048).expect("new reader"); assert_eq!(items.len(), reader.len()); items.into_iter().zip(reader.iter()).for_each(|pair| { assert_eq!(pair.0, pair.1); }); } }