use std::{
    fs,
    io::{self, Cursor, Read, Seek, SeekFrom, Write},
    marker::PhantomData,
    path::Path,
};
use zstd::stream::{raw::Operation, zio};

/// Integer type of every offset stored in the file's offset table.
type LSize = u32;
/// Number of bytes one serialized [`LSize`] occupies in the offset table.
const LEN_SIZE: usize = (LSize::BITS / u8::BITS) as usize;
/// bincode configuration used both when encoding items and when decoding
/// them, so that both directions agree on the byte layout.
const BINCODE_CFG: bincode::config::Configuration = bincode::config::standard();

/// Adapter that turns a fallible value into one whose error is a plain
/// `String`, so differently-typed errors can be propagated with `?` from
/// functions returning `Result<_, String>`.
trait ErrorToString {
    /// The success type, passed through untouched.
    type Output;

    /// Consumes `self`, keeping the success value and replacing the error
    /// with its textual form.
    fn str_err(self) -> Result<Self::Output, String>;
}

impl<T, E> ErrorToString for std::result::Result<T, E>
where
    E: std::error::Error,
{
    type Output = T;
    fn str_err(self) -> std::result::Result<Self::Output, String> {
        self.map_err(|e| e.to_string())
    }
}

/// Tuning options accepted by `Writer::new`.
///
/// All sizes are in bytes and are used as initial capacities; see the
/// `Default` implementation for the values used when nothing is overridden.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WriterOpts {
    /// Compression level handed to the zstd encoder.
    pub compress_lvl: i32,
    /// Initial capacity of the in-memory buffer that collects the compressed
    /// bytes of all items.
    pub data_buf_size: usize,
    /// Capacity of the buffered writer wrapping the output file.
    pub out_buf_size: usize,
    /// Initial capacity of the scratch buffer that receives the compressed
    /// bytes of the item currently being pushed.
    pub current_buf_size: usize,
}

/// Defaults: compression level 1, a 500 MiB data buffer, a 200 MiB output
/// buffer and a 20 KiB per-item scratch buffer.
impl Default for WriterOpts {
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;

        Self {
            compress_lvl: 1,
            data_buf_size: 500 * MIB,
            out_buf_size: 200 * MIB,
            current_buf_size: 20 * KIB,
        }
    }
}

/// Builds a file of individually compressed `T` values preceded by an
/// offset table.
///
/// Items are encoded and compressed in memory as they are pushed; the output
/// file is only written to when the writer is finished.
pub struct Writer<'a, T>
where
    T: bincode::Encode,
{
    /// Buffered handle to the output file.
    out: io::BufWriter<fs::File>,
    /// In-memory buffer collecting the compressed bytes of every pushed item.
    data_buf: Cursor<Vec<u8>>,
    /// zstd compressor writing into a scratch cursor; the scratch contents
    /// are copied into `data_buf` after each item.
    cur_buf_z: zio::Writer<Cursor<Vec<u8>>, zstd::stream::raw::Encoder<'a>>,
    /// Position of `data_buf` recorded at the start of each pushed item.
    table: Vec<LSize>,
    /// Zero-sized marker tying the writer to its item type.
    /// NOTE(review): `PhantomData<*const T>` makes this struct neither
    /// `Send` nor `Sync` — confirm that this is intended.
    _t: PhantomData<*const T>,
}

