Source code for SIDISH.DEEP_COX

import torch
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from tqdm import tqdm
from lifelines.utils import concordance_index
import os
from SIDISH.DEEP_COX_ARCHITECTURE import DEEPCOX_ARCHITECTURE as deepCox


[docs] def loss_DeepCox(pred, events, durations, weight=None, train=True): """ Compute the negative log-likelihood for the Deep Cox model in Phase 2 of SIDISH. Parameters ---------- pred : torch.Tensor Predicted risk scores. events : torch.Tensor Event indicators (1 if event occurred, 0 otherwise or censored). durations : torch.Tensor Time durations. weight : torch.Tensor Patient weights. train : bool, optional Whether the model is in training mode. Defaults to True. Returns ------- torch.Tensor Negative log-likelihood. Notes ----- This method is based on the implementation from DeepSurv: https://github.com/jaredleekatzman/DeepSurv/blob/master/deepsurv/deep_surv.py """ if train: idx = durations.sort(descending=True)[1] # Sort by durations in descending order events = events[idx] # Sort events by durations pred = pred[idx] # Sort risk predictions by durations weight = weight[idx] # Sort patient weight by durations hazard_ratio = torch.exp(pred) / weight log_risk = torch.log(torch.cumsum(hazard_ratio, dim=0)) uncensored_likelihood = pred.t() - log_risk censored_likelihood = uncensored_likelihood * events num_observed_events = torch.sum(events) neg_likelihood = -torch.sum(censored_likelihood) / num_observed_events elif train == False: idx = durations.sort(descending=True)[1] events = events[idx] pred = pred[idx] hazard_ratio = torch.exp(pred) log_risk = torch.log(torch.cumsum(hazard_ratio, dim=0)) uncensored_likelihood = pred.t() - log_risk censored_likelihood = uncensored_likelihood * events num_observed_events = torch.sum(events) neg_likelihood = -torch.sum(censored_likelihood) / num_observed_events return neg_likelihood
[docs] class DEEPCOX(): def __init__(self, X_train, Y_train, weights, hidden, encoder, device, batch_size,seed, lr=0.000001, dropout=0): self.device = device self.X_train = X_train self.Y_train = Y_train self.weights = weights self.hidden = hidden self.encoder = encoder self.dropout = dropout self.lr = lr self.seed = seed torch.manual_seed(self.seed) torch.cuda.manual_seed(self.seed) torch.backends.cudnn.deterministic = True np.random.seed(self.seed) os.environ["PYTHONHASHSEED"] = str(self.seed) # Initializing the Deep Cox regression model of phase 2 of SIDISH self.model = deepCox(hidden=self.hidden, encoder=self.encoder, dropout=self.dropout) # Transfer Learning, unfrezzing the weights of the encoder from the VAE from phase 1. Continuing the training from where it left off. for name, para in self.model.encoder_layer.named_parameters(): para.requires_grad = True self.non_frozen_parameters = [p for p in self.model.parameters() if p.requires_grad] self.model = self.model.to(self.device) self.opt = optim.Adam(self.non_frozen_parameters, lr=self.lr) self.X_train[:, -1] = self.weights train_dataset = TensorDataset(self.X_train.float(), self.Y_train.float()) self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
[docs] def train(self, epochs): # Training the Deep Cox Regression model for epoch in tqdm(range(epochs)): for x, y in self.train_loader: self.model.train() x = x.to(self.device, non_blocking=True) y = y.to(self.device, non_blocking=True) x_ = x[:, :-1] w = x[:, -1:].flatten() days = y[:, :-1].flatten() events = y[:, -1:].flatten() pred = self.model(x_).flatten() loss = loss_DeepCox(pred, events, days, w).view(1,) self.opt.zero_grad(set_to_none=True) loss.backward() self.opt.step() # Loss and Concordance index calclation out = -np.exp(pred.detach().cpu().numpy().reshape((1, -1))[0]) self.ci_train = concordance_index(days.cpu().numpy(), out, events.cpu().numpy()) self.loss_train = loss.item()
[docs] def get_train_ci(self): #Return training concordance index return self.ci_train
[docs] def get_train_loss(self): #Return final training loss. return self.loss_train
[docs] def get_test_loss(self, test_loader): #Compute loss on test data. self.model.eval() with torch.no_grad(): genes, surv = next(iter(test_loader)) genes = genes.to(self.device, non_blocking=True) surv = surv.to(self.device, non_blocking=True) genes_ = genes[:, :-1] w_test = genes[:, -1:] days_test = surv[:, :-1].flatten() events_test = surv[:, -1:].flatten() pred_test = self.model(genes_).flatten() loss_test = loss_DeepCox(pred_test, events_test, days_test, train=False).view(1,) return loss_test
[docs] def get_test_ci(self, test_loader): #Compute concordance index on test data. self.model.eval() with torch.no_grad(): genes, surv = next(iter(test_loader)) genes = genes.to(self.device, non_blocking=True) surv = surv.to(self.device, non_blocking=True) genes_ = genes[:, :-1] w_test = genes[:, -1:] days_test = surv[:, :-1].flatten() events_test = surv[:, -1:].flatten() pred_test = self.model(genes_).flatten() out = -np.exp(pred_test.detach().cpu().numpy().reshape((1, -1))[0]) ci_test = concordance_index(days_test.cpu().numpy(), out, events_test.cpu().numpy()) return ci_test