From feb2303db9318c917eca8e68b334e1fafd9ac5c0 Mon Sep 17 00:00:00 2001 From: Dmitry Date: Wed, 16 Aug 2023 11:44:03 +0300 Subject: [PATCH] async_db: add BufReader, BufReaderStream for using as mutable reader --- lib/src/async_db.rs | 152 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 146 insertions(+), 6 deletions(-) diff --git a/lib/src/async_db.rs b/lib/src/async_db.rs index 712f50b..45ec1a7 100644 --- a/lib/src/async_db.rs +++ b/lib/src/async_db.rs @@ -1,4 +1,5 @@ use std::marker::PhantomData; +use std::ops::Deref; use std::vec; use std::{path::Path, sync::Arc}; @@ -224,7 +225,7 @@ impl Reader where T: bincode::Decode, { - pub async fn new>(path: P, _buf_size: usize) -> Result { + pub async fn new>(path: P) -> Result { let mmap = AsyncOptions::new() .read(true) .open_mmap_file(path) @@ -372,6 +373,145 @@ where } } +pub struct BufReader +where + T: bincode::Decode, +{ + inner: Reader, + buf: Vec, +} + +impl BufReader +where + T: bincode::Decode, +{ + pub async fn new>(path: P, buf_size: usize) -> Result { + match Reader::::new(path).await { + Ok(inner) => Ok(Self { + inner, + buf: Vec::with_capacity(buf_size), + }), + Err(e) => Err(e), + } + } + + pub async fn get(&mut self, index: usize) -> Result { + self.inner.get_with_buf(index, &mut self.buf).await + } + + pub fn into_inner(self) -> Reader { + self.inner + } + + pub fn stream(self) -> BufReaderStream { + BufReaderStream::new(self) + } +} + +impl From> for BufReader +where + T: bincode::Decode, +{ + fn from(inner: Reader) -> Self { + Self { + inner, + buf: Vec::new(), + } + } +} + +impl From> for Reader +where + T: bincode::Decode, +{ + fn from(value: BufReader) -> Self { + value.into_inner() + } +} + +impl Deref for BufReader +where + T: bincode::Decode, +{ + type Target = Reader; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +pub struct BufReaderStream +where + T: bincode::Decode, +{ + reader: BufReader, + index: Option, +} + +impl BufReaderStream +where + T: bincode::Decode, +{ + fn new(reader: BufReader) -> Self { + BufReaderStream { + reader, + index: None, + } + } + + async fn get_next(&mut self) -> Result { + match self.index { + None => Err("index is None".into()), + Some(index) => { + let res = self.reader.get(index).await; + self.index = Some(index + 1); + + res + } + } + } +} + +impl Stream for BufReaderStream +where + T: bincode::Decode, +{ + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.index.is_none() && !self.reader.is_empty() { + self.index = Some(0); + } + + if self.index.unwrap() == self.reader.len() { + return Poll::Ready(None); + } + + // FIXME: mayby work only if reader.get().poll() return Ready immediately + let future = self.get_next(); + pin_mut!(future); + match Pin::new(&mut future).poll(cx) { + Poll::Ready(Ok(item)) => Poll::Ready(Some(item)), + Poll::Ready(Err(_)) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.reader.len(); + if self.index.is_none() { + return (len, Some(len)); + } + + let index = self.index.unwrap(); + let rem = if len > index + 1 { + len - (index + 1) + } else { + 0 + }; + (rem, Some(rem)) + } +} + #[cfg(test)] mod test { use super::*; @@ -438,7 +578,7 @@ mod test { writer.load(src).await.expect("load"); writer.finish().await.expect("finish write"); - let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader"); + let reader: Reader = Reader::new(&tmpfile).await.expect("new reader"); assert_eq!(items.len(), reader.len()); for (idx, item) in items.iter().enumerate() { @@ -466,7 +606,7 @@ mod test { src.forward(writer.sink()).await.expect("forward"); writer.finish().await.expect("finish write"); - let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader"); + let reader: Reader = Reader::new(&tmpfile).await.expect("new reader"); assert_eq!(items.len(), reader.len()); for (idx, item) in items.iter().enumerate() { @@ -494,7 +634,7 @@ mod test { writer.load(src).await.expect("load"); writer.finish().await.expect("finish write"); - let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader"); + let reader: Reader = Reader::new(&tmpfile).await.expect("new reader"); assert_eq!(items.len(), reader.len()); for (idx, item) in items.iter().enumerate() { @@ -523,7 +663,7 @@ mod test { writer.load(src).await.expect("load"); writer.finish().await.expect("finish write"); - let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader"); + let reader: Reader = Reader::new(&tmpfile).await.expect("new reader"); assert_eq!(items.len(), reader.len()); let dst_stream = reader.stream(); @@ -560,7 +700,7 @@ mod test { writer.load(src).await.expect("load"); writer.finish().await.expect("finish write"); - let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader"); + let reader: Reader = Reader::new(&tmpfile).await.expect("new reader"); assert_eq!(items.len(), reader.len()); let reader = Arc::new(reader);