Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add map and set extract_if #308

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ mod tests;
pub use self::core::raw_entry_v1::{self, RawEntryApiV1};
pub use self::core::{Entry, IndexedEntry, OccupiedEntry, VacantEntry};
pub use self::iter::{
Drain, IntoIter, IntoKeys, IntoValues, Iter, IterMut, IterMut2, Keys, Splice, Values, ValuesMut,
Drain, ExtractIf, IntoIter, IntoKeys, IntoValues, Iter, IterMut, IterMut2, Keys, Splice,
Values, ValuesMut,
};
pub use self::mutable::MutableEntryKey;
pub use self::mutable::MutableKeys;
Expand All @@ -36,7 +37,7 @@ use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::collections::hash_map::RandomState;

use self::core::IndexMapCore;
pub(crate) use self::core::{ExtractCore, IndexMapCore};
use crate::util::{third, try_simplify_range};
use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError};

Expand Down Expand Up @@ -307,6 +308,51 @@ impl<K, V, S> IndexMap<K, V, S> {
Drain::new(self.core.drain(range))
}

/// Creates an iterator which uses a closure to determine if an element should be removed,
/// for all elements in the given range.
///
/// If the closure returns true, the element is removed from the map and yielded.
/// If the closure returns false, or panics, the element remains in the map and will not be
/// yielded.
///
/// Note that `extract_if` lets you mutate every value in the filter closure, regardless of
/// whether you choose to keep or remove it.
///
/// The range may be any type that implements [`RangeBounds<usize>`],
/// including all of the `std::ops::Range*` types, or even a tuple pair of
/// `Bound` start and end values. To check the entire map, use `RangeFull`
/// like `map.extract_if(.., predicate)`.
///
/// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
/// or the iteration short-circuits, then the remaining elements will be retained.
/// Use [`retain`] with a negated predicate if you do not need the returned iterator.
///
/// [`retain`]: IndexMap::retain
///
/// # Examples
///
/// Splitting a map into even and odd keys, reusing the original map:
///
/// ```
/// use indexmap::IndexMap;
///
/// let mut map: IndexMap<i32, i32> = (0..8).map(|x| (x, x)).collect();
/// let extracted: IndexMap<i32, i32> = map.extract_if(.., |k, _v| k % 2 == 0).collect();
///
/// let evens = extracted.keys().copied().collect::<Vec<_>>();
/// let odds = map.keys().copied().collect::<Vec<_>>();
///
/// assert_eq!(evens, vec![0, 2, 4, 6]);
/// assert_eq!(odds, vec![1, 3, 5, 7]);
/// ```
pub fn extract_if<F, R>(&mut self, range: R, pred: F) -> ExtractIf<'_, K, V, F>
where
F: FnMut(&K, &mut V) -> bool,
R: RangeBounds<usize>,
{
ExtractIf::new(&mut self.core, range, pred)
}

/// Splits the collection into two at the given index.
///
/// Returns a newly allocated map containing the elements in the range
Expand Down
3 changes: 3 additions & 0 deletions src/map/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//! However, we should probably not let this show in the public API or docs.

mod entry;
mod extract;

pub mod raw_entry_v1;

Expand All @@ -25,6 +26,7 @@ type Indices = hash_table::HashTable<usize>;
type Entries<K, V> = Vec<Bucket<K, V>>;

pub use entry::{Entry, IndexedEntry, OccupiedEntry, VacantEntry};
pub(crate) use extract::ExtractCore;

