Skip to content

Commit

Permalink
Add Shrinkage Loss to handle regression label imbalance problem. (#1157)
Browse files Browse the repository at this point in the history
*Issue #, if available:*

*Description of changes:*
This PR follows the paper:
https://openaccess.thecvf.com/content_ECCV_2018/html/Xiankai_Lu_Deep_Regression_Tracking_ECCV_2018_paper.html
to implement the Shrinkage Loss for imbalanced regression tasks.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
Co-authored-by: Jian Zhang (James) <[email protected]>
  • Loading branch information
3 people authored Feb 8, 2025
1 parent deac15b commit 72f11f3
Show file tree
Hide file tree
Showing 16 changed files with 294 additions and 18 deletions.
1 change: 1 addition & 0 deletions .github/workflow_scripts/e2e_check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ sh ./tests/end2end-tests/data_process/homogeneous_test.sh
bash ./tests/end2end-tests/data_process/gpartition_test.sh
sh ./tests/end2end-tests/custom-gnn/run_test.sh
bash ./tests/end2end-tests/graphstorm-nc/test.sh
bash ./tests/end2end-tests/graphstorm-nr/test.sh
bash ./tests/end2end-tests/graphstorm-lp/test.sh
bash ./tests/end2end-tests/graphstorm-ec/test.sh
bash ./tests/end2end-tests/graphstorm-er/test.sh
1 change: 1 addition & 0 deletions docs/source/api/references/graphstorm.model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ Loss Function
ClassifyLossFunc
FocalLossFunc
RegressionLossFunc
ShrinkageLossFunc
LinkPredictBCELossFunc
WeightedLinkPredictBCELossFunc
LinkPredictAdvBCELossFunc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,12 +528,12 @@ Link Prediction Task
- Yaml: ``lp_decoder_type: dot_product``
- Argument: ``--lp-decoder-type dot_product``
- Default value: ``distmult``
- **gamma**: Set the value of the hyperparameter denoted by the symbol gamma. Gamma is used in the following cases: i/ focal loss for binary classification ii/ DistMult score function for link prediction, iii/ TransE score function for link prediction, and iv/ RotatE score function for link prediction.
- **gamma**: Set the value of the hyperparameter denoted by the symbol gamma. Gamma is used in the following cases: i/ focal loss for binary classification ii/ DistMult score function for link prediction, iii/ TransE score function for link prediction, iv/ RotatE score function for link prediction, v/ shrinkage loss for regression.

- Yaml: ``gamma: 10.0``
- Argument: ``--gamma 10.0``
- Default value: None
- **alpha**: Set the value of the hyperparameter denoted by the symbol alpha. Alpha is used in focal loss for binary classification.
- **alpha**: Set the value of the hyperparameter denoted by the symbol alpha. Alpha is used in the following cases: i/ focal loss for binary classification and ii/ shrinkage loss for regression.

- Yaml: ``alpha: 10.0``
- Argument: ``--alpha 10.0``
Expand All @@ -543,6 +543,11 @@ Link Prediction Task
- Yaml: ``class_loss_func: cross_entropy``
- Argument: ``--class-loss-func focal``
- Default value: ``cross_entropy``
- **regression_loss_func**: Node/Edge regression loss function. Builtin loss functions include ``mse`` and ``shrinkage``.

- Yaml: ``regression_loss_func: mse``
- Argument: ``--regression-loss-func shrinkage``
- Default value: ``mse``
- **lp_loss_func**: Link prediction loss function. Builtin loss functions include ``cross_entropy`` and ``contrastive``.

- Yaml: ``lp_loss_func: cross_entropy``
Expand Down
5 changes: 4 additions & 1 deletion python/graphstorm/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@
BUILTIN_LP_LOSS_CONTRASTIVELOSS,
BUILTIN_CLASS_LOSS_CROSS_ENTROPY,
BUILTIN_CLASS_LOSS_FOCAL,
BUILTIN_CLASS_LOSS_FUNCTION)
BUILTIN_CLASS_LOSS_FUNCTION,
BUILTIN_REGRESSION_LOSS_MSE,
BUILTIN_REGRESSION_LOSS_SHRINKAGE,
BUILTIN_REGRESSION_LOSS_FUNCTION)
from .config import (GRAPHSTORM_LP_EMB_L2_NORMALIZATION,
GRAPHSTORM_LP_EMB_NORMALIZATION_METHODS)
from .config import (GRAPHSTORM_SAGEMAKER_TASK_TRACKER,
Expand Down
23 changes: 22 additions & 1 deletion python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
BUILTIN_LP_LOSS_CROSS_ENTROPY,
BUILTIN_LP_LOSS_CONTRASTIVELOSS,
BUILTIN_CLASS_LOSS_CROSS_ENTROPY,
BUILTIN_CLASS_LOSS_FUNCTION)
BUILTIN_CLASS_LOSS_FUNCTION,
BUILTIN_REGRESSION_LOSS_MSE,
BUILTIN_REGRESSION_LOSS_FUNCTION)

