From 31c02ae8efb2066e7b95e3ee6cc225937ac1ac15 Mon Sep 17 00:00:00 2001 From: Dmitry Date: Sun, 13 Aug 2023 18:16:49 +0300 Subject: [PATCH] async_db: add WriterSink --- app_async/src/main.rs | 12 +++---- lib/src/async_db.rs | 83 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/app_async/src/main.rs b/app_async/src/main.rs index 17a0316..9049829 100644 --- a/app_async/src/main.rs +++ b/app_async/src/main.rs @@ -178,15 +178,11 @@ async fn db_writer_task(rx: UnboundedReceiver) { let mut writer: async_db::Writer = async_db::Writer::new(NEW_DB_FILENAME, writer_opts) .await - .expect("new db writer"); - - let mut stream: UnboundedReceiverStream<_> = rx.into(); - - writer - .load(&mut stream) - .await - .unwrap_or_else(|e| panic!("db writer load, {e:#?}")); + .unwrap_or_else(|e| panic!("db writer load, {e:#?}")); + let stream: UnboundedReceiverStream<_> = rx.into(); + let stream = stream.map(Ok); + stream.forward(writer.sink()).await.expect("forward"); writer.finish().await.expect("db writer finish"); println!("write done"); diff --git a/lib/src/async_db.rs b/lib/src/async_db.rs index c74dbf3..ca9c09b 100644 --- a/lib/src/async_db.rs +++ b/lib/src/async_db.rs @@ -4,6 +4,7 @@ use std::{path::Path, sync::Arc}; use async_compression::tokio::bufread::ZstdDecoder; use async_compression::tokio::bufread::ZstdEncoder; use async_compression::Level; +use futures::sink::Sink; use futures::stream::StreamExt; use futures_core::stream::Stream; use futures_core::Future; @@ -124,6 +125,80 @@ where self.out.flush().await.str_err()?; Ok(()) } + + pub fn sink(&mut self) -> WriterSink<'_, T> { + WriterSink { + writer: self, + item: None, + } + } +} + +use pin_project::pin_project; + +#[pin_project] +pub struct WriterSink<'a, T> +where + T: bincode::Encode, +{ + #[pin] + writer: &'a mut Writer, + item: Option, +} + +impl<'a, T> Sink for WriterSink<'a, T> +where + T: bincode::Encode, +{ + type Error = String; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + ctx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut this = self.project(); + + if this.item.is_none() { + return Poll::Ready(Ok(())); + } + + let item = this.item.take().unwrap(); + + let fut = this.writer.push(item); + pin_mut!(fut); + match fut.poll(ctx) { + Poll::Ready(Ok(_)) => { + *this.item = None; + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => { + *this.item = None; + Poll::Ready(Err(e)) + } + Poll::Pending => Poll::Pending, + } + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + let this = self.project(); + *this.item = Some(item); + Ok(()) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + ctx: &mut std::task::Context<'_>, + ) -> Poll> { + self.poll_ready(ctx) + } + + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + ctx: &mut std::task::Context<'_>, + ) -> Poll> { + futures::ready!(self.as_mut().poll_ready(ctx))?; + Poll::Ready(Ok(())) + } } pub struct Reader @@ -369,9 +444,9 @@ mod test { let items_iter = gen_data(5); let items: Vec = items_iter.collect(); - let src = futures::stream::iter(items.clone()); + let src = futures::stream::iter(items.clone()).map(Ok); pin_mut!(src); - writer.load(src).await.expect("load"); + src.forward(writer.sink()).await.expect("forward"); writer.finish().await.expect("finish write"); let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader"); @@ -406,9 +481,9 @@ mod test { let items_iter = gen_data(5); let items: Vec = items_iter.collect(); - let src = futures::stream::iter(items.clone()); + let src = futures::stream::iter(items.clone()).map(Ok); pin_mut!(src); - writer.load(src).await.expect("load"); + src.forward(writer.sink()).await.expect("forward"); writer.finish().await.expect("finish write"); let reader: Reader = Reader::new(&tmpfile, 2048).await.expect("new reader");