/// Core of the map that does not depend on S
#[derive(Debug)]
Expand Down Expand Up @@ -163,6 +165,7 @@ impl<K, V> IndexMapCore<K, V> {

#[inline]
pub(crate) fn len(&self) -> usize {
debug_assert_eq!(self.entries.len(), self.indices.len());
self.indices.len()
}

Expand Down
107 changes: 107 additions & 0 deletions src/map/core/extract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#![allow(unsafe_code)]

use super::{Bucket, IndexMapCore};
use crate::util::simplify_range;

use core::ops::RangeBounds;

impl<K, V> IndexMapCore<K, V> {
pub(crate) fn extract<R>(&mut self, range: R) -> ExtractCore<'_, K, V>
where
R: RangeBounds<usize>,
{
let range = simplify_range(range, self.entries.len());

// SAFETY: We must have consistent lengths to start, so that's a hard assertion.
// Then the worst `set_len` can do is leak items if `ExtractCore` doesn't drop.
assert_eq!(self.entries.len(), self.indices.len());
unsafe {
self.entries.set_len(range.start);
}
ExtractCore {
map: self,
new_len: range.start,
current: range.start,
end: range.end,
}
}
}

pub(crate) struct ExtractCore<'a, K, V> {
map: &'a mut IndexMapCore<K, V>,
new_len: usize,
current: usize,
end: usize,
}

impl<K, V> Drop for ExtractCore<'_, K, V> {
fn drop(&mut self) {
let old_len = self.map.indices.len();
let mut new_len = self.new_len;

debug_assert!(new_len <= self.current);
debug_assert!(self.current <= self.end);
debug_assert!(self.current <= old_len);
debug_assert!(old_len <= self.map.entries.capacity());

// SAFETY: We assume `new_len` and `current` were correctly maintained by the iterator.
// So `entries[new_len..current]` were extracted, but the rest before and after are valid.
unsafe {
if new_len == self.current {
// Nothing was extracted, so any remaining items can be left in place.
new_len = old_len;
} else if self.current < old_len {
// Need to shift the remaining items down.
let tail_len = old_len - self.current;
let base = self.map.entries.as_mut_ptr();
let src = base.add(self.current);
let dest = base.add(new_len);
src.copy_to(dest, tail_len);
new_len += tail_len;
}
self.map.entries.set_len(new_len);
}

if new_len != old_len {
// We don't keep track of *which* items were extracted, so reindex everything.
self.map.rebuild_hash_table();
}
}
}

impl<K, V> ExtractCore<'_, K, V> {
pub(crate) fn extract_if<F>(&mut self, mut pred: F) -> Option<Bucket<K, V>>
where
F: FnMut(&mut Bucket<K, V>) -> bool,
{
debug_assert!(self.end <= self.map.entries.capacity());

let base = self.map.entries.as_mut_ptr();
while self.current < self.end {
// SAFETY: We're maintaining both indices within bounds of the original entries, so
// 0..new_len and current..indices.len() are always valid items for our Drop to keep.
unsafe {
let item = base.add(self.current);
if pred(&mut *item) {
// Extract it!
self.current += 1;
return Some(item.read());
} else {
// Keep it, shifting it down if needed.
if self.new_len != self.current {
debug_assert!(self.new_len < self.current);
let dest = base.add(self.new_len);
item.copy_to_nonoverlapping(dest, 1);
}
self.current += 1;
self.new_len += 1;
}
}
}
None
}

pub(crate) fn remaining(&self) -> usize {
self.end - self.current
}
}
51 changes: 49 additions & 2 deletions src/map/iter.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::core::IndexMapCore;
use super::{Bucket, Entries, IndexMap, Slice};
use super::{Bucket, Entries, ExtractCore, IndexMap, IndexMapCore, Slice};

use alloc::vec::{self, Vec};
use core::fmt;
Expand Down Expand Up @@ -774,3 +773,51 @@ where
.finish()
}
}

/// An extracting iterator for `IndexMap`.
///
/// This `struct` is created by [`IndexMap::extract_if()`].
/// See its documentation for more.
pub struct ExtractIf<'a, K, V, F> {
inner: ExtractCore<'a, K, V>,
pred: F,
}

impl<K, V, F> ExtractIf<'_, K, V, F> {
pub(super) fn new<R>(core: &mut IndexMapCore<K, V>, range: R, pred: F) -> ExtractIf<'_, K, V, F>
where
R: RangeBounds<usize>,
F: FnMut(&K, &mut V) -> bool,
{
ExtractIf {
inner: core.extract(range),
pred,
}
}
}