from .config import BUILTIN_TASK_NODE_CLASSIFICATION
from .config import BUILTIN_TASK_NODE_REGRESSION
Expand Down Expand Up @@ -688,6 +690,7 @@ def verify_node_regression_arguments(self):
_ = self.batch_size
_ = self.eval_metric
_ = self.label_field
_ = self.regression_loss_func

def verify_edge_class_arguments(self):
""" Verify the correctness of arguments for edge classification tasks.
Expand Down Expand Up @@ -715,6 +718,7 @@ def verify_edge_regression_arguments(self):
_ = self.decoder_type
_ = self.num_decoder_basis
_ = self.decoder_edge_feat
_ = self.regression_loss_func

def verify_link_prediction_arguments(self):
""" Verify the correctness of arguments for link prediction tasks.
Expand Down Expand Up @@ -2873,6 +2877,21 @@ def class_loss_func(self):

return BUILTIN_CLASS_LOSS_CROSS_ENTROPY

@property
def regression_loss_func(self):
""" Regression loss function. Builtin loss functions include
``mse`` and ``shrinkage``. Default is ``mse``.
"""
# pylint: disable=no-member
if hasattr(self, "_regression_loss_func"):
assert self._regression_loss_func in BUILTIN_REGRESSION_LOSS_FUNCTION, \
f"Only support {BUILTIN_REGRESSION_LOSS_FUNCTION} " \
"loss functions for regression tasks"
return self._regression_loss_func

return BUILTIN_REGRESSION_LOSS_MSE


