[GSProcessing] Make label transformation use Spark for output instead of Pyarrow when using custom masks. (#1150)

*Issue #, if available:*

*Description of changes:*

* Previously, during classification, labels would be collected to the
driver and written using PyArrow regardless of the type of mask that
would be generated.
* To improve scaling, we now write the label files using Spark when
custom masks are used, since we observed that the collect is not
necessary in that case; this avoids collecting to the driver and
improves scalability (a minimal sketch follows below).
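
For illustration only, here is a minimal sketch of the two write paths this change distinguishes. The function, argument names, and write calls below are assumptions made for the sketch, not the actual GSProcessing implementation:

```python
# Minimal sketch (assumed names) of the two output paths for classification labels.
# `transformed_label`, `label_col`, `label_output_path`, and `use_custom_split`
# are illustrative stand-ins, not the real GSProcessing API.
from pyspark.sql import DataFrame


def write_label_output(
    transformed_label: DataFrame,
    label_col: str,
    label_output_path: str,
    use_custom_split: bool,
) -> None:
    if use_custom_split:
        # Custom masks: row order within partitions is preserved by Spark, so
        # executors write the label column in parallel without moving data to
        # the driver.
        transformed_label.select(label_col).write.mode("overwrite").parquet(
            label_output_path
        )
    else:
        # Random splits: the ordered labels are collected to the driver as a
        # Pandas DataFrame and written directly; this collect is what limits
        # scaling.
        label_pd = transformed_label.select(label_col).toPandas()
        label_pd.to_parquet(f"{label_output_path}/labels.parquet")
```

The custom-mask branch keeps all I/O on the executors, which is why it scales with the number of partitions instead of being bounded by driver memory.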

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: jalencato <[email protected]>
thvasilo and jalencato authored Jan 30, 2025
1 parent 721e561 commit f03a0b8
Showing 2 changed files with 46 additions and 37 deletions.
@@ -1342,14 +1342,18 @@ def _process_node_labels(
order_col = None

self.graph_info["is_multilabel"] = label_conf.multilabel

# Create loader object. For classification, order_col != None enforces re-ordering
node_label_loader = DistLabelLoader(label_conf, self.spark, order_col)

logging.info(
"Processing label data for node type %s, label col: %s...",
node_type,
label_conf.label_column,
)

transformed_label = node_label_loader.process_label(nodes_df)

self.graph_info["label_map"] = node_label_loader.label_map

label_output_path = (
@@ -1361,37 +1365,45 @@
# The presence of order_col ensures transformed_label DF comes in ordered
# but do we want to double-check before writing?
# Get number of original partitions
input_num_parts = nodes_df.rdd.getNumPartitions()
# If num parts is different for original and transformed, log a warning
transformed_num_parts = transformed_label.rdd.getNumPartitions()
if input_num_parts != transformed_num_parts:
logging.warning(
"Number of partitions for original (%d) and transformed label data "
"(%d) differ. This may cause issues with the label split files.",
input_num_parts,
transformed_num_parts,
)
# For classification we need to order the DF, collect to Pandas
# and write to storage directly
logging.info(
"Collecting label data for node type '%s', label col: '%s' to leader...",
node_type,
label_conf.label_column,
)
transformed_label_pd = transformed_label.select(
label_conf.label_column, order_col
).toPandas()

# Write to parquet using zero-copy column values from Pandas DF
path_list = self._write_pyarrow_table(
pa.Table.from_arrays(
[transformed_label_pd[label_conf.label_column].values],
names=[label_conf.label_column],
),
label_output_path,
num_files=input_num_parts,
)
if label_conf.custom_split_filenames:
# When using custom splits we can rely on order being preserved by Spark
path_list = self._write_df(
transformed_label.select(label_conf.label_column), label_output_path
)
else:
input_num_parts = nodes_df.rdd.getNumPartitions()
# If num parts is different for original and transformed, log a warning
transformed_num_parts = transformed_label.rdd.getNumPartitions()
if input_num_parts != transformed_num_parts:
logging.warning(
"Number of partitions for original (%d) and transformed label data "
"(%d) differ. This may cause issues with the label split files.",
input_num_parts,
transformed_num_parts,
)
# For random splits we need to collect the ordered DF to Pandas
# and write to storage directly
logging.info(
"Collecting label data for node type '%s', label col: '%s' to leader...",
node_type,
label_conf.label_column,
)
transformed_label_pd = transformed_label.select(
label_conf.label_column, order_col
).toPandas()

# Write to parquet using zero-copy column values from Pandas DF
path_list = self._write_pyarrow_table(
pa.Table.from_arrays(
[transformed_label_pd[label_conf.label_column].values],
names=[label_conf.label_column],
),
label_output_path,
num_files=input_num_parts,
)
else:
# Regression and LP tasks will preserve input order, no need to re-order
path_list = self._write_df(
transformed_label.select(label_conf.label_column), label_output_path
)
13 changes: 5 additions & 8 deletions graphstorm-processing/graphstorm_processing/repartition_files.py
@@ -916,9 +916,8 @@ def modify_flat_array_metadata(
repartitioner.modify_metadata_for_flat_arrays(edge_mask_files)

if node_data_meta:
node_types_with_labels = graph_info["ntype_label"] # type: List[str]
if node_types_with_labels:
ntype_label_property = graph_info["ntype_label_property"][0] # type: str
node_types_with_labels: list[str] = graph_info["ntype_label"]

for type_idx, (type_name, ntype_data_dict) in enumerate(node_data_meta.items()):
logging.info(
"Modifying Parquet metadata for node type '%s', %d/%d:",
@@ -927,7 +926,8 @@
len(node_data_meta),
)
if type_name in node_types_with_labels:
node_label_files = ntype_data_dict[ntype_label_property]["data"] # type: List[str]
ntype_label_property: str = graph_info["ntype_label_property"][0]
node_label_files: list[str] = ntype_data_dict[ntype_label_property]["data"]
logging.info(
"Modifying Parquet metadata for %d files of label '%s' of node type '%s'",
len(node_label_files),
@@ -1196,14 +1196,11 @@ def main():

if repartition_config.input_prefix.startswith("s3://"):
filesystem_type = FilesystemType.S3
input_prefix = s3_utils.s3_path_remove_trailing(repartition_config.input_prefix)
else:
input_prefix = str(Path(repartition_config.input_prefix).resolve(strict=True))
filesystem_type = FilesystemType.LOCAL

# Trim trailing '/' from S3 URI
if filesystem_type == FilesystemType.S3:
input_prefix = s3_utils.s3_path_remove_trailing(repartition_config.input_prefix)

logging.info(
"Re-partitioning files under %s to ensure all files that belong to the same "
"edge/node type have the same number of rows per part-file.",
