Skip to content

Commit

Permalink
fix long matches
Browse files Browse the repository at this point in the history
  • Loading branch information
mcroomp committed Nov 22, 2023
1 parent 009dbbf commit e44b2fe
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 27 deletions.
Binary file added samples/compressed_flate2_level1_longmatch.bin
Binary file not shown.
Binary file added samples/sample2.bin
Binary file not shown.
Binary file added samples/sample3.bin
Binary file not shown.
4 changes: 2 additions & 2 deletions src/complevel_estimator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ impl<'a> CompLevelEstimatorState<'a> {
self.info.recommended_compression_level = 9;
self.info.very_far_matches = self.info.longest_dist_at_hop_0
> self.window_size() - preflate_constants::MIN_LOOKAHEAD
&& self.info.longest_dist_at_hop_1_plus
< self.window_size() - preflate_constants::MIN_LOOKAHEAD;
|| self.info.longest_dist_at_hop_1_plus
>= self.window_size() - preflate_constants::MIN_LOOKAHEAD;
self.info.far_len_3_matches = self.info.longest_len_3_dist > 4096;

self.info.zlib_compatible = self.info.possible_compression_levels > 1
Expand Down
93 changes: 81 additions & 12 deletions src/hash_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub struct HashChain<'a> {
hash_shift: u32,
running_hash: RotatingHash,
hash_mask: u16,
total_shift: u32,
total_shift: i32,
}

