diff --git a/lib/src/async_db.rs b/lib/src/async_db.rs index 17ee043..e3d1f81 100644 --- a/lib/src/async_db.rs +++ b/lib/src/async_db.rs @@ -5,8 +5,12 @@ use async_compression::tokio::write::ZstdEncoder; use async_compression::Level; use futures::stream::StreamExt; use futures_core::stream::Stream; +use futures_core::Future; use futures_util::pin_mut; +use std::pin::Pin; +use std::task::{Context, Poll}; + use tokio::{ fs, io::{self, AsyncReadExt, AsyncWriteExt}, @@ -219,12 +223,80 @@ where Ok(item.0) } + + pub fn stream(&self) -> ReaderStream<'_, T> { + ReaderStream::new(self) + } +} + +pub struct ReaderStream<'a, T> +where + T: bincode::Decode, +{ + reader: &'a Reader, + index: Option, +} + +impl<'a, T> ReaderStream<'a, T> +where + T: bincode::Decode, +{ + fn new(reader: &'a Reader) -> Self { + ReaderStream { + reader, + index: None, + } + } +} + +impl<'a, T> Stream for ReaderStream<'a, T> +where + T: bincode::Decode, +{ + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.index.is_none() && !self.reader.is_empty() { + self.index = Some(0); + } + + if self.index.unwrap() == self.reader.len() { + return Poll::Ready(None); + } + + let future = self.reader.get(self.index.unwrap()); + pin_mut!(future); + match Pin::new(&mut future).poll(cx) { + Poll::Ready(Ok(item)) => { + self.index = Some(self.index.unwrap() + 1); + Poll::Ready(Some(item)) + } + Poll::Ready(Err(_)) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.reader.len(); + if self.index.is_none() { + return (len, Some(len)); + } + + let index = self.index.unwrap(); + let rem = if len > index + 1 { + len - (index + 1) + } else { + 0 + }; + (rem, Some(rem)) + } } #[cfg(test)] mod test { use super::*; use async_stream::stream; + use core::fmt::Debug; use tempfile::tempdir; #[derive(bincode::Encode, bincode::Decode, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -298,4 +370,36 @@ mod test { assert_eq!(*item, ritem); } } + + #[tokio::test] + async fn test_write_read_stream() { + let dir = tempdir().expect("tempdir"); + let tmpfile = dir.path().join("test.tmp"); + let opts = WriterOpts { + compress_lvl: Level::Default, + data_buf_size: 10 * 1024 * 1024, + out_buf_size: 10 * 1024 * 1024, + }; + let mut writer: Writer = Writer::new(&tmpfile, opts).await.expect("new writer"); + + let items_iter = gen_data(5); + let items: Vec = items_iter.collect(); + + let src = stream_iter(items.clone().into_iter()); + pin_mut!(src); + writer.load(src).await.expect("load"); + writer.finish().await.expect("finish write"); + + let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader"); + assert_eq!(items.len(), reader.len()); + + let dst_stream = reader.stream(); + let src_stream = stream_iter(items.iter()); + + async fn test_values((x, y): (&TestData, TestData)) { + assert_eq!(*x, y); + } + + src_stream.zip(dst_stream).for_each(test_values).await; + } }