Fix overlapping IDs being written when ID map writes to multiple files (#1178)

*Issue #, if available:*

*Description of changes:*

* Fixes the partition-writing behavior in the ID map's `save()` so consecutive output files no longer contain overlapping rows, and adds a corresponding unit test (see the sketch of the underlying issue below)
* GenAI was used to extend the new unit test

Fixes #1177 
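Root cause: `pyarrow.Table.slice(offset, length)` takes an offset and a *length*, but the old loop passed an absolute end-row index as the second argument, so every file after the first re-wrote rows that had already been written. A minimal standalone sketch of the difference (illustrative only, not the GraphStorm code itself; assumes only pyarrow):

```python
import pyarrow as pa

# A small stand-in for the ID mapping table: 10 rows, two columns.
table = pa.table({"orig": [f"id_{i}" for i in range(10)], "new": list(range(10))})

start, end = 4, 8  # we want rows 4..7, i.e. 4 rows

# Buggy pattern: Table.slice(offset, length) interprets the second
# argument as a LENGTH, so passing the absolute end index grabs extra rows.
buggy_chunk = table.slice(start, end)                         # rows 4..9 -> 6 rows
# Fixed pattern: pass the number of rows per chunk as the length.
fixed_chunk = table.slice(offset=start, length=end - start)   # rows 4..7 -> 4 rows

print(buggy_chunk.num_rows, fixed_chunk.num_rows)             # 6 4
```

This mirrors the change in `id_map.py` below, where `table.slice(start, end)` becomes `table.slice(offset=start, length=max_rows_per_file)`.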

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
thvasilo authored Feb 24, 2025
1 parent 6758c8d commit 21ae728
Showing 3 changed files with 107 additions and 32 deletions.
7 changes: 7 additions & 0 deletions python/graphstorm/data/constants.py
@@ -30,3 +30,10 @@

 REGRESSION_TASK = "regression"
 CLASSIFICATION_TASK = "classification"
+
+# Names to give ID mapping columns
+MAPPING_INPUT_ID = "orig"
+MAPPING_OUTPUT_ID = "new"
+# The names that GSProcessing uses
+GSP_MAPPING_INPUT_ID = "node_str_id"
+GSP_MAPPING_OUTPUT_ID = "node_int_id"
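As context for these constants, a minimal sketch (hypothetical node IDs, assuming only pyarrow) of the two-column mapping table they name: each row pairs an original node ID with the contiguous integer ID assigned during graph construction.

```python
import pyarrow as pa

# Hypothetical mapping table using the column names defined above:
# "orig" (MAPPING_INPUT_ID) holds the original node IDs,
# "new" (MAPPING_OUTPUT_ID) holds the assigned integer IDs.
mapping = pa.table({
    "orig": ["user_a", "user_b", "user_c"],
    "new": [0, 1, 2],
})
print(mapping.to_pydict())
# {'orig': ['user_a', 'user_b', 'user_c'], 'new': [0, 1, 2]}
```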
41 changes: 25 additions & 16 deletions python/graphstorm/gconstruct/id_map.py
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.

-Generate example graph data using built-in datasets for node classifcation,
+Generate example graph data using built-in datasets for node classification,
 node regression, edge classification and edge regression.
 """
 import os
@@ -22,6 +22,12 @@
 import pyarrow.parquet as pq
 import numpy as np

+from graphstorm.data.constants import (
+    GSP_MAPPING_INPUT_ID,
+    GSP_MAPPING_OUTPUT_ID,
+    MAPPING_INPUT_ID,
+    MAPPING_OUTPUT_ID,
+)
 from .file_io import read_data_parquet
 from .utils import ExtMemArrayWrapper

@@ -82,17 +88,17 @@ def __init__(self, id_map_prefix):
         assert os.path.exists(id_map_prefix), \
             f"{id_map_prefix} does not exist."
         try:
-            data = read_data_parquet(id_map_prefix, ["orig", "new"])
+            data = read_data_parquet(id_map_prefix, [MAPPING_INPUT_ID, MAPPING_OUTPUT_ID])
         except AssertionError:
             # To maintain backwards compatibility with GraphStorm v0.2.1
-            data = read_data_parquet(id_map_prefix, ["node_str_id", "node_int_id"])
-            data["new"] = data["node_int_id"]
-            data["orig"] = data["node_str_id"]
-            data.pop("node_int_id")
-            data.pop("node_str_id")
+            data = read_data_parquet(id_map_prefix, [GSP_MAPPING_INPUT_ID, GSP_MAPPING_OUTPUT_ID])
+            data[MAPPING_OUTPUT_ID] = data[GSP_MAPPING_OUTPUT_ID]
+            data[MAPPING_INPUT_ID] = data[GSP_MAPPING_INPUT_ID]
+            data.pop(GSP_MAPPING_INPUT_ID)
+            data.pop(GSP_MAPPING_OUTPUT_ID)

-        sort_idx = np.argsort(data['new'])
-        self._ids = data['orig'][sort_idx]
+        sort_idx = np.argsort(data[MAPPING_OUTPUT_ID])
+        self._ids = data[MAPPING_INPUT_ID][sort_idx]

     def __len__(self):
         return len(self._ids)
@@ -226,18 +232,21 @@ def save(self, file_prefix):
         """
         os.makedirs(file_prefix, exist_ok=True)
         table = pa.Table.from_arrays([pa.array(self._ids.keys()), self._ids.values()],
-                                     names=["orig", "new"])
+                                     names=[MAPPING_INPUT_ID, MAPPING_OUTPUT_ID])
         bytes_per_row = table.nbytes // table.num_rows
         # Split table in parts, such that the max expected file size is ~1GB
         max_rows_per_file = GIB_BYTES // bytes_per_row
-        rows_written = 0
+        total_rows_written = 0
         file_idx = 0
-        while rows_written < table.num_rows:
-            start = rows_written
-            end = min(rows_written + max_rows_per_file, table.num_rows)
+        while total_rows_written < table.num_rows:
+            start = total_rows_written
             filename = f"part-{str(file_idx).zfill(5)}.parquet"
-            pq.write_table(table.slice(start, end), os.path.join(file_prefix, filename))
-            rows_written = end
+
+            pq.write_table(
+                table.slice(offset=start, length=max_rows_per_file),
+                os.path.join(file_prefix, filename)
+            )
+            total_rows_written = min(start + max_rows_per_file, table.num_rows)
             file_idx += 1

 def map_node_ids(src_ids, dst_ids, edge_type, node_id_map, skip_nonexist_edges):
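One property the corrected `save()` relies on: `pyarrow.Table.slice` clamps the requested length to the end of the table, so the final part simply receives whatever rows remain and no explicit `min()` is needed inside the slice call. A quick standalone check (assumes only pyarrow):

```python
import pyarrow as pa

table = pa.table({"orig": [f"id_{i}" for i in range(10)], "new": list(range(10))})

# Asking for more rows than remain past the offset is safe:
# the slice is clamped to the end of the table.
tail = table.slice(8, 5)
print(tail.num_rows)                    # 2
print(tail.column("new").to_pylist())   # [8, 9]
```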
91 changes: 75 additions & 16 deletions tests/unit-tests/gconstruct/test_remap_result.py
@@ -13,26 +13,31 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+import glob
 import os
 import argparse
+import math
 import tempfile
-import pytest
 from functools import partial
 from pathlib import Path

+import pytest
+import pyarrow as pa
 import pandas as pd
 import torch as th
 import numpy as np
+from pyarrow import parquet as pq
 from numpy.testing import assert_equal, assert_almost_equal

 from graphstorm.config import GSConfig
 from graphstorm.config.config import get_mttask_id
-from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
-                               BUILTIN_TASK_NODE_REGRESSION,
-                               BUILTIN_TASK_EDGE_CLASSIFICATION,
-                               BUILTIN_TASK_EDGE_REGRESSION,
-                               BUILTIN_TASK_LINK_PREDICTION,
-                               BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
+from graphstorm.config import (
+    BUILTIN_TASK_NODE_CLASSIFICATION,
+    BUILTIN_TASK_EDGE_CLASSIFICATION,
+)
+from graphstorm.data.constants import (
+    MAPPING_INPUT_ID,
+    MAPPING_OUTPUT_ID,
+)
 from graphstorm.gconstruct import remap_result
 from graphstorm.gconstruct.file_io import read_data_parquet
 from graphstorm.gconstruct.id_map import IdMap, IdReverseMap
@@ -502,7 +507,6 @@ def test_parse_config():
         assert len(pred_etypes) == 2
         assert pred_etypes[0] == ['n0', 'r0', 'r1']
         assert pred_etypes[1] == ['n0', 'r0', 'r2']
-        print(task_emb_dirs)
         assert len(task_emb_dirs) == 1
         assert task_emb_dirs[0] == get_mttask_id(
             task_type="link_prediction",
@@ -535,11 +539,66 @@ def test_parse_config():
         assert predict_dir is None
         assert emb_dir is None

-if __name__ == '__main__':
-    test_parse_config()
+@pytest.mark.parametrize("num_rows", [1000, 10001])
+def test_idmap_save_no_duplicates(num_rows, monkeypatch):
+    # Mock GIB_BYTES to 1kib to force multiple partitions
+    mock_gib_bytes = 1024
+    monkeypatch.setattr("graphstorm.gconstruct.id_map.GIB_BYTES", mock_gib_bytes)
+
-    test_write_data_csv_file()
-    test_write_data_parquet_file()
-    test__get_file_range()
-    test_worker_remap_edge_pred()
-    test_worker_remap_node_data("pred")
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        # Create a table with known IDs to make verification easier
+        # string ids will be sequential id_<idx>
+        string_ids = [f"id_{i}" for i in range(num_rows)]
+
+        # Estimate the expected number of files
+        table = pa.Table.from_arrays([pa.array(string_ids), list(range(num_rows))],
+                                     names=[MAPPING_INPUT_ID, MAPPING_OUTPUT_ID])
+        bytes_per_row = table.nbytes // table.num_rows
+        max_rows_per_file = mock_gib_bytes // bytes_per_row
+        expected_num_files = math.ceil(table.num_rows / max_rows_per_file)
+
+        # Create the IdMap
+        id_map = IdMap(np.array(string_ids))
+
+        # Save the id map - will create multiple partitions due to small GIB_BYTES
+        mapping_prefix = os.path.join(tmpdirname, "raw_id_mappings")
+        id_map.save(mapping_prefix)
+
+        # Read back all partition files and check for duplicates
+        all_rows = []
+        file_idx = 0
+        for filename in glob.glob(os.path.join(mapping_prefix, "*.parquet")):
+            # Read this partition
+            table = pq.read_table(filename)
+            partition_rows = table.to_pydict()
+            all_rows.extend(zip(
+                partition_rows[MAPPING_INPUT_ID],
+                partition_rows[MAPPING_OUTPUT_ID])
+            )
+            file_idx += 1
+
+        # There should be multiple partition files due to small GIB_BYTES
+        assert file_idx == expected_num_files, \
+            f"Expected {expected_num_files} partition files, got {file_idx}"
+
+        # Check total number of rows
+        assert len(all_rows) == num_rows, \
+            f"Expected {num_rows} total rows, got {len(all_rows)}"
+
+        # Check for duplicates
+        seen_ids = set()
+        duplicate_ids = set()
+        for orig_id, _ in all_rows:
+            if orig_id in seen_ids:
+                duplicate_ids.add(orig_id)
+            seen_ids.add(orig_id)
+
+        assert len(duplicate_ids) == 0, \
+            f"Found duplicate IDs across partitions: {duplicate_ids}"
+
+        # Verify correct mapping is maintained
+        # We expect mapped ids to be: id_<idx> -> <idx>
+        for orig_id, mapped_id in all_rows:
+            expected_mapped_id = int(orig_id.split('_')[1])
+            assert mapped_id == expected_mapped_id, \
+                f"Incorrect mapping for {orig_id}: expected {expected_mapped_id}, got {mapped_id}"

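Assuming a standard pytest setup, the new test can be run on its own with, for example, `pytest tests/unit-tests/gconstruct/test_remap_result.py -k test_idmap_save_no_duplicates`; the `monkeypatch` fixture shrinks `GIB_BYTES` to 1 KiB so that multiple part files are produced even for small inputs.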