Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: optim loss of MultivariateFailure #75

Merged
merged 1 commit into from
Jan 4, 2024

Conversation

bbayukari
Copy link
Collaborator

@bbayukari bbayukari commented Jan 3, 2024

skmodel.MultivariateFailure

  • Implement vectorized calculations to replace nested loops for efficiency.
  • Utilize logsumexp for better numerical stability in exponential calculations.

test code

import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import time


def multivariate_failure_objective(params, X, y, delta, n, K):
    Xbeta = jnp.matmul(X, params)
    tmp = jnp.ones((n, K))
    for i in range(n):
        for k in range(K):
            tmp = tmp.at[i, k].set(Xbeta[i] - jnp.log(jnp.matmul(y[:, k] >= y[i, k], jnp.exp(Xbeta))))
    loss = -jnp.mean(tmp * delta)
    return loss


def multivariate_failure_objective_vectorized_logsumexp(params, X, y, delta, n, K):
    Xbeta_expanded = jnp.matmul(X, params)[:, None]
    sum_exp_Xbeta = logsumexp(Xbeta_expanded + jnp.log(y >= y[:, None, :]), axis=1)
    loss = -jnp.mean((Xbeta_expanded - sum_exp_Xbeta) * delta)
    return loss


def make_Clayton2_data(n, theta=15, lambda1=1, lambda2=1, c1=1, c2=1):
    u1 = np.random.uniform(0, 1, n)
    u2 = np.random.uniform(0, 1, n)
    time2 = -np.log(1 - u2) / lambda2
    time1 = (np.log(1 - np.power((1 - u2), -theta) + np.power((1 - u1), -theta / (1 + theta)) * np.power((1 - u2), -theta)) / theta / lambda1)
    ctime1 = np.random.uniform(0, c1, n)
    ctime2 = np.random.uniform(0, c2, n)
    delta1 = (time1 < ctime1) * 1
    delta2 = (time2 < ctime2) * 1
    time1 = np.minimum(time1, ctime1)
    time2 = np.minimum(time2, ctime2)
    y = np.hstack((time1.reshape((-1, 1)), time2.reshape((-1, 1))))
    delta = np.hstack((delta1.reshape((-1, 1)), delta2.reshape((-1, 1))))
    return y, delta


def test(seed):
    np.random.seed(seed)
    n, p, s, rho = 100, 100, 10, 0.5
    K = 2

    beta = np.zeros(p)
    beta[:s] = 5
    Sigma = np.power(rho, np.abs(np.linspace(1, p, p) - np.linspace(1, p, p).reshape(p, 1)))
    X = np.random.multivariate_normal(mean=np.zeros(p), cov=Sigma, size=(n,))
    lambda1 = 1 * np.exp(np.matmul(X, beta))
    lambda2 = 10 * np.exp(np.matmul(X, beta))

    y, delta = make_Clayton2_data(n, theta=50, lambda1=lambda1, lambda2=lambda2, c1=5, c2=5)

    # Convert numpy arrays to jax numpy arrays
    X_jax = jnp.array(X)
    y_jax = jnp.array(y)
    delta_jax = jnp.array(delta)

    # Generate random parameters for testing
    params = jnp.array(np.random.randn(p))

    # Calculate loss using both methods
    t1 = time.time()
    loss_original = multivariate_failure_objective(params, X_jax, y_jax, delta_jax, n, K)
    t2 = time.time()
    loss_vectorized = multivariate_failure_objective_vectorized_logsumexp(params, X_jax, y_jax, delta_jax, n, K)
    t3 = time.time()

    return loss_original, t2 - t1, loss_vectorized, t3 - t2


if __name__ == "__main__":
    for i in range(10):
        loss_original, time_original, loss_vectorized, time_vectorized = test(i)
        print("loss_original:   ", loss_original, "time_original: ", time_original)
        print("loss_vectorized: ", loss_vectorized, "time_vectorized: ", time_vectorized)

- Implement vectorized calculations to replace nested loops for efficiency.
- Utilize logsumexp for better numerical stability in exponential calculations.
Copy link

codecov bot commented Jan 3, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (d87fe76) 94.24% compared to head (1190f61) 94.19%.
Report is 1 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master      #75      +/-   ##
==========================================
- Coverage   94.24%   94.19%   -0.06%     
==========================================
  Files          19       19              
  Lines        2103     2101       -2     
  Branches      653      653              
==========================================
- Hits         1982     1979       -3     
- Misses         91       92       +1     
  Partials       30       30              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@bbayukari bbayukari merged commit 8bb88af into abess-team:master Jan 4, 2024
11 of 13 checks passed
@bbayukari bbayukari deleted the dev branch January 4, 2024 12:52
@Mamba413 Mamba413 added the enhancement New feature or request label Jan 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants