extern crate priority_queue;
use priority_queue::PriorityQueue;
use std::collections::HashMap;
use std::collections::LinkedList;
use std::hash::Hash;
/// Number of input bytes that form one Huffman symbol.
const CHUNK_SIZE: usize = 32;
/// A node of a Huffman tree whose leaves carry symbols of type `T`.
#[derive(Debug, Eq, Hash, PartialEq)]
enum Node<T>
where
    T: Sized + Eq + Hash,
{
    /// A leaf holding one symbol.
    Terminal { value: T },
    /// An internal node joining two subtrees.
    Nonterminal {
        left: Box<Node<T>>,
        right: Box<Node<T>>,
    },
}

impl<T> Node<T>
where
    T: Sized + Eq + Hash,
{
    /// Builds a leaf holding `val`.
    fn terminal(val: T) -> Self {
        Self::Terminal { value: val }
    }

    /// Builds an internal node owning `left` and `right`.
    ///
    /// (The spelling of the name is kept as-is because callers use it.)
    fn noneterminal(left: Self, right: Self) -> Self {
        let (left, right) = (Box::new(left), Box::new(right));
        Self::Nonterminal { left, right }
    }
}
/// Walks a Huffman tree and records the code assigned to every leaf.
///
/// * `node` - the (sub)tree to walk.
/// * `code` - the path bits from the root down to `node`, most significant bit
///   first; a left edge appends `0`, a right edge appends `1`.
/// * `depth` - how many bits of `code` are meaningful.
/// * `map` - accumulator; it is returned with one entry added per leaf of `node`.
///
/// Each entry maps a leaf value to `(code length, code bits)`. When the whole
/// tree is a single leaf (a `Terminal` reached at depth 0), that leaf gets the
/// one-bit code `0` instead of an empty code. This way every encoded symbol
/// still occupies at least one bit of the output stream.
///
/// # Panics
///
/// If the tree is more than 32 levels deep, because such codes cannot be
/// represented in a `u32`.
fn get_huffmanencoding<'a, T: Sized + Eq + Hash>(
    node: &'a Node<T>,
    code: u32,
    depth: u32,
    mut map: HashMap<&'a T, (u32, u32)>,
) -> HashMap<&'a T, (u32, u32)> {
    match node {
        Node::Nonterminal { left, right } => {
            // The children get `depth + 1` bits; beyond 32 the top bits would
            // be shifted out of `code` and the codes would silently collide.
            assert!(
                depth < 32,
                "Huffman tree deeper than 32 levels: codes no longer fit in a u32"
            );
            let map = get_huffmanencoding(left.as_ref(), code << 1, depth + 1, map);
            get_huffmanencoding(right.as_ref(), (code << 1) | 1, depth + 1, map)
        }
        Node::Terminal { value } => {
            // A leaf at depth 0 is the only symbol; give it a non-empty code.
            let entry = if depth == 0 { (1, 0) } else { (depth, code) };
            map.insert(value, entry);
            map
        }
    }
}
use std::io::{Read, Write};
fn decode(writer:&mut Write, reader:&mut Read,origin_size:usize, table:&HashMap<(u32, u32), [u8; CHUNK_SIZE]>){
	let mut t:Vec< _ > = table.iter().map(|(key, value)|{
		(key.0, key.1, value)
	}).collect();
	t.sort_by_key(|it| it.0);
	let mut byte = [0u8;1];
	let mut buffer:u32 = 0;
	let mut bit_count = 0;
	let mut decompressed_size:usize = 0;
	while let Ok( _ ) = reader.read_exact(&mut byte){
		buffer = buffer << 8 | (byte[0] as u32);
		bit_count += 8;
		//println!("bit count {}", bit_count);
		for (len, bit, data) in &t{
			if *len > bit_count{
				continue;
			}
			if *bit == buffer >> (bit_count - *len){

				writer.write_all(*data);
				decompressed_size += data.len();
				bit_count -= *len;
				buffer = buffer & !(0xFFFFFFFF << bit_count);
			}
		}
	}
	println!("complete {} {}", decompressed_size, origin_size);
	
	while bit_count != 0{
		for (len, bit, data) in &t{
			if *len > bit_count{
				continue;
			}
			if *bit == buffer >> (bit_count - *len){
				decompressed_size += data.len();
				writer.write_all(*data);
				bit_count -= *len;
				buffer = buffer & !(0xFFFFFFFF << bit_count);
			}
		}
		if decompressed_size == origin_size{
			return;
		}
	}
}
use std::ops::Index;
use std::slice::SliceIndex;
/// A [`Read`] adapter that yields the bytes of a borrowed slice in order.
struct SliceReader<'a> {
    /// The bytes being read.
    slice: &'a [u8],
    /// Index of the next unread byte in `slice`.
    offset: usize,
}

impl<'a> SliceReader<'a> {
    /// Creates a reader positioned at the start of `slice`.
    fn new(slice: &'a [u8]) -> Self {
        Self { slice, offset: 0 }
    }
}