@property
def lp_loss_func(self):
""" Link prediction loss function. Builtin loss functions include
Expand Down Expand Up @@ -3530,6 +3549,8 @@ def _add_link_prediction_args(parser):
)
group.add_argument("--class-loss-func", type=str, default=argparse.SUPPRESS,
help="Classification loss function.")
group.add_argument("--regression-loss-func", type=str, default=argparse.SUPPRESS,
help="Regression loss function.")
group.add_argument("--lp-loss-func", type=str, default=argparse.SUPPRESS,
help="Link prediction loss function.")
group.add_argument("--contrastive-loss-temperature", type=float, default=argparse.SUPPRESS,
Expand Down
5 changes: 5 additions & 0 deletions python/graphstorm/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
BUILTIN_CLASS_LOSS_FOCAL = "focal"
BUILTIN_CLASS_LOSS_FUNCTION = [BUILTIN_CLASS_LOSS_CROSS_ENTROPY, BUILTIN_CLASS_LOSS_FOCAL]

BUILTIN_REGRESSION_LOSS_MSE = "mse"
BUILTIN_REGRESSION_LOSS_SHRINKAGE = "shrinkage"
BUILTIN_REGRESSION_LOSS_FUNCTION = [BUILTIN_REGRESSION_LOSS_MSE,
BUILTIN_REGRESSION_LOSS_SHRINKAGE]

BUILTIN_LP_LOSS_CROSS_ENTROPY = "cross_entropy"
BUILTIN_LP_LOSS_LOGSIGMOID_RANKING = "logsigmoid"
BUILTIN_LP_LOSS_CONTRASTIVELOSS = "contrastive"
Expand Down
30 changes: 27 additions & 3 deletions python/graphstorm/gsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@
from .config import (BUILTIN_LP_LOSS_CROSS_ENTROPY,
BUILTIN_LP_LOSS_CONTRASTIVELOSS,
BUILTIN_CLASS_LOSS_CROSS_ENTROPY,
BUILTIN_CLASS_LOSS_FOCAL)
BUILTIN_CLASS_LOSS_FOCAL,
BUILTIN_REGRESSION_LOSS_MSE,
BUILTIN_REGRESSION_LOSS_SHRINKAGE)
from .eval.eval_func import (
SUPPORTED_HIT_AT_METRICS,
SUPPORTED_LINK_PREDICTION_METRICS)
Expand All @@ -65,6 +67,7 @@
from .model.lp_gnn import GSgnnLinkPredictionModel
from .model.loss_func import (ClassifyLossFunc,
RegressionLossFunc,
ShrinkageLossFunc,
LinkPredictBCELossFunc,
WeightedLinkPredictBCELossFunc,
LinkPredictAdvBCELossFunc,
Expand Down Expand Up @@ -521,7 +524,17 @@ def create_builtin_node_decoder(g, decoder_input_dim, config, train_task):
dropout=dropout,
norm=config.decoder_norm,
use_bias=config.decoder_bias)
loss_func = RegressionLossFunc()
if config.regression_loss_func == BUILTIN_REGRESSION_LOSS_MSE:
loss_func = RegressionLossFunc()
elif config.regression_loss_func == BUILTIN_REGRESSION_LOSS_SHRINKAGE:
# set default value of alpha to 10. for shrinkage loss
# set default value of gamma to 0.2 for shrinkage loss
alpha = config.alpha if config.alpha is not None else 10.
gamma = config.gamma if config.gamma is not None else 0.2
loss_func = ShrinkageLossFunc(alpha, gamma)
else:
raise RuntimeError(
f"Unknown regression loss {config.regression_loss_func}")
else:
raise ValueError('unknown node task: {}'.format(config.task_type))

Expand Down Expand Up @@ -723,7 +736,18 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
use_bias=config.decoder_bias)
else:
assert False, "decoder not supported"
loss_func = RegressionLossFunc()

if config.regression_loss_func == BUILTIN_REGRESSION_LOSS_MSE:
loss_func = RegressionLossFunc()
elif config.regression_loss_func == BUILTIN_REGRESSION_LOSS_SHRINKAGE:
# set default value of alpha to 10. for shrinkage loss
# set default value of gamma to 0.2 for shrinkage loss
alpha = config.alpha if config.alpha is not None else 10.
gamma = config.gamma if config.gamma is not None else 0.2
loss_func = ShrinkageLossFunc(alpha, gamma)
else:
raise RuntimeError(
f"Unknown regression loss {config.regression_loss_func}")
else:
raise ValueError('unknown node task: {}'.format(config.task_type))
return decoder, loss_func
Expand Down
1 change: 1 addition & 0 deletions python/graphstorm/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from .loss_func import (ClassifyLossFunc,
FocalLossFunc,
RegressionLossFunc,
ShrinkageLossFunc,
LinkPredictBCELossFunc,
WeightedLinkPredictBCELossFunc,
LinkPredictAdvBCELossFunc,
Expand Down
87 changes: 87 additions & 0 deletions python/graphstorm/model/loss_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,93 @@ def out_dims(self):
"""
return None

class ShrinkageLossFunc(GSLayer):
r""" Shrinkage Loss for imbalanced regression tasks.
The shrinkage loss is defined as:
.. math::
loss = \frac{l^2}{1 + \exp \left( \alpha \cdot (\gamma - l) \right)}
where l is the absolute difference between the
predicted value and the groud truth. \alpha and
\gamma are hyper-parameters controlling the
shrinkage speed and the localization respectively.
The shrinkage loss only penalizes the importance of
easy samples (when l < 0.5) and keeps the loss of
hard samples unchanged.
# pylint: disable=line-too-long
For more details, please refer to the paper
"Deep Regression Tracking with Shrinkage Loss"
(https://openaccess.thecvf.com/content_ECCV_2018/html/Xiankai_Lu_Deep_Regression_Tracking_ECCV_2018_paper.html)
Parameters
----------
alpha: float
A hyper-parameter controls the loss shrinkage
speed when the prediction error decreases.
Default: ``10.``.
gamma: float
A hyper-parameter controls the localization
of the loss regarding to the x-axis.
Default: ``0.2``.
.. versionadded:: 0.4.1
Add shrinkage loss for regressoin tasks.
"""
def __init__(self, alpha=10, gamma=0.2):
super(ShrinkageLossFunc, self).__init__()
self.alpha = alpha
self.gamma = gamma