#[derive(Default, Debug, Copy, Clone)]
Expand All @@ -108,9 +108,12 @@ impl<'a> HashChain<'a> {
let hash_bits = mem_level + 7;
let hash_mask = ((1u32 << hash_bits) - 1) as u16;

// Important: total_shift starts at -8 since 0 indicates the end of the hash chain
// so this means that all valid values will be >= 8, otherwise the very first hash
// offset would be zero and so it would get missed
let mut hash_chain_ext = HashChain {
input: PreflateInput::new(i),
total_shift: 0,
total_shift: -8,
hash_shift: (hash_bits + MIN_MATCH - 1) / MIN_MATCH,
hash_mask,
hash_table: HashTable::default_boxed(),
Expand Down Expand Up @@ -150,22 +153,88 @@ impl<'a> HashChain<'a> {
}

fn reshift_if_necessary(&mut self) {
if self.input.pos() - self.total_shift >= 0xfd00 {
if self.input.pos() as i32 - self.total_shift >= 0xfe00 {
const DELTA: usize = 0x7e00;
for i in 0..=self.hash_mask as usize {
self.hash_table.head[i] = self.hash_table.head[i].saturating_sub(DELTA as u16);
}

for i in DELTA..(1 << 16) {
for i in DELTA..=65535 {
self.hash_table.prev[i - DELTA] =
self.hash_table.prev[i].saturating_sub(DELTA as u16);
}

self.hash_table.chain_depth.copy_within(DELTA..65536, 0);
self.total_shift += DELTA as u32;
self.hash_table.chain_depth.copy_within(DELTA..=65535, 0);
self.total_shift += DELTA as i32;
}
}

/// Rebuilds every hash chain from scratch by re-hashing the input, then asserts that the
/// result matches the chains currently stored in `hash_table`.
///
/// Debugging aid only. It prints the bytes at the current position and at `dist` bytes
/// before it, prints every bucket whose stored chain differs from the rebuilt one, and
/// panics if at least one bucket differs.
#[allow(dead_code)] // debugging helper, only wired in while investigating mismatches
pub fn verify_hash(&self, dist: Option<PreflateTokenReference>) {
    let current_pos = self.input.pos() as i32;

    let mut hash = RotatingHash::default();
    let mut chains: Vec<Vec<u16>> = vec![Vec::new(); self.hash_mask as usize + 1];

    // No position is recorded for the first two bytes fed into the hash.
    let mut start_delay = 2;
    let mut start_pos = self.total_shift;

    while start_pos - 1 <= current_pos {
        hash = hash.append(self.input.cur_char(start_pos - current_pos), self.hash_shift);

        if start_delay > 0 {
            start_delay -= 1;
        } else {
            chains[hash.hash(self.hash_mask) as usize]
                .push((start_pos - 2 - self.total_shift) as u16);
        }

        start_pos += 1;
    }

    let distance = dist.map_or(0, |d| d.dist() as i32);

    // Show at most 10 bytes from each side; slicing a fixed 0..10 would panic when fewer
    // bytes are available.
    let at_match = self.input.cur_chars(-distance);
    let at_current = self.input.cur_chars(0);

    println!(
        "MATCH t={:?} a={:?} b={:?} d={}",
        dist,
        &at_match[..at_match.len().min(10)],
        &at_current[..at_current.len().min(10)],
        // Signed arithmetic on purpose: `total_shift` can be negative, and casting it to
        // u32 makes the subtraction overflow (a panic in debug builds).
        current_pos - self.total_shift - distance
    );

    let mut mismatch = false;
    for i in 0..=self.hash_mask {
        let current_chain = &chains[i as usize];

        // Follow the stored chain from its head through `prev`; 0 terminates the chain.
        let mut hash_table_chain = Vec::with_capacity(current_chain.len());
        let mut curr_pos = self.hash_table.head[i as usize];
        while curr_pos != 0 {
            hash_table_chain.push(curr_pos);
            curr_pos = self.hash_table.prev[curr_pos as usize];
        }
        // The rebuilt chain is in ascending position order, so flip the stored walk
        // before comparing.
        hash_table_chain.reverse();

        if hash_table_chain[..] != current_chain[..] {
            mismatch = true;
            println!(
                "HASH {i} MISMATCH a={:?} b={:?}",
                hash_table_chain, current_chain
            );
        }
    }
    assert!(!mismatch);
}

/// Returns the head entry of the hash-table bucket selected by `hash`, widened to `u32`.
pub fn get_head(&self, hash: RotatingHash) -> u32 {
    let bucket = hash.hash(self.hash_mask) as usize;
    u32::from(self.hash_table.head[bucket])
}
Expand All @@ -184,7 +253,7 @@ impl<'a> HashChain<'a> {
HashIterator::new(
&self.hash_table.prev,
&self.hash_table.chain_depth,
ref_pos - self.total_shift,
(ref_pos as i32 - self.total_shift) as u32,
max_dist,
head,
)
Expand All @@ -194,9 +263,9 @@ impl<'a> HashChain<'a> {
HashIterator::new(
&self.hash_table.prev,
&self.hash_table.chain_depth,
ref_pos - self.total_shift,
(ref_pos as i32 - self.total_shift) as u32,
max_dist,
pos - self.total_shift,
(pos as i32 - self.total_shift) as u32,
)
}

Expand Down Expand Up @@ -228,7 +297,7 @@ impl<'a> HashChain<'a> {

self.reshift_if_necessary();

let pos = (self.input.pos() - self.total_shift) as u16;
let pos = (self.input.pos() as i32 - self.total_shift) as u16;

let limit = std::cmp::min(length + 2, self.input.remaining()) as u16;

Expand All @@ -251,7 +320,7 @@ impl<'a> HashChain<'a> {
pub fn skip_hash(&mut self, l: u32) {
self.reshift_if_necessary();

let pos = self.input.pos();
let pos = self.input.pos() as i32;

let remaining = self.input.remaining();
if remaining > 2 {
Expand All @@ -268,7 +337,7 @@ impl<'a> HashChain<'a> {
// bad analysis results
// --------------------
for i in 1..l {
let p = (pos + i) - self.total_shift;
let p = (pos + i as i32) - self.total_shift;
self.hash_table.chain_depth[p as usize] = 0xffff8000;
}

Expand Down
20 changes: 14 additions & 6 deletions src/predictor_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub enum MatchResult {
Success(PreflateTokenReference),
DistanceLargerThanHop0(u32, u32),
NoInput,
NoMoreMatchesFound(u32),
NoMoreMatchesFound { start_len: u32, last_dist: u32 },
MaxChainExceeded,
}

Expand Down Expand Up @@ -174,10 +174,9 @@ impl<'a> PredictorState<'a> {
let mut best_match: Option<PreflateTokenReference> = None;
let input = self.hash.input().cur_chars(offset as i32);
loop {
let match_start = self
.hash
.input()
.cur_chars(offset as i32 - chain_it.dist() as i32);
let dist = chain_it.dist();

let match_start = self.hash.input().cur_chars(offset as i32 - dist as i32);

let match_length = Self::prefix_compare(match_start, input, best_len, max_len);
if match_length > best_len {
Expand All @@ -195,7 +194,10 @@ impl<'a> PredictorState<'a> {
if let Some(r) = best_match {
return MatchResult::Success(r);
} else {
return MatchResult::NoMoreMatchesFound(match_length);
return MatchResult::NoMoreMatchesFound {
start_len: match_length,
last_dist: dist,
};
}
}

Expand Down Expand Up @@ -303,4 +305,10 @@ impl<'a> PredictorState<'a> {
}
}
}

/// Debug-only helper: forwards `dist` to the hash chain's own `verify_hash` check.
#[allow(dead_code)]
pub fn verify_hash(&self, dist: Option<PreflateTokenReference>) {
    let chain = &self.hash;
    chain.verify_hash(dist);
}
}
9 changes: 9 additions & 0 deletions src/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,15 @@ fn do_analyze(crc: Option<u32>, compressed_data: &[u8], verify: bool) {
}
}

/// Runs the analysis over the long-match flate2 level 1 sample, with no CRC and without
/// verification.
#[test]
fn verify_longmatch() {
    let compressed = read_file("compressed_flate2_level1_longmatch.bin");
    do_analyze(None, &compressed, false);
}

#[test]
fn verify_zlib_compressed() {
for i in 0..9 {
Expand Down
26 changes: 21 additions & 5 deletions src/token_predictor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,13 @@ impl<'a> TokenPredictor<'a> {
CodecMisprediction::LiteralPredictionWrong,
true,
);
self.repredict_reference().with_context(|| {
format!("repredict_reference target={:?} index={}", target_ref, i)
})?
self.repredict_reference(Some(*target_ref))
.with_context(|| {
format!(
"repredict_reference target={:?} index={}",
target_ref, i
)
})?
}
PreflateToken::Reference(r) => {
// we predicted a reference correctly, so verify that the length/dist was correct
Expand Down Expand Up @@ -280,7 +284,7 @@ impl<'a> TokenPredictor<'a> {
continue;
}

predicted_ref = self.repredict_reference().with_context(|| {
predicted_ref = self.repredict_reference(None).with_context(|| {
format!(
"repredict_reference token_count={:?}",
self.current_token_count
Expand Down Expand Up @@ -443,13 +447,22 @@ impl<'a> TokenPredictor<'a> {

/// When the predicted token was a literal, but the actual token was a reference, try again
/// to find a match for the reference.
fn repredict_reference(&mut self) -> anyhow::Result<PreflateTokenReference> {
fn repredict_reference(
&mut self,
dist_match: Option<PreflateTokenReference>,
) -> anyhow::Result<PreflateTokenReference> {
if self.state.current_input_pos() == 0 || self.state.available_input_size() < MIN_MATCH {
return Err(anyhow::Error::msg(
"Not enough space left to find a reference",
));
}

if let Some(x) = dist_match {
if x.dist() == 32653 {
println!("dist_match = {:?}", dist_match);
}
}

let hash = self.state.calculate_hash();
let match_token =
self.state
Expand All @@ -462,6 +475,9 @@ impl<'a> TokenPredictor<'a> {
return Ok(m);
}
}

//self.state.verify_hash(dist_match);

Err(anyhow::Error::msg(format!(
"Didnt find a match {:?}",
match_token
Expand Down
19 changes: 17 additions & 2 deletions tests/end_to_end.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*--------------------------------------------------------------------------------------------*/

use std::fs::File;
use std::io::{Cursor, Read};
use std::io::{Cursor, Read, Write};
use std::path::Path;

use flate2::{read::ZlibEncoder, Compression};
Expand Down Expand Up @@ -35,7 +35,17 @@ fn end_to_end_compressed() {
}

#[test]
fn test_wrong() {
fn test_matchnotfound() {
test_file("sample3.bin");
}

/// Runs the `test_file` check on `sample2.bin`.
#[test]
fn test_nomatch() {
    let sample = "sample2.bin";
    test_file(sample);
}

/// Runs the `test_file` check on `sample1.bin`.
#[test]
fn test_sample1() {
    let sample = "sample1.bin";
    test_file(sample);
}

Expand Down Expand Up @@ -86,6 +96,11 @@ fn test_file(filename: &str) {
// skip header and final crc
let minusheader = &output[2..output.len() - 4];

// write to file
let mut f =
File::create(format!("c:\\temp\\compressed_flate2_level{}.bin", level)).unwrap();
f.write_all(minusheader).unwrap();

verifyresult(minusheader);
}
}

0 comments on commit e44b2fe

Please sign in to comment.