Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KM time sampler #81

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,5 +164,9 @@ cython_debug/
doc/generated/


# Remove files auto-generated from mac
*.DS_Store

# This dataset should not be redistributed, because users have to sign an agreement.
hazardous/data/seer_cancer_cardio_raw_data.txt
hazardous/data/*.txt
131 changes: 131 additions & 0 deletions hazardous/_ipcw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,134 @@
from scipy.interpolate import interp1d
from sklearn.base import clone
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.utils.validation import check_is_fitted

from .utils import check_y_survival


class KaplanMeierEstimator:
    """Kaplan-Meier estimator of the survival function, with interpolators.

    Fit the Kaplan-Meier survival curve ``S(t)`` on the ``(event, duration)``
    pairs of the training target and expose step-wise ("previous" kind)
    interpolation functions so that the survival probability — and its
    inverse — can be evaluated at arbitrary time points.

    In addition to the raw survival curve, a min-max rescaled version
    (spanning exactly [0, 1]) is fitted; its inverse maps uniform quantiles
    back to time horizons, which enables inverse-transform sampling of times
    distributed like the observed durations.

    Note: all interpolators extrapolate by holding the last observed value
    constant beyond the last recorded time point (``kind="previous"`` with
    ``fill_value="extrapolate"``).

    Attributes
    ----------
    unique_times_ : ndarray of shape (n_unique_times,)
        The unique observed durations from the training target (index of the
        fitted Kaplan-Meier survival function).

    survival_probs_ : ndarray of shape (n_unique_times,)
        The estimated survival probabilities at ``unique_times_``.

    survival_probs_rescaled : ndarray of shape (n_unique_times,)
        ``survival_probs_`` min-max rescaled to span exactly [0, 1].

    survival_func_ : callable
        Step-wise interpolation mapping a time ``t`` to ``S(t)``.

    inverse_surv_func_ : callable
        Step-wise interpolation mapping ``1 - S(t)`` back to ``t``.

    survival_func_rescaled : callable
        Step-wise interpolation mapping a time ``t`` to the rescaled
        survival probability.

    inverse_surv_func_rescaled : callable
        Step-wise interpolation mapping a quantile ``q`` in [0, 1] to a time
        horizon, using the rescaled survival curve.
    """

    def __init__(self):
        # Stateless construction: everything is inferred from the data at
        # fit time.
        pass

    def fit(self, y, X=None):
        """Estimate the survival function and fit interpolators.

        In addition to running the Kaplan-Meier estimator on the event
        labels, this method also fits step-wise interpolation functions
        (direct and inverse, raw and min-max rescaled) to be able to make
        predictions at any time.

        Parameters
        ----------
        y : array-like of shape (n_samples, 2)
            The target data.

        X : None
            The input samples. Unused since this estimator is
            non-conditional.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        event, duration = check_y_survival(y)

        # NOTE(review): the event indicator is passed as-is (not negated),
        # so this estimates the survival to the *event*, not to censoring —
        # unlike the IPCW estimators below.
        km = KaplanMeierFitter()
        km.fit(
            durations=duration,
            event_observed=event,
        )

        df = km.survival_function_
        self.unique_times_ = df.index
        self.survival_probs_ = df.values[:, 0]

        # Rescale the survival probabilities to span exactly [0, 1] so that
        # the inverse interpolator below can be queried with uniform
        # quantiles over the full unit interval.
        scaler = MinMaxScaler()
        self.survival_probs_rescaled = scaler.fit_transform(
            self.survival_probs_.reshape(-1, 1)
        ).flatten()

        # kind="previous" reproduces the step-function shape of the
        # Kaplan-Meier curve; fill_value="extrapolate" keeps the last
        # observed value constant beyond the observed time range.
        self.survival_func_ = interp1d(
            self.unique_times_,
            self.survival_probs_,
            kind="previous",
            bounds_error=False,
            fill_value="extrapolate",
        )

        # Inverse mappings: x is 1 - S(t), which is non-decreasing in t.
        # NOTE(review): flat segments of the survival curve yield duplicate
        # x values, for which interp1d's behavior is unspecified — confirm
        # this is acceptable for the quantile sampler.
        self.inverse_surv_func_rescaled = interp1d(
            1 - self.survival_probs_rescaled,
            self.unique_times_,
            kind="previous",
            bounds_error=False,
            fill_value="extrapolate",
        )
        self.survival_func_rescaled = interp1d(
            self.unique_times_,
            self.survival_probs_rescaled,
            kind="previous",
            bounds_error=False,
            fill_value="extrapolate",
        )

        self.inverse_surv_func_ = interp1d(
            1 - self.survival_probs_,
            self.unique_times_,
            kind="previous",
            bounds_error=False,
            fill_value="extrapolate",
        )

        return self


class KaplanMeierIPCW:
"""Estimate the Inverse Probability of Censoring Weight (IPCW).

Expand Down Expand Up @@ -108,6 +231,14 @@ def fit(self, y, X=None):
bounds_error=False,
fill_value="extrapolate",
)

self.inverse_surv_func_ = interp1d(
1 - self.censoring_survival_probs_,
self.unique_times_,
kind="previous",
bounds_error=False,
fill_value="extrapolate",
)
return self

def compute_ipcw_at(self, times, X=None, ipcw_training=False):
Expand Down
98 changes: 77 additions & 21 deletions hazardous/_survival_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.utils.validation import check_array, check_random_state
from tqdm import tqdm

from ._ipcw import AlternatingCensoringEstimator, KaplanMeierIPCW
from ._ipcw import AlternatingCensoringEstimator, KaplanMeierEstimator, KaplanMeierIPCW
from .metrics._brier_score import (
IncidenceScoreComputer,
integrated_brier_score_incidence,
Expand Down Expand Up @@ -61,18 +62,28 @@ def __init__(
ipcw_estimator=None,
n_iter_before_feedback=20,
random_state=None,
time_sampler="uniform",
y_censor=None,
):
self.rng = check_random_state(random_state)
self.hard_zero_fraction = hard_zero_fraction
self.n_iter_before_feedback = n_iter_before_feedback

if y_censor is None:
y_censor = y_train.copy()
super().__init__(
y_train,
event_of_interest="any",
ipcw_estimator=ipcw_estimator,
y_censor=y_censor,
)
# Precompute the censoring probabilities at the time of the events on the
# training set:
self.ipcw_train = self.ipcw_estimator.compute_ipcw_at(self.duration_train)
if time_sampler == "uniform":
self.time_sampler = None
elif time_sampler == "kaplan-meier":
self.time_sampler = KaplanMeierEstimator().fit(y_censor)

def draw(self, ipcw_training=False, X=None):
# Sample time horizons uniformly on the observed time range:
Expand All @@ -82,9 +93,16 @@ def draw(self, ipcw_training=False, X=None):
# Sample from t_min=0 event if never observed in the training set
# because we want to make sure that the model learns to predict a 0
# incidence at t=0.
t_min = 0.0
t_max = observation_durations.max()
sampled_time_horizons = self.rng.uniform(t_min, t_max, n_samples)
if self.time_sampler is None:
t_min = 0.0
t_max = observation_durations.max()
sampled_time_horizons = self.rng.uniform(t_min, t_max, n_samples)
else:
q_min, q_max = 0.0, 1.0
quantiles = self.rng.uniform(q_min, q_max, n_samples)
sampled_time_horizons = self.time_sampler.inverse_surv_func_rescaled(
quantiles
)

# Add some hard zeros to make sure that the model learns to
# predict 0 incidence at t=0.
Expand All @@ -105,16 +123,35 @@ def draw(self, ipcw_training=False, X=None):
# sampled time horizon;
# * 0 when an event has happened before the sampled time horizon.
# The sample weight is zero in that case.
n_samples_censor = self.duration_censor.shape[0]
if self.time_sampler is None:
t_min = 0.0
t_max = observation_durations.max()
sampled_time_horizons = self.rng.uniform(t_min, t_max, n_samples_censor)
else:
q_min, q_max = 0.0, 1.0
quantiles = self.rng.uniform(q_min, q_max, n_samples_censor)
sampled_time_horizons = self.time_sampler.inverse_surv_func_rescaled(
quantiles
)

# Add some hard zeros to make sure that the model learns to
# predict 0 incidence at t=0.
n_hard_zeros = max(int(self.hard_zero_fraction * n_samples_censor), 1)
hard_zero_indices = self.rng.choice(
n_samples_censor, n_hard_zeros, replace=False
)
sampled_time_horizons[hard_zero_indices] = 0.0

if not hasattr(self, "inv_any_survival_train"):
self.inv_any_survival_train = self.ipcw_estimator.compute_ipcw_at(
self.duration_train, ipcw_training=True, X=X
self.duration_censor, ipcw_training=True, X=X
)

censored_observations = self.any_event_train == 0
censored_observations = self.any_event_censor == 0
y_targets, sample_weight = self._weighted_binary_targets(
censored_observations,
observation_durations,
self.duration_censor,
sampled_time_horizons,
ipcw_y_duration=self.inv_any_survival_train,
ipcw_training=True,
Expand Down Expand Up @@ -145,11 +182,14 @@ def draw(self, ipcw_training=False, X=None):
return sampled_time_horizons.reshape(-1, 1), y_targets, sample_weight

def fit(self, X):
"""Fit the IPCW estimator."""
self.inv_any_survival_train = self.ipcw_estimator.compute_ipcw_at(
self.duration_train, ipcw_training=True, X=X
self.duration_censor, ipcw_training=True, X=X
)

for _ in range(self.n_iter_before_feedback):
for _ in range(
self.n_iter_before_feedback
): # Maybe a hyperparams here depending on the size of y_censor
sampled_time_horizons, y_targets, sample_weight = self.draw(
ipcw_training=True,
X=X,
Expand All @@ -161,12 +201,6 @@ def fit(self, X):
sample_weight=sample_weight,
)

self.ipcw_train = self.ipcw_estimator.compute_ipcw_at(
self.duration_train,
ipcw_training=False,
X=X,
)


class SurvivalBoost(BaseEstimator, ClassifierMixin):
r"""Cause-specific Cumulative Incidence Function (CIF) with GBDT [1]_.
Expand Down Expand Up @@ -333,6 +367,8 @@ def __init__(
n_iter_before_feedback=20,
random_state=None,
n_horizons_per_observation=3,
time_sampler="uniform",
split_data_censor_and_surv=False,
):
self.hard_zero_fraction = hard_zero_fraction
self.n_iter = n_iter
Expand All @@ -347,6 +383,8 @@ def __init__(
self.ipcw_strategy = ipcw_strategy
self.random_state = random_state
self.n_horizons_per_observation = n_horizons_per_observation
self.time_sampler = time_sampler
self.split_data_censor_and_surv = split_data_censor_and_surv

def fit(self, X, y, times=None):
"""Fit the model.
Expand Down Expand Up @@ -413,30 +451,41 @@ def fit(self, X, y, times=None):
"Valid values are 'alternating' and 'kaplan-meier'."
)

if self.split_data_censor_and_surv:
X_surv, X_censor, y_surv, y_censor = train_test_split(
X, y, random_state=0, test_size=max(1 - y["event"].mean(), 0.1)
)
else:
X_censor = X_surv = X
y_surv = y
y_censor = None

self.weighted_targets_ = WeightedMultiClassTargetSampler(
y,
y_surv,
hard_zero_fraction=self.hard_zero_fraction,
random_state=self.random_state,
ipcw_estimator=ipcw_estimator,
n_iter_before_feedback=self.n_iter_before_feedback,
time_sampler=self.time_sampler,
y_censor=y_censor,
)

self.X_surv = X_surv # remove
iterator = range(self.n_iter)
if self.show_progressbar:
iterator = tqdm(iterator)

for idx_iter in iterator:
X_with_time = np.empty((0, X.shape[1] + 1))
X_with_time = np.empty((0, X_surv.shape[1] + 1))
y_targets = np.empty((0,))
sample_weight = np.empty((0,))
for _ in range(self.n_horizons_per_observation):
(
sampled_times_,
y_targets_,
sample_weight_,
) = self.weighted_targets_.draw(X=X, ipcw_training=False)

X_with_time_ = np.hstack([sampled_times_, X])
) = self.weighted_targets_.draw(X=X_surv, ipcw_training=False)
X_with_time_ = np.hstack([sampled_times_, X_surv])
X_with_time = np.vstack([X_with_time, X_with_time_])
y_targets = np.hstack([y_targets, y_targets_])
sample_weight = np.hstack([sample_weight, sample_weight_])
Expand All @@ -455,7 +504,14 @@ def fit(self, X, y, times=None):
if (idx_iter % self.n_iter_before_feedback == 0) and isinstance(
ipcw_estimator, AlternatingCensoringEstimator
):
self.weighted_targets_.fit(X)
self.weighted_targets_.fit(X_censor)
self.weighted_targets_.ipcw_train = (
self.weighted_targets_.ipcw_estimator.compute_ipcw_at(
self.weighted_targets_.duration_train,
ipcw_training=False,
X=X_surv,
)
)

# XXX: implement verbose logging with a version of IBS that
# can handle competing risks.
Expand Down
Loading
Loading