impl<K, V, F> Iterator for ExtractIf<'_, K, V, F>
where
F: FnMut(&K, &mut V) -> bool,
{
type Item = (K, V);

fn next(&mut self) -> Option<Self::Item> {
self.inner
.extract_if(|bucket| {
let (key, value) = bucket.ref_mut();
(self.pred)(key, value)
})
.map(Bucket::key_value)
}

fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(self.inner.remaining()))
}
}

impl<'a, K, V, F> fmt::Debug for ExtractIf<'a, K, V, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ExtractIf").finish_non_exhaustive()
}
}
44 changes: 43 additions & 1 deletion src/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod slice;
mod tests;

pub use self::iter::{
Difference, Drain, Intersection, IntoIter, Iter, Splice, SymmetricDifference, Union,
Difference, Drain, ExtractIf, Intersection, IntoIter, Iter, Splice, SymmetricDifference, Union,
};
pub use self::mutable::MutableValues;
pub use self::slice::Slice;
Expand Down Expand Up @@ -258,6 +258,48 @@ impl<T, S> IndexSet<T, S> {
Drain::new(self.map.core.drain(range))
}

/// Creates an iterator which uses a closure to determine if a value should be removed,
/// for all values in the given range.
///
/// If the closure returns true, then the value is removed and yielded.
/// If the closure returns false, the value will remain in the list and will not be yielded
/// by the iterator.
///
/// The range may be any type that implements [`RangeBounds<usize>`],
/// including all of the `std::ops::Range*` types, or even a tuple pair of
/// `Bound` start and end values. To check the entire set, use `RangeFull`
/// like `set.extract_if(.., predicate)`.
///
/// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
/// or the iteration short-circuits, then the remaining elements will be retained.
/// Use [`retain`] with a negated predicate if you do not need the returned iterator.
///
/// [`retain`]: IndexSet::retain
///
/// # Examples
///
/// Splitting a set into even and odd values, reusing the original set:
///
/// ```
/// use indexmap::IndexSet;
///
/// let mut set: IndexSet<i32> = (0..8).collect();
/// let extracted: IndexSet<i32> = set.extract_if(.., |v| v % 2 == 0).collect();
///
/// let evens = extracted.into_iter().collect::<Vec<_>>();
/// let odds = set.into_iter().collect::<Vec<_>>();
///
/// assert_eq!(evens, vec![0, 2, 4, 6]);
/// assert_eq!(odds, vec![1, 3, 5, 7]);
/// ```
pub fn extract_if<F, R>(&mut self, range: R, pred: F) -> ExtractIf<'_, T, F>
where
F: FnMut(&T) -> bool,
R: RangeBounds<usize>,
{
ExtractIf::new(&mut self.map.core, range, pred)
}

/// Splits the collection into two at the given index.
///
/// Returns a newly allocated set containing the elements in the range
Expand Down
47 changes: 47 additions & 0 deletions src/set/iter.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::map::{ExtractCore, IndexMapCore};

use super::{Bucket, Entries, IndexSet, Slice};

use alloc::vec::{self, Vec};
Expand Down Expand Up @@ -626,3 +628,48 @@ impl<I: fmt::Debug> fmt::Debug for UnitValue<I> {
fmt::Debug::fmt(&self.0, f)
}
}

/// An extracting iterator for `IndexSet`.
///
/// This `struct` is created by [`IndexSet::extract_if()`].
/// See its documentation for more.
pub struct ExtractIf<'a, T, F> {
inner: ExtractCore<'a, T, ()>,
pred: F,
}

impl<T, F> ExtractIf<'_, T, F> {
pub(super) fn new<R>(core: &mut IndexMapCore<T, ()>, range: R, pred: F) -> ExtractIf<'_, T, F>
where
R: RangeBounds<usize>,
F: FnMut(&T) -> bool,
{
ExtractIf {
inner: core.extract(range),
pred,
}
}
}

impl<T, F> Iterator for ExtractIf<'_, T, F>
where
F: FnMut(&T) -> bool,
{
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
self.inner
.extract_if(|bucket| (self.pred)(bucket.key_ref()))
.map(Bucket::key)
}

fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(self.inner.remaining()))
}
}

impl<'a, T, F> fmt::Debug for ExtractIf<'a, T, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ExtractIf").finish_non_exhaustive()
}
}
Loading