impl<T> Writer<'_, T>
where
    T: bincode::Encode,
{
    pub fn new<P: AsRef<Path>>(path: P, opts: WriterOpts) -> Result<Self, String> {
        let out = fs::File::create(path).str_err()?;
        let out = io::BufWriter::with_capacity(opts.out_buf_size, out);
        let data_buf: Vec<u8> = Vec::with_capacity(opts.data_buf_size);
        let cur_buf_raw: Vec<u8> = Vec::with_capacity(opts.current_buf_size);
        let data_buf = Cursor::new(data_buf);
        let cur_buf_raw = Cursor::new(cur_buf_raw);

        let zencoder = zstd::stream::raw::Encoder::new(opts.compress_lvl).str_err()?;
        let cur_buf_z = zstd::stream::zio::Writer::new(cur_buf_raw, zencoder);

        let table: Vec<LSize> = vec![];

        Ok(Self {
            out,
            data_buf,
            cur_buf_z,
            table,
            _t: PhantomData,
        })
    }

    pub fn push(&mut self, item: T) -> Result<(), String> {
        let pos: LSize = self.data_buf.position() as LSize;

        let item_data = bincode::encode_to_vec(item, BINCODE_CFG).str_err()?;

        let zencoder = self.cur_buf_z.operation_mut();
        zencoder.reinit().str_err()?;
        zencoder
            .set_pledged_src_size(item_data.len() as u64)
            .str_err()?;

        self.cur_buf_z.write_all(&item_data).str_err()?;
        self.cur_buf_z.finish().str_err()?;
        self.cur_buf_z.flush().str_err()?;

        self.table.push(pos);
        let cur_buf_raw = self.cur_buf_z.writer_mut();
        let size = cur_buf_raw.position();
        cur_buf_raw.set_position(0);
        io::copy(&mut cur_buf_raw.take(size), &mut self.data_buf).str_err()?;

        Ok(())
    }

    pub fn load<I>(&mut self, iter: &mut I) -> Result<(), String>
    where
        I: Iterator<Item = T>,
    {
        for item in iter {
            self.push(item)?;
        }

        Ok(())
    }

    pub fn finish(mut self) -> Result<(), String> {
        // finish tab
        let pos: LSize = self.data_buf.position() as LSize;
        self.table.push(pos);

        // write tab
        let tab_size = (self.table.len() * LEN_SIZE) as LSize;
        for pos in self.table {
            let pos_data = (pos + tab_size).to_le_bytes();
            self.out.write_all(&pos_data).str_err()?;
        }

        // copy data
        let data_size = self.data_buf.position();
        self.data_buf.set_position(0);
        let mut data = self.data_buf.take(data_size);
        io::copy(&mut data, &mut self.out).str_err()?;

        self.out.flush().str_err()?;
        Ok(())
    }
}

/// Random-access reader for files consisting of an offset table followed by
/// individually compressed `T` values.
pub struct Reader<T>
where
    T: bincode::Decode,
{
    /// Buffered handle to the file being read.
    input: io::BufReader<fs::File>,
    /// Number of items in the file, derived from the first table entry.
    count: usize,
    /// First entry of the offset table: the absolute position at which the
    /// data of item 0 starts.
    first_pos: LSize,
    /// Zero-sized marker tying the reader to its item type.
    /// NOTE(review): `PhantomData<*const T>` makes this struct neither
    /// `Send` nor `Sync` — confirm that this is intended.
    _t: PhantomData<*const T>,
}

impl<T> Reader<T>
where
    T: bincode::Decode,
{
    pub fn new<P: AsRef<Path>>(path: P, buf_size: usize) -> Result<Self, String> {
        let input = fs::File::open(path).str_err()?;
        let mut input = io::BufReader::with_capacity(buf_size, input);

        // read first pos and records count
        let mut first_data: [u8; LEN_SIZE] = [0; LEN_SIZE];
        input.read_exact(&mut first_data).str_err()?;
        let first_pos = LSize::from_le_bytes(first_data);
        let tab_len = (first_pos as usize) / LEN_SIZE;
        let count = tab_len - 1;

        Ok(Self {
            input,
            count,
            first_pos,
            _t: PhantomData,
        })
    }

    pub fn len(&self) -> usize {
        self.count
    }

    pub fn get(&mut self, index: usize) -> Result<T, String> {
        if index >= self.len() {
            return Err("index out of range".into());
        }

        // read item data pos
        let data_pos = if 0 == index {
            self.first_pos
        } else {
            let tab_pos: u64 = (index * LEN_SIZE).try_into().str_err()?;
            let mut pos_curr_data: [u8; LEN_SIZE] = [0; LEN_SIZE];
            self.input.seek(SeekFrom::Start(tab_pos)).str_err()?;
            self.input.read_exact(&mut pos_curr_data).str_err()?;
            LSize::from_le_bytes(pos_curr_data)
        };

        // read next item pos
        let mut pos_next_data: [u8; LEN_SIZE] = [0; LEN_SIZE];
        self.input.read_exact(&mut pos_next_data).str_err()?;
        let data_pos_next = LSize::from_le_bytes(pos_next_data);
        // calc item data length
        let data_len = data_pos_next - data_pos;

        // read & unpack item data
        self.input
            .seek(SeekFrom::Start(data_pos as u64))
            .str_err()?;
        let reader = self.input.by_ref().take(data_len as u64);
        let data = zstd::decode_all(reader).str_err()?;

        // decode item
        let item: (T, usize) = bincode::decode_from_slice(&data, BINCODE_CFG).str_err()?;

        Ok(item.0)
    }
}