use std::{path::Path, sync::Arc}; use async_compression::tokio::bufread::ZstdDecoder; use async_compression::tokio::write::ZstdEncoder; use async_compression::Level; use futures::stream::StreamExt; use futures_core::stream::Stream; use futures_core::Future; use futures_util::pin_mut; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::{ fs, io::{self, AsyncReadExt, AsyncWriteExt}, }; use fmmap::tokio::{AsyncMmapFile, AsyncMmapFileExt, AsyncOptions}; type LSize = u32; const LEN_SIZE: usize = std::mem::size_of::(); const BINCODE_CFG: bincode::config::Configuration = bincode::config::standard(); use crate::util::ErrorToString; pub struct WriterOpts { pub compress_lvl: Level, pub data_buf_size: usize, pub out_buf_size: usize, } impl Default for WriterOpts { fn default() -> Self { Self { compress_lvl: Level::Default, data_buf_size: 500 * 1024 * 1024, out_buf_size: 200 * 1024 * 1024, } } } pub struct Writer where T: bincode::Encode, { out: io::BufWriter, data_buf: Vec, table: Vec, compress_lvl: Level, _t: Option>, // PhantomData replacement } impl Writer where T: bincode::Encode, { pub async fn new>(path: P, opts: WriterOpts) -> Result { let out = fs::File::create(path).await.str_err()?; let out = io::BufWriter::with_capacity(opts.out_buf_size, out); let data_buf: Vec = Vec::with_capacity(opts.data_buf_size); let compress_lvl = opts.compress_lvl; let table: Vec = vec![]; Ok(Self { out, data_buf, table, compress_lvl, _t: None, }) } pub async fn push(&mut self, item: T) -> Result<(), String> { let pos: LSize = self.data_buf.len() as LSize; let item_data = bincode::encode_to_vec(item, BINCODE_CFG).str_err()?; let mut zencoder = ZstdEncoder::with_quality(&mut self.data_buf, self.compress_lvl); zencoder.write_all(&item_data).await.str_err()?; zencoder.flush().await.str_err()?; self.table.push(pos); Ok(()) } pub async fn load(&mut self, source: S) -> Result<(), String> where S: Stream + std::marker::Unpin, { let hint = source.size_hint(); let hint = std::cmp::max(hint.0, hint.1.unwrap_or(0)); if hint > 0 { self.table.reserve(hint); } pin_mut!(source); while let Some(item) = source.next().await { self.push(item).await?; } Ok(()) } pub async fn finish(mut self) -> Result<(), String> { // finish tab let pos: LSize = self.data_buf.len() as LSize; self.table.push(pos); // write tab let tab_size = (self.table.len() * LEN_SIZE) as LSize; for pos in self.table { let pos_data = (pos + tab_size).to_le_bytes(); self.out.write_all(&pos_data).await.str_err()?; } // copy data self.out.write_all(&self.data_buf[..]).await.str_err()?; self.out.flush().await.str_err()?; Ok(()) } } pub struct Reader where T: bincode::Decode, { mmap: AsyncMmapFile, count: usize, first_pos: LSize, _t: Option>, // PhantomData replacement } impl Reader where T: bincode::Decode, { pub async fn new>(path: P, _buf_size: usize) -> Result { let mmap = AsyncOptions::new() .read(true) .open_mmap_file(path) .await .str_err()?; mmap.try_lock_shared().str_err()?; // read first pos and records count let first_data: [u8; LEN_SIZE] = mmap.bytes(0, LEN_SIZE).str_err()?.try_into().str_err()?; let first_pos = LSize::from_le_bytes(first_data); let tab_len = (first_pos as usize) / LEN_SIZE; let count = tab_len - 1; Ok(Self { mmap, count, first_pos, _t: None, }) } pub fn len(&self) -> usize { self.count } pub fn is_empty(&self) -> bool { 0 == self.len() } pub async fn get(&self, index: usize) -> Result { if index >= self.len() { return Err("index out of range".into()); } let next_pos: usize = (index + 1) * LEN_SIZE; // read item data pos let data_pos = if 0 == index { self.first_pos } else { let tab_pos: usize = index * LEN_SIZE; let pos_curr_data: [u8; LEN_SIZE] = self .mmap .bytes(tab_pos, LEN_SIZE) .str_err()? .try_into() .str_err()?; LSize::from_le_bytes(pos_curr_data) } as usize; // read next item pos let pos_next_data: [u8; LEN_SIZE] = self .mmap .bytes(next_pos, LEN_SIZE) .str_err()? .try_into() .str_err()?; let data_pos_next = LSize::from_le_bytes(pos_next_data) as usize; let data_len = data_pos_next - data_pos; // read & unpack item data let mut decoder = ZstdDecoder::new(self.mmap.range_reader(data_pos, data_len).str_err()?); let mut data = Vec::::new(); decoder.read_to_end(&mut data).await.str_err()?; // decode item let item: (T, usize) = bincode::decode_from_slice(&data, BINCODE_CFG).str_err()?; Ok(item.0) } pub fn stream(&self) -> ReaderStream<'_, T> { ReaderStream::new(self) } } pub struct ReaderStream<'a, T> where T: bincode::Decode, { reader: &'a Reader, index: Option, } impl<'a, T> ReaderStream<'a, T> where T: bincode::Decode, { fn new(reader: &'a Reader) -> Self { ReaderStream { reader, index: None, } } } impl<'a, T> Stream for ReaderStream<'a, T> where T: bincode::Decode, { type Item = T; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.index.is_none() && !self.reader.is_empty() { self.index = Some(0); } if self.index.unwrap() == self.reader.len() { return Poll::Ready(None); } let future = self.reader.get(self.index.unwrap()); pin_mut!(future); match Pin::new(&mut future).poll(cx) { Poll::Ready(Ok(item)) => { self.index = Some(self.index.unwrap() + 1); Poll::Ready(Some(item)) } Poll::Ready(Err(_)) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } fn size_hint(&self) -> (usize, Option) { let len = self.reader.len(); if self.index.is_none() { return (len, Some(len)); } let index = self.index.unwrap(); let rem = if len > index + 1 { len - (index + 1) } else { 0 }; (rem, Some(rem)) } } #[cfg(test)] mod test { use super::*; use core::fmt::Debug; use tempfile::tempdir; #[derive(bincode::Encode, bincode::Decode, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] struct TestData { num: u64, test: String, vnum: Vec, vstr: Vec, } fn gen_data(count: usize) -> impl Iterator { (0..count).into_iter().map(|i| TestData { num: i as u64, test: "test".repeat(i), vnum: (0..i * 120).map(|x| (x ^ 0x345FE34) as u64).collect(), vstr: (0..i * 111).map(|x| "test".repeat(x).to_string()).collect(), }) } async fn assert_data_eq((x, y): (&TestData, TestData)) { assert_eq!(*x, y); } #[tokio::test] async fn test_write() { let dir = tempdir().expect("tempdir"); let tmpfile = dir.path().join("test.tmp"); let opts = WriterOpts { compress_lvl: Level::Default, data_buf_size: 10 * 1024 * 1024, out_buf_size: 10 * 1024 * 1024, }; let mut writer: Writer = Writer::new(&tmpfile, opts).await.expect("new writer"); let items_iter = gen_data(5); let items: Vec = items_iter.collect(); let src = futures::stream::iter(items.clone()); pin_mut!(src); writer.load(src).await.expect("load"); writer.finish().await.expect("finish write"); } #[tokio::test] async fn test_write_read() { let dir = tempdir().expect("tempdir"); let tmpfile = dir.path().join("test.tmp"); let opts = WriterOpts { compress_lvl: Level::Default, data_buf_size: 10 * 1024 * 1024, out_buf_size: 10 * 1024 * 1024, }; let mut writer: Writer = Writer::new(&tmpfile, opts).await.expect("new writer"); let items_iter = gen_data(5); let items: Vec = items_iter.collect(); let src = futures::stream::iter(items.clone()); pin_mut!(src); writer.load(src).await.expect("load"); writer.finish().await.expect("finish write"); let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader"); assert_eq!(items.len(), reader.len()); for (idx, item) in items.iter().enumerate() { let ritem = reader.get(idx).await.expect("get"); assert_eq!(*item, ritem); } } #[tokio::test] async fn test_write_read_stream() { let dir = tempdir().expect("tempdir"); let tmpfile = dir.path().join("test.tmp"); let opts = WriterOpts { compress_lvl: Level::Default, data_buf_size: 10 * 1024 * 1024, out_buf_size: 10 * 1024 * 1024, }; let mut writer: Writer = Writer::new(&tmpfile, opts).await.expect("new writer"); let items_iter = gen_data(5); let items: Vec = items_iter.collect(); let src = futures::stream::iter(items.clone()); pin_mut!(src); writer.load(src).await.expect("load"); writer.finish().await.expect("finish write"); let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader"); assert_eq!(items.len(), reader.len()); let dst_stream = reader.stream(); let src_stream = futures::stream::iter(items.iter()); let mut count = 0; src_stream .zip(dst_stream) .map(|x| { count += 1; x }) .for_each(assert_data_eq) .await; assert_eq!(count, items.len()) } /// sharing Reader instance between threads #[tokio::test] async fn test_share_reader() { let dir = tempdir().expect("tempdir"); let tmpfile = dir.path().join("test.tmp"); let opts = WriterOpts { compress_lvl: Level::Default, data_buf_size: 10 * 1024 * 1024, out_buf_size: 10 * 1024 * 1024, }; let mut writer: Writer = Writer::new(&tmpfile, opts).await.expect("new writer"); let items_iter = gen_data(5); let items: Vec = items_iter.collect(); let src = futures::stream::iter(items.clone()); pin_mut!(src); writer.load(src).await.expect("load"); writer.finish().await.expect("finish write"); let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader"); assert_eq!(items.len(), reader.len()); let reader = Arc::new(reader); for _ in 0..=3 { let cur_items = items.clone(); let cur_reader = Arc::clone(&reader); tokio::spawn(async move { let dst_stream = cur_reader.stream(); let src_stream = futures::stream::iter(cur_items.iter()); src_stream.zip(dst_stream).for_each(assert_data_eq).await; }); } } }