impl<'a> Read for SliceReader<'a> {
    /// Copies as many unread bytes as fit into `buffer` and returns their count.
    ///
    /// Once the slice is exhausted this returns `Ok(0)`, which is how the
    /// `Read` contract signals end of input. Helpers such as `read_to_end` and
    /// `read_exact` rely on this. An error here would make `read_to_end` fail,
    /// and would make `read_exact` report the wrong error kind.
    fn read(&mut self, buffer: &mut [u8]) -> Result<usize, std::io::Error> {
        let remaining = &self.slice[self.offset..];
        let n = remaining.len().min(buffer.len());
        buffer[..n].copy_from_slice(&remaining[..n]);
        self.offset += n;
        Ok(n)
    }
}
fn main() {
	let mut text: Vec<u8> = include_bytes!("./sample.txt").to_vec();
	
    let origin_size = text.len();
    for _ in 0..(CHUNK_SIZE - text.len() % CHUNK_SIZE){
        text.push(0u8);
    }
    
	let hash_map = {
		let mut map = HashMap::<[u8;CHUNK_SIZE], i32>::new();
		for it in text.chunks_exact(CHUNK_SIZE) {
            let mut chunk = [0u8; CHUNK_SIZE];
            chunk.copy_from_slice(it);
			if !map.contains_key(&chunk) {
                
				map.insert(chunk.clone(), 0);
			}
			*map.get_mut(&chunk).unwrap() += 1;
		}
		map
	};
	let mut queue = PriorityQueue::new();
	for (key, count) in hash_map.into_iter() {
		queue.push(Node::terminal(key), -1 * count);
	}
	let mut rl = true;
	let res = loop {
		let node1 = queue.pop().unwrap();
		let node2 = match queue.pop() {
			None => break node1.0,
			Some(node) => node,
		};
		let new_node = match rl {
			true => (Node::noneterminal(node1.0, node2.0), (node1.1 + node2.1)),
			false => (Node::noneterminal(node2.0, node1.0), (node1.1 + node2.1)),
		};
		queue.push(new_node.0, new_node.1);
		rl = !rl;
	};
	let table = get_huffmanencoding(&res, 0, 0, HashMap::new());
	// for (key, code) in &table {
	// 	print!("{:?} ", key);
    //     for i in 1..=code.0{
    //         print!("{}", (code.1 >> (code.0 - i)) & 1);
    //     }
    //     println!(" len:{}", code.0);
	// }
    
	let mut array = Vec::new();
	let mut byte = 0u8;
	let mut reft_bit = 8;
	let encoded = text.chunks_exact(CHUNK_SIZE)
		.map(|ch|{
            let mut temp = [0u8; CHUNK_SIZE];
            temp.copy_from_slice(ch);
            temp
        }).map(|ch| table.get(&ch).unwrap());
    /*
    let mut bit_count = 0; 
    for it in text.iter()
		.map(|ch| table.get(ch).unwrap()){
        for i in 1..=it.0{
            print!("{}", (it.1 >> (it.0 - i)) & 1);
            bit_count += 1;
            if bit_count == 8{
                print!("\n");
                bit_count = 0;
            }
        }
    }
    
	println!("\n\n");
    */
    let mut compressed_size = 0;
	for it in encoded{
		
		let mut bit_count = it.0;
		//println!("{:b} bit count: {}",it, bit_count);
        compressed_size += it.0;
		while bit_count != 0{
			
			if bit_count >= reft_bit{
				//println!("\nif {}, {}", bit_count, reft_bit);
				byte = byte | (it.1 >> (bit_count - reft_bit)) as u8;
				bit_count -= reft_bit;
				reft_bit = 0;
			}
			else{
				//("\nelse {}, {}", bit_count, reft_bit);
				byte = byte | ((it.1 & (0xFF >>(8 - bit_count))) << (reft_bit - bit_count)) as u8;
				reft_bit -= bit_count;
				bit_count = 0;
			}
			if reft_bit == 0{
				//print!("{:b}", byte);
				array.push(byte);
				byte = 0;
				reft_bit = 8;
				//읽는다.

			}
		}
	}
	if reft_bit != 0{
		array.push(byte);
	}
	println!("origin {} compressed {}", origin_size, compressed_size / 8);
    let mut extract_result = Vec::new();
	let mut decode_table = HashMap::new();
	for (data, (len, key) ) in &table{
		decode_table.insert((*len, *key), *data.clone());
	}
	decode(&mut extract_result, &mut SliceReader::new(array.as_slice()),text.len(),  &decode_table);
	println!("{}", String::from_utf8_lossy(&extract_result[0..origin_size]));

	//array.iter().for_each(|it| println!("{:08b}", it));
	
    let mut res = std::fs::File::create("./hff.compress").unwrap();
    use std::io::Write;
    res.write_all(&unsafe{std::mem::transmute::<_, [u8;4]>(compressed_size)});
    let b= unsafe{std::mem::transmute::<_, [u8;4]>(table.len() as u32)};
    res.write_all(&b);
    for (key, (len, bits)) in table{
        let key =unsafe{ std::mem::transmute::<_, [u8;CHUNK_SIZE]>(key.clone())};
        let len =unsafe{  std::mem::transmute::<_, [u8;1]>(len as u8)};
        let bits =unsafe{  std::mem::transmute::<_, [u8;4]>(bits)};
        res.write_all(&key);
        res.write_all(&len);
        res.write_all(&bits);
    }
    res.write_all(array.as_slice());
    res.flush();
	
}