def forward(self, logits, labels):
""" The forward function.
Parameters
----------
logits: torch.Tensor
The prediction results.
labels: torch.Tensor
The training labels.
Returns
-------
loss: Tensor
The loss value.
"""
# Make sure the labels is a float tensor
labels = labels.float()
diff = th.abs(logits - labels)
numerator = diff ** 2
denominator = 1 + th.exp(self.alpha * (self.gamma - diff))

loss = numerator / denominator
return loss.mean()

@property
def in_dims(self):
""" The number of input dimensions.
Returns
-------
int : the number of input dimensions.
"""
return None

@property
def out_dims(self):
""" The number of output dimensions.
Returns
-------
int : the number of output dimensions.
"""
return None


class LinkPredictBCELossFunc(GSLayer):
r""" Loss function for link prediction tasks using binary
cross entropy loss.
Expand Down
12 changes: 12 additions & 0 deletions tests/end2end-tests/data_gen/movielens.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
}
]
},
{
"node_type": "user",
"format": {"name": "parquet"},
"files": "/data/ml-100k/users.parquet",
"labels": [
{
"label_col": "age",
"task_type": "regression",
"split_pct": [0.8, 0.1, 0.1]
}
]
},
{
"node_id_col": "id",
"node_type": "movie",
Expand Down
4 changes: 3 additions & 1 deletion tests/end2end-tests/data_gen/process_movielens.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def write_data_parquet(data, data_file):
table = pa.Table.from_arrays(list(arr_dict.values()), names=list(arr_dict.keys()))
pq.write_table(table, data_file)

user_data = {'id': user['id'], 'feat': feat, 'occupation': user['occupation']}
user_data = {'id': user['id'], 'feat': feat,
'age': user['age'],
'occupation': user['occupation']}
write_data_parquet(user_data, '/data/ml-100k/users.parquet')

movie_data = {'id': ids,
Expand Down
5 changes: 5 additions & 0 deletions tests/end2end-tests/graphstorm-er/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ python3 -m graphstorm.run.gs_edge_regression --workspace $GS_HOME/training_scrip

error_and_exit $?

echo "**************dataset: Test edge regression, RGCN layer: 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch, with shrinkage loss"
python3 -m graphstorm.run.gs_edge_regression --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_er_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_er.yaml --num-epochs 1 --regression-loss-func shrinkage

error_and_exit $?

echo "**************dataset: Test edge regression, RGCN layer: 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch, no test"
python3 -m graphstorm.run.gs_edge_regression --workspace $GS_HOME/training_scripts/gsgnn_ep/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_er_no_test_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_er.yaml --num-epochs 1 --logging-file /tmp/train_log.txt

Expand Down
46 changes: 46 additions & 0 deletions tests/end2end-tests/graphstorm-nr/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/bin/bash

service ssh restart

DGL_HOME=/root/dgl
GS_HOME=$(pwd)
NUM_TRAINERS=1
export PYTHONPATH=$GS_HOME/python/
cd $GS_HOME/training_scripts/gsgnn_np
echo "127.0.0.1" > ip_list.txt

cd $GS_HOME/inference_scripts/ep_infer
echo "127.0.0.1" > ip_list.txt

cat ip_list.txt

error_and_exit () {
# check exec status of launch.py
status=$1
echo $status

if test $status -ne 0
then
exit -1
fi
}

echo "Test GraphStorm node regression"

date

echo "**************dataset: Test node regression, RGCN layer: 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch"
python3 -m graphstorm.run.gs_node_regression --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nr.yaml --num-epochs 1

error_and_exit $?

echo "**************dataset: Test node regression, RGCN layer: 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch, with shrinkage loss"
python3 -m graphstorm.run.gs_node_regression --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nr.yaml --num-epochs 1 --regression-loss-func shrinkage

error_and_exit $?

rm -fr $GS_HOME/training_scripts/gsgnn_np/logs/

date

echo 'Done'
Loading

0 comments on commit 72f11f3

Please sign in to comment.