Create Cache class for exact, fuzzy, and semantic deduplication #384

Open
wants to merge 34 commits into base: main

Commits (34)
769e2ea
add global cache variable and use it for exact dedup
sarahyurick Nov 19, 2024
b77139c
global cache for semdedup
sarahyurick Nov 19, 2024
337cec8
run black and modify pytest
sarahyurick Nov 19, 2024
6d55d8c
update image notebook
sarahyurick Nov 20, 2024
622912b
Merge branch 'main' into global_cache_dir
sarahyurick Nov 20, 2024
4cb26d5
save fuzzy dedup progress
sarahyurick Nov 20, 2024
b001622
save progress
sarahyurick Nov 20, 2024
0c14626
update remaining docs
sarahyurick Nov 20, 2024
7486459
run black
sarahyurick Nov 20, 2024
053f312
Merge branch 'main' into global_cache_dir
sarahyurick Dec 6, 2024
1b1ba30
Merge branch 'main' into global_cache_dir
sarahyurick Dec 11, 2024
4b12651
Merge branch 'main' into global_cache_dir
sarahyurick Dec 17, 2024
4160471
Merge branch 'main' into global_cache_dir
sarahyurick Dec 20, 2024
8a22ace
Merge branch 'main' into global_cache_dir
sarahyurick Dec 23, 2024
5e9bef1
Merge branch 'main' into global_cache_dir
sarahyurick Jan 3, 2025
d823a0b
Merge remote-tracking branch 'upstream/main' into global_cache_dir
sarahyurick Jan 21, 2025
0890fb0
re-add get_cache_directory changes
sarahyurick Jan 21, 2025
8fd79fb
create Cache singleton class
sarahyurick Jan 21, 2025
0d7b969
update exact_dedup
sarahyurick Jan 22, 2025
2c1a435
add semdedup functionality with Cache
sarahyurick Jan 22, 2025
f0ff2ce
add semdedup_example script
sarahyurick Jan 22, 2025
a379893
Cache singleton option for fuzzy dedup
sarahyurick Jan 23, 2025
67f609c
run black
sarahyurick Jan 23, 2025
8693177
fix tutorials
sarahyurick Jan 23, 2025
c296cc7
Merge branch 'main' into global_cache_dir
sarahyurick Jan 29, 2025
510347c
Merge branch 'main' into global_cache_dir
sarahyurick Feb 18, 2025
0635ebf
run black
sarahyurick Feb 18, 2025
a229857
import assert_eq
sarahyurick Feb 18, 2025
30ec409
fix semdedup test
sarahyurick Feb 19, 2025
1a63468
Merge branch 'main' into global_cache_dir
sarahyurick Feb 20, 2025
2075588
Merge branch 'main' into global_cache_dir
sarahyurick Feb 25, 2025
a6c5de3
remove repeating param
sarahyurick Feb 25, 2025
b805ce9
Merge remote-tracking branch 'upstream/main' into global_cache_dir
sarahyurick Feb 28, 2025
2ee3547
fix semdedup test
sarahyurick Feb 28, 2025
1 change: 1 addition & 0 deletions config/fuzzy_dedup_config.yaml
@@ -1,4 +1,5 @@
cache_dir: "./fuzzy_dedup_cache"

# Optional Params below with default values
# profile_dir: null
# id_field: "id"
2 changes: 0 additions & 2 deletions config/sem_dedup_config.yaml
@@ -11,13 +11,11 @@ write_embeddings_to_disk: true
# Clustering configuration
clustering_save_loc: "clustering_results"
n_clusters: 1000
seed: 1234
max_iter: 100
kmeans_with_cos_dist: false

# Semdedup configuration
which_to_keep: "hard"
largest_cluster_size_to_process: 100000
sim_metric: "cosine"

# Extract dedup configuration
13 changes: 7 additions & 6 deletions docs/user-guide/semdedup.rst
@@ -50,13 +50,11 @@ Semantic deduplication in NeMo Curator can be configured using a YAML file. Here
# Clustering configuration
clustering_save_loc: "clustering_results"
n_clusters: 1000
seed: 1234
max_iter: 100
kmeans_with_cos_dist: false
# Semdedup configuration
which_to_keep: "hard"
largest_cluster_size_to_process: 100000
sim_metric: "cosine"
# Extract dedup configuration
@@ -170,7 +168,8 @@ Use Individual Components
embedding_creator = EmbeddingCreator(
embedding_model_name_or_path="path/to/pretrained/model",
embedding_batch_size=128,
embedding_output_dir="path/to/output/embeddings",
cache_dir="path/to/output",
embeddings_save_loc="embeddings",
input_column="text",
logger="path/to/log/dir",
)
@@ -188,7 +187,8 @@ Use Individual Components
id_column="doc_id",
max_iter=100,
n_clusters=50000,
clustering_output_dir="path/to/output/clusters",
cache_dir="path/to/output",
clustering_save_loc="clustering_results",
logger="path/to/log/dir"
)
clustered_dataset = clustering_model(embeddings_dataset)
@@ -202,12 +202,13 @@ Use Individual Components
# Step 3: Semantic Deduplication
semantic_dedup = SemanticClusterLevelDedup(
n_clusters=50000,
emb_by_clust_dir="path/to/embeddings/by/cluster",
sorted_clusters_dir="path/to/sorted/clusters",
id_column="doc_id",
id_column_type="str",
which_to_keep="hard",
output_dir="path/to/output/deduped",
# cache_dir and clustering_save_loc should match ClusteringModel
cache_dir="path/to/output",
clustering_save_loc="clustering_results",
logger="path/to/log/dir"
)
semantic_dedup.compute_semantic_match_dfs()
1 change: 0 additions & 1 deletion examples/exact_deduplication.py
@@ -64,7 +64,6 @@ def main(args):
if isinstance(duplicates, str):
duplicates = DocumentDataset.read_parquet(duplicates, backend=backend)

# It's easy to apply dataframe operations to the dataset by using the underlying df.
result = exact_dup.remove(input_dataset, duplicates)
write_to_disk(result, output_dir, output_type="parquet")
print(time.time() - t0)
6 changes: 3 additions & 3 deletions examples/fuzzy_deduplication.py
@@ -38,7 +38,7 @@ def main(args):

filetype = "parquet"

# Fuzzy dup calculation only supports the cuDF/GPU backend
# Fuzzy deduplication only supports the cuDF/GPU backend
backend = "cudf"
assert args.device == "gpu"

@@ -89,12 +89,12 @@ def main(args):

if duplicates is None:
print("No duplicates found")
print(f"Time taken:{time.time() - t0}s")
print(f"Time taken: {time.time() - t0}s")
return

result = fuzzy_dup.remove(input_dataset, duplicates)
write_to_disk(result, output_dir, output_type=filetype)
print(f"Time taken:{time.time() - t0}s")
print(f"Time taken: {time.time() - t0}s")


def attach_args(
7 changes: 7 additions & 0 deletions examples/semdedup_example.py
@@ -49,23 +49,30 @@ def main(args):
log_level=logging.INFO,
stdout=True,
)

st = time.time()

input_files = get_all_files_paths_under(
root=args.input_data_dir,
)

if semdedup_config.num_files > 0:
input_files = input_files[: semdedup_config.num_files]

logger.info(f"Processing {len(input_files)} files")

ddf = read_data(
input_files=input_files,
file_type=args.input_file_type,
add_filename=False,
backend="cudf",
)
dataset = DocumentDataset(ddf)

semdup = SemDedup(semdedup_config, logger=logger)
dedup_ids = semdup(dataset)
print(dedup_ids.df.head())

logger.info(f"Time taken: {time.time() - st}")
client.cancel(client.futures, force=True)
client.close()
48 changes: 48 additions & 0 deletions nemo_curator/cache.py
@@ -0,0 +1,48 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_curator.utils.file_utils import expand_outdir_and_mkdir


class Cache:
_instance = None
_cache_dir = None

def __new__(cls, cache_dir=None):
if cls._instance is None:
cls._instance = super(Cache, cls).__new__(cls)
if cache_dir is not None:
cls._cache_dir = expand_outdir_and_mkdir(cache_dir)
else:
cls._cache_dir = None
elif cache_dir is not None and cls._cache_dir is None:
cls._cache_dir = expand_outdir_and_mkdir(cache_dir)
return cls._instance

@classmethod
def get_cache_directory(cls) -> str:
"""
Retrieve the cache directory.
"""
return cls._cache_dir

@classmethod
def delete_cache_instance(cls):
"""
Reset the Cache singleton.
"""
if cls._cache_dir is not None:
cls._cache_dir = None

cls._instance = None
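
A minimal usage sketch of this singleton, for orientation (the directory path is illustrative):

from nemo_curator.cache import Cache

# The first call that passes cache_dir sets the global cache directory.
Cache(cache_dir="./curator_cache")

# Later constructions return the same instance and keep the original directory.
assert Cache() is Cache(cache_dir="./some_other_dir")
print(Cache.get_cache_directory())  # the expanded "./curator_cache" path

# Reset the singleton (e.g. between tests) so a new directory can be set.
Cache.delete_cache_instance()
assert Cache.get_cache_directory() is None
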
166 changes: 104 additions & 62 deletions nemo_curator/modules/config.py
@@ -18,6 +18,8 @@

import yaml

from nemo_curator.cache import Cache


@dataclass
class BaseConfig:
@@ -31,44 +33,49 @@ def from_yaml(cls, file_path: str):
@dataclass
class FuzzyDuplicatesConfig(BaseConfig):
"""
Configuration for MinHash based fuzzy duplicates detection.
Configuration for MinHash-based fuzzy duplicates detection.
Parameters
----------
seed: Seed for minhash permutations
char_ngrams: Size of Char ngram shingles used in minhash computation
num_buckets: Number of Bands or buckets to use during Locality Sensitive Hashing
hashes_per_bucket: Number of hashes per bucket/band.
cache_dir: If specified, directory to store deduplication intermediates, such as
minhashes, buckets, etc. If None, we check if a cache_dir has been initialized
with Cache().get_cache_directory(). Default is None.
profile_dir: If specified, directory to write Dask profile. Default is None.
id_field: Column in the dataset denoting document ID. Default is "id".
text_field: Column in the dataset denoting document content. Default is "text".
perform_removal: Boolean value to specify whether calling the module should remove
the duplicates from the original dataset, or return the list of IDs denoting
duplicates. Default is False.
seed: Seed for minhash permutations. Default is 42.
char_ngrams: Size of character n-gram shingles used in minhash computation.
Default is 5.
num_buckets: Number of bands or buckets to use during Locality Sensitive Hashing.
Default is 20.
hashes_per_bucket: Number of hashes per bucket/band. Default is 13.
use_64_bit_hash: Whether to use a 32bit or 64bit hash function for minhashing.
buckets_per_shuffle: Number of bands/buckets to shuffle concurrently.
Larger values process larger batches by processing multiple bands
but might lead to memory pressures and related errors.
id_field: Column in the Dataset denoting document ID.
text_field: Column in the Dataset denoting document content.
perform_removal: Boolean value to specify whether calling the module should remove the duplicates from
the original dataset, or return the list of IDs denoting duplicates.
profile_dir: str, Default None
If specified directory to write dask profile
cache_dir: str, Default None
Location to store deduplcation intermediates such as minhashes/buckets etc.
false_positive_check: bool,
Whether to run a check to look for false positives within buckets.
Note: This is a computationally expensive step.
num_anchors: int
Number of documents per bucket to use as reference for computing jaccard
pairs within that bucket to identify false positives.
jaccard_threshold: float
The Jaccard similariy threshold to consider a document a near duplicate
during false positive evaluations.
Default is False.
buckets_per_shuffle: Number of bands/buckets to shuffle concurrently. Larger values
process larger batches by processing multiple bands but might lead to memory
pressures and related errors. Default is 1.
false_positive_check: Whether to run a check to look for false positives within
buckets. Note: This is a computationally expensive step. Default is False.
num_anchors: Number of documents per bucket to use as reference for computing
Jaccard pairs within that bucket to identify false positives. Default is 2.
jaccard_threshold: The Jaccard similarity threshold to consider a document a near
duplicate during false positive evaluations. Default is 0.8.
bucket_mapping_blocksize: Default is 256.
parts_per_worker: Default is 1.
bucket_parts_per_worker: Default is 8.
"""

# General config
cache_dir: str
cache_dir: Optional[str] = None
profile_dir: Optional[str] = None
id_field: str = "id"
text_field: str = "text"
perform_removal: bool = False

# Minhash + LSH Config
# Minhash + LSH config
seed: int = 42
char_ngrams: int = 24
num_buckets: int = 20
@@ -86,53 +93,72 @@ class FuzzyDuplicatesConfig(BaseConfig):

def __post_init__(self):
self.num_hashes = self.num_buckets * self.hashes_per_bucket

false_positive_defaults = {
"num_anchors": 2,
"jaccard_threshold": 0.8,
"bucket_mapping_blocksize": 256,
"parts_per_worker": 1,
"bucket_parts_per_worker": 8,
}

if self.false_positive_check:
warnings.warn(
"Identifying false positives during the Minhash deduplication is computationally expensive."
" For improved performance consider setting this to False"
"Identifying false positives during Minhash deduplication is "
"computationally expensive. For improved performance consider setting "
"this to False."
)

for arg, default in false_positive_defaults.items():
if getattr(self, arg) is None:
setattr(self, arg, default)

if self.num_anchors <= 0:
raise ValueError("Number of anchors must be greater than 0")
raise ValueError("Number of anchors must be greater than 0.")

if self.num_anchors > 2:
warnings.warn(
"Using a higher number of anchor docs might lead to higher memory footprint and might impact performance",
"Using a higher number of anchor documents might lead to higher memory "
"footprint and might impact performance.",
category=UserWarning,
)

if not 0 <= self.jaccard_threshold <= 1:
raise ValueError("Jaccard Threshold must be between [0,1]")
raise ValueError("Jaccard threshold must be between [0, 1].")

else:
if self.char_ngrams < 20:
warnings.warn(
"Using a small char_ngrams value might lead to a large number (~5%) of false positives during deduplication."
" Using a value of at least 20 for char_ngrams is recommended."
)

unused_false_positive_args = [
arg
for arg in false_positive_defaults.keys()
if getattr(self, arg) is not None
]

if unused_false_positive_args:
warnings.warn(
f"False positive check is disabled. Unused arguments {unused_false_positive_args} will be ignored",
f"False positive check is disabled. Unused arguments {unused_false_positive_args} will be ignored.",
category=UserWarning,
)

if self.cache_dir is None:
raise ValueError(
"Finding fuzzy duplicates requires a cache directory accessible via all workers to store intermediates"
)
if not 1 <= self.buckets_per_shuffle <= self.num_buckets:
raise ValueError("Buckets per shuffle must be between [1, num_buckets]")
raise ValueError("Buckets per shuffle must be between [1, num_buckets].")

if self.cache_dir is None:
cache_dir = Cache().get_cache_directory()
if cache_dir is None:
raise ValueError(
"Finding fuzzy duplicates requires a cache directory accessible via "
"all workers to store intermediates. Please use "
"Cache(cache_dir=...) or FuzzyDuplicatesConfig(cache_dir=...) to "
"set the cache directory."
)
else:
self.cache_dir = cache_dir

if not self.perform_removal:
warnings.warn(
@@ -146,30 +172,40 @@ class SemDedupConfig(BaseConfig):
Configuration for Semantic Deduplication.
Attributes:
cache_dir (str): Directory to store cache.
profile_dir (Optional[str]): If specified directory to write dask profile. Default is None.
cache_dir (str): Directory to store cache.
cache_dir (Optional[str]): If specified, directory to store cache.
If None, we check if a cache_dir has been initialized with Cache().get_cache_directory().
Default is None.
profile_dir (Optional[str]): If specified, directory to write Dask profile.
Default is None.
num_files (int): Number of files. Default is -1, meaning all files.
embeddings_save_loc (str): Location to save embeddings.
Default is "embeddings".
embedding_model_name_or_path (str): Model name or path for embeddings.
embedding_batch_size (int): Inital Batch size for processing embeddings.
embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling".
write_embeddings_to_disk (bool): If True, saves the embeddings to disk, defaults to True.
Default is "sentence-transformers/all-MiniLM-L6-v2".
embedding_batch_size (int): Initial batch size for processing embeddings.
Default is 128.
embedding_pooling_strategy (str): Strategy for pooling embeddings, either
"mean_pooling" or "last_token". Default is "mean_pooling".
write_embeddings_to_disk (bool): If True, saves the embeddings to disk.
We recommend setting this to False when you have a delayed pipeline.
Setting it to False can lead to more memory overhead.
Setting it to False can lead to more memory overhead. Default is True.
clustering_save_loc (str): Location to save clustering results.
n_clusters (int): Number of clusters.
seed (int): Seed for clustering.
max_iter (int): Maximum iterations for clustering.
kmeans_with_cos_dist (bool): Use KMeans with cosine distance.
which_to_keep (str): Which duplicates to keep.
largest_cluster_size_to_process (int): Largest cluster size to process.
Default is "clustering_results".
n_clusters (int): Number of clusters. Default is 1000.
max_iter (int): Maximum iterations for clustering. Default is 100.
kmeans_with_cos_dist (bool): Whether or not to use KMeans with cosine distance.
Default is False.
which_to_keep (str): Method to determine which duplicates to keep.
Default is "hard".
sim_metric (str): Similarity metric for deduplication.
eps_thresholds (List[float]): Epsilon thresholds to calculate if semantically similar or not.
Default is "cosine".
eps_thresholds (List[float]): Epsilon thresholds to calculate if semantically
similar or not.
eps_to_extract (float): Epsilon value to extract deduplicated data.
Default is 0.1.
"""

cache_dir: str
cache_dir: str = None
profile_dir: Optional[str] = None
num_files: int = -1

@@ -181,29 +217,35 @@ class SemDedupConfig(BaseConfig):
embedding_pooling_strategy: str = "mean_pooling"
write_embeddings_to_disk: bool = True

# Clustering config
# ClusteringModel
clustering_save_loc: str = "clustering_results"
n_clusters: int = 1000
seed: int = 1234
max_iter: int = 100
kmeans_with_cos_dist: bool = False

# Semdedup config
# SemanticClusterLevelDedup
which_to_keep: str = "hard"
largest_cluster_size_to_process: int = 100000
sim_metric: str = "cosine"

# Extract dedup config
# SemDedup
eps_thresholds: List[float] = field(default_factory=lambda: [0.01, 0.001])
eps_to_extract: float = 0.01

def __post_init__(self):
if self.cache_dir is None:
raise ValueError(
"Finding sem-dedup requires a cache directory accessible via all workers to store intermediates"
)
cache_dir = Cache().get_cache_directory()
if cache_dir is None:
raise ValueError(
"Finding semantic duplicates requires a cache directory accessible "
"via all workers to store intermediates. Please use "
"Cache(cache_dir=...) or SemDedupConfig(cache_dir=...) to "
"set the cache directory."
)
else:
self.cache_dir = cache_dir

if self.eps_to_extract not in self.eps_thresholds:
raise ValueError(
f"Epsilon to extract {self.eps_to_extract} must be in eps_thresholds {self.eps_thresholds}"
f"Epsilon to extract {self.eps_to_extract} must be in eps_thresholds "
f"{self.eps_thresholds}."
)
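
A hedged sketch of how both configs resolve their cache directory after this change (directory names are illustrative):

from nemo_curator.cache import Cache
from nemo_curator.modules.config import FuzzyDuplicatesConfig, SemDedupConfig

# An explicit cache_dir on the config always wins.
fuzzy_config = FuzzyDuplicatesConfig(cache_dir="./fuzzy_dedup_cache")

# Otherwise, set the directory once on the Cache singleton and omit cache_dir.
Cache(cache_dir="./curator_cache")
sem_config = SemDedupConfig()  # __post_init__ falls back to "./curator_cache"

# With neither an explicit cache_dir nor an initialized Cache singleton,
# __post_init__ raises a ValueError asking for one of the two.
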
73 changes: 51 additions & 22 deletions nemo_curator/modules/exact_dedup.py
@@ -26,6 +26,7 @@
from dask import dataframe as dd

from nemo_curator._compat import DASK_P2P_ERROR
from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.base import BaseModule
@@ -53,13 +54,17 @@ def __init__(
Parameters
----------
logger: Existing logger to log to, or a path to a log directory.
id_field: Column in the Dataset denoting document ID.
text_field: Column in the Dataset denoting document content.
hash_method: The hashing algorithm used for identifying exact duplicates. Currently supports {"md5"}
profile_dir: str, Default None
If specified directory to write dask profile
cache_dir: str, Default None
If specified, will compute & write duplicate id's to cache directory.
id_field: Column in the dataset denoting document ID.
text_field: Column in the dataset denoting document content.
hash_method: The hashing algorithm used for identifying exact duplicates.
Currently only supports "md5".
perform_removal: Boolean value to specify whether calling the module should
remove the duplicates from the original dataset, or return the list of IDs
denoting duplicates.
profile_dir: If specified, directory to write Dask profile. Default is None.
cache_dir: If specified, will compute and write duplicate IDs to cache directory.
If None, we check if a cache_dir has been initialized with Cache().get_cache_directory().
Default is None.
"""
super().__init__(input_backend="any")

@@ -71,19 +76,29 @@ def __init__(
self.hash_method = hash_method
self.id_field = id_field
self.text_field = text_field

if cache_dir is None:
self.cache_dir = Cache().get_cache_directory()
else:
self.cache_dir = cache_dir

if self.cache_dir is None and profile_dir is not None:
warnings.warn(
"cache_dir for intermediate outputs is required to generate profiles. "
"Please initialize with Cache(cache_dir=...) or ExactDuplicates(cache_dir=...)"
)
self.profile_dir = profile_dir

self.perform_removal = perform_removal
if not self.perform_removal:
warnings.warn(
"In future releases (starting with 0.8.0) the default will be True."
"In future NeMo Curator releases, the default value for perform_removal will be True."
)
if self.perform_removal and cache_dir is None:
warnings.warn("cache_dir is recommended to remove duplicates.")
if cache_dir is None and profile_dir is not None:
if self.perform_removal and self.cache_dir is None:
warnings.warn(
"cache_dir for intermediate outputs is required to generate profiles"
"cache_dir is recommended to remove duplicates. "
"Please initialize with Cache(cache_dir=...) or ExactDuplicates(cache_dir=...)"
)
self.cache_dir = cache_dir
self.profile_dir = profile_dir

if isinstance(logger, str):
self._logger = create_logger(
@@ -96,7 +111,8 @@ def __init__(

def _exact_dup_ids(self, df: dd.DataFrame):
"""
Get the id's for text/documents that are exact duplicates
Get the IDs for text/documents that are exact duplicates.
Parameters
----------
df: dask.dataframe.DataFrame
@@ -105,33 +121,39 @@ def _exact_dup_ids(self, df: dd.DataFrame):
* A unique ID column for each document
"""
hash_df = self._compute_hashes(df)

shuffle_context = (
config.set({"dataframe.shuffle.method": "tasks"})
if DASK_P2P_ERROR
else nullcontext()
)

with shuffle_context:
dup_ids = hash_df.shuffle(
on=["_hashes"],
ignore_index=True,
npartitions=max(1, (hash_df.npartitions // 3)),
).map_partitions(lambda x: x[x["_hashes"].duplicated(keep=False)])

return dup_ids

def _compute_hashes(
self,
df: dd.DataFrame,
) -> dd.DataFrame:
"""
Computes the hash of the text_column provided and returns a dataframe
containing the id_column and relevant hashes in the _hashes column.
Computes the hash of the text field provided and returns a DataFrame
containing the ID field and relevant hashes in the _hashes column.
"""
self._logger.info("Starting lazy hash generation")

res = df[[self.id_field]]
res["_hashes"] = df[self.text_field].map_partitions(self.hash_documents)

self._logger.info(
f"Lazy hash generation complete for {res.npartitions} partitions"
)

return res

def hash_documents(
@@ -142,41 +164,48 @@ def hash_documents(
"""
if is_cudf_type(df):
return df.hash_values(method=self.hash_method)

elif isinstance(df, pd.Series):
# TODO: Generalize ty using self.hash_method
# TODO: Generalize by using self.hash_method
return df.apply(lambda x: md5(x.encode()).hexdigest())

def identify_duplicates(self, dataset: DocumentDataset) -> DocumentDataset:
"""
Find document ID's for exact duplicates in a given DocumentDataset
Find document IDs for exact duplicates in a given DocumentDataset.
Parameters
----------
dataset: DocumentDataset
The input datset to find exact duplicates
Returns
-------
DocumentDataset containing ID's and hashes of all duplicate documents
DocumentDataset containing IDs and hashes of all duplicate documents
"""
result = self._exact_dup_ids(df=dataset.df)

if self.cache_dir is None:
return DocumentDataset(result)

t0 = time.time()
self._logger.info("Starting execution for ExactDedup")
self._logger.info("Starting execution for ExactDuplicates")
write_path = os.path.join(self.cache_dir, "_exact_duplicates.parquet")

if os.path.exists(write_path):
warnings.warn(
f"Output path f{write_path} already exists and will be overwritten"
)

with performance_report_if_with_ts_suffix(
self.profile_dir,
"exact-dedup-profile",
):
result.to_parquet(write_path, write_index=False, overwrite=True)

self._logger.info(
f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}"
f"Time taken for ExactDuplicates computation = {time.time() - t0}s \n"
f"Output written at {write_path}"
)

backend = "cudf" if is_cudf_type(result) else "pandas"
return DocumentDataset.read_parquet(
write_path,
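
A hedged usage sketch of the cache resolution above (the input path is illustrative; column names follow the documented defaults):

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.exact_dedup import ExactDuplicates

Cache(cache_dir="./exact_dedup_cache")  # used because no cache_dir is passed below

dataset = DocumentDataset.read_parquet("./input_data", backend="pandas")
exact_dup = ExactDuplicates(id_field="id", text_field="text", perform_removal=False)

# Returns IDs and hashes of duplicate documents; with a cache directory set,
# they are also written to ./exact_dedup_cache/_exact_duplicates.parquet.
duplicates = exact_dup.identify_duplicates(dataset)

The examples/exact_deduplication.py change above then calls exact_dup.remove(dataset, duplicates) to drop the duplicates from the original dataset.
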
70 changes: 42 additions & 28 deletions nemo_curator/modules/fuzzy_dedup/_mapbuckets.py
@@ -32,12 +32,13 @@

class _MapBuckets:
"""
buckets to a logical partition by using a modified bin packing algorithm.
Buckets to a logical partition by using a modified bin packing algorithm.
Combines buckets generated from LSH (typically high cardinality)
to more coarse lower cardinality bucket groups by mapping multiple buckets
to a logical partition using document length information and a modified bin
packing algorithm.
Only needed if running False Postive check to remove false positives.
Only needed if running false positive check to remove false positives.
"""

def __init__(
@@ -49,17 +50,18 @@ def __init__(
logger: Union[logging.LoggerAdapter, str] = "./",
):
"""
id_fields: list or str
id fields of df
text_field: str = "text",
bucket_column: str = "bucket_column",
num_anchors: int = 2,
logger: Union[logging.LoggerAdapter, str] = "./",
id_fields (list or str): ID fields of DataFrame. Default is "id".
text_field (str): text field of DataFrame. Default is "text".
bucket_field (str): Default is "_bucket_id".
num_anchors (int): Default is 2.
logger (Union[logging.LoggerAdapter, str]): Default is "./".
"""

self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
self.text_field = text_field
self.num_anchors = num_anchors
self.bucket_field = bucket_field

if isinstance(logger, str):
self._logger = create_logger(
rank=0,
@@ -78,12 +80,13 @@ def _get_output_part_ids_with_approx_equal_sum(
output_partition_column: str,
) -> cudf.DataFrame:
"""
Create a output_series that maps the ser.index into `nparts`
Create an output_series that maps the ser.index into `nparts`
so that the total sum of bucket_val_counts_df
for each output id are all most equal and
less than max_text_bytes_per_part
This is used downstream for creating equal output_ids
for each output ID are almost equal and
less than max_text_bytes_per_part.
This is used downstream for creating equal output_ids.
"""

sizes = bucket_text_bytes_df[bytes_column].values
bucket_output_ar = build_partition(
sizes=sizes.get(), max_size=max_text_bytes_per_part
@@ -104,8 +107,8 @@ def _get_output_map_from_text_bytes_per_bucket(
max_text_bytes_per_part = int(np.iinfo(np.int32).max * 3)

self._logger.info(f"max_text_bytes_per_part = {max_text_bytes_per_part}")
# Increasing in an attempt to prevent hitting
# ulimits

# Increasing in an attempt to prevent hitting ulimits
output_map_df_meta = cudf.DataFrame(
{self.bucket_field: [0], output_partition_column: [1]}
)
@@ -122,9 +125,11 @@ def _get_output_map_from_text_bytes_per_bucket(
meta=output_map_df_meta,
)
output_map_df = output_map_df.persist()

self._logger.info(
f"Step 1 of output_map_df of len: {len(output_map_df)} computed"
f"Step 1 of output_map_df of length {len(output_map_df)} computed"
)

lower_bounds = (
output_map_df[output_partition_column]
.map_partitions(lambda s: (s.max() + 1))
@@ -145,9 +150,11 @@ def update_id(df, lower_bound):
updated_parts.append(output_map_df.get_partition(0))
output_map_df = dask_cudf.concat(updated_parts)
output_map_df = output_map_df.persist()

self._logger.info(
f"All steps of output_map_df of len: {len(output_map_df)} computed"
f"All steps of output_map_df of length {len(output_map_df)} computed"
)

return output_map_df

def _get_output_map_based_on_str_bytes(
@@ -156,6 +163,7 @@ def _get_output_map_based_on_str_bytes(
"""
Add output_partition_id to buckets_ddf
"""

documents_df = documents_df.copy()
documents_df[bytes_column] = documents_df[self.text_field].map_partitions(
lambda s: s.str.byte_count()
@@ -168,14 +176,17 @@ def _get_output_map_based_on_str_bytes(
npartitions=n_partitions
)
del documents_df

ddf_bk_text_bytes, agg_df_len = get_agg_text_bytes_df(
df=buckets_df,
agg_column=self.bucket_field,
bytes_column=bytes_column,
n_partitions=n_partitions,
shuffle=True,
)
self._logger.info(f"Agg_df computed of length = {agg_df_len}")

self._logger.info(f"agg_df of length {agg_df_len} computed")

del buckets_df
output_map_df = self._get_output_map_from_text_bytes_per_bucket(
ddf_bk_text_bytes=ddf_bk_text_bytes,
@@ -185,17 +196,20 @@ def _get_output_map_based_on_str_bytes(

def _random_select_anchor(self, buckets_df, n=2):
"""
Randomly select `n` anchors from each bucket.
Randomly select n anchors from each bucket.
"""

buckets_df = buckets_df.copy()
buckets_df["_id_hash"] = buckets_df[self.id_fields].hash_values()
buckets_df = buckets_df.sort_values([self.bucket_field, "_id_hash"])
buckets_df["_order_in_bucket"] = buckets_df.groupby(
self.bucket_field
).cumcount()
buckets_df["is_anchor"] = buckets_df["_order_in_bucket"] < n

for i in range(0, n):
buckets_df[f"is_anchor_id_{i}"] = buckets_df["_order_in_bucket"] == i

buckets_df = buckets_df.drop(columns=["_id_hash", "_order_in_bucket"], axis=1)
buckets_df = buckets_df.reset_index(drop=True)
buckets_df = buckets_df[buckets_df.is_anchor]
@@ -205,14 +219,17 @@ def _add_anchor_docs(self, buckets_df, num_anchors):
"""
Get anchor documents for each bucket.
"""

df_anchor_bk = self._random_select_anchor(buckets_df=buckets_df, n=num_anchors)
df_anchor_docs = None

for i in range(num_anchors):
df_anchor_bk_i = df_anchor_bk[df_anchor_bk[f"is_anchor_id_{i}"]][
[self.bucket_field] + self.id_fields
].reset_index(drop=True)
column_mapping = {id: f"anchor_{i}_{id}" for id in self.id_fields}
df_anchor_bk_i = df_anchor_bk_i.rename(columns=column_mapping)

if i == 0:
df_anchor_docs = df_anchor_bk_i
else:
@@ -232,17 +249,10 @@ def map_buckets_with_anchors(
shuffle_type: Union[str, bool, None] = "tasks",
) -> dask_cudf.DataFrame:
"""
Get anchor docs with bucket info
Args:
input_data_paths: list of paths to input data
input_bucket_path: path to input buckets
text_ddf_blocksize: blocksize for text ddf
num_files: number of files to read
num_workers: number of workers
shuffle_type: type of shuffle to use
Returns:
ddf_anchor_docs_with_bk
Get anchor documents with bucket information.
"""

output_map_df = self._get_output_map_based_on_str_bytes(
buckets_df=buckets_df, documents_df=documents_df
)
@@ -253,10 +263,12 @@ def map_buckets_with_anchors(
ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.merge(
output_map_df, on=self.bucket_field
)

# Bucket is no longer needed
ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.drop(
columns=[self.bucket_field]
)

# Below removes any duplicates lying around after dropping buckets
ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.map_partitions(
M.drop_duplicates,
@@ -265,6 +277,7 @@ def map_buckets_with_anchors(
transform_divisions=False,
align_dataframes=False,
)

ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.shuffle(
self.id_fields,
ignore_index=True,
@@ -276,5 +289,6 @@ def map_buckets_with_anchors(
transform_divisions=False,
align_dataframes=False,
)

del output_map_df
return ddf_anchor_docs_with_bk
37 changes: 22 additions & 15 deletions nemo_curator/modules/fuzzy_dedup/_shuffle.py
@@ -76,21 +76,24 @@ def shuffle_docs_on_buckets(
bucket_parts_per_worker: int = 8,
partition_on: str = "_output_partition_id",
):

ddf_anchor_docs_with_bk, bk_mapping = aggregated_anchor_docs_with_bk_read(
path=bucket_w_anchors_path,
blocksize=bucket_mapping_df_blocksize,
)
self._logger.info("Getting ddf_anchor_docs_with_bk completed")

self._logger.info("Computing ddf_anchor_docs_with_bk completed")
self._logger.debug(
f"ddf_anchor_docs_with_bk.npartitions = {ddf_anchor_docs_with_bk.npartitions}"
)

st = time.time()

num_workers = get_num_workers(get_current_client())
parts_per_batch = num_workers * parts_per_worker
self._logger.debug(f"parts_per_batch = {parts_per_batch}")

self._logger.debug(f"parts_per_batch = {parts_per_batch}")
parts_per_bucket_batch = num_workers * bucket_parts_per_worker
self._logger.debug(f"parts_per_bucket_batch = {parts_per_bucket_batch}")
self._logger.debug(f"parts_per_bucket_batch = {parts_per_bucket_batch}")

dask_profile_name = (
"suffle_docs"
@@ -111,8 +114,9 @@ def shuffle_docs_on_buckets(
bk_mapping=bk_mapping,
num_workers=num_workers,
)

self._logger.info(
f"Time taken for Shuffle = {time.time()-st}s and output written at {output_shuffled_docs_path}"
f"Time taken for _Shuffle: {time.time()-st}s and output written at {output_shuffled_docs_path}"
)

def _batched_merge_and_write(
@@ -145,10 +149,9 @@ def _batched_merge_and_write(
)

# Set end offsets
# NOTE: These end offsets are always set to the end
# of the data. However, we may want to be able to set
# both the start and end offsets from the command line
# in the future.
# NOTE: These end offsets are always set to the end of the data.
# However, we may want to be able to set both the start and end offsets from
# the command line in the future.
bucket_part_end_offset = total_bucket_partitions
text_part_end_offset = total_text_partitions

@@ -158,7 +161,6 @@ def _batched_merge_and_write(
assert text_part_end_offset > text_part_start_offset

# Initialize "retry" variables
#
# - retry_count: The number of successive batches that
# we have already performed at a reduced batch size.
# - retry_threshold: The number of successive batches
@@ -179,11 +181,11 @@ def _batched_merge_and_write(
bucket_part_start_offset, bucket_part_end_offset, parts_per_bucket_batch
)
):

# Outer loop over batches of "bucket-map" partitions
end_bucket_offset = min(
bucket_part_offset + parts_per_bucket_batch, bucket_part_end_offset
)

print(
f"\nStarted processing bucket-map partitions {bucket_part_offset} "
f"through {end_bucket_offset} of {bucket_part_end_offset}",
@@ -207,13 +209,13 @@ def _batched_merge_and_write(

text_part_offset = text_part_start_offset
while text_part_offset < text_part_end_offset:

# Check if we are "retrying" with a smaller "parts_per_text_batch"
if parts_per_text_batch_retry:
parts_per_text_batch_use = parts_per_text_batch_retry
else:
st_text = time.time()
parts_per_text_batch_use = parts_per_text_batch

print(f"Using {parts_per_text_batch_use} text partitions.", flush=True)

# Select partitions for our text batch
@@ -234,7 +236,9 @@ def _batched_merge_and_write(
output_df = output_df.map_partitions(
int_ids_to_str, id_column=self.int_to_str_id
)

batch_label = f"{end_bucket_offset}_{end_text_offset}"

if output_df is not None:
written_files = output_df.map_partitions(
write_partitioned_file,
@@ -244,13 +248,14 @@ def _batched_merge_and_write(
meta=cudf.Series([True]),
)
written_files = written_files.compute()

update_restart_offsets(output_path, bucket_part_offset, end_text_offset)
del output_df

print(
"Text-df partition ",
"text-df partition ",
f"{end_text_offset}/{text_part_end_offset} "
f"completed in {time.time()-st_text}",
f"completed in {time.time()-st_text}s",
flush=True,
)

@@ -268,13 +273,15 @@ def _batched_merge_and_write(
# case we fail again
parts_per_text_batch_retry = None
retry_count, retry_threshold = 0, min(retry_threshold * 2, 16)

text_part_offset += parts_per_text_batch_use

update_restart_offsets(output_path, end_bucket_offset, end_text_offset)

print(
"Bucket partition ",
f"{end_bucket_offset}/{bucket_part_end_offset} "
f"completed in {time.time()-st_bucket}",
f"completed in {time.time()-st_bucket}s",
flush=True,
)

52 changes: 33 additions & 19 deletions nemo_curator/modules/fuzzy_dedup/bucketstoedges.py
@@ -27,21 +27,21 @@
import pandas as pd
import pyarrow as pa

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix


class BucketsToEdges:
"""
Maps buckets generated from LSH into an edgelist that
can be processed further by Connected Components to find duplicate
documents
Maps buckets generated from LSH into an edgelist that can be processed further by
Connected Components to find duplicate documents.
"""

def __init__(
self,
cache_dir: str = None,
cache_dir: Optional[str] = None,
id_fields: Union[list, str] = "id",
str_id_name: str = "id",
bucket_field: str = "_bucket_id",
@@ -51,24 +51,29 @@ def __init__(
"""
Parameters
----------
cache_dir: str or None
If specified, will compute & write the edgelist to a file
id_fields: list or str
id fields of documents in buckets_df
str_id_name: str
Ignored if there is a single id field. Multiple id fields
will be combined into a single id field with the given name.
bucket_field: str
Column denoting bucket ID
num_buckets: Number of bands/buckets to create from the minhash signature.
Hashes_per_signature = num_hashes / num_buckets
cache_dir: Directory to compute and write edgelist. Can also be set with
Cache(cache_dir=...). Default is None.
id_fields: List or string representing column(s) in buckets_df denoting
document ID. Default is "id".
str_id_name: Ignored if there is a single ID field. Multiple ID fields will be
combined into a single ID field with the given name. Default is "id".
bucket_field: Column denoting bucket ID. Default is "_bucket_id".
logger: Existing logger to log to, or a path to a log directory.
Default is "./".
profile_dir: If specified, directory to write Dask profile. Default is None.
"""
self.cache_dir = cache_dir

if cache_dir is None:
self.cache_dir = Cache().get_cache_directory()
else:
self.cache_dir = cache_dir

self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
self.str_id_name = str_id_name if len(self.id_fields) > 1 else self.id_fields[0]
self.output_ids = [f"{self.str_id_name}_x", f"{self.str_id_name}_y"]
self.bucket_field = bucket_field
self.profile_dir = profile_dir

if isinstance(logger, str):
self._logger = create_logger(
rank=0,
@@ -84,12 +89,13 @@ def _combine_multiple_ids(
) -> cudf.DataFrame:
if output_id_field in input_df.columns:
raise ValueError(
f"Input df already contains column named: {output_id_field}"
f"Input DataFrame already contains column named {output_id_field}"
)

output_df = input_df.copy()[input_df.columns.difference(input_id_fields)]

output_df[output_id_field] = input_df[input_id_fields[0]].astype(str)

for input_field in input_id_fields[1:]:
output_df[output_id_field] = output_df[output_id_field] = (
input_df[input_id_fields[0]].astype(str)
@@ -109,23 +115,29 @@ def buckets_to_edges(
.agg(list)
.list.sort_values()
)

bucket_docs = grouped_buckets.to_arrow().to_pylist()
edges = []

# Create pairs of all documents within a bucket since they are near duplicates
# Effectively create a edge list of all near duplicate documents
for bucket_doc in bucket_docs:
edges.extend(pairwise(bucket_doc))

edges = pd.DataFrame(edges, columns=self.output_ids)
edges = pa.Table.from_pandas(edges)
result_df = cudf.DataFrame.from_arrow(edges)
del edges

result_df = result_df.drop_duplicates(self.output_ids).reset_index(drop=True)
result_df["jaccard"] = np.float32(1.0)
return result_df

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
buckets_df = dataset.df
self._logger.info(f"Starting conversion of LSH Buckets to Graph Edgelist")

self._logger.info(f"Starting conversion of LSH buckets to graph edgelist")

if len(self.id_fields) > 1:
buckets_df = buckets_df.map_partitions(
BucketsToEdges._combine_multiple_ids,
@@ -145,14 +157,16 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
warnings.warn(
f"Output path {write_path} already exists and will be overwritten"
)

t0 = time.time()
with performance_report_if_with_ts_suffix(
self.profile_dir,
"bucket-to-edges",
):
edges_df.to_parquet(write_path, write_index=False, overwrite=True)

self._logger.info(
f"Time taken for Converted Buckets To Edgelist = {time.time() - t0}s and output written at {write_path}"
f"Time taken for converted buckets to edgelist: {time.time() - t0}s and output written at {write_path}"
)

return DocumentDataset(
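
A hedged sketch of the new fallback in this constructor (the import path follows the file path in this diff; lsh_buckets stands in for the LSH output produced earlier in the pipeline):

from nemo_curator.cache import Cache
from nemo_curator.modules.fuzzy_dedup.bucketstoedges import BucketsToEdges

Cache(cache_dir="./fuzzy_dedup_cache")  # picked up because cache_dir is not passed
buckets_to_edges = BucketsToEdges(id_fields="id", bucket_field="_bucket_id")

# Returns the edgelist as a DocumentDataset; it is also written under the
# cache directory, per the docstring above.
edges = buckets_to_edges(lsh_buckets)
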
67 changes: 51 additions & 16 deletions nemo_curator/modules/fuzzy_dedup/connectedcomponents.py
@@ -28,27 +28,39 @@
from cugraph import MultiGraph
from dask.utils import M

from nemo_curator.cache import Cache
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix


class ConnectedComponents:
def __init__(
self,
cache_dir: str,
jaccard_pairs_path: str,
cache_dir: Optional[str] = None,
id_column="id",
jaccard_threshold: float = 0.8,
logger: Union[logging.LoggerAdapter, str] = "./",
profile_dir: Optional[str] = None,
):
self.cache_dir = cache_dir
self.jaccard_pairs_path = jaccard_pairs_path

if cache_dir is None:
self.cache_dir = Cache().get_cache_directory()
else:
self.cache_dir = cache_dir
if self.cache_dir is None:
raise ValueError(
"cache_dir is required for Connected Components. Please initialize with "
"Cache(cache_dir=...) or ConnectedComponents(cache_dir=...)"
)

self.id_column = id_column
self.left_id = f"{id_column}_x"
self.right_id = f"{id_column}_y"
self.jaccard_threshold = jaccard_threshold
self.profile_dir = profile_dir

if isinstance(logger, str):
self._logger = create_logger(
rank=0,
@@ -60,12 +72,15 @@ def __init__(

def cc_workflow(self, output_path):
deduped_parsed_id_path = self._write_dedup_parsed_id()

encoded_jaccard_pair_path = self._write_encoded_jaccard_pair(
deduped_parsed_id_path
)

deduped_encoded_jaccard_path = self._write_dedup_encoded_jaccard_pair(
encoded_jaccard_pair_path
)

cc_path = self._run_connected_components(
deduped_encoded_jaccard_path, deduped_parsed_id_path, output_path
)
@@ -81,7 +96,6 @@ def _run_connected_components(
with performance_report_if_with_ts_suffix(
self.profile_dir, "connected-components-run"
):

Comms.initialize(p2p=False)
df = dask_cudf.read_parquet(
deduped_encoded_jaccard_path, blocksize="1GB", aggregate_files=True
@@ -102,6 +116,7 @@ def _run_connected_components(
)
result = dcg.weakly_connected_components(G)
del G

max_partitions = min(32, result.npartitions)
n_components = len(
result[["labels"]].drop_duplicates(split_out=max_partitions)
@@ -110,6 +125,7 @@ def _run_connected_components(
labels_df = labels_df.merge(
result, left_on=["uid"], right_on=["vertex"], how="inner"
)

id_columns = [self.id_column]
labels_df = labels_df[id_columns + ["labels"]]
labels_df = labels_df.rename(columns={"labels": "group"})
@@ -119,27 +135,31 @@ def _run_connected_components(

self._logger.info(
"Result of connected compoinents are "
f"# of groups : {n_components}, "
f"# of docs removed : {num_labels - n_components}, "
f"# nodes = {num_nodes}, "
f"# rows in labels_df = {len(labels_df)}"
f"# of groups: {n_components}, "
f"# of documents removed: {num_labels - n_components}, "
f"# nodes: {num_nodes}, "
f"# rows in labels_df: {len(labels_df)}"
)
assert num_nodes == len(labels_df)
# Ensure all docs in the same group are in the same partition

# Ensure all documents in the same group are in the same partition
labels_df = labels_df.shuffle(on=["group"], ignore_index=True)
labels_df.to_parquet(output_path, write_index=False, overwrite=True)
Comms.destroy()

self._logger.info(
f"Time taken for Connected Components Run = {time.time() - t0}s and output written at {output_path}"
f"Time taken for connected components: {time.time() - t0}s and output written at {output_path}"
)

@staticmethod
def _sort_ids(df, id_columns):
x = df[id_columns].values
x = cp.sort(x, axis=1)

for i, id_column in enumerate(id_columns):
df[id_column] = x[:, i]
df[id_column] = df[id_column].astype("uint64")

return df

@staticmethod
@@ -152,10 +172,10 @@ def thresholding(df, threshold, column_to_threshold):
def _write_dedup_encoded_jaccard_pair(self, encoded_jaccard_pair_path):
output_path = f"{self.cache_dir}/final_dedup_encoded_jaccard_pair.parquet"
t0 = time.time()

with performance_report_if_with_ts_suffix(
self.profile_dir, "connected-components-dedup-encoded-jaccard-pair"
):

ddf = dask_cudf.read_parquet(
encoded_jaccard_pair_path, blocksize="512MB", aggregate_files=True
)
@@ -196,14 +216,17 @@ def _write_dedup_encoded_jaccard_pair(self, encoded_jaccard_pair_path):
align_dataframes=False,
)
ddf.to_parquet(output_path, write_index=False, overwrite=True)

self._logger.info(
f"Time taken for Dedup Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}"
f"Time taken for dedupe encoding Jaccard pairs: {time.time() - t0}s and output written at {output_path}"
)

return output_path

def _write_dedup_parsed_id(self):
dedup_parsed_id_path = f"{self.cache_dir}/dedup_parsed_id.parquet"
t0 = time.time()

with performance_report_if_with_ts_suffix(
self.profile_dir, "connected-components-dedup-parsed-id"
):
@@ -221,20 +244,24 @@ def _write_dedup_parsed_id(self):
# Dask does not guard against split_out=0
split_out=max(ddf.npartitions // 4, 1)
)

unique_docs["uid"] = np.uint64(1)
unique_docs["uid"] = unique_docs["uid"].cumsum()
unique_docs["uid"] = unique_docs["uid"] - 1
unique_docs.to_parquet(
dedup_parsed_id_path, write_index=False, overwrite=True
)

self._logger.info(
f"Time taken for Dedup Parsed Id = {time.time() - t0}s and output written at {dedup_parsed_id_path}"
f"Time taken for dedupe parsed ID: {time.time() - t0}s and output written at {dedup_parsed_id_path}"
)

return dedup_parsed_id_path

def _write_encoded_jaccard_pair(self, dedup_parsed_id_path):
output_path = f"{self.cache_dir}/encoded_jaccard_pair/"
t0 = time.time()

with performance_report_if_with_ts_suffix(
self.profile_dir, "connected-components-encoded-jaccard-pair"
):
@@ -252,9 +279,11 @@ def _write_encoded_jaccard_pair(self, dedup_parsed_id_path):
output_path=output_path,
id_column=self.id_column,
)

self._logger.info(
f"Time taken for Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}"
f"Time taken for encoding Jaccard pairs: {time.time() - t0}s and output written at {output_path}"
)

return output_path

def _merge_and_write(
@@ -265,11 +294,13 @@ def _merge_and_write(
id_column: str,
) -> None:
st = time.time()
# Ensure 'id_columns' is a list

# Ensure id_columns is a list
ddf_id = ddf_id.set_index(id_column)

for tag in ["x", "y"]:
pair_id = f"{id_column}_{tag}"
# Merge 'ddf' with 'ddf_id' to map ids to uids
# Merge ddf with ddf_id to map IDs to UIDs
ddf = ddf.merge(
ddf_id,
left_on=pair_id,
@@ -279,19 +310,22 @@ def _merge_and_write(
)
ddf = ddf.drop(columns=pair_id)
ddf = ddf.rename(columns={"uid": f"{self.id_column}_{tag}"})

ddf = ddf[[self.left_id, self.right_id, "jaccard"]]
ddf.to_parquet(output_path, write_index=False, overwrite=True)

et = time.time()
self._logger.info(
f"Time taken for merge and write = {et - st}s and output written at {output_path}"
f"Time taken for merge and write: {et - st}s and output written at {output_path}"
)

@staticmethod
def _get_unique_ids_per_partition(df, id_columns):
unique_df_ls = []

for tag in ["x", "y"]:
cols_to_drop = []

for id_col in id_columns:
cols_to_drop.append(f"{id_col}_{tag}")

@@ -300,6 +334,7 @@ def _get_unique_ids_per_partition(df, id_columns):
columns={f"{id_col}_{tag}": f"{id_col}" for id_col in id_columns}
)
unique_df_ls.append(subset_df)

unique_df = cudf.concat(unique_df_ls, ignore_index=True)
unique_df = unique_df.drop_duplicates(ignore_index=True)
return unique_df
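
Since this class now raises when no cache directory can be resolved, here is a hedged sketch of standalone use (paths are illustrative; the Jaccard-pairs parquet is assumed to already exist from the earlier stages):

from nemo_curator.cache import Cache
from nemo_curator.modules.fuzzy_dedup.connectedcomponents import ConnectedComponents

Cache(cache_dir="./fuzzy_dedup_cache")  # satisfies the cache_dir requirement
cc = ConnectedComponents(
    jaccard_pairs_path="./fuzzy_dedup_cache/jaccard_similarity_results.parquet",
    id_column="id",
    jaccard_threshold=0.8,
)
cc.cc_workflow(output_path="./fuzzy_dedup_cache/connected_components.parquet")
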
67 changes: 42 additions & 25 deletions nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
@@ -19,6 +19,7 @@
import time
from typing import Optional, Union

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.base import BaseModule
@@ -44,15 +45,16 @@ def __init__(
"""
Parameters
----------
config: FuzzyDuplicatesConfig,
Config options for finding FuzzyDuplicates
config (FuzzyDuplicatesConfig): Config options for finding fuzzy duplicates.
logger: Existing logger to log to, or a path to a log directory.
Default is "./".
Returns
-------
DocumentDataset containing IDs of all documents and the corresponding duplicate group
they belong to. Documents in the same group are near duplicates.
"""

super().__init__(input_backend="cudf")
if isinstance(logger, str):
self._logger = create_logger(
@@ -65,6 +67,16 @@ def __init__(

self.config = config

if self.config.cache_dir is not None:
self.cache_dir = self.config.cache_dir
elif Cache().get_cache_directory() is not None:
self.cache_dir = Cache().get_cache_directory()
else:
raise RuntimeError(
"No cache directory specified. Please initialize with Cache(cache_dir=...) "
"or specify a cache_dir in your YAML file."
)

self.minhash = MinHash(
seed=self.config.seed,
num_hashes=self.config.num_hashes,
@@ -74,10 +86,11 @@ def __init__(
id_field=self.config.id_field,
text_field=self.config.text_field,
profile_dir=self.config.profile_dir,
cache_dir=self.config.cache_dir,
cache_dir=self.cache_dir,
)

self.lsh = LSH(
cache_dir=self.config.cache_dir,
cache_dir=self.cache_dir,
num_hashes=self.config.num_hashes,
num_buckets=self.config.num_buckets,
buckets_per_shuffle=self.config.buckets_per_shuffle,
@@ -109,9 +122,10 @@ def __init__(
for i in range(self.config.num_anchors)
],
)

else:
self.buckets_to_edges = BucketsToEdges(
cache_dir=self.config.cache_dir,
cache_dir=self.cache_dir,
id_fields=self.config.id_field,
logger=self._logger,
profile_dir=self.config.profile_dir,
@@ -123,8 +137,8 @@ def __init__(
else "_edges.parquet"
)
self.connected_components = ConnectedComponents(
cache_dir=self.config.cache_dir,
jaccard_pairs_path=os.path.join(self.config.cache_dir, jaccard_pairs_fname),
cache_dir=self.cache_dir,
jaccard_pairs_path=os.path.join(self.cache_dir, jaccard_pairs_fname),
id_column=self.config.id_field,
jaccard_threshold=self.config.jaccard_threshold,
logger=self._logger,
@@ -137,8 +151,8 @@ def identify_duplicates(
"""
Parameters
----------
dataset: DocumentDataset
The input datset to compute FuzzyDuplicates. Must contain a text and unique id field.
dataset (DocumentDataset): The input dataset on which to compute fuzzy deduplication.
Must contain a text field and unique ID field.
Returns
-------
@@ -152,20 +166,22 @@ def identify_duplicates(
minhashLSH = Sequential([self.minhash, self.lsh])
buckets_df = minhashLSH(dataset)
print(f"Stage {stage_num}: Minhash + LSH complete!")

if buckets_df is None:
print(
f"Stage {stage_num}: No potential duplicate documents found during LSH"
)
return None
stage_num += 1

stage_num += 1
if self.config.false_positive_check:
# Map buckets to lower cardinality distribution
print(f"Stage {stage_num} (False Positive Check): Starting Map_Buckets")
t0 = time.time()
mapped_buckets_w_anchors_path = os.path.join(
self.config.cache_dir, "anchor_docs_with_bk.parquet"
self.cache_dir, "anchor_docs_with_bk.parquet"
)

with performance_report_if_with_ts_suffix(
self.config.profile_dir,
"map_buckets",
@@ -178,18 +194,17 @@ def identify_duplicates(
ddf_mapped_buckets_w_anchors.to_parquet(
mapped_buckets_w_anchors_path, write_index=False, overwrite=True
)

self._logger.info(
f"Time taken for Map_buckets : {time.time() - t0}s and output written at {mapped_buckets_w_anchors_path}"
f"Time taken for Map_Buckets: {time.time() - t0}s and output written at {mapped_buckets_w_anchors_path}"
)

print(f"Stage {stage_num} (False Postive Check): Map_Buckets Complete!")
print(f"Stage {stage_num} (False Positive Check): Map_Buckets complete!")
stage_num += 1

# Shuffle documents based on mapped buckets
print(f"Stage {stage_num} (False Postive Check): Shuffle docs")
shuffled_docs_path = os.path.join(
self.config.cache_dir, "shuffled_docs.parquet"
)
print(f"Stage {stage_num} (False Positive Check): Shuffle documents")
shuffled_docs_path = os.path.join(self.cache_dir, "shuffled_docs.parquet")
self.jaccard_shuffle.shuffle_docs_on_buckets(
documents_df=dataset.df,
bucket_w_anchors_path=mapped_buckets_w_anchors_path,
@@ -198,15 +213,17 @@ def identify_duplicates(
parts_per_worker=self.config.parts_per_worker,
bucket_parts_per_worker=self.config.bucket_parts_per_worker,
)
print(f"Stage {stage_num} (False Postive Check): Shuffle docs complete!")
print(
f"Stage {stage_num} (False Positive Check): Shuffle documents complete!"
)
stage_num += 1

# jaccard comparision within buckets
# Jaccard comparision within buckets
print(
f"Stage {stage_num} (False Postive Check): Jaccard Similarity in Buckets"
f"Stage {stage_num} (False Positive Check): Jaccard similarity in buckets"
)
jaccard_pairs_path = os.path.join(
self.config.cache_dir, "jaccard_similarity_results.parquet"
self.cache_dir, "jaccard_similarity_results.parquet"
)
t0 = time.time()
with performance_report_if_with_ts_suffix(
@@ -223,11 +240,11 @@ def identify_duplicates(
overwrite=True,
)
self._logger.info(
f"Time taken for Jaccard Similarity = {time.time()-t0}s and output written at {jaccard_pairs_path}"
f"Time taken for Jaccard similarity: {time.time()-t0}s and output written at {jaccard_pairs_path}"
)

print(
f"Stage {stage_num} (False Postive Check): Jaccard Similarity in Buckets Complete!"
f"Stage {stage_num} (False Positive Check): Jaccard similarity in buckets complete!"
)
stage_num += 1

@@ -236,13 +253,13 @@ def identify_duplicates(
print(f"Stage {stage_num}: Starting LSH Buckets to Graph Edgelist")
self.buckets_to_edges(buckets_df)
print(
f"Stage {stage_num}: Starting LSH Buckets to Graph Edgelist Complete!"
f"Stage {stage_num}: Starting LSH Buckets to Graph Edgelist complete!"
)
stage_num += 1

# Connected components across buckets
print(f"Stage {stage_num}: Connected Components across buckets")
cc_path = os.path.join(self.config.cache_dir, "connected_components.parquet")
cc_path = os.path.join(self.cache_dir, "connected_components.parquet")
self.connected_components.cc_workflow(cc_path)
print(f"Stage {stage_num}: Connected Components across buckets complete!")
stage_num += 1
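
Putting the pieces together, a hedged end-to-end sketch of this module after the change (the cuDF/GPU requirement comes from the fuzzy-deduplication example earlier in this diff; paths are illustrative):

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.config import FuzzyDuplicatesConfig
from nemo_curator.modules.fuzzy_dedup.fuzzyduplicates import FuzzyDuplicates

Cache(cache_dir="./fuzzy_dedup_cache")  # or pass cache_dir to FuzzyDuplicatesConfig
config = FuzzyDuplicatesConfig(id_field="id", text_field="text", perform_removal=False)

dataset = DocumentDataset.read_parquet("./input_data", backend="cudf")  # cuDF backend required
fuzzy_dup = FuzzyDuplicates(config=config, logger="./logs")

duplicates = fuzzy_dup.identify_duplicates(dataset)
if duplicates is not None:
    result = fuzzy_dup.remove(dataset, duplicates)
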
33 changes: 31 additions & 2 deletions nemo_curator/modules/fuzzy_dedup/jaccardsimilarity.py
@@ -46,16 +46,19 @@ def jaccard_compute(self, shuffled_docs_path):
for entry in os.scandir(shuffled_docs_path)
if not entry.path.endswith(".txt")
]

meta_df = cudf.DataFrame(
{
self.left_id: ["x"],
self.right_id: ["y"],
"jaccard": np.float32([0.0]),
}
)

result_df = dd.from_map(
self._compute_jaccard_on_1_partition, paths, meta=meta_df
).reset_index(drop=True)

return result_df

def _compute_jaccard_on_1_partition(self, path):
@@ -64,60 +67,75 @@ def _compute_jaccard_on_1_partition(self, path):
pair_df = self._compute_jaccard_and_create_pair_df(df)
except OverflowError:
paths = [entry.path for entry in os.scandir(os.path.join(path))]

anchor_df_str_size_ls = [
self._get_anchor_docs_and_string_size(path) for path in paths
]

anchor_df = cudf.concat(
[anchor_doc for anchor_doc, _ in anchor_df_str_size_ls],
ignore_index=True,
).drop_duplicates()

df_str_size = [str_size for _, str_size in anchor_df_str_size_ls]

paths = JaccardSimilarity._create_bins(
df_str_size, np.iinfo(np.int32).max // 10
)

pair_dfs = []
for path in paths:
print(path)
df = cudf.read_parquet(path).reset_index(drop=True)
df = cudf.concat([df, anchor_df], ignore_index=True)
pair_df = self._compute_jaccard_and_create_pair_df(df)
pair_dfs.append(pair_df)

pair_df = cudf.concat(pair_dfs, ignore_index=True)

return pair_df

def _get_anchor_docs_and_string_size(self, path):
df = cudf.read_parquet(path)
str_bytes = df[self.text_field].str.byte_count().sum()
is_anchor_flag = df[self.id_field] == df[self.anchor_id_fields[0]]

for anchor_id in self.anchor_id_fields[1:]:
is_anchor_flag = is_anchor_flag | (df[self.id_field] == df[anchor_id])

anchor_df = df[is_anchor_flag].reset_index(drop=True)
return anchor_df, {"path": path, "str_bytes": str_bytes}

@staticmethod
def _create_bins(path_dicts, max_size):
path_dicts.sort(key=lambda x: x["str_bytes"], reverse=True)
bins, bin_sizes = [], []

for path_d in path_dicts:
new_path, new_size = path_d["path"], path_d["str_bytes"]

for i, bin_size in enumerate(bin_sizes):
if bin_size + new_size <= max_size:
bins[i].append(new_path)
bin_sizes[i] += new_size
new_size = 0
break

if new_size:
bins.append([new_path])
bin_sizes.append(new_size)

return bins

def _compute_jaccard_and_create_pair_df(self, df):
df = df.drop_duplicates(
subset=[self.id_field] + self.anchor_id_fields, ignore_index=True
)

anchor_columns = self.anchor_id_fields
id_field = self.id_field
result_ls = []

try:
for anchor_col in anchor_columns:
doc_df = df[[id_field, self.text_field, anchor_col]]
@@ -128,15 +146,17 @@ def _compute_jaccard_and_create_pair_df(self, df):
result_ls.append(result_df)

return cudf.concat(result_ls)

except OverflowError as e:
print(
"Failed with OverflowError in compute_jaccard_and_create_pair_df",
"Failed with OverflowError in compute_jaccard_and_create_pair_df",
flush=True,
)
print(df, flush=True)
print("--" * 30)
print("Error")
print("---" * 30)

raise e

def _get_anchor_df(self, df, anchor_col):
@@ -150,22 +170,29 @@ def _compute_jaccard_pair(self, docs_df, anchor_df):
nrows_at_once = JaccardSimilarity._get_max_num_rows_to_process_once(
df=docs_df, text_field=self.text_field
)

result_ls = []
for i in range(0, docs_df.shape[0], nrows_at_once):
pair_df = docs_df[i : i + nrows_at_once]
pair_df = pair_df.merge(anchor_df, on=self.anchor_id)

pair_df = pair_df.rename(
columns={self.id_field: self.left_id, self.anchor_id: self.right_id}
)

mask = pair_df[self.left_id] != pair_df[self.right_id]
pair_df = pair_df[mask].reset_index(drop=True)

if len(pair_df) == 0:
result_df = self._create_empty_jaccard_result()
else:
result_df = self._compute_jaccard_partition(pair_df)

result_ls.append(result_df)

if len(result_ls) == 0:
return self._create_empty_jaccard_result()

df_pair = cudf.concat(result_ls)
return df_pair

@@ -186,10 +213,12 @@ def _compute_jaccard_partition(self, df):
@staticmethod
def _get_max_num_rows_to_process_once(df, text_field):
nbytes = df[text_field].str.byte_count().sum()
# Number of exmploded bytes

# Number of exploded bytes
exploded_bytes = nbytes * 5 * 2
max_chars_allowed = 2_147_483_647
byte_ratio = int(exploded_bytes) // max_chars_allowed

if byte_ratio > 1:
nrows_at_once = len(df) // byte_ratio
else:
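The OverflowError branch above re-reads the oversized partition in byte-bounded chunks; _create_bins packs file paths into bins with a greedy first-fit pass over the descending byte sizes. The following standalone sketch reproduces that logic with made-up file names and sizes, purely to show how the bins come out.

def create_bins(path_dicts, max_size):
    # Greedy first-fit: largest inputs first, each placed in the first bin with room.
    path_dicts.sort(key=lambda x: x["str_bytes"], reverse=True)
    bins, bin_sizes = [], []
    for path_d in path_dicts:
        new_path, new_size = path_d["path"], path_d["str_bytes"]
        for i, bin_size in enumerate(bin_sizes):
            if bin_size + new_size <= max_size:
                bins[i].append(new_path)
                bin_sizes[i] += new_size
                new_size = 0
                break
        if new_size:
            bins.append([new_path])
            bin_sizes.append(new_size)
    return bins


sample = [
    {"path": "part_a.parquet", "str_bytes": 70},
    {"path": "part_b.parquet", "str_bytes": 50},
    {"path": "part_c.parquet", "str_bytes": 40},
    {"path": "part_d.parquet", "str_bytes": 20},
]
# Prints [['part_a.parquet', 'part_d.parquet'], ['part_b.parquet', 'part_c.parquet']]
print(create_bins(sample, max_size=100))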
91 changes: 63 additions & 28 deletions nemo_curator/modules/fuzzy_dedup/lsh.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@
import dask_cudf
import numpy as np

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
@@ -33,14 +34,14 @@

class LSH:
"""
Performs LSH on a MinhashSignatures
Performs LSH on MinHash signatures.
"""

def __init__(
self,
cache_dir: str,
num_hashes: int,
num_buckets: int,
cache_dir: Optional[str] = None,
buckets_per_shuffle: int = 1,
false_positive_check: bool = False,
logger: Union[logging.LoggerAdapter, str] = "./",
@@ -51,37 +52,48 @@ def __init__(
"""
Parameters
----------
cache_dir: str
Needs to be specified, will compute & write duplicate id, bucket pairs to cache directory.
num_hashes: Length of minhash signature
num_hashes: Length of minhash signature.
num_buckets: Number of bands/buckets to create from the minhash signature.
Hashes_per_signature = num_hashes / num_buckets
buckets_per_shuffle: Number of bands/buckets to shuffle concurrently.
but might lead to memory pressures and related errors.
false_positive_check: bool
If True, writes out buckets in a format compatible with downstream false positive check.
hashes_per_signature = num_hashes / num_buckets.
cache_dir: Directory in which to compute and write duplicate ID, bucket pairs.
This field is required via LSH(cache_dir=...) or Cache(cache_dir=...).
buckets_per_shuffle: Number of bands/buckets to shuffle concurrently. Larger
values process larger batches by processing multiple bands but might lead
to memory pressures and related errors. Default is 1.
false_positive_check: If True, writes out buckets in a format compatible with
downstream false positive check. Default is False.
logger: Existing logger to log to, or a path to a log directory.
id_field: Columns in the Dataset denoting document ID.
minhash_field: Column in the Dataset denoting minhash signature.
profile_dir: str, Default None
If specified directory to write dask profile
Default is "./".
id_fields: List or string representing column(s) in the dataset denoting
document ID. Default is "id".
minhash_field: Column in the dataset denoting minhash signature.
Default is "_minhash_signature".
profile_dir: If specified, directory to write Dask profile. Default is None.
"""

self.num_hashes = num_hashes
self.num_buckets = num_buckets
self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
self.minhash_field = minhash_field
self.buckets_per_shuffle = buckets_per_shuffle

self.bucket_ranges = self._generate_bucket_ranges(
self.num_buckets, self.num_hashes
)

self.buckets_as_int = false_positive_check

if cache_dir is None:
self.cache_dir = Cache().get_cache_directory()
else:
self.cache_dir = cache_dir
self.profile_dir = profile_dir

if self.cache_dir is None:
raise ValueError(
"cache_dir for intermediate outputs is required for this stage"
"cache_dir for intermediate outputs is required for this stage. "
"Please initialize with Cache(cache_dir=...) or LSH(cache_dir=...)"
)
self.cache_dir = cache_dir
self.profile_dir = profile_dir

if isinstance(logger, str):
self._logger = create_logger(
@@ -96,11 +108,12 @@ def _generate_bucket_ranges(
self, num_buckets: int, num_hashes: int
) -> List[List[int]]:
"""
Generates a list of indices for the minhash ranges given num_bands &
num_hashes.
eg: num_bands=3, num_hashes=6
Generates a list of indices for the minhash ranges, given num_bands and num_hashes.
For example: num_bands=3, num_hashes=6
[[0, 1], [2, 3], [4, 5]]
"""

minhashes_per_bucket = num_hashes // num_buckets

bucket_ranges = [
@@ -111,6 +124,7 @@ def _generate_bucket_ranges(
)
for bucket in range(num_buckets)
]

return bucket_ranges

def minhash_to_buckets(
@@ -119,11 +133,13 @@ def minhash_to_buckets(
bucket_ranges: List[List[int]],
) -> cudf.DataFrame:
df2 = df[self.id_fields]

for i, h in enumerate(bucket_ranges):
indices = cudf.Series([h]).repeat(len(df2))
df2[f"_bucket_{i}"] = f"b{i}_" + df[self.minhash_field].list.take(
indices
).hash_values(method="md5")

return df2

def bucket_id_to_int(
@@ -133,23 +149,28 @@ def bucket_id_to_int(
start_id: int = 0,
) -> Tuple[dask_cudf.DataFrame, int]:
"""
Maps bucket ids to a contigious integer range from starting from start_id.
Maps bucket IDs to a contiguous integer range starting from start_id.
"""

unique_bucket_df = (
bucket_ddf[[bucket_col_name]]
.map_partitions(lambda x: x.drop_duplicates(ignore_index=True))
.persist()
)

end_bucket_id = len(unique_bucket_df) - 1 + start_id
unique_bucket_df["bucket_int_id"] = np.uint64(1)
unique_bucket_df["bucket_int_id"] = unique_bucket_df["bucket_int_id"].cumsum()

unique_bucket_df["bucket_int_id"] = (
unique_bucket_df["bucket_int_id"] - 1 + start_id
)

bucket_ddf = bucket_ddf.merge(unique_bucket_df, on=[bucket_col_name])
bucket_ddf = bucket_ddf.drop(columns=[bucket_col_name])
bucket_ddf = bucket_ddf.rename(columns={"bucket_int_id": "_bucket_id"})
bucket_ddf["_bucket_id"] = bucket_ddf["_bucket_id"].astype(np.uint64)

return (bucket_ddf, end_bucket_id)

def _minhash_to_bucket_meta(
@@ -165,29 +186,33 @@ def lsh(
df: dask_cudf.DataFrame,
) -> bool:
"""
Computes hash buckets for the DataFrame and writes them as parquet files to the specified path.
Computes hash buckets for the DataFrame and writes them as Parquet files to the specified path.
Parameters:
- write_path (str): The directory path to write parquet files.
- write_path (str): The directory path to write Parquet files.
- df (dask_cudf.DataFrame): The input DataFrame with minhashes to be bucketed.
Returns:
are_buckets_empty: True if buckets were empty (no duplicates found), False otherwise.
"""

wrote_buckets = False
are_buckets_empty = True

meta = self._minhash_to_bucket_meta(df)

df = df.map_partitions(
self.minhash_to_buckets,
bucket_ranges=self.bucket_ranges,
meta=meta,
)

bucket_start_id = 0
for i in range(0, self.num_buckets, self.buckets_per_shuffle):
bucket_columns = [
f"_bucket_{i}"
for i in range(i, min(self.num_buckets, i + self.buckets_per_shuffle))
]

df2 = df.melt(
id_vars=self.id_fields,
value_name="_bucket_id",
@@ -201,17 +226,20 @@ def lsh(
).map_partitions(lambda x: x[x["_bucket_id"].duplicated(keep=False)])

df2 = df2.reset_index(drop=True)
# Buckets to Int

# Buckets to int
if self.buckets_as_int:
df2, end_id = self.bucket_id_to_int(
df2, bucket_col_name="_bucket_id", start_id=bucket_start_id
)
# If bucketing return empty dataframe

# If bucketing returns empty DataFrame
if end_id < bucket_start_id:
self._logger.info(
f"No duplicate documents found for buckets: {bucket_columns}"
)
continue

bucket_start_id = end_id + 1
are_buckets_empty = False

@@ -241,15 +269,17 @@ def _write_bucket_parquet(
buckets_to_write: List[str],
) -> tuple[bool, bool]:
"""
Utility function to write the bucketed data to parquet
Utility function to write the bucketed data to Parquet,
handling cases of overwriting and appending as needed.
"""

if not wrote_buckets:
if os.path.exists(write_path):
warnings.warn(
f"Output path {write_path} already exists and will be overwritten"
)
df.to_parquet(write_path, write_index=False, overwrite=True)

else:
df.to_parquet(
write_path,
@@ -258,9 +288,11 @@ def _write_bucket_parquet(
append=not are_buckets_empty,
ignore_divisions=True,
)

# Only check if buckets written so far are empty
if are_buckets_empty:
are_buckets_empty = check_empty_buckets(write_path)

wrote_buckets = True

if are_buckets_empty:
@@ -269,21 +301,24 @@ def _write_bucket_parquet(
)
else:
self._logger.info(f"Wrote data for buckets: {buckets_to_write}")

return wrote_buckets, are_buckets_empty

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
df = dataset.df

write_path = os.path.join(self.cache_dir, "_buckets.parquet")

t0 = time.time()
with performance_report_if_with_ts_suffix(self.profile_dir, "lsh-profile"):
empty_result = self.lsh(write_path=write_path, df=df)

self._logger.info(
f"Time taken for LSH = {time.time() - t0}s and output written at {write_path}"
f"Time taken for LSH: {time.time() - t0}s and output written at {write_path}"
)

if empty_result:
return None

buckets_df = dask_cudf.read_parquet(write_path, split_row_groups=False)

return DocumentDataset(buckets_df)
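Given the new constructor, LSH can now pick up its cache directory from the singleton instead of a keyword argument. A minimal usage sketch, assuming a running GPU Dask client, a minhash Parquet output on disk, and the import path implied by this file's location (paths and band counts below are placeholders):

import dask_cudf

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.fuzzy_dedup.lsh import LSH

Cache(cache_dir="/tmp/dedup_cache")  # shared by every stage that accepts cache_dir

minhashes = DocumentDataset(
    dask_cudf.read_parquet("/tmp/dedup_cache/_minhashes.parquet")
)
lsh = LSH(
    num_hashes=260,
    num_buckets=20,
    buckets_per_shuffle=1,
    id_fields="id",
    minhash_field="_minhash_signature",
)  # cache_dir omitted: it is resolved through Cache()
buckets = lsh(minhashes)  # returns None when no documents share a bucket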
65 changes: 44 additions & 21 deletions nemo_curator/modules/fuzzy_dedup/minhash.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@
import numpy as np

from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE
from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
@@ -50,17 +51,19 @@ def __init__(
"""
Parameters
----------
seed: Seed for minhash permutations
num_hashes: Length of minhash signature (No. of minhash permutations)
seed: Seed for minhash permutations. Default is 42.
num_hashes: Length of minhash signature (number of minhash permutations).
Default is 260.
char_ngrams: Width of text window (in characters) while computing minhashes.
use_64bit_hash: Whether to use a 64 bit hash function.
Default is 5.
use_64bit_hash: Whether to use a 64 bit hash function. Default is False.
logger: Existing logger to log to, or a path to a log directory.
id_field: Column in the Dataset denoting document ID.
text_field: Column in the Dataset denoting document content.
profile_dir: str, Default None
If specified directory to write dask profile
cache_dir: str, Default None
If specified, will compute & write id, minhash pairs to directory
Default is "./".
id_field: Column in the dataset denoting document ID. Default is "id".
text_field: Column in the dataset denoting document content. Default is "text".
profile_dir: If specified, directory to write Dask profile. Default is None.
cache_dir: If specified, will compute and write ID and minhash pairs to this
directory. Can also be set with Cache(cache_dir=...). Default is None.
"""
self.num_hashes = num_hashes
self.char_ngram = char_ngrams
@@ -78,12 +81,16 @@ def __init__(
self.id_field = id_field
self.text_field = text_field

if cache_dir is None and profile_dir is not None:
if cache_dir is None:
self.cache_dir = Cache().get_cache_directory()
else:
self.cache_dir = cache_dir
self.profile_dir = profile_dir
if self.cache_dir is None and profile_dir is not None:
warnings.warn(
"cache_dir for intermediate outputs is required to generate profiles"
"cache_dir for intermediate outputs is required to generate profiles. "
"Please initialize with Cache(cache_dir=...) or MinHash(cache_dir=...)"
)
self.cache_dir = cache_dir
self.profile_dir = profile_dir

if isinstance(logger, str):
self._logger = create_logger(
@@ -98,6 +105,7 @@ def generate_seeds(self, n_seeds: int = 260, seed: int = 0) -> np.ndarray:
"""
Generate seeds for all minhash permutations based on the given seed.
"""

gen = np.random.RandomState(seed)
return gen.randint(0, 1e6, size=n_seeds)

@@ -107,6 +115,7 @@ def generate_hash_permutation_seeds(
"""
Generate seeds for all minhash permutations based on the given seed.
"""

gen = np.random.RandomState(seed)

if bit_width == 32:
@@ -117,7 +126,7 @@ def generate_hash_permutation_seeds(
MERSENNE_PRIME = np.uint64((1 << 61) - 1)
dtype = np.uint64
else:
raise ValueError("Unsupported bit width. Use either 32 or 64.")
raise ValueError("Unsupported bit width. Please use either 32 or 64.")

return np.array(
[
@@ -134,8 +143,9 @@ def minhash32(
self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
) -> cudf.Series:
"""
Compute 32bit minhashes based on the MurmurHash3 algorithm
Compute 32-bit minhashes based on the MurmurHash3 algorithm.
"""

if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")

@@ -146,8 +156,10 @@ def minhash32(
"Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`",
category=FutureWarning,
)

seeds = cudf.Series(seeds, dtype="uint32")
return ser.str.minhash(seeds=seeds, width=char_ngram)

else:
seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")
@@ -165,19 +177,23 @@ def minhash64(
self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
) -> cudf.Series:
"""
Compute 64bit minhashes based on the MurmurHash3 algorithm
Compute 64-bit minhashes based on the MurmurHash3 algorithm.
"""

if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")

if MINHASH_DEPRECATED_API:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
"Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`",
category=FutureWarning,
)

seeds = cudf.Series(seeds, dtype="uint64")
return ser.str.minhash64(seeds=seeds, width=char_ngram)

else:
seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")
@@ -193,16 +209,19 @@ def minhash64(

def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
"""
Computes the MinHash Signatures for a given dataset.
Computes the MinHash signatures for a given dataset.
Parameters
----------
dataset: DocumentDataset
The input datset to compute MinHashes.
dataset (DocumentDataset): The input dataset on which to compute MinHashes.
Returns
-------
DocumentDataset containing IDs of all documents and the corresponding MinHash Signature
DocumentDataset containing IDs of all documents and the corresponding MinHash signature
"""

result = dataset.df[[self.id_field]]

result["_minhash_signature"] = dataset.df[self.text_field].map_partitions(
self.minhash_method,
seeds=self.seeds,
@@ -214,16 +233,20 @@ def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:

t0 = time.time()
self._logger.info("Starting execution for Minhashes")

write_path = os.path.join(self.cache_dir, "_minhashes.parquet")
if os.path.exists(write_path):
warnings.warn(
f"Output path {write_path} already exists and will be overwritten"
)

with performance_report_if_with_ts_suffix(self.profile_dir, "minhash-profile"):
result.to_parquet(write_path, write_index=False, overwrite=True)

self._logger.info(
f"Time taken for Minhash signature computation = {time.time() - t0}s and output written at {write_path}"
f"Time taken for Minhash signature computation: {time.time() - t0}s and output written at {write_path}"
)

return DocumentDataset(
dask_cudf.read_parquet(write_path, blocksize="2GB", aggregate_files=True)
)
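With the same fallback in MinHash, a pipeline only needs to set the singleton once before computing signatures. A small sketch under those assumptions (the input path and column names are placeholders; the import path follows this file's location):

import dask_cudf

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.fuzzy_dedup.minhash import MinHash

Cache(cache_dir="/tmp/dedup_cache")

docs = DocumentDataset(dask_cudf.read_parquet("my_docs.parquet"))
minhasher = MinHash(id_field="id", text_field="text")  # cache_dir comes from Cache()
minhash_ds = minhasher(docs)  # also persisted to /tmp/dedup_cache/_minhashes.parquet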
61 changes: 42 additions & 19 deletions nemo_curator/modules/semantic_dedup/clusteringmodel.py
Original file line number Diff line number Diff line change
@@ -25,14 +25,14 @@
import numpy as np
from cuml.dask.cluster import KMeans

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
from nemo_curator.utils.semdedup_utils import assign_and_sort_clusters


### Clustering Module
def get_embedding_ar(df: "cudf.DataFrame", embedding_col: str) -> cp.ndarray:
return df[embedding_col].list.leaves.values.reshape(len(df), -1)

@@ -47,13 +47,15 @@ def add_dist_to_cents(
return df


# Clustering module
class ClusteringModel:
def __init__(
self,
id_column: str,
max_iter: int,
n_clusters: int,
clustering_output_dir: str,
cache_dir: Optional[str] = None,
clustering_save_loc: str = "clustering_results",

We deprecate clustering_output_dir in favor of cache_dir and clustering_save_loc for ClusteringModel, which matches the logic in the SemDedup class.

embedding_col: str = "embeddings",
sim_metric: str = "cosine",
which_to_keep: str = "hard",
@@ -70,22 +72,29 @@ def __init__(
id_column (str): Column name used as the identifier in the dataset.
max_iter (int): Maximum number of iterations for the clustering algorithm.
n_clusters (int): The number of clusters to form.
clustering_output_dir (str): Directory path where clustering results will be saved.
cache_dir (str, optional): Directory path where clustering results will be saved.
clustering_save_loc (str): Location within cache_dir to save clustering results.
Default is "clustering_results".
embedding_col (str): Column name where the embeddings are stored.
sim_metric (str): Similarity metric to use for clustering, default is "cosine".
which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard".
sort_clusters (bool): Whether to sort clusters, default is True.
kmeans_with_cos_dist (bool): Whether to use KMeans with cosine distance, default is False.
partition_size (str): The size of data partition to run kmeans with, default is "2gb".
logger (Union[logging.Logger, str]): Logger object or directory path to save logs; default is "./".
profile_dir (str): If specified directory to write dask profile. Default is None.
sim_metric (str): Similarity metric to use for clustering.
Default is "cosine".
which_to_keep (str): Strategy to decide which duplicates to keep.
Default is "hard".
sort_clusters (bool): Whether to sort clusters. Default is True.
kmeans_with_cos_dist (bool): Whether to use KMeans with cosine distance.
Default is False.
partition_size (str): The size of data partition with which to run KMeans.
Default is "2gb".
logger (Union[logging.Logger, str]): Logger object or directory path to save logs.
Default is "./".
profile_dir (str, optional): If specified, directory to write Dask profile.
Default is None.
This constructor sets up the parameters required for clustering operations.
"""
self.id_col = id_column
self.max_iter = max_iter
self.n_clusters = n_clusters
self.clustering_output_dir = clustering_output_dir
self.embedding_col = embedding_col
self.sim_metric = sim_metric
self.keep_hard = which_to_keep == "hard"
@@ -95,11 +104,24 @@ def __init__(
self.logger = self._setup_logger(logger)
self.profile_dir = profile_dir

if cache_dir is not None:
self.clustering_output_dir = os.path.join(cache_dir, clustering_save_loc)
elif Cache().get_cache_directory() is not None:
self.clustering_output_dir = os.path.join(
Cache().get_cache_directory(), clustering_save_loc
)
else:
raise RuntimeError(
"No cache directory specified. Please initialize with Cache(cache_dir=...) "
"or ClusteringModel(cache_dir=...)"
)

if not os.path.exists(self.clustering_output_dir):
expand_outdir_and_mkdir(self.clustering_output_dir)
else:
self.logger.warning(
f"Clustering output directory {self.clustering_output_dir} already exists and will be overwritten"
f"Clustering output directory {self.clustering_output_dir} already exists"
" and will be overwritten"
)

def _setup_logger(self, logger):
@@ -119,7 +141,7 @@ def __call__(self, embeddings_dataset: DocumentDataset):

if self.embedding_col not in embeddings_df.columns:
raise ValueError(
f"Expected embedding column '{self.embedding_col}'"
f'Expected embedding column "{self.embedding_col}"'
f" to be in dataset. Only found columns {embeddings_df.columns}"
)

@@ -140,14 +162,14 @@ def __call__(self, embeddings_dataset: DocumentDataset):
self.logger.info("KMeans starting fit")
kmeans.fit(cupy_darr)
self.logger.info("KMeans fit complete")
self.logger.info(f"Time taken for KMeans Fit: {time.time() - t0}")
self.logger.info(f"Time taken for KMeans fit: {time.time() - t0}")

self.logger.info(
"Computing nearest centroids + distance to centers using kmeans.predict"
"Computing nearest centroids and distance to centers using kmeans.predict"
)
t0 = time.time()
nearest_cents = kmeans.predict(cupy_darr)
self.logger.info(f"Time taken for KMeans Predict: {time.time() - t0}")
self.logger.info(f"Time taken for KMeans predict: {time.time() - t0}")

t0 = time.time()
embeddings_df["nearest_cent"] = nearest_cents.astype(np.int32)
@@ -174,7 +196,8 @@ def __call__(self, embeddings_dataset: DocumentDataset):
)
if os.path.exists(clustering_output_dir):
self.logger.warning(
f"Output directory {clustering_output_dir} already exists and will be overwritten"
f"Output directory {clustering_output_dir} already exists and will"
" be overwritten."
)
shutil.rmtree(clustering_output_dir)

@@ -184,8 +207,8 @@ def __call__(self, embeddings_dataset: DocumentDataset):
partition_on="nearest_cent",
)
self.logger.info(
f"Time taken for Assigning distance to each embedding : {time.time() - t0} "
f"and output written at {clustering_output_dir}"
f"Time taken for assigning distance to each embedding: {time.time() - t0}s"
f" and output written at {clustering_output_dir}"
)

del embeddings_df
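Mirroring the notebook update later in this PR, ClusteringModel now takes cache_dir plus clustering_save_loc instead of a full clustering_output_dir. A rough sketch of the new initialization, with illustrative values and the cache directory supplied through the singleton:

from nemo_curator.cache import Cache
from nemo_curator.modules.semantic_dedup.clusteringmodel import ClusteringModel

Cache(cache_dir="/tmp/semdedup_cache")

clustering_model = ClusteringModel(
    id_column="id",
    max_iter=100,
    n_clusters=1000,
    clustering_save_loc="clustering_results",  # written under the cache directory
)
# clustered = clustering_model(embeddings_dataset)  # a DocumentDataset with an "embeddings" column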
61 changes: 37 additions & 24 deletions nemo_curator/modules/semantic_dedup/embeddings.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@
)


# Embedding Creation Module
# Embedding creation module
@dataclass
class EmbeddingConfig:
model_name_or_path: str
@@ -47,7 +47,7 @@ def __post_init__(self):
self.max_seq_length = AutoTokenizer.from_pretrained(
self.model_name_or_path
).model_max_length
# Gaurd against the HF bug
# Guard against Hugging Face bug
# which sets max_seq_length to max(int) for some models
if self.max_seq_length > 1e5:
self.max_seq_length = AutoConfig.from_pretrained(
@@ -135,7 +135,8 @@ def __init__(
self,
embedding_model_name_or_path: str,
embedding_batch_size: int,
embedding_output_dir: str,
cache_dir: Optional[str] = None,
embeddings_save_loc: str = "embeddings",

We deprecate embedding_output_dir in favor of cache_dir and embeddings_save_loc for EmbeddingCreator, which matches the logic in the SemDedup class.

embedding_max_mem_gb: Optional[int] = None,
embedding_pooling_strategy: str = "mean_pooling",
input_column: str = "text",
@@ -151,26 +152,27 @@ def __init__(
Args:
embedding_model_name_or_path (str): The path or identifier for the model used to generate embeddings.
embedding_batch_size (int): Number of samples to process in each batch.
embedding_output_dir (str): Directory path where embeddings will be saved.
embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process.
If None, it defaults to the available GPU memory minus 4 GB.
embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling".
input_column (str): Column name from the data to be used for embedding generation, defaults to "text".
write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True.
We recommend setting this to False when you have a delayed pipeline.
Setting it to False can lead to more memory overhead.
write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
logger (Union[logging.Logger, str]): Logger object or path to store logs, defaults to "./".
profile_dir (str): If specified directory to write dask profile. Default is None.
Attributes:
embeddings_config (EmbeddingConfig): Configuration for embeddings.
batch_size (int): Batch size for embedding generation.
logger (logging.Logger): Logger instance for the class.
embedding_output_dir (str): Output directory for embeddings.
input_column (str): Input column for data processing.
model (EmbeddingCrossFitModel): Model instance for embedding generation.
write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
cache_dir (str, optional): Directory path where embeddings will be saved.
If None, the cache directory set via Cache(cache_dir=...) is used instead.
Default is None.
embeddings_save_loc (str): Location within cache_dir to save embeddings.
Default is "embeddings".
embedding_max_mem_gb (int, optional): Maximum memory usage in GB for the embedding process.
If None, it defaults to the available GPU memory minus 4 GB.
embedding_pooling_strategy: Strategy for pooling embeddings, either "mean_pooling" or "last_token".
Default is "mean_pooling".
input_column (str): Column name from the data to be used for embedding generation.
Default is "text".
embedding_column (str): The column name that stores the embeddings. Default is "embeddings".
write_embeddings_to_disk (bool): If True, saves the embeddings to disk.
We recommend setting this to False when you have a delayed pipeline.
Setting it to False can lead to more memory overhead. Default is True.
write_to_filename (bool): If True, saves the embeddings to the same filename as input files.
Default False.
logger (Union[logging.Logger, str]): Logger object or path to store logs.
Default is "./".
profile_dir (str, optional): If specified, directory to write Dask profile.
Default is None.
"""

self.embeddings_config = EmbeddingConfig(
@@ -179,7 +181,6 @@ def __init__(
)
self.batch_size = embedding_batch_size
self.logger = self._setup_logger(logger)
self.embedding_output_dir = embedding_output_dir
self.input_column = input_column
self.embedding_column = embedding_column
self.model = EmbeddingCrossFitModel(
@@ -189,6 +190,18 @@ def __init__(
self.write_to_filename = write_to_filename
self.profile_dir = profile_dir

if cache_dir is not None:
self.embedding_output_dir = os.path.join(cache_dir, embeddings_save_loc)
elif Cache().get_cache_directory() is not None:
self.embedding_output_dir = os.path.join(
Cache().get_cache_directory(), embeddings_save_loc
)
else:
raise RuntimeError(
"No cache directory specified. Please initialize with Cache(cache_dir=...) "
"or EmbeddingCreator(cache_dir=...)"
)

def _setup_logger(self, logger):
if isinstance(logger, str):
return create_logger(
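EmbeddingCreator follows the same pattern, with embeddings landing in cache_dir/embeddings_save_loc. A minimal sketch with an explicit cache_dir (the model name matches the one used in the tests below; batch size and paths are illustrative):

from nemo_curator.modules.semantic_dedup.embeddings import EmbeddingCreator

embedding_creator = EmbeddingCreator(
    embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    embedding_batch_size=128,
    cache_dir="/tmp/semdedup_cache",
    embeddings_save_loc="embeddings",
    input_column="text",
)
# embeddings_ds = embedding_creator(dataset)  # dataset: DocumentDataset with a "text" column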
49 changes: 36 additions & 13 deletions nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@

import dask.bag as db

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import SemDedupConfig
@@ -36,13 +37,13 @@ class SemanticClusterLevelDedup:
def __init__(
self,
n_clusters: int,
emb_by_clust_dir: str,
sorted_clusters_dir: str,
id_column: str,
id_column_type: str,
which_to_keep: str,
output_dir: str,
output_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
embedding_col: str = "embeddings",
clustering_save_loc: str = "clustering_results",
logger: Union[logging.Logger, str] = "./",
profile_dir: Optional[str] = None,
) -> None:
@@ -51,31 +52,52 @@ def __init__(
Args:
n_clusters (int): Number of clusters.
emb_by_clust_dir (str): Directory containing embeddings by cluster.
sorted_clusters_dir (str): Directory containing sorted clusters.

We deprecate emb_by_clust_dir and sorted_clusters_dir in favor of cache_dir and clustering_save_loc for SemanticClusterLevelDedup, which matches the logic in the SemDedup class.

id_column (str): Column name for IDs.
id_column_type (str): Data type of the ID column.
which_to_keep (str): Strategy for which duplicate to keep.
output_dir (str): Directory to save output files.
output_dir (str, optional): Directory to save output files.
If None, it will be saved to cache_dir/clustering_save_loc.
Default is None.
cache_dir (str, optional): Should be the same as specified in ClusteringModel.
embedding_col (str): Column where the embeddings are stored.
clustering_save_loc (str): Should be the same as specified in ClusteringModel.
logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
profile_dir (str): If specified directory to write dask profile. Default is None.
Default is "./".
profile_dir (str, optional): If specified, directory to write Dask profile.
Default is None.
"""
self.n_clusters = n_clusters
self.emb_by_clust_dir = emb_by_clust_dir
self.sorted_clusters_dir = sorted_clusters_dir
self.id_col = id_column
self.id_col_type = id_column_type
self.which_to_keep = which_to_keep
self.output_dir = output_dir
self.semdedup_pruning_tables_dir = os.path.join(
output_dir, "semdedup_pruning_tables"
)
self.computed_semantic_match_dfs = False
self.embedding_col = embedding_col
self.logger = self._setup_logger(logger)
self.profile_dir = profile_dir

if cache_dir is None:
if Cache().get_cache_directory() is None:
raise RuntimeError(
"No cache directory specified. Please initialize with Cache(cache_dir=...) "
"or SemanticClusterLevelDedup(cache_dir=...)"
)
else:
cache_dir = Cache().get_cache_directory()
self.emb_by_clust_dir = os.path.join(
cache_dir, clustering_save_loc, "embs_by_nearest_center"
)
self.sorted_clusters_dir = os.path.join(
cache_dir, clustering_save_loc, "sorted"
)

if output_dir is None:
self.output_dir = os.path.join(cache_dir, clustering_save_loc)
else:
self.output_dir = output_dir
self.semdedup_pruning_tables_dir = os.path.join(
output_dir, "semdedup_pruning_tables"
)

def _setup_logger(self, logger: Union[logging.Logger, str]) -> logging.Logger:
"""
Set up the logger.
@@ -117,6 +139,7 @@ def compute_semantic_match_dfs(
)
shutil.rmtree(self.semdedup_pruning_tables_dir)
expand_outdir_and_mkdir(self.semdedup_pruning_tables_dir)

t0 = time.time()
with performance_report_if_with_ts_suffix(
self.profile_dir, "semantic-match-compute"
50 changes: 37 additions & 13 deletions nemo_curator/modules/semantic_dedup/semdedup.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
import os
from typing import Union

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.base import BaseModule
from nemo_curator.modules.config import SemDedupConfig
@@ -43,42 +44,65 @@ def __init__(
config (SemDedupConfig): Configuration for SemDedup.
logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
"""

super().__init__(input_backend="cudf")
self.config = config
self.logger = logger
cache_dir = config.cache_dir
if config.cache_dir is not None:
cache_dir = config.cache_dir
elif Cache().get_cache_directory() is not None:
cache_dir = Cache().get_cache_directory()
else:
raise RuntimeError(
"No cache directory specified. Please initialize with Cache(cache_dir=...) "
"or specify a cache_dir in your YAML file."
)
profile_dir = self.config.profile_dir
clustering_save_loc = config.clustering_save_loc

self.embedding_creator = EmbeddingCreator(
embedding_model_name_or_path=config.embedding_model_name_or_path,
embedding_batch_size=config.embedding_batch_size,
cache_dir=cache_dir,
embeddings_save_loc=config.embeddings_save_loc,
embedding_pooling_strategy=config.embedding_pooling_strategy,
input_column=input_column,
embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
write_embeddings_to_disk=config.write_embeddings_to_disk,
logger=logger,
profile_dir=self.config.profile_dir,
profile_dir=profile_dir,
# Hardcoded as recommended values
embedding_max_mem_gb=None,
embedding_column="embeddings",
write_to_filename=False,
)
self.clustering_model = ClusteringModel(
id_column=id_column,
max_iter=config.max_iter,
n_clusters=config.n_clusters,
clustering_output_dir=os.path.join(cache_dir, config.clustering_save_loc),
cache_dir=cache_dir,
clustering_save_loc=clustering_save_loc,
sim_metric=config.sim_metric,
which_to_keep=config.which_to_keep,
kmeans_with_cos_dist=config.kmeans_with_cos_dist,
logger=logger,
profile_dir=self.config.profile_dir,
profile_dir=profile_dir,
# Hardcoded as recommended values
embedding_col="embeddings",
sort_clusters=True,
partition_size="2gb",
)
self.semantic_cluster_dedup = SemanticClusterLevelDedup(
n_clusters=config.n_clusters,
emb_by_clust_dir=os.path.join(
cache_dir, config.clustering_save_loc, "embs_by_nearest_center"
),
sorted_clusters_dir=os.path.join(
cache_dir, config.clustering_save_loc, "sorted"
),
id_column=id_column,
id_column_type=id_column_type,
which_to_keep=config.which_to_keep,
output_dir=os.path.join(cache_dir, config.clustering_save_loc),
cache_dir=cache_dir,
clustering_save_loc=clustering_save_loc,
logger=logger,
profile_dir=self.config.profile_dir,
profile_dir=profile_dir,
# Hardcoded as recommended values
output_dir=os.path.join(cache_dir, clustering_save_loc),
embedding_col="embeddings",
)
self.eps_thresholds = config.eps_thresholds
self.eps_to_extract = config.eps_to_extract
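For the high-level SemDedup module, the cache directory can come either from the config or from the singleton, as the new resolution order above shows. A sketch of both options (cluster count and directory are illustrative; option 2 assumes SemDedupConfig accepts cache_dir=None):

from nemo_curator import SemDedup, SemDedupConfig
from nemo_curator.cache import Cache

# Option 1: put the directory in the config.
config = SemDedupConfig(cache_dir="/tmp/semdedup_cache", n_clusters=1000)

# Option 2: set the singleton and leave the config's cache_dir unset.
# Cache(cache_dir="/tmp/semdedup_cache")
# config = SemDedupConfig(cache_dir=None, n_clusters=1000)

sem_dedup = SemDedup(config=config, input_column="text", id_column="id")
# result = sem_dedup(dataset)  # dataset: DocumentDataset on the cudf backend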
16 changes: 14 additions & 2 deletions nemo_curator/scripts/find_exact_duplicates.py
Original file line number Diff line number Diff line change
@@ -38,8 +38,10 @@ def main(args):
logger.info(f"Starting workflow with args:\n {args}")

assert args.hash_method == "md5", "Currently only md5 hash is supported"

client = get_client(**ArgumentHelper.parse_client_args(args))
logger.info(f"Client Created {client}")
logger.info(f"Client created: {client}")

if args.device == "gpu":
client.run(pre_imports)
logger.info("Pre imports complete")
@@ -48,29 +50,37 @@ def main(args):
id_field = args.input_json_id_field
text_field = args.input_json_text_field
num_files = args.num_files

t0 = time.time()

dfs = []
for data_path in data_paths:
data_path = strip_trailing_sep(data_path)

if num_files is not None and num_files <= 0:
logger.info(f"Processed {num_files}... quitting")
break

files = get_all_files_paths_under(
root=data_path, recurse_subdirectories=False, keep_extensions="jsonl"
)

df = read_data(
files[:num_files] if num_files else files,
file_type="jsonl",
backend="pandas" if args.device != "gpu" else "cudf",
files_per_partition=args.files_per_partition,
add_filename=False,
)[[id_field, text_field]]

if num_files is not None:
num_files -= len(files)

dfs.append(df)
logger.info(f"Lazy read complete for {dfs[-1].npartitions} partitions")

input_df = dask_cudf.concat(dfs, ignore_unknown_divisions=True)

exact_dups = ExactDuplicates(
logger=logger,
id_field=id_field,
@@ -80,8 +90,10 @@ def main(args):
cache_dir=args.output_dir,
)
exact_dups(dataset=DocumentDataset(input_df))

logger.info(
f"Exact deduplication computation across datasets took {time.time() - t0}s complete at {args.output_dir}" # noqa:E501
f"Exact deduplication computation across datasets took {time.time() - t0}s \n"
f"Output written at {args.output_dir}" # noqa:E501
)


2 changes: 1 addition & 1 deletion nemo_curator/scripts/fuzzy_deduplication/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
## Fuzzy Deduplication Steps
This directory consists of scripts that can be invoked directly via the command line for finding fuzzy duplicates from a group of Jsonl files consisting of text & unique ID's that are specifically formatted using the `add_id` script included as a part of NeMo-Curator.
This directory consists of scripts that can be invoked directly via the command line for finding fuzzy duplicates from a group of JSONL files consisting of text and unique IDs that are specifically formatted using the `add_id` script included as a part of NeMo Curator.

> [!IMPORTANT]
> The up to date documentation on running the fuzzy deduplication scripts can be found in the [NeMo Curator User Guide](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html#id4). It is recommended to use the Python API directly rather than the CLI scripts for most cases.
5 changes: 4 additions & 1 deletion nemo_curator/scripts/fuzzy_deduplication/buckets_to_edges.py
Original file line number Diff line number Diff line change
@@ -68,6 +68,7 @@ def main(args):
OUTPUT_PATH = args.output_dir

client = get_client(**ArgumentHelper.parse_client_args(args))

logger.info(f"Client Created {client}")
logger.info(f"Num Workers = {get_num_workers(client)}")
logger.info(
@@ -81,13 +82,15 @@ def main(args):
bucket_field=args.input_bucket_field,
logger=logger,
)

st = time.time()
buckets_df = DocumentDataset(
dask_cudf.read_parquet(input_bucket_path, split_row_groups=False)
)
_ = buckets_to_edges(buckets_df)

et = time.time()
logger.info(f"Bucket to Edges conversion took = {et-st} s")
logger.info(f"Bucket to edges conversion took {et-st} seconds")


def console_script():
12 changes: 9 additions & 3 deletions nemo_curator/scripts/fuzzy_deduplication/compute_minhashes.py
Original file line number Diff line number Diff line change
@@ -39,11 +39,11 @@ def main(args):
)
logger.info(f"Starting workflow with args:\n {args}")

assert args.hash_bytes in {4, 8}, "Currently only 32bit/64bit hashes are supported"
assert args.hash_bytes in {4, 8}, "Currently only 32bit/64bit hashes are supported."
assert args.device == "gpu"

client = get_client(**ArgumentHelper.parse_client_args(args))
logger.info(f"Client Created {client}")
logger.info(f"Client created {client}")
client.run(pre_imports)
logger.info("Pre imports complete")

@@ -65,14 +65,17 @@ def main(args):
t0 = time.time()
for data_path in data_paths:
print(f"Computing minhashes for {data_path}", flush=True)

data_path = strip_trailing_sep(data_path)

if num_files is not None and num_files <= 0:
print(f"Processed {args.num_files}... quitting")
break

files = get_all_files_paths_under(
root=data_path, recurse_subdirectories=False, keep_extensions="jsonl"
)

df = read_data(
files[:num_files] if num_files else files,
file_type="jsonl",
@@ -86,10 +89,12 @@ def main(args):
num_files -= len(files)

res = minhasher(DocumentDataset(df)).df

logger.info(
f"Lazy minhash generation complete for {res.npartitions} partitions"
)
logger.info(f"Starting execution for {data_path}")

write_path = os.path.join(
args.output_minhash_dir, os.path.basename(data_path), "minhashes.parquet"
)
@@ -99,8 +104,9 @@ def main(args):
args.profile_path, f"{os.path.basename(data_path)}-minhash-profile.html"
):
res.to_parquet(write_path, write_index=False)

logger.info(
f"Minhash computation for f{data_path} took {time.time() - t1}s complete at {write_path}" # noqa:E501
f"Minhash computation for {data_path} took {time.time() - t1}s complete at {write_path}" # noqa:E501
)
logger.info(
f"Minhash computation across datasets took {time.time() - t0}s complete at {args.output_minhash_dir}" # noqa:E501
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@ def main(args):
profile_dir=args.profile_path,
)
components_stage.cc_workflow(output_path=output_path)

print(f"All done in {time.time()-st:.1f} seconds")
print(f"Results written to {output_path}")

13 changes: 9 additions & 4 deletions nemo_curator/scripts/fuzzy_deduplication/jaccard_compute.py
Original file line number Diff line number Diff line change
@@ -26,24 +26,28 @@ def main(args):
from a partitioned Parquet dataset. The result is a Parquet dataset consisting of
document and ID pairs, along with their Jaccard similarity scores.
"""

OUTPUT_PATH = args.output_dir
shuffled_docs_path = args.shuffled_docs_path
output_final_results_path = os.path.join(
OUTPUT_PATH, "jaccard_similarity_results.parquet"
)

args.enable_spilling = True
client = get_client(**ArgumentHelper.parse_client_args(args))

print(f"Num Workers = {get_num_workers(client)}", flush=True)
print("Connected to dask cluster", flush=True)
print("Running jaccard compute script", flush=True)
print(f"Number of workers: {get_num_workers(client)}", flush=True)
print("Connected to Dask cluster", flush=True)
print("Running Jaccard compute script", flush=True)
st = time.time()

jaccard = JaccardSimilarity(
id_field=args.input_json_id_field,
text_field=args.input_json_text_field,
anchor_id_fields=[f"anchor_{i}_{args.input_json_id_field}" for i in range(2)],
ngram_width=args.ngram_size,
)

# Run actual computation
result_df = jaccard.jaccard_compute(shuffled_docs_path)

@@ -52,7 +56,8 @@ def main(args):
write_index=False,
write_metadata_file=False,
)
print(f"Jaccard Computing+Writing time: {time.time() - st:.1f} seconds")

print(f"Jaccard computing and writing time: {time.time() - st:.1f}s")


def attach_args():
16 changes: 11 additions & 5 deletions nemo_curator/scripts/fuzzy_deduplication/jaccard_shuffle.py
Original file line number Diff line number Diff line change
@@ -38,11 +38,13 @@ def main(args):

client = get_client(**ArgumentHelper.parse_client_args(args))
client.run(func)
print(f"Num Workers = {get_num_workers(client)}", flush=True)
print("Connected to dask cluster", flush=True)
print("Running jaccard shuffle script", flush=True)

print(f"Number of workers: {get_num_workers(client)}", flush=True)
print("Connected to Dask cluster", flush=True)
print("Running Jaccard shuffle script", flush=True)
print(f"Args = {args}")
st = time.time()

text_ddf = get_text_ddf_from_json_path_with_blocksize(
input_data_paths=input_data_paths,
num_files=args.num_files,
@@ -51,17 +53,20 @@ def main(args):
text_column=args.input_json_text_field,
input_meta=args.input_meta,
)

print(
"Graph creation for get_text_ddf_from_json_path_with_blocksize complete.",
flush=True,
)
print(f"text_ddf.npartitions = {text_ddf.npartitions}", flush=True)
print(f"text_ddf.npartitions = {text_ddf.npartitions}", flush=True)

shuffle = _Shuffle(
id_fields=["dataset_id", "doc_id"],
text_field=args.input_json_text_field,
profile_dir=args.profile_path,
int_to_str_id=args.input_json_id_field,
)

shuffle.shuffle_docs_on_buckets(
documents_df=text_ddf,
bucket_w_anchors_path=input_anchor_docs_with_bk_dir,
@@ -71,8 +76,9 @@ def main(args):
bucket_parts_per_worker=args.bucket_parts_per_worker,
partition_on="_output_partition_id",
)

et = time.time()
print(f"Jaccard Shuffle E2E time taken = {et-st} s")
print(f"Jaccard shuffle E2E time taken: {et-st}s")


def attach_args():
32 changes: 23 additions & 9 deletions nemo_curator/scripts/fuzzy_deduplication/map_buckets.py
Original file line number Diff line number Diff line change
@@ -38,7 +38,8 @@ def get_anchor_and_output_map_info(
input_meta,
):
"""
Get anchor docs with bucket info
Get anchor documents with bucket information.
Args:
input_data_paths: list of paths to input data
input_bucket_path: path to input buckets
@@ -49,6 +50,7 @@ def get_anchor_and_output_map_info(
Returns:
ddf_anchor_docs_with_bk
"""

ddf_text = get_text_ddf_from_json_path_with_blocksize(
input_data_paths=input_data_paths,
num_files=num_files,
@@ -57,17 +59,21 @@ def get_anchor_and_output_map_info(
text_column=input_text_field,
input_meta=input_meta,
)

ddf_bk = get_bucket_ddf_from_parquet_path(
input_bucket_path=input_bucket_path, num_workers=num_workers
)

map_buckets = _MapBuckets(
id_fields=["dataset_id", "doc_id"],
bucket_field=input_bucket_field,
text_field=input_text_field,
)

ddf_anchor_docs_with_bk = map_buckets.map_buckets_with_anchors(
documents_df=ddf_text, buckets_df=ddf_bk, shuffle_type=shuffle_type
)

return ddf_anchor_docs_with_bk


@@ -123,18 +129,21 @@ def jaccard_get_output_map_workflow(
input_meta,
):
"""
Workflow for jaccard shuffle
Workflow for Jaccard shuffle.
Args:
client: dask client
client: Dask client
input_data_paths: list of paths to input data
input_bucket_path: path to input buckets
output_anchor_docs_with_bk_path: path to save anchor docs with bucket info
output_anchor_docs_with_bk_path: path to save anchor documents with bucket
information
text_ddf_blocksize: blocksize for text ddf
num_files: number of files to read
parts_per_worker: number of parts per worker
shuffle_type: type of shuffle to use before writing to parquet
"""

num_workers = get_num_workers(client)

ddf_anchor_docs_with_bk = get_anchor_and_output_map_info(
input_data_paths,
input_bucket_path,
@@ -147,6 +156,7 @@ def jaccard_get_output_map_workflow(
input_text_field,
input_meta=input_meta,
)

ddf_anchor_docs_with_bk.to_parquet(
output_anchor_docs_with_bk_path,
write_index=False,
@@ -160,12 +170,15 @@ def main(args):
output_anchor_docs_with_bk_path = os.path.join(
OUTPUT_PATH, "anchor_docs_with_bk.parquet"
)

client = get_client(**ArgumentHelper.parse_client_args(args))
print(f"Num Workers = {get_num_workers(client)}", flush=True)
print("Connected to dask cluster", flush=True)
print("Running jaccard map buckets script", flush=True)

print(f"Number of workers: {get_num_workers(client)}", flush=True)
print("Connected to Dask cluster", flush=True)
print("Running Jaccard map buckets script", flush=True)
print(f"Args = {args}")
st = time.time()

jaccard_get_output_map_workflow(
client,
input_data_paths,
@@ -179,8 +192,9 @@ def main(args):
args.input_json_text_field,
args.input_meta,
)

et = time.time()
print(f"Bucket Mapping time taken = {et-st} s")
print(f"Bucket mapping time taken: {et-st}s")


def console_script():
7 changes: 5 additions & 2 deletions nemo_curator/scripts/fuzzy_deduplication/minhash_lsh.py
Original file line number Diff line number Diff line change
@@ -40,7 +40,8 @@ def main(args):

assert args.device == "gpu"
client = get_client(**ArgumentHelper.parse_client_args(args))
logger.info(f"Client Created {client}")

logger.info(f"Client created {client}")
client.run(pre_imports)
logger.info("Pre imports complete")

@@ -53,8 +54,10 @@ def main(args):
dfs.append(
dask_cudf.read_parquet(data_path, blocksize="2GB", aggregate_files=True)
)

df = dask_cudf.concat(dfs, ignore_unknown_divisions=True)
df = df[~df[id_field].isna()]

df = df.map_partitions(
convert_str_id_to_int,
id_column=id_field,
@@ -77,7 +80,7 @@ def main(args):

t1 = time.time()
_ = lsh(DocumentDataset(df))
logger.info(f"Computing and writing buckets took {time.time() - t1} s")
logger.info(f"Computing and writing buckets took {time.time() - t1}s")


def attach_args():
15 changes: 10 additions & 5 deletions nemo_curator/scripts/semdedup/clustering.py
Original file line number Diff line number Diff line change
@@ -52,9 +52,6 @@ def main(args):
embedding_fp = os.path.join(
semdedup_config.cache_dir, semdedup_config.embeddings_save_loc
)
clustering_output_dir = os.path.join(
semdedup_config.cache_dir, semdedup_config.clustering_save_loc
)

# Switch to https://github.com/NVIDIA/NeMo-Curator/issues/50
# When we fix that
@@ -65,8 +62,17 @@ def main(args):
id_column=args.id_column,
max_iter=semdedup_config.max_iter,
n_clusters=semdedup_config.n_clusters,
clustering_output_dir=clustering_output_dir,
cache_dir=semdedup_config.cache_dir,
clustering_save_loc=semdedup_config.clustering_save_loc,
sim_metric=semdedup_config.sim_metric,
which_to_keep=semdedup_config.which_to_keep,
kmeans_with_cos_dist=semdedup_config.kmeans_with_cos_dist,
logger=logger,
# Hardcoded as recommended values
embedding_col="embeddings",
sort_clusters=True,
partition_size="2gb",
profile_dir=None,
)

clustered_embeddings = clustering_model(embedding_dataset)
@@ -93,7 +99,6 @@ def attach_args():
" cache_dir for the directory to store cache,"
" clustering_save_loc for the location to save clustering results,"
" n_clusters for the number of clusters,"
" seed for the seed for clustering,"
" max_iter for the maximum iterations for clustering,"
" kmeans_with_cos_dist for using K-Means with cosine distance."
),
12 changes: 8 additions & 4 deletions nemo_curator/scripts/semdedup/compute_embeddings.py
Original file line number Diff line number Diff line change
@@ -41,7 +41,7 @@ def main(args):
semdedup_config.cache_dir, semdedup_config.embeddings_save_loc
)

# Some time jsonl files are stored as .json
# Sometimes JSONL files are stored as .json
# So to handle that case we can pass the input_file_extension
if args.input_file_extension is not None:
input_file_extension = args.input_file_extension
@@ -76,13 +76,17 @@ def main(args):
embedding_creator = EmbeddingCreator(
embedding_model_name_or_path=semdedup_config.embedding_model_name_or_path,
embedding_batch_size=semdedup_config.embedding_batch_size,
embedding_output_dir=os.path.join(
semdedup_config.cache_dir, semdedup_config.embeddings_save_loc
),
cache_dir=semdedup_config.cache_dir,
embeddings_save_loc=semdedup_config.embeddings_save_loc,
input_column=args.input_text_field,
write_embeddings_to_disk=semdedup_config.write_embeddings_to_disk,
logger=logger,
# Hardcoded as recommended values
embedding_max_mem_gb=None,
embedding_column="embeddings",
write_embeddings_to_disk=True,
write_to_filename=True,
profile_dir=None,
)

embedding_dataset = embedding_creator(dataset=dataset)
15 changes: 7 additions & 8 deletions nemo_curator/scripts/semdedup/extract_dedup_data.py
100755 → 100644
Original file line number Diff line number Diff line change
@@ -28,21 +28,21 @@ def main(args):
dt1 = datetime.now()
logger.info(f"Start: {dt1}")
cache_dir = semdedup_config.cache_dir

semantic_dedup = SemanticClusterLevelDedup(
n_clusters=semdedup_config.n_clusters,
emb_by_clust_dir=os.path.join(
cache_dir, semdedup_config.clustering_save_loc, "embs_by_nearest_center"
),
sorted_clusters_dir=os.path.join(
cache_dir, semdedup_config.clustering_save_loc, "sorted"
),
id_column=args.id_column,
id_column_type=args.id_column_type,
which_to_keep=semdedup_config.which_to_keep,
cache_dir=semdedup_config.cache_dir,
clustering_save_loc=semdedup_config.clustering_save_loc,
logger=logger,
# Hardcoded as recommended values
output_dir=os.path.join(
semdedup_config.cache_dir, semdedup_config.clustering_save_loc
),
logger=logger,
embedding_col="embeddings",
profile_dir=None,
)

semantic_dedup.compute_semantic_match_dfs(semdedup_config.eps_thresholds)
@@ -72,7 +72,6 @@ def attach_args():
"Important configuration parameters include:"
" cache_dir for the directory to store cache"
" which_to_keep for specifying which duplicates to keep,"
" largest_cluster_size_to_process for the largest cluster size to process,"
" sim_metric for the similarity metric for deduplication,"
" eps_thresholds for epsilon thresholds to calculate if semantically similar or not"
" and eps_to_extract for the epsilon value to extract deduplicated data."
45 changes: 43 additions & 2 deletions tests/test_exact_dedup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from hashlib import md5

import pandas as pd
import pytest
from dask import dataframe as dd
from dask.dataframe.utils import assert_eq

from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules import ExactDuplicates

@@ -40,20 +43,58 @@ def test_unsupported_hash(self):
with pytest.raises(ValueError):
ExactDuplicates(hash_method="sha256")

@pytest.mark.parametrize("cache_method", [None, "Cache", "ExactDuplicates"])
def test_exact_dedup_cache_method(self, exact_dedup_data, cache_method, tmpdir):

Cache().delete_cache_instance() # Fresh start for new PyTest
if cache_method == "Cache":
Cache(cache_dir=tmpdir)
cache_dir = None
elif cache_method == "ExactDuplicates":
cache_dir = tmpdir
else:
cache_dir = None

exact_dups = ExactDuplicates(
id_field="id",
text_field="text",
hash_method="md5",
cache_dir=cache_dir,
)

result = exact_dups(exact_dedup_data)
result = result.df.compute()
expected_df = exact_dedup_data.df.compute()
expected_df = expected_df[expected_df.text.duplicated(keep=False)]

assert_eq(result.id, expected_df.id, check_index=False)

# Check that the output is written when either:
# (1) Cache(cache_dir=...) is initialized, or
# (2) ExactDuplicates(cache_dir=...) is initialized.
# If there is no Cache and ExactDuplicates(cache_dir=None),
# then there should be no output file.
if cache_method in ["Cache", "ExactDuplicates"]:
assert os.path.exists(str(tmpdir / "_exact_duplicates.parquet"))
else:
assert not os.path.exists(str(tmpdir / "_exact_duplicates.parquet"))

@pytest.mark.parametrize("cache_result", [False, True])
def test_dup(self, exact_dedup_data, cache_result, tmpdir):
def test_exact_dedup(self, exact_dedup_data, cache_result, tmpdir):
exact_dups = ExactDuplicates(
id_field="id",
text_field="text",
hash_method="md5",
cache_dir=tmpdir if cache_result else None,
)

duplicates = exact_dups.identify_duplicates(exact_dedup_data)
deduplicated_ds = exact_dups.remove(exact_dedup_data, duplicates)
deduplicated_ids_series = deduplicated_ds.df.to_backend("pandas").compute()[
"id"
]
output_deduplicated_ids = set(deduplicated_ids_series.tolist())

assert (
len(output_deduplicated_ids) == 3
and 300 in output_deduplicated_ids
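The tests lean on the singleton semantics of Cache: deleting the instance clears any directory set by an earlier test, and the next construction with cache_dir makes that directory visible to every module that defaults to the singleton. A compact sketch of that flow (the directory path is a placeholder):

from nemo_curator.cache import Cache

Cache().delete_cache_instance()      # clear state left over from a previous test
Cache(cache_dir="/tmp/dedup_cache")  # modules with cache_dir=None now resolve to this
print(Cache().get_cache_directory())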
156 changes: 126 additions & 30 deletions tests/test_fuzzy_dedup.py

Large diffs are not rendered by default.

59 changes: 55 additions & 4 deletions tests/test_semdedup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
from transformers import AutoConfig, AutoModel, AutoTokenizer

from nemo_curator import SemDedup, SemDedupConfig
from nemo_curator.cache import Cache
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from

@@ -54,17 +55,25 @@ def dedup_data():

@pytest.mark.gpu
class TestSemDuplicates:
@pytest.mark.parametrize("cache_method", ["Cache", "SemDedupConfig"])
def test_sem_dedup(
self,
dedup_data,
tmpdir,
cache_method,
gpu_client,
):
print("client", gpu_client)
cache_dir = os.path.join(tmpdir, "test_sem_dedup_cache")

Cache().delete_cache_instance() # Fresh start for new PyTest
if cache_method == "Cache":
Cache(cache_dir=os.path.join(tmpdir, "test_sem_dedup_cache"))
cache_dir = None
else:
cache_dir = os.path.join(tmpdir, "test_sem_dedup_cache")

config = SemDedupConfig(
cache_dir=cache_dir,
seed=42,
n_clusters=3,
eps_thresholds=[0.10],
eps_to_extract=0.10,
@@ -81,28 +90,70 @@ def test_sem_dedup(
expected_df = cudf.Series(duplicate_docs, name="id")
assert_eq(result_df["id"].sort_values(), expected_df, check_index=False)

# Check that the output is written when either:
# (1) Cache(cache_dir=...) is initialized, or
# (2) SemDedupConfig(cache_dir=...) is initialized.
# Either way, their output files should be identical.
cache_dir = os.path.join(tmpdir, "test_sem_dedup_cache")

assert os.path.exists(cache_dir)
assert os.path.exists(cache_dir + "/embeddings/part.0.parquet")
assert os.path.exists(cache_dir + "/embeddings/part.1.parquet")
assert os.path.exists(cache_dir + "/clustering_results/dedup_summary_0.1.csv")
assert os.path.exists(cache_dir + "/clustering_results/kmeans_centroids.npy")
assert os.path.exists(cache_dir + "/clustering_results/sorted/cluster_0.npy")
assert os.path.exists(cache_dir + "/clustering_results/sorted/cluster_1.npy")
assert os.path.exists(cache_dir + "/clustering_results/sorted/cluster_2.npy")
assert os.path.exists(
cache_dir
+ "/clustering_results/embs_by_nearest_center/nearest_cent=0/part.0.parquet"
)
assert os.path.exists(
cache_dir
+ "/clustering_results/embs_by_nearest_center/nearest_cent=1/part.0.parquet"
)
assert os.path.exists(
cache_dir
+ "/clustering_results/embs_by_nearest_center/nearest_cent=2/part.0.parquet"
)
assert os.path.exists(
cache_dir + "/clustering_results/semdedup_pruning_tables/cluster_0.parquet"
)
assert os.path.exists(
cache_dir + "/clustering_results/semdedup_pruning_tables/cluster_1.parquet"
)
assert os.path.exists(
cache_dir + "/clustering_results/semdedup_pruning_tables/cluster_2.parquet"
)
assert os.path.exists(cache_dir + "/clustering_results/unique_ids_0.1.parquet")

@pytest.mark.parametrize("pooling_strategy", ["last_token", "mean_pooling"])
def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy):
test_text_1 = "The quick brown fox jumps over the lazy dog"
test_text_2 = "The brown fox jumps over the dog"
test_texts = [test_text_1, test_text_2] * 32
df = cudf.DataFrame({"text": test_texts})
ddf = dask_cudf.from_cudf(df, 1)

cache_dir = os.path.join(tmpdir, "test_embeddings_cache")

embedding_creator = EmbeddingCreator(
embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
embedding_batch_size=32,
cache_dir=cache_dir,
embeddings_save_loc="mean_embeddings",
embedding_pooling_strategy=pooling_strategy,
input_column="text",
embedding_output_dir=os.path.join(cache_dir, "mean_embeddings"),
)

embeddings = embedding_creator.create_embeddings(ddf).compute()
embeddings = embeddings["embeddings"].to_arrow().to_pylist()
embeddings = np.array(embeddings)

reference_embeddings = get_reference_embeddings(
test_texts, pooling_strategy=pooling_strategy
)

assert np.allclose(
embeddings, reference_embeddings, atol=1e-3
), "Embeddings should match reference embeddings"
9 changes: 4 additions & 5 deletions tutorials/image-curation/image-curation.ipynb
Original file line number Diff line number Diff line change
@@ -641,7 +641,8 @@
" embedding_col=\"image_embedding\",\n",
" max_iter=10,\n",
" n_clusters=1,\n",
" clustering_output_dir=clustering_output,\n",
" cache_dir=semantic_dedup_outputs,\n",
" clustering_save_loc=\"cluster_output\",\n",
")\n",
"clustered_dataset = clustering_model(embeddings_dataset)"
]
@@ -669,19 +670,17 @@
],
"source": [
"# Run cluster-level dedup\n",
"emb_by_cluster_output = os.path.join(clustering_output, \"embs_by_nearest_center\")\n",
"sorted_cluster_output = os.path.join(clustering_output, \"sorted\")\n",
"duplicate_output = os.path.join(semantic_dedup_outputs, \"duplicates\")\n",
"\n",
"semantic_dedup = SemanticClusterLevelDedup(\n",
" n_clusters=1,\n",
" emb_by_clust_dir=emb_by_cluster_output,\n",
" sorted_clusters_dir=sorted_cluster_output,\n",
" id_column=id_col,\n",
" id_column_type=\"str\",\n",
" embedding_col=\"image_embedding\",\n",
" which_to_keep=\"hard\",\n",
" output_dir=duplicate_output,\n",
" cache_dir=semantic_dedup_outputs,\n",
" clustering_save_loc=\"cluster_output\",\n",
")\n",
"semantic_dedup.compute_semantic_match_dfs([0.01, 0.001])\n",
"deduplicated_dataset_ids = semantic_dedup.extract_dedup_data(eps_to_extract=0.01)"
2 changes: 2 additions & 0 deletions tutorials/peft-curation-with-sdg/main.py
Original file line number Diff line number Diff line change
@@ -131,13 +131,15 @@ def semantic_dedupe(dataset):
os.path.join(CONFIG_DIR, "sem_dedup_config.yaml")
)
expand_outdir_and_mkdir(semdedup_config.cache_dir)

semdup = SemDedup(
config=semdedup_config,
input_column="text",
id_column="id",
id_column_type="str",
)
dedup_ids = semdup(dataset)

# When there are few duplicates we can compute the results to a list and use `isin`.
result = dataset.df[dataset.df["id"].isin(dedup_ids.df["id"].compute())]
return DocumentDataset(result)
Original file line number Diff line number Diff line change
@@ -806,7 +806,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -832,7 +832,7 @@
" id_field=exact_dedup_dataset_id_field,\n",
" text_field=exact_dedup_dataset_text_field,\n",
" hash_method=\"md5\",\n",
" cache_dir=exact_dedup_output_dir\n",
" cache_dir=exact_dedup_output_dir,\n",
")\n",
"duplicates = exact_dup(dataset=input_dataset)\n",
"\n",
134 changes: 63 additions & 71 deletions tutorials/single_node_tutorial/single_gpu_tutorial.ipynb
Original file line number Diff line number Diff line change
@@ -768,15 +768,15 @@
"id": "1baf027e",
"metadata": {},
"source": [
"## 4.Exact Deduplication\n",
"## 4. Exact Deduplication\n",
"\n",
"In exact deduplication, the document text is hashed into unique string using certain hashing algorithm, such as 'md5'. The documents with exact hashed values are having identical text. We will output the `ID` of duplicated documents for removal later. The function used is `ExactDuplicates()`. Arguments for this function include:\n",
"- `id_field`: Key in input file for identifying document ID\n",
"- `text_field`: Key in input file which contains document text.\n",
"- `hash_method`: Hashing algorithm used. Default is `md5`\n",
"- `cache_dir`: If specified, the duplicated document IDs will be output to the `cache_dir`. Otherwise, the IDs will not be saved\n",
"In exact deduplication, the document text is hashed into a unique string by using a hashing algorithm such as md5. The documents with exact hashed values are identified as having identical text. We will output the ID of duplicated documents for removal later. The class used for exact deduplication in NeMo Curator is called `ExactDuplicates`. Fields for this class include:\n",
"- `id_field`: Column in input file which contains a unique ID.\n",
"- `text_field`: Column in input file which contains document text.\n",
"- `hash_method`: Hashing algorithm used. Default is \"md5\".\n",
"- `cache_dir`: If specified via `ExactDuplicates(cache_dir=...)` or `Cache(cache_dir=...)`, the duplicated document IDs will be output to the cache directory. Otherwise, the IDs will not be saved.\n",
"\n",
"Also, we are going to use GPU dask cluster to accelerate computation for deduplication (both exact and fuzzy)\n"
"We are going to use a GPU-based Dask cluster to accelerate computation for deduplication (both exact and fuzzy deduplication).\n"
]
},
{
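A toy illustration of the idea described in the cell above, using plain pandas and hashlib rather than NeMo Curator; the `_hashes` column name and sample rows are made up for the example.

```python
import hashlib

import pandas as pd

df = pd.DataFrame(
    {
        "id": [100, 200, 300],
        "text": ["hello world", "hello world", "something else"],
    }
)

# md5 each document's text; rows sharing a digest have byte-identical text
df["_hashes"] = df["text"].apply(lambda t: hashlib.md5(t.encode()).hexdigest())
duplicates = df[df["_hashes"].duplicated(keep=False)]  # ids 100 and 200
```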
@@ -901,19 +901,19 @@
"# Read input dataset\n",
"input_dataset = DocumentDataset.read_json(exact_dedup_input_dataset_dir, backend='cudf')\n",
"\n",
"#Run exact deduplication to the input\n",
"# Run exact deduplication to the input\n",
"exact_dup = ExactDuplicates(\n",
" logger=exact_dedup_log_dir,\n",
" id_field=exact_dedup_dataset_id_field,\n",
" text_field=exact_dedup_dataset_text_field,\n",
" hash_method=\"md5\",\n",
" cache_dir=exact_dedup_output_dir #Duplicated document ID list is output to the cache_dir\n",
" cache_dir=exact_dedup_output_dir, # Duplicated document ID list is output to the cache_dir\n",
")\n",
"duplicates = exact_dup(dataset=input_dataset)\n",
"\n",
"print(f\"Number of exact duplicated file:{len(duplicates)}\")\n",
"print(f\"Number of exact duplicated files: {len(duplicates)}\")\n",
"\n",
"print(f\"Time taken for exact duplicate:{time.time()-t0}\")"
"print(f\"Time taken for exact deduplication: {time.time()-t0}s\")"
]
},
{
@@ -1248,23 +1248,21 @@
"id": "5df73743",
"metadata": {},
"source": [
"### 5.2 Minhash\n",
"### 5.2 MinHash\n",
"\n",
"Run `MinHash()` for this section. The output of a minhash is a parquet file which contains document ID and hashed value which is an array contains 260 32-bit integer data. To obtain such hashed values we need to go through the following steps:\n",
"1. Generate a set of n-gram components of a document. For example, doc = `Nemo Curator is a data curation tool`, a 3-gram set of this document will be `['Nemo Curator is','Curator is a','is a data','a data curation','data curation tool']`\n",
"2. Hashed each n-gram into numerical values\n",
"2. Hash each n-gram into numerical values\n",
"3. Generate a random hash function $H_1()$ which will hash each numeric n-gram into a 32-bit integer and take the minimum integer to use as minhash value for $H_1()$\n",
"4. Repeat step 2 and 3 with hash function $H_x()$ until desired minhash length is reached. Minhash value of each iteration will be append together to form the final minhash array. \n",
"4. Repeat step 2 and 3 with hash function $H_x()$ until the desired minhash length is reached. The minhash value of each iteration will be append together to form the final minhash array. \n",
"\n",
"Arguments include:\n",
"- `seed`:Random seed used for initializing the hash functions used to compute the MinHashes. It's advised to keep this value the same for different experiment for reproducibility\n",
"- `num_hashes`:Length of each minhash array. Default is 260. Longer minhash length will have better estimate of actual Jaccard similarity, but require more computational power\n",
"- `char_ngrams`:n-gram length. Assuming an average of 4.5 chars per word it's recommended to use `char_ngrams>=24` to use ~5 word ngrams or greater.\n",
"- `use_64bit_hash`:Whether to use 64bit or 32bit hash function\n",
"- `id_field`: Key in input file for identifying document ID\n",
"- `text_field`: Key in input file which contains document text.\n",
"- `cache_dir`: If specified, the intermediate result will be output to the `cache_dir`. \n",
"\n"
"- `seed`: Random seed used for initializing the hash functions used to compute the minhashes. It is advised to keep this value the same for different experiments for reproducibility.\n",
"- `num_hashes`: Length of each minhash array. Default is 260. A longer minhash length will have better estimate of actual Jaccard similarity, but require more computational power.\n",
"- `char_ngrams`: n-gram length. Assuming an average of 4.5 characters per word, it is recommended to use `char_ngrams>=24` to use ~5 word n-grams or greater.\n",
"- `use_64bit_hash`: Whether to use 64-bit or 32-bit hash function.\n",
"- `id_field`: Column in input file which contains a unique ID.\n",
"- `text_field`: Column in input file which contains document text.\n",
"- `cache_dir`: If specified via `MinHash(cache_dir=...)` or `Cache(cache_dir=...)`, the intermediate result will be output to the cache directory."
]
},
{
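A toy, CPU-only sketch of steps 1 to 4 above. The helper names are made up, and a real implementation uses proper universal hash families and GPU acceleration rather than re-hashing with md5 per seed; this only illustrates the "take the minimum hashed n-gram per hash function" idea.

```python
import hashlib


def char_ngrams(text: str, n: int = 24) -> list[str]:
    """Character n-grams of a document (step 1)."""
    return [text[i : i + n] for i in range(max(1, len(text) - n + 1))]


def minhash_signature(text: str, num_hashes: int = 260, n: int = 24) -> list[int]:
    """Steps 2-4: hash every n-gram under each seeded hash function and keep the minimum."""
    grams = char_ngrams(text, n)
    signature = []
    for seed in range(num_hashes):
        # Mixing the seed into the input stands in for a family of hash functions H_1..H_x.
        values = [
            int.from_bytes(hashlib.md5(f"{seed}:{g}".encode()).digest()[:4], "little")
            for g in grams
        ]
        signature.append(min(values))
    return signature


sig = minhash_signature("NeMo Curator is a data curation tool", num_hashes=8, n=5)
print(sig)  # 8 minimum hash values, one per seeded hash function
```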
@@ -1296,18 +1294,12 @@
},
"outputs": [],
"source": [
"#Input\n",
"# Input\n",
"minhash_data_path = added_id_output_path\n",
"#Output\n",
"minhash_base_output_path = os.path.join(data_dir,\"fuzzy/minhash\")\n",
"minhash_log_dir = os.path.join(minhash_base_output_path,'log')\n",
"minhash_output_dir = os.path.join(minhash_base_output_path,'data')\n",
"#Specify dataset name\n",
"dataset_name = 'TH_wikipedia'\n",
"\n",
"#Relevant parameters\n",
"minhash_id_field = 'id'\n",
"minhash_text_field = 'text'\n",
"# Relevant parameters\n",
"minhash_id_field = \"id\"\n",
"minhash_text_field = \"text\"\n",
"seed = 10 # Using the same value as the wrapper above for consistency\n",
"minhash_length = 260\n",
"char_ngram = 24\n",
@@ -1338,7 +1330,8 @@
"t0 = time.time()\n",
"print(f\"Computing minhashes for {minhash_data_path}\")\n",
"\n",
"# Load data. Only the [minhash_id_field, text_field] columns are needed\n",
"# Load data\n",
"# Only the [minhash_id_field, minhash_text_field] columns are needed\n",
"files = get_all_files_paths_under(\n",
" root=minhash_data_path, recurse_subdirectories=False, keep_extensions=\"jsonl\"\n",
")\n",
@@ -1350,7 +1343,7 @@
" add_filename=False,\n",
")[[minhash_id_field, minhash_text_field]]\n",
"\n",
"# Run MinHash() on input data\n",
"# Run MinHash on input data\n",
"minhasher = MinHash(\n",
" seed=seed,\n",
" num_hashes=minhash_length,\n",
@@ -1359,11 +1352,11 @@
" logger=minhash_log_dir,\n",
" id_field=minhash_id_field,\n",
" text_field=minhash_text_field,\n",
" cache_dir=minhash_output_dir\n",
" cache_dir=minhash_output_dir,\n",
")\n",
"res = minhasher(DocumentDataset(df)).df\n",
"\n",
"print(f\"Time taken for MinHash:{time.time()-t0}\")"
"print(f\"Time taken for MinHash: {time.time()-t0}\")"
]
},
{
@@ -1393,19 +1386,18 @@
"metadata": {},
"source": [
"### 5.3 LSH\n",
"`LSH()` implements LSH algorithm which includes the following steps:\n",
"`LSH` implements Locality Sensitive Hashing algorithm which includes the following steps:\n",
"1. Divide the minhash array into `X` different portions. \n",
"2. For each portions, hash the minhash values into buckets. One document will be assigned to `X` buckets.\n",
"3. Documents within the same bucket will be deemed similar. Since every document will be assigned `X` buckets and as long as two documents share 1 or more buckets they are deemed similar.\n",
"\n",
"Arguments include:\n",
"- `minhash_length`:Length of minhash signature. Must be consistent with `MinHash()`\n",
"- `num_buckets`: Number of buckets\n",
"- `buckets_per_shuffle`: Number of buckets to shuffle concurrently\n",
"- `id_field`: Key in input file for identifying document ID\n",
"- `minhash_field`: Key in input file for identifying document MinHash signature \n",
"- `cache_dir`:If specified, the intermediate result will be output to the `cache_dir`.\n",
"\n"
"- `minhash_length`: Length of minhash signature. Must be consistent with `MinHash`.\n",
"- `num_buckets`: Number of buckets.\n",
"- `buckets_per_shuffle`: Number of buckets to shuffle concurrently.\n",
"- `id_field`: Column in input file which contains a unique ID.\n",
"- `minhash_field`: Column in input file for identifying a document's MinHash signature.\n",
"- `cache_dir`: If specified via `LSH(cache_dir=...)` or `Cache(cache_dir=...)`, the intermediate result will be output to the cache directory.\""
]
},
{
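A toy sketch of the banding scheme described above. Function and variable names are made up, and the real `LSH` class works on distributed GPU DataFrames rather than Python dicts; this only shows how bands map to buckets and how shared buckets produce candidate pairs.

```python
from collections import defaultdict


def lsh_buckets(signatures: dict[str, list[int]], num_bands: int = 20) -> set[tuple[str, str]]:
    """Map each document into one bucket per band; co-bucketed docs become candidate pairs."""
    buckets = defaultdict(set)
    for doc_id, sig in signatures.items():
        rows_per_band = len(sig) // num_bands  # e.g. 260 // 20 = 13 minhashes per band
        for band in range(num_bands):
            band_slice = tuple(sig[band * rows_per_band : (band + 1) * rows_per_band])
            buckets[(band, band_slice)].add(doc_id)

    candidate_pairs = set()
    for members in buckets.values():
        if len(members) > 1:
            ordered = sorted(members)
            for i, a in enumerate(ordered):
                for b in ordered[i + 1 :]:
                    candidate_pairs.add((a, b))
    return candidate_pairs


# Two identical signatures collide in every band; the third never collides.
sigs = {"doc1": [1] * 260, "doc2": [1] * 260, "doc3": [2] * 260}
print(lsh_buckets(sigs))  # {('doc1', 'doc2')}
```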
@@ -1437,20 +1429,20 @@
},
"outputs": [],
"source": [
"#Input\n",
"# Input\n",
"lsh_input_data_path = minhash_output_dir\n",
"\n",
"#Output\n",
"lsh_base_output_path = os.path.join(data_dir,\"fuzzy/lsh\")\n",
"lsh_log_dir = os.path.join(lsh_base_output_path,'log')\n",
"lsh_output_dir = os.path.join(lsh_base_output_path,'data')\n",
"# Output\n",
"lsh_base_output_path = os.path.join(data_dir, \"fuzzy/lsh\")\n",
"lsh_log_dir = os.path.join(lsh_base_output_path, \"log\")\n",
"lsh_output_dir = os.path.join(lsh_base_output_path, \"data\")\n",
"\n",
"#Relevant parameters\n",
"lsh_id_field = 'id'\n",
"minhash_field = '_minhash_signature'\n",
"minhash_length=260\n",
"num_bands=20\n",
"buckets_per_shuffle=1\n",
"# Relevant parameters\n",
"lsh_id_field = \"id\"\n",
"minhash_field = \"_minhash_signature\"\n",
"minhash_length = 260\n",
"num_bands = 20\n",
"buckets_per_shuffle = 1\n",
"\n",
"!mkdir -p {lsh_log_dir}\n",
"!mkdir -p {lsh_output_dir}"
@@ -1475,10 +1467,10 @@
"source": [
"t0 = time.time()\n",
"\n",
"#Load MinHash output\n",
"# Load MinHash output\n",
"df = dask_cudf.read_parquet(lsh_input_data_path, blocksize=\"2GB\", aggregate_files=True, backend = \"cudf\")\n",
"\n",
"#Run LSH()\n",
"# Run LSH\n",
"lsh = LSH(\n",
" cache_dir=lsh_output_dir,\n",
" num_hashes=minhash_length,\n",
@@ -1491,7 +1483,7 @@
"res = lsh(DocumentDataset(df))\n",
"\n",
"t1 = time.time()\n",
"print(f\"Time taken for LSH:{time.time()-t0}\")"
"print(f\"Time taken for LSH: {time.time()-t0}s\")"
]
},
{
@@ -1631,13 +1623,13 @@
"metadata": {},
"source": [
"### 5.5 Connected Components\n",
"This section uses `ConnectedComponents()`.This section takes a dataset consisting of document pairs and their corresponding jaccard similarity to construct a non-directed graph. A edge will be formed between documents whose Jaccard similarity is higher than the threshold. It will then identify the connected components in this graph. Documents within the same connected components are deemed duplicated.\n",
"This section uses the `ConnectedComponents` class. This section takes a dataset consisting of document pairs and their corresponding Jaccard similarity scores to construct a non-directed graph. A edge will be formed between documents whose Jaccard similarity is higher than a given threshold (0.8 in this example). It will then identify the connected components in this graph. Documents within the same connected components are deemed duplicates.\n",
"\n",
"Arguments include:\n",
"- `cache_dir`: Output path for intermediate results\n",
"- `jaccard_pairs_path`: Input path for `jaccard_similarity_results.parquet`\n",
"- `id_column`: prefix of ID column in `jaccard_similarity_results.parquet`\n",
"- `jaccard_threshold`: Threshold to determine if an edge exists between two documents"
"- `cache_dir`: If specified via `ConnectedComponents(cache_dir=...)` or `Cache(cache_dir=...)`, the intermediate results will be output to the cache directory.\n",
"- `jaccard_pairs_path`: Input path for `jaccard_similarity_results.parquet`.\n",
"- `id_column`: Prefix of ID column in `jaccard_similarity_results.parquet`.\n",
"- `jaccard_threshold`: Threshold to determine if an edge exists between two documents."
]
},
{
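A toy sketch of this step using networkx, whereas the real `ConnectedComponents` class runs distributed on GPUs; the document IDs and similarity values below are made up.

```python
import networkx as nx

# (id_x, id_y, jaccard) rows, analogous to jaccard_similarity_results.parquet
jaccard_pairs = [
    ("doc1", "doc2", 0.92),
    ("doc2", "doc3", 0.85),
    ("doc4", "doc5", 0.30),
]
jaccard_threshold = 0.8

graph = nx.Graph()
graph.add_edges_from((a, b) for a, b, sim in jaccard_pairs if sim > jaccard_threshold)

# Each connected component is one group of near-duplicates; keep one document
# per component and drop the rest.
components = list(nx.connected_components(graph))  # [{'doc1', 'doc2', 'doc3'}]
```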
@@ -1669,16 +1661,16 @@
},
"outputs": [],
"source": [
"#Input\n",
"# Input\n",
"jaccard_pairs_path = edgelist_output_dir\n",
"\n",
"#Output\n",
"connected_component_base_output_path = os.path.join(data_dir,\"fuzzy/cc\")\n",
"# Output\n",
"connected_component_base_output_path = os.path.join(data_dir, \"fuzzy/cc\")\n",
"connected_component_output_path = os.path.join(connected_component_base_output_path, \"connected_components.parquet\")\n",
"connected_component_cache_dir = os.path.join(connected_component_base_output_path, \"cache\")\n",
"\n",
"#Relevant parameters\n",
"input_id_field = 'id'\n",
"# Relevant parameters\n",
"input_id_field = \"id\"\n",
"\n",
"!mkdir -p {connected_component_base_output_path}"
]
@@ -1708,9 +1700,9 @@
" id_column=input_id_field,\n",
")\n",
"\n",
"#Load and run connected component\n",
"# Load and run connected components\n",
"components_stage.cc_workflow(output_path=connected_component_output_path)\n",
"print(f\"Time taken for Connected Component: {time.time()-t0} s\")"
"print(f\"Time taken for Connected Components: {time.time()-t0}s\")"
]
},
{
Original file line number Diff line number Diff line change
@@ -39,7 +39,6 @@
cache_dir=buckets_to_edges_out,
id_fields=["dataset_id", "doc_id"],
)

ddf_b2e = buckets_to_edges(DocumentDataset(ddf_bk))

logging.info(f"Time taken for Buckets to Edges: {time.time() - t0} s")