diff --git a/src/db.rs b/src/db.rs index a2d92cd..bc5b0b2 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,10 +1,9 @@ use std::{ fs, - io::{self, Cursor, Read, Seek, SeekFrom, Write}, + io::{self, Cursor, Read, Seek, Write}, marker::PhantomData, path::Path, }; -use zstd::stream::{raw::Operation, zio}; type LSize = u32; const LEN_SIZE: usize = std::mem::size_of::(); @@ -38,44 +37,44 @@ impl Default for WriterOpts { compress_lvl: 1, data_buf_size: 500 * 1024 * 1024, out_buf_size: 200 * 1024 * 1024, - current_buf_size: 20 * 1024, + current_buf_size: 20 * 1024 * 1024, } } } -pub struct Writer<'a, T> +pub struct Writer where T: bincode::Encode, { out: io::BufWriter, data_buf: Cursor>, - cur_buf_z: zio::Writer>, zstd::stream::raw::Encoder<'a>>, table: Vec, + compress_lvl: i32, + current_buf_size: usize, _t: PhantomData<*const T>, } -impl Writer<'_, T> +impl Writer where - T: bincode::Encode, + T: bincode::Encode + std::fmt::Debug, { pub fn new>(path: P, opts: WriterOpts) -> Result { let out = fs::File::create(path).str_err()?; let out = io::BufWriter::with_capacity(opts.out_buf_size, out); let data_buf: Vec = Vec::with_capacity(opts.data_buf_size); - let cur_buf_raw: Vec = Vec::with_capacity(opts.current_buf_size); let data_buf = Cursor::new(data_buf); - let cur_buf_raw = Cursor::new(cur_buf_raw); - - let zencoder = zstd::stream::raw::Encoder::new(opts.compress_lvl).str_err()?; - let cur_buf_z = zstd::stream::zio::Writer::new(cur_buf_raw, zencoder); + let compress_lvl = opts.compress_lvl; + let current_buf_size = opts.current_buf_size; + let table: Vec = vec![]; Ok(Self { out, data_buf, - cur_buf_z, table, + compress_lvl, + current_buf_size, _t: PhantomData, }) } @@ -84,22 +83,26 @@ where let pos: LSize = self.data_buf.position() as LSize; let item_data = bincode::encode_to_vec(item, BINCODE_CFG).str_err()?; - - let zencoder = self.cur_buf_z.operation_mut(); - zencoder.reinit().str_err()?; + + let cur_buf_raw: Vec = Vec::with_capacity(self.current_buf_size); + let cur_buf_raw = Cursor::new(cur_buf_raw); + let mut zencoder = zstd::stream::raw::Encoder::new(self.compress_lvl).str_err()?; zencoder .set_pledged_src_size(item_data.len() as u64) .str_err()?; - self.cur_buf_z.write_all(&item_data).str_err()?; - self.cur_buf_z.finish().str_err()?; - self.cur_buf_z.flush().str_err()?; + let mut cur_buf_z = zstd::stream::zio::Writer::new(cur_buf_raw, zencoder); + cur_buf_z.write_all(&item_data).str_err()?; + cur_buf_z.finish().str_err()?; + cur_buf_z.flush().str_err()?; self.table.push(pos); - let cur_buf_raw = self.cur_buf_z.writer_mut(); + let (mut cur_buf_raw, _) = cur_buf_z.into_inner(); let size = cur_buf_raw.position(); + cur_buf_raw.set_position(0); - io::copy(&mut cur_buf_raw.take(size), &mut self.data_buf).str_err()?; + let mut chunk_reader = cur_buf_raw.take(size); + io::copy(&mut chunk_reader, &mut self.data_buf).str_err()?; Ok(()) } @@ -186,7 +189,9 @@ where } else { let tab_pos: u64 = (index * LEN_SIZE).try_into().str_err()?; let mut pos_curr_data: [u8; LEN_SIZE] = [0; LEN_SIZE]; - self.input.seek(SeekFrom::Start(tab_pos)).str_err()?; + let cur_pos = self.input.stream_position().str_err()? as i64; + self.input.seek_relative((tab_pos as i64) - cur_pos).str_err()?; + self.input.read_exact(&mut pos_curr_data).str_err()?; LSize::from_le_bytes(pos_curr_data) }; @@ -199,8 +204,9 @@ where let data_len = data_pos_next - data_pos; // read & unpack item data + let cur_pos = self.input.stream_position().str_err()? as i64; self.input - .seek(SeekFrom::Start(data_pos as u64)) + .seek_relative((data_pos as i64) - cur_pos) .str_err()?; let reader = self.input.by_ref().take(data_len as u64); let data = zstd::decode_all(reader).str_err()?; @@ -211,3 +217,44 @@ where Ok(item.0) } } + +#[cfg(test)] +mod test { + use rand::RngCore; + + use super::*; + + #[derive(bincode::Encode, bincode::Decode, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] + struct TestData { + num: u64, + } + + fn gen_data(count: usize) -> impl Iterator { + (0..count).into_iter().map(|i| TestData{ num: i as u64}) + } + + #[test] + fn test_write_read() { + let mut rng = rand::thread_rng(); + let tempfile = std::env::temp_dir().with_file_name(format!("test-{}.tmp", rng.next_u32())); + let opts = WriterOpts { compress_lvl: 1, data_buf_size: 10 * 1024 * 1024, out_buf_size: 10 * 1024 * 1024, current_buf_size: 4096 }; + let mut writer: Writer = Writer::new(&tempfile, opts).expect("new writer"); + + let items_iter = gen_data(5); + let items: Vec = items_iter.collect(); + + writer.load(&mut items.clone().into_iter()).expect("load"); + writer.finish().expect("finish write"); + + let mut reader: Reader = Reader::new(&tempfile, 2048).expect("new reader"); + assert_eq!(items.len(), reader.len()); + + for (idx,item) in items.iter().enumerate() { + let ritem = reader.get(idx).expect("get"); + assert_eq!(*item, ritem); + } + + let r = fs::remove_file(tempfile); + drop(r); + } +}