# Core SIDISH components
from SIDISH.DEEP_COX import DEEPCOX as DeepCox # Deep Cox model wrapper (Phase 2)
from SIDISH.VAE import VAE as VAE # VAE encoder (Phase 1)
from SIDISH.Utils import Utils as utils # Helper utilities used across phases
from SIDISH.Utils import get_spatial_graph_from_adata # Build spatial graph (for spatial-VAE)
from SIDISH.in_silico_perturbation import InSilicoPerturbation
from SIDISH.ppi_network_handler import PPINetworkHandler # Loads/queries PPI network
from SIDISH.gene_perturbation_utils import GenePerturbationUtils # Gene KO / network-based adj.
from statsmodels.stats.multitest import multipletests
import seaborn as sns
import pandas as pd
import numpy as np
import torch
import scanpy as sc
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm
from typing import Literal
from lifelines import CoxPHFitter
from lifelines.statistics import logrank_test
from lifelines import KaplanMeierFitter
import math
import itertools
from scipy.stats import binomtest
from scipy.stats import wilcoxon
import pyro
[docs]
def process_Data(X: np.ndarray, Y: np.ndarray, test_size: float, batch_size: int, seed: int) -> tuple:
    """
    Split bulk RNA-seq data into train/test partitions and convert them to float tensors.

    The split is stratified on the second column of ``Y`` (``Y[:, 1]``), so both
    partitions keep a similar proportion of each value found in that column.

    Parameters
    ----------
    X : np.ndarray
        Bulk feature matrix, one row per patient.
    Y : np.ndarray
        Survival labels, one row per patient. Column 1 is used for stratification
        (presumably the event indicator; confirm against the caller).
    test_size : float
        Proportion of the dataset allocated to the test split.
    batch_size : int
        Unused. Kept only so existing callers keep working; earlier versions built
        DataLoaders with it that were discarded before returning.
    seed : int
        Random seed passed to ``train_test_split`` for reproducibility.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)`` as ``torch.float`` tensors.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=test_size, random_state=seed, stratify=Y[:, 1]
    )
    # np.array(...) copies each split so the tensors do not alias the caller's arrays.
    X_train, X_test, y_train, y_test = (
        torch.from_numpy(np.array(part)).type(torch.float)
        for part in (X_train, X_test, y_train, y_test)
    )
    return X_train, X_test, y_train, y_test
[docs]
def plot_umap(ax, umap_combined, palette, percentage_change_):
    """
    Draw a UMAP scatter coloured by the ``status`` column and append the
    perturbation percentage change as an extra legend entry.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    umap_combined : DataFrame-like
        Must provide ``UMAP1``, ``UMAP2`` and ``status`` columns.
    palette : seaborn palette
        Colours used for the ``status`` categories.
    percentage_change_ : float
        Value rendered as ``"<value>%"`` (two decimals) in the legend.
    """
    scatter_style = dict(
        x="UMAP1",
        y="UMAP2",
        hue='status',
        palette=palette,
        edgecolor='none',
        alpha=1,
        s=15,
    )
    sns.scatterplot(data=umap_combined, ax=ax, **scatter_style)

    # Minimal frame: no top/right spines, no tick marks.
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_xlabel("UMAP1", fontsize=12)
    ax.set_ylabel("UMAP2", fontsize=12)
    ax.set_xticks([])
    ax.set_yticks([])

    # Extend the seaborn legend with one extra entry carrying the percentage change.
    legend_handles, legend_labels = ax.get_legend_handles_labels()
    legend_handles.append(plt.Line2D([0], [0]))
    legend_labels.append(f"{percentage_change_:.2f}%")
    ax.legend(
        legend_handles,
        legend_labels,
        loc="upper left",
        fontsize=12,
        bbox_to_anchor=(0.95, 0.75),
        frameon=False,
    )
[docs]
def plot_umap_differential(ax, umap_combined):
    """
    Draw a UMAP scatter coloured by the continuous ``risk`` column using the
    ``rocket`` palette.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    umap_combined : DataFrame-like
        Must provide ``UMAP1``, ``UMAP2`` and ``risk`` columns.
    """
    scatter_style = dict(
        x="UMAP1",
        y="UMAP2",
        hue='risk',
        palette="rocket",
        edgecolor='none',
        alpha=1,
        s=15,
        legend=False,
    )
    sns.scatterplot(data=umap_combined, ax=ax, **scatter_style)

    # Minimal frame: no top/right spines, no tick marks.
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_xlabel("UMAP1", fontsize=12)
    ax.set_ylabel("UMAP2", fontsize=12)
    ax.set_xticks([])
    ax.set_yticks([])

    # NOTE(review): the scatter is drawn with legend=False, so the label list is
    # presumably empty here and the extra handle has no matching label; the legend
    # call is kept as in the original - confirm whether a colourbar was intended.
    legend_handles, legend_labels = ax.get_legend_handles_labels()
    legend_handles.append(plt.Line2D([0], [0]))
    ax.legend(
        legend_handles,
        legend_labels,
        loc="upper left",
        fontsize=12,
        bbox_to_anchor=(0.95, 0.75),
        frameon=False,
    )
[docs]
def map_event_column(val):
    """
    Convert a survival event status into a strict binary integer.

    The value is stringified, stripped and lower-cased before matching, so
    string labels are case-insensitive and numeric strings are accepted.

    Parameters
    ----------
    val : str, int, or float
        Event status. Accepted values:
        - ``"Dead"`` (any case) -> 1, ``"Alive"`` (any case) -> 0
        - anything that parses as the number 1 or 0 (``1``, ``1.0``, ``"0"``, ``"0.0"``)

    Returns
    -------
    int
        1 if the event occurred (Dead), 0 if censored (Alive).

    Raises
    ------
    ValueError
        For every other input (e.g. ``'censored'``, ``'unknown'``, ``2``, ``NaN``).
    """
    label_codes = {"dead": 1, "alive": 0}
    text = str(val).strip().lower()

    # Explicit string labels take precedence over numeric parsing.
    if text in label_codes:
        return label_codes[text]

    # Numeric forms: only exactly 1 or 0 are valid.
    try:
        number = float(text)
    except ValueError:
        number = None
    if number is not None:
        if number == 1.0:
            return 1
        if number == 0.0:
            return 0

    raise ValueError(
        f"Invalid survival event found: '{val}'. "
        "The status column must only contain 'Dead', 'Alive', 1, or 0."
    )
[docs]
def _sc_gene_labels(adata):
    """Return the per-gene identifiers used for matching: ``var['gene_ids']`` if present, else ``var_names``."""
    if 'gene_ids' in adata.var.columns:
        return adata.var['gene_ids'].astype(str).values
    return adata.var_names.astype(str).values


def _shared_genes(bulk, adata):
    """Return the sorted gene identifiers present in both the bulk columns and the single-cell object."""
    shared = np.intersect1d(bulk.columns.to_numpy(), _sc_gene_labels(adata))
    if shared.size == 0:
        raise ValueError("No overlapping genes between bulk and scRNA-seq.")
    return shared


def _restrict_to_genes(adata, shared):
    """Return a copy of ``adata`` keeping only the genes whose identifier is in ``shared``."""
    if 'gene_ids' in adata.var.columns:
        keep = adata.var['gene_ids'].astype(str).isin(shared)
    else:
        keep = adata.var_names.isin(shared)
    return adata[:, keep].copy()


def _merge_survival(bulk, survival_df, gene_order, survival_, status):
    """Order bulk columns like ``gene_order``, prepend ``survival_df`` and normalise duration/event columns."""
    expression = bulk.loc[:, gene_order]
    merged = pd.concat([survival_df, expression], axis=1)
    merged.rename(columns={survival_: "duration", status: "event"}, inplace=True)
    merged["event"] = merged["event"].apply(map_event_column)
    return merged


def _qc_and_select_hvg(adata, n_genes_by_counts, pct_counts_mt):
    """Run QC filtering and HVG selection; return (log-normalised HVG adata, pre-normalisation HVG subset)."""
    adata.var_names_make_unique()
    sc.pp.filter_cells(adata, min_genes=3)
    sc.pp.filter_genes(adata, min_cells=400)
    # annotate mito genes
    adata.var['mt'] = adata.var_names.str.startswith(('MT-', 'mt-'))
    sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
    # QC plots
    sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'], jitter=0.4, multi_panel=True)
    sc.pl.scatter(adata, x='total_counts', y='pct_counts_mt')
    sc.pl.scatter(adata, x='total_counts', y='n_genes_by_counts')
    # cell-level thresholds
    adata = adata[adata.obs.n_genes_by_counts < n_genes_by_counts, :].copy()
    adata = adata[adata.obs.pct_counts_mt < pct_counts_mt, :].copy()
    # Snapshot the matrix before normalisation; `subset` below is rebuilt from it.
    adata.raw = adata.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, flavor='seurat')
    # Pre-normalisation values restricted to the HVGs.
    subset = adata.raw.to_adata()
    subset = subset[:, adata.var.highly_variable].copy()
    adata = adata[:, adata.var.highly_variable].copy()
    return adata, subset


def _batch_correct(adata, subset, patient_id):
    """Plot a Harmony-integrated UMAP of ``adata`` as a visual check and ComBat-correct ``subset`` in place."""
    sc.pp.scale(adata, max_value=10)
    sc.tl.pca(adata, svd_solver='arpack')
    # Integrate first, then build the neighbour graph on the corrected embedding;
    # otherwise the UMAP below would still reflect the uncorrected PCA.
    sc.external.pp.harmony_integrate(adata, key=patient_id)
    sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40, use_rep='X_pca_harmony')
    sc.tl.umap(adata)
    sc.pl.umap(adata, color=[patient_id])
    ## Data batch correction using ComBat
    if patient_id not in subset.obs.columns:
        # Fall back to the last '_'-separated token of each cell name as the batch label.
        subset.obs[patient_id] = subset.obs_names.str.split('_').str[-1]
    sc.pp.combat(subset, key=patient_id)


def preprocess(
    adata,
    bulk,
    survival_df,
    patient_id,
    celltype_name,
    processed = True,
    n_genes_by_counts = 5000,
    pct_counts_mt = 10,
    batch_correction=False,
    survival_ = "Overall_survival_days",
    status = "Sample_Status"
):
    """
    Harmonize scRNA-seq (AnnData) and bulk tables.

    - If ``processed`` is falsy: QC filtering, normalisation and HVG selection are
      run first, and the returned single-cell object holds the pre-normalisation
      values of the selected genes.
    - Genes are restricted to those shared between the bulk columns and the
      single-cell object (matched on ``var['gene_ids']`` when that column exists,
      otherwise on ``var_names``).
    - Bulk columns are reordered to the single-cell gene order and concatenated
      after ``survival_df``; the ``survival_`` and ``status`` columns are renamed
      to ``duration`` and ``event`` and ``event`` is mapped to {0, 1}.
    - Optionally (raw path only) a Harmony-integrated UMAP is plotted as a visual
      check and ComBat is applied to the returned single-cell object.

    Parameters
    ----------
    adata : AnnData
        Single-cell data. Modified in place by the QC steps (raw path) and by the
        ``obs`` renaming (processed path).
    bulk : pd.DataFrame
        Bulk expression table whose columns are gene identifiers.
    survival_df : pd.DataFrame
        Survival metadata containing the ``survival_`` and ``status`` columns.
    patient_id : str
        ``obs`` column used as batch key for Harmony/ComBat.
    celltype_name : str
        ``obs`` column renamed to ``celltype_major`` when present.
    processed : bool, optional
        True if ``adata`` is already preprocessed (default=True).
    n_genes_by_counts : int, optional
        Upper bound (exclusive) on ``obs.n_genes_by_counts`` (default=5000).
    pct_counts_mt : float, optional
        Upper bound (exclusive) on ``obs.pct_counts_mt`` (default=10).
    batch_correction : bool, optional
        Apply the Harmony check and ComBat on the raw path (default=False).
    survival_ : str, optional
        Name of the survival-time column in ``survival_df``.
    status : str, optional
        Name of the event-status column in ``survival_df``.

    Returns
    -------
    (AnnData, pd.DataFrame)
        (sc object restricted to intersecting genes, bulk with survival columns)

    Raises
    ------
    ValueError
        If no gene is shared between bulk and scRNA-seq, or an event value is invalid.
    """
    if not processed:
        adata, subset = _qc_and_select_hvg(adata, n_genes_by_counts, pct_counts_mt)

        shared = _shared_genes(bulk, adata)
        adata = _restrict_to_genes(adata, shared)
        subset = _restrict_to_genes(subset, shared)

        if 'cells' in subset.obs.columns:
            subset.obs_names = subset.obs['cells'].astype(str).values
        if celltype_name in subset.obs.columns:
            subset.obs.rename(columns={celltype_name: "celltype_major"}, inplace=True)

        # Select bulk columns with the same identifiers used for matching, in
        # single-cell gene order (no dense to_df() materialisation needed).
        bulk = _merge_survival(bulk, survival_df, _sc_gene_labels(adata), survival_, status)

        ## Optional UMAP check with Harmony and batch correction with ComBat
        if batch_correction:
            _batch_correct(adata, subset, patient_id)
        return subset, bulk

    # Already-processed input: only harmonise names and genes.
    if 'cells' in adata.obs.columns:
        adata.obs_names = adata.obs['cells'].astype(str).values
    if celltype_name in adata.obs.columns:
        adata.obs.rename(columns={celltype_name: "celltype_major"}, inplace=True)

    shared = _shared_genes(bulk, adata)
    adata = _restrict_to_genes(adata, shared)
    # Reorder bulk to sc gene order, then merge survival metadata (event -> {0,1}).
    bulk = _merge_survival(bulk, survival_df, _sc_gene_labels(adata), survival_, status)
    return adata, bulk
[docs]
class SIDISH:
"""
SIDISH (Semi-Supervised Iterative Deep Learning for Identifying High-Risk Cells).
This framework integrates single-cell and bulk RNA-seq data to identify
High-Risk cancer cells and potential biomarkers.
Parameters
----------
adata : AnnData
Single-cell RNA-seq data.
bulk : pd.DataFrame
Bulk RNA-seq data.
use_spatial_graph : bool, optional
Whether to use spatial graph information (default=False).
k_neighbors : int, optional
Number of neighbors to use for constructing the spatial graph (default=5).
device : str
Computation device ('cpu' or 'cuda').
seed : int, optional
Random seed for reproducibility (default=1234).
"""
[docs]
def __init__(self, adata, bulk, device: str = "cpu", seed: int = 1234, use_spatial_graph: bool = False, k_neighbors: int = None) -> None:
self.adata = adata
self.bulk = bulk
self.device = device
self.seed = seed
self.use_spatial_graph = use_spatial_graph
self.k_neighbors = k_neighbors
if self.use_spatial_graph and "spatial" in self.adata.obsm:
edge_index, edge_weight = get_spatial_graph_from_adata(self.adata, spatial_key="spatial", method="knn", k=self.k_neighbors)
self.spatial_graph = (edge_index, edge_weight)
print("SIDISH Spatial graph constructed using k-NN with k =", self.k_neighbors)
else:
print("SIDISH No spatial graph used. Proceeding with dense VAE.")
[docs]
def init_Phase1(self, epochs: int, i_epochs: int, latent_size: int, layer_dims: list, batch_size: int, optimizer: str, lr: float, lr_3: float, dropout: float, type: str = 'Normal') -> None:
"""
Initializes Phase 1: training a Variational Autoencoder (VAE) on single-cell RNA-seq data.
Parameters
----------
epochs : int
Number of epochs for initial VAE training.
i_epochs : int
Number of iterations for retraining VAE.
latent_size : int
Latent dimension size.
layer_dims : list
List of hidden layer dimensions.
batch_size : int
Batch size.
optimizer : str
Optimizer for VAE training.
lr : float
Learning rate.
lr_3 : float
Learning rate for later iterations.
dropout : float
Dropout rate.
type : str, optional
Specifies dense or normal representation (default="Normal").
Returns
-------
None
"""
self.epochs_1 = epochs
self.epochs_3 = i_epochs
self.latent_size = latent_size
self.layer_dims = layer_dims
self.optimizer = optimizer
self.lr_1 = lr
self.lr_3 = lr_3
self.dropout_1 = dropout
self.batch_size = batch_size
self.type = type
# Initialise the weight matrix of phase 1
self.W_matrix = np.ones(self.adata.X.shape)
[docs]
def init_Phase2(self, epochs: int, hidden: int, lr: float, dropout: float, test_size: float, batch_size_bulk: int) -> None:
    """
    Record the Phase 2 (Deep Cox on bulk data) hyper-parameters, attach a
    per-patient weight column to ``self.bulk`` and create the train/test split.

    Parameters
    ----------
    epochs : int
        Number of training epochs for the Deep Cox model.
    hidden : int
        Number of neurons in the hidden layer.
    lr : float
        Learning rate for the Deep Cox model.
    dropout : float
        Dropout rate for training.
    test_size : float
        Proportion of the dataset allocated to the test split.
    batch_size_bulk : int
        Number of samples per batch for bulk data.

    Returns
    -------
    None
    """
    self.epochs_2, self.hidden = epochs, hidden
    self.lr_2, self.dropout_2 = lr, dropout
    self.batch_size_bulk = batch_size_bulk

    # One weight per patient, all starting at 1. The column is added to the bulk
    # table before the feature matrix is extracted, so it ends up as the last
    # column of X.
    n_patients = self.bulk.iloc[:, 2:].shape[0]
    self.W_vector = np.ones(n_patients)
    self.bulk['weight'] = self.W_vector

    # First two columns are the labels; everything after them is the feature matrix.
    labels = self.bulk.iloc[:, :2]
    features = self.bulk.iloc[:, 2:]
    self.X = features.values
    self.Y = labels.values

    splits = process_Data(self.X, self.Y, test_size, batch_size_bulk, self.seed)
    self.X_train, self.X_test, self.y_train, self.y_test = splits
[docs]
def train(self, iterations: int, percentile: float, steepness: float, path: str, num_workers: int = 0, show: bool = True, distribution_fit:Literal["default", "fitted"] = "default") -> sc.AnnData:
    """
    Trains the SIDISH framework iteratively, refining the identification of High-Risk cells.

    The VAE is trained once up front. Each iteration then:
    - trains a fresh Deep Cox model on the bulk training split, using the current VAE as encoder,
    - computes patient scores (added to the patient weight vector) and annotates cells,
    - computes a cell weight matrix that is added to ``W_matrix`` (capped at 2) and saved as CSV,
    - except after the last iteration, rebuilds the VAE with the updated weights, restores the
      previous VAE state and retrains it.
    ``init_Phase1`` and ``init_Phase2`` must have been called beforehand.

    Parameters
    ----------
    iterations : int
        Number of training iterations.
    percentile : float
        Threshold percentile for defining High-Risk cells (forwarded to ``utils.getWeightVector``).
    steepness : float
        Scaling factor for updating weights (forwarded to ``utils.getWeightMatrix``).
    path : str
        Output location for checkpoints, weight matrices and the final AnnData file.
        The directory is created if needed; file names are appended to this string
        directly, so it should end with a path separator.
    num_workers : int, optional
        Number of data-loading workers passed to the VAE (default=0).
    show : bool, optional
        Not referenced in this method body (default=True).
    distribution_fit : {"default", "fitted"}, optional
        Forwarded to ``utils.getWeightVector``. With "fitted", the distribution returned
        in the first iteration is passed back in on later iterations (default="default").

    Returns
    -------
    sc.AnnData
        Updated AnnData object containing the refined High-Risk cell classifications.
    """
    os.makedirs(path, exist_ok=True)
    self.path = path
    self.num_workers = num_workers
    self.percentile = percentile
    self.steepness = steepness
    # Re-Initialise the weight vector of phase 2
    test_dataset = TensorDataset(self.X_test.float(), self.y_test.float())
    self.test_loader = DataLoader(test_dataset, batch_size=self.X_test.shape[0], shuffle=False)
    # Last column of X_train is the per-patient weight column added in init_Phase2.
    # NOTE(review): this slice presumably aliases X_train, so the in-place `+=` in the
    # loop below would also update X_train's last column - confirm this is intended.
    self.W_vector = self.X_train[:, -1]
    # Initialise the VAE of phase 1
    if self.use_spatial_graph and self.spatial_graph is not None:
        print("########################################## Using Spatial Graph in VAE ##########################################")
        self.vae = VAE(epochs=self.epochs_1, adata=self.adata, z_dim=self.latent_size, layer_dims=self.layer_dims, lr=self.lr_1, dropout=self.dropout_1, device=self.device, seed=self.seed, gcn_dims=[32, self.latent_size])
        self.vae.initialize(self.adata, W=self.W_matrix, batch_size=self.batch_size, num_workers=self.num_workers, spatial_graph=self.spatial_graph, num_neighbors=self.k_neighbors)
    else:
        print("########################################## Using Dense VAE ##########################################")
        self.vae = VAE(self.epochs_1,self.adata,self.latent_size, self.layer_dims, self.lr_1, self.dropout_1,self.device, self.seed)
        self.vae.initialize(self.adata, self.W_matrix, self.batch_size, self.type, self.num_workers)
    # Initial training of VAE in iteration 1
    print("########################################## ITERATION 1 OUT OF {} ##########################################".format(iterations))
    self.vae.train()
    # Save VAE for iterative process
    torch.save(self.vae.model.state_dict(), "{}vae_transfer".format(self.path))
    # Per-iteration Deep Cox metrics and cell thresholds
    self.train_loss_Cox = []
    self.train_ci_Cox = []
    self.test_loss_Cox = []
    self.test_ci_Cox = []
    self.percentile_list = []
    for i in range(iterations):
        # A new Deep Cox model is built every iteration on top of the current VAE.
        self.encoder = self.vae
        self.deepCox_model = DeepCox(self.X_train, self.y_train, self.W_vector, self.hidden, self.encoder, self.device,self.batch_size,self.seed, self.lr_2, self.dropout_2)
        self.deepCox_model.train(self.epochs_2)
        self.train_loss_Cox.append(self.deepCox_model.get_train_loss())
        self.train_ci_Cox.append(self.deepCox_model.get_train_ci())
        self.test_loss_Cox.append(self.deepCox_model.get_test_loss(self.test_loader))
        self.test_ci_Cox.append(self.deepCox_model.get_test_ci(test_loader=self.test_loader))
        # Patient features without the trailing weight column
        patients_data = self.X_train[:, :-1].to(self.device)
        print("########################################## Calculating Patients Weight Vector ##########################################")
        if i == 0:
            self.scores, self.adata_, self.percentile_cells, self.cells_max, self.cells_min, returned_dist = utils.getWeightVector(patients_data, self.vae.adata, self.deepCox_model.model, self.percentile, self.device, distribution_fit)
            # Keep the first iteration's distribution for reuse when distribution_fit == 'fitted'.
            self.dist = returned_dist
        else:
            if distribution_fit == 'fitted':
                self.scores, self.adata_, self.percentile_cells, self.cells_max, self.cells_min , returned_dist = utils.getWeightVector(patients_data, self.vae.adata, self.deepCox_model.model, self.percentile, self.device, distribution_fit, self.dist)
            else:
                self.scores, self.adata_, self.percentile_cells, self.cells_max, self.cells_min , returned_dist = utils.getWeightVector(patients_data, self.vae.adata, self.deepCox_model.model, self.percentile, self.device, "default")
        self.percentile_list.append(self.percentile_cells)
        self.W_vector += self.scores
        print("########################################## Calculating Cells Weight Matrix ##########################################")
        self.W_temp = utils.getWeightMatrix(self.adata_, self.seed, self.steepness, self.type)
        self.W_matrix += self.W_temp
        # Cap accumulated cell weights at 2.
        self.W_matrix[self.W_matrix >= 2] = 2
        self.adata = self.adata_.copy()
        print("########################################## Saving Weight Matrix at Iteration {} ##########################################".format(i))
        pd.DataFrame(self.W_matrix).to_csv("{}W_matrix_{}.csv".format(self.path,i))
        if i == (iterations - 1):
            print("########################################## SIDISH TRAINING DONE ##########################################")
            break
        else:
            print("########################################## ITERATION {} OUT OF {} ##########################################".format(i+2, iterations))
            # Rebuild the VAE with the updated weights and restore the previously saved VAE state.
            if self.use_spatial_graph and self.spatial_graph is not None:
                self.vae = VAE(epochs=self.epochs_3, adata=self.adata, z_dim=self.latent_size, layer_dims=self.layer_dims, lr=self.lr_3, dropout=self.dropout_1, device=self.device, seed=self.seed, gcn_dims=[32, self.latent_size])
                self.vae.initialize(self.adata, spatial_graph=self.spatial_graph, W=self.W_matrix, batch_size=self.batch_size, num_neighbors=self.k_neighbors, num_workers=self.num_workers)
                self.vae.model.load_state_dict(torch.load("{}vae_transfer".format(self.path)))
            else:
                self.vae = VAE(self.epochs_3,self.adata,self.latent_size, self.layer_dims,self.lr_3, self.dropout_1,self.device, self.seed)
                self.vae.initialize(self.adata, self.W_matrix, self.batch_size, self.type,self.num_workers)
                self.vae.model.load_state_dict(torch.load("{}vae_transfer".format(self.path)))
            self.vae.train()
            # Save VAE for iterative process
            torch.save(self.vae.model.state_dict(), "{}vae_transfer".format(self.path))
    # Persist the last Deep Cox model (restored later by reload()).
    torch.save(self.deepCox_model.model.state_dict(), "{}deepCox".format(self.path))
    print("########################################## Saving Final AnnData Object ##########################################")
    fn = "{}adata_SIDISH.h5ad".format(self.path)
    self.adata.write_h5ad(fn, compression="gzip")
    return self.adata
[docs]
def getEmbedding_adata(self) -> sc.AnnData:
    """
    Run every batch of the VAE's ``total_loader`` through the trained model and
    store the resulting latent vectors on the AnnData object.

    Returns
    -------
    AnnData
        ``self.adata`` with the float32 embeddings stored in ``obsm['latent']``.
    """
    self.TZ = []
    self.vae.model.eval()
    with torch.no_grad():
        for batch, _ in self.vae.total_loader:
            pyro.clear_param_store()
            # Move the mini-batch to the compute device before encoding.
            batch = batch.to(self.device, non_blocking=True)
            latent = self.vae.model.get_latent_representation(batch)
            self.TZ.extend(latent.cpu().detach().numpy().tolist())
    self.adata.obsm['latent'] = np.array(self.TZ).astype(np.float32)
    return self.adata
[docs]
def plotUMAP(self, resolution: float, figure_size: tuple = (8, 6), fontsize: int = 12, cell_size: int = 20) -> None:
    """
    Performs UMAP dimensionality reduction and Leiden clustering on the latent space,
    annotates the cells with display labels and plots the High-Risk/Background UMAP.

    Requires ``obsm['latent']`` and ``obs['SIDISH']`` (values "h"/"b") on ``self.adata``.

    Parameters
    ----------
    resolution : float
        The resolution parameter for Leiden clustering.
    figure_size : tuple, optional
        Size of the generated UMAP plot (default=(8, 6)).
    fontsize : int, optional
        Font size for labels and legends (default=12).
    cell_size : int, optional
        Size of points in the scatter plot (default=20).

    Returns
    -------
    None
    """
    print("################### Calculating Neighbors #################")
    sc.pp.neighbors(self.adata, n_neighbors=30, use_rep="latent", random_state=self.seed)
    print("################### Calculating UMAP coordinated #################")
    sc.tl.umap(self.adata, random_state=self.seed)
    print("################### Leiden Clustering #################")
    sc.tl.leiden(self.adata, resolution=resolution, random_state=self.seed)
    print("################## Annotating Anndata #################")
    risk_labels = self.adata.obs.SIDISH
    h = (risk_labels == "h").sum()
    b = (risk_labels == "b").sum()
    self.adata.obs["SIDISH_"] = ["High-Risk Cells ({})".format(h) if i == "h" else "Background Cells ({})".format(b) for i in risk_labels]
    self.adata.uns["SIDISH__colors"] = np.array(["grey", "red"], dtype="object")
    self.adata.obs["SIDISH__"] = ["High-Risk" if i == "h" else "Background" for i in risk_labels]
    self.adata.uns["SIDISH___colors"] = np.array(["grey", "red"], dtype="object")
    print("################## Plotting SIDISH identified High-Risk cells #################")
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.size"] = fontsize
    plt.rcParams["legend.fontsize"] = fontsize
    # Create the axes ourselves and hand them to scanpy: without `ax`, scanpy opens
    # its own figure, leaving an empty one behind and ignoring `figure_size`.
    _, ax = plt.subplots(figsize=figure_size)
    sc.pl.umap(self.adata, color=["SIDISH_"], title="", show=False, size=cell_size, edgecolor="none", ax=ax)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # Adjust legend properties: location at the top with more space
    ax.legend(loc="upper center", frameon=False, markerscale=1.5, bbox_to_anchor=(0.5, 1.15))
    plt.tight_layout()
    plt.show()
[docs]
def annotateCells(self, test_adata, percentile_cells, mode, perturbation=False):
    """
    Annotate ``test_adata`` with the trained Deep Cox model.

    Thin wrapper around ``utils.annotateCells`` that supplies this instance's
    Deep Cox model, device and percentile.

    Parameters
    ----------
    test_adata : AnnData
        Cells to annotate.
    percentile_cells
        Cell-level threshold forwarded unchanged.
    mode
        Forwarded unchanged as ``mode``.
    perturbation : bool, optional
        Forwarded unchanged as ``perturbation`` (default=False).

    Returns
    -------
    Whatever ``utils.annotateCells`` returns (the annotated data object).
    """
    return utils.annotateCells(
        test_adata,
        self.deepCox_model.model,
        percentile_cells,
        self.device,
        self.percentile,
        mode=mode,
        perturbation=perturbation,
    )
[docs]
def reload(self, path, num_workers = 0):
    """
    Rebuild the VAE and Deep Cox models and restore their saved weights.

    Reads the ``vae_transfer`` and ``deepCox`` checkpoints written by ``train``.
    ``init_Phase1`` and ``init_Phase2`` must have been called first, since the
    models are rebuilt from the stored hyper-parameters and data split.

    Parameters
    ----------
    path : str
        Location used during training. File names are appended to this string
        directly, so it should end with a path separator.
    num_workers : int, optional
        Number of data-loading workers passed to the VAE (default=0).
    """
    self.path = path
    self.num_workers = num_workers
    # Tolerate instances on which the spatial graph attribute was never set.
    spatial_graph = getattr(self, "spatial_graph", None)
    if self.use_spatial_graph and spatial_graph is not None:
        print("########################################## Using Spatial Graph in VAE ##########################################")
        self.vae = VAE(epochs=self.epochs_1, adata=self.adata, z_dim=self.latent_size, layer_dims=self.layer_dims, lr=self.lr_1, dropout=self.dropout_1, device=self.device, seed=self.seed, gcn_dims=[32, self.latent_size])
        self.vae.initialize(self.adata, W=self.W_matrix, batch_size=self.batch_size, num_workers=self.num_workers, spatial_graph=spatial_graph, num_neighbors=self.k_neighbors)
    else:
        print("########################################## Using Dense VAE ##########################################")
        self.vae = VAE(self.epochs_1,self.adata,self.latent_size, self.layer_dims, self.lr_1, self.dropout_1,self.device, self.seed)
        self.vae.initialize(self.adata, self.W_matrix, self.batch_size, self.type, self.num_workers)
    # map_location lets a checkpoint saved on another device (e.g. CUDA) be
    # restored on the device this instance is configured for.
    self.vae.model.load_state_dict(torch.load("{}vae_transfer".format(self.path), map_location=self.device))
    self.encoder = self.vae
    # Last column of X_train is the per-patient weight column added in init_Phase2.
    self.W_vector = self.X_train[:, -1]
    self.deepCox_model = DeepCox(self.X_train, self.y_train, self.W_vector, self.hidden, self.encoder, self.device,self.batch_size, self.seed, self.lr_2, self.dropout_2)
    self.deepCox_model.model.load_state_dict(torch.load("{}deepCox".format(self.path), map_location=self.device))
    print("✅ Reload complete – VAE and DeepCox restored")
[docs]
def get_percentille(self, percentile):
    """
    Update the stored percentile and recompute the matching cell-level threshold.

    Parameters
    ----------
    percentile : float
        New percentile, stored on the instance and forwarded to ``utils.get_threshold``.

    Returns
    -------
    The threshold returned by ``utils.get_threshold``, also stored as ``self.percentile_cells``.
    """
    self.percentile = percentile
    threshold = utils.get_threshold(self.adata, self.deepCox_model.model, self.percentile, self.device)
    self.percentile_cells = threshold
    return threshold
[docs]
def get_embedding(self, n_neighbors=30, resolution=None, celltype=True):
if celltype and resolution is not None:
raise ValueError("Resolution should not be provided when celltype=True.")
self.adata = self.getEmbedding_adata()
sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, use_rep="latent", random_state=self.seed)
sc.tl.umap(self.adata, random_state=self.seed)
if not celltype:
if resolution is None:
resolution = 0.8 # Default resolution if not provided
sc.tl.leiden(self.adata, resolution=resolution, random_state=self.seed)
return self.adata
[docs]
def set_adata(self):
    """
    Add display labels and colours for the SIDISH annotation and write the
    AnnData object to ``<path>adata_SIDISH_embedding.h5ad``.

    Reads ``obs['SIDISH']`` (values "h"/"b") and creates ``obs['SIDISH_']``
    (labels with cell counts) and ``obs['SIDISH__']`` (plain labels), plus the
    matching colour arrays in ``uns``.
    """
    risk_labels = self.adata.obs.SIDISH
    n_high = (risk_labels == "h").sum()
    n_background = (risk_labels == "b").sum()
    two_tone = ["grey", "red"]

    counted_labels = []
    plain_labels = []
    for label in risk_labels:
        if label == "h":
            counted_labels.append("High-Risk Cells ({})".format(n_high))
            plain_labels.append("High-Risk")
        else:
            counted_labels.append("Background Cells ({})".format(n_background))
            plain_labels.append("Background")

    self.adata.obs["SIDISH_"] = counted_labels
    self.adata.uns["SIDISH__colors"] = np.array(two_tone, dtype="object")
    self.adata.obs["SIDISH__"] = plain_labels
    self.adata.uns["SIDISH___colors"] = np.array(two_tone, dtype="object")

    out_file = "{}adata_SIDISH_embedding.h5ad".format(self.path)
    self.adata.write_h5ad(out_file, compression="gzip")
[docs]
def plot_HighRisk_UMAP(self, size= 10, resolution=None, celltype=True):
    """
    Recompute the embedding and plot the UMAP coloured by High-Risk/Background status.

    Parameters
    ----------
    size : int, optional
        Point size passed to the scatter plot (default=10).
    resolution : float, optional
        Forwarded to ``get_embedding``.
    celltype : bool, optional
        Forwarded to ``get_embedding`` (default=True).
    """
    self.adata = self.get_embedding(resolution=resolution, celltype=celltype)
    self.set_adata()
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.size"] = 12
    plt.rcParams["legend.fontsize"] = 12
    # Create the axes ourselves and hand them to scanpy: without `ax`, scanpy opens
    # its own figure, leaving an empty one behind and ignoring the requested size.
    _, ax = plt.subplots(figsize=(8, 6))
    sc.pl.umap(self.adata, color=["SIDISH_"], title="", show=False, size=size, edgecolor="none", ax=ax)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # Adjust legend properties: location at the top with more space
    ax.legend(loc="upper center", frameon=False, markerscale=1.5, bbox_to_anchor=(0.5, 1.15))
    plt.tight_layout()
    plt.show()
[docs]
def plot_CellType_UMAP(self, size = 10, resolution=None, celltype=True):
    """
    Recompute the embedding and plot the UMAP with cluster labels drawn on the data.

    Cells are coloured by ``obs['leiden']``. With ``celltype=True`` no Leiden
    clustering is run by ``get_embedding``; if ``leiden`` is therefore missing and
    ``obs['celltype_major']`` exists, that column is used instead.

    Parameters
    ----------
    size : int, optional
        Point size passed to the scatter plot (default=10).
    resolution : float, optional
        Forwarded to ``get_embedding``.
    celltype : bool, optional
        Forwarded to ``get_embedding`` (default=True).
    """
    self.adata = self.get_embedding(resolution=resolution, celltype=celltype)
    self.set_adata()
    # Fall back to the cell-type annotation only when colouring by "leiden" would fail.
    obs_columns = self.adata.obs.columns
    color_key = "leiden"
    if celltype and "leiden" not in obs_columns and "celltype_major" in obs_columns:
        color_key = "celltype_major"
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.size"] = 12
    plt.rcParams["legend.fontsize"] = 12
    # Create the axes ourselves and hand them to scanpy: without `ax`, scanpy opens
    # its own figure, leaving an empty one behind and ignoring the requested size.
    _, ax = plt.subplots(figsize=(8, 6))
    sc.pl.umap(self.adata, color=[color_key], title="", legend_loc="on data", legend_fontsize=12, show=False, size=size, edgecolor="none", palette='Set3', ax=ax)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # Adjust legend properties: location at the top with more space
    ax.legend(loc="upper center", frameon=False, markerscale=1.5, bbox_to_anchor=(0.5, 1.15))
    plt.tight_layout()
[docs]
def get_MarkerGenes(self,logfc_threshold=1.5, pval_threshold=0.05, method="wilcoxon", group="h"):
    """
    Identifies marker genes for the specified group using different statistical methods.

    ``self.adata`` is normalised (total counts to 1e4) and log1p-transformed in place
    the first time this method is called on a given AnnData object; later calls on the
    same object reuse the transformed data, so the method can be re-run with different
    thresholds without transforming the data twice.

    Parameters:
        logfc_threshold (float): Log fold change threshold for filtering significant genes.
        pval_threshold (float): P-value threshold for statistical significance.
        method (str): Method for ranking genes ('wilcoxon', 't-test', 'logreg').
        group (str): The group to compare against others (default is 'h').

    Returns:
        upregulated_genes (array): Names of upregulated marker genes.
        downregulated_genes (array): Names of downregulated marker genes.

    Raises:
        ValueError: If ``method`` is not one of the supported methods.
    """
    # Validate method
    supported_methods = ["wilcoxon", "t-test", "logreg"]
    if method not in supported_methods:
        raise ValueError(f"Unsupported method. Choose from {supported_methods}.")
    # Normalize and log-transform the data, once per AnnData object. Tracking the
    # object itself (not a plain flag) means a replaced self.adata is transformed again.
    if getattr(self, "_deg_normalized_adata", None) is not self.adata:
        sc.pp.normalize_total(self.adata, target_sum=1e4)
        sc.pp.log1p(self.adata)
        self._deg_normalized_adata = self.adata
    # Rank genes using the selected method
    sc.tl.rank_genes_groups(self.adata, groupby="SIDISH", groups=[group], method=method, key_added="SIDISH_deg")
    # Extract ranked genes for the specified group
    ranked_genes_df = sc.get.rank_genes_groups_df(self.adata, group=group, key="SIDISH_deg")
    # Filter for upregulated and downregulated genes
    significant = ranked_genes_df["pvals"] < pval_threshold
    logfc = ranked_genes_df["logfoldchanges"]
    self.upregulated_genes = ranked_genes_df[(logfc > logfc_threshold) & significant]["names"].values
    self.downregulated_genes = ranked_genes_df[(logfc < -logfc_threshold) & significant]["names"].values
    return self.upregulated_genes, self.downregulated_genes
[docs]
def analyze_perturbation_effects(self):
    """
    Quantify, for every perturbed gene, how cell risk labels and scores change.

    For each ``(gene, perturbed data)`` pair in ``self.genes`` / ``self.optimized_results``
    the perturbed cells are re-annotated and compared with ``self.adata.obs['SIDISH']``:
    - ``percent_change``: net High-Risk -> Background flips as a percentage of the
      original High-Risk cell count,
    - ``p_flip``: one-sided binomial sign test on the flips (1.0 when there are none),
    - ``delta_change``: mean of ``obs['perturbation_score']``,
    - ``p_score``: one-sided Wilcoxon signed-rank test on ``perturbation_score``
      (1.0 when every score is zero, where the test is undefined).
    Per-gene transition counts are stored in ``h_to_b_dict``, ``b_to_h_dict``,
    ``h_dict`` and ``b_dict``.

    Returns
    -------
    tuple of dict
        ``(percent_change, delta_change, p_flip, p_score)``, each keyed by gene.
    """
    self.percent_change, self.p_flip, self.p_score = {}, {}, {}
    self.delta_change = {}
    self.b_dict = {}
    self.b_to_h_dict = {}
    self.h_to_b_dict = {}
    self.h_dict = {}
    # Original labels do not change inside the loop; compute the masks once.
    was_high = self.adata.obs.SIDISH == "h"
    was_background = self.adata.obs.SIDISH == "b"
    n_hcells = was_high.sum()
    for gene, data in tqdm(zip(self.genes, self.optimized_results), total=len(self.genes), desc="Stats"):
        adata_p = self.annotateCells(data, self.percentile_cells, mode="no", perturbation=True)
        now_high = adata_p.obs.SIDISH == "h"
        now_background = adata_p.obs.SIDISH == "b"
        h_to_b = (was_high & now_background).sum()
        b_to_h = (was_background & now_high).sum()
        # % drop in high-risk cells
        self.percent_change[gene] = ((h_to_b - b_to_h) / n_hcells) * 100
        # --- 1 one-sided sign test on label flips ---
        flips = h_to_b + b_to_h
        if flips:  # avoid 0-trials edge case
            p_flip = binomtest(k=h_to_b, n=flips, p=0.5, alternative="greater").pvalue
        else:
            p_flip = 1.0
        self.p_flip[gene] = p_flip
        # --- 2 one-sided Wilcoxon on risk-score delta ---
        delta = adata_p.obs["perturbation_score"].values
        self.delta_change[gene] = delta.mean()
        if np.any(delta != 0):
            _, p_score = wilcoxon(delta, alternative="greater")
        else:
            # All-zero differences: the signed-rank test is undefined (and raises in
            # some SciPy versions); report "no evidence" like the no-flip case above.
            p_score = 1.0
        self.h_to_b_dict[gene] = h_to_b
        self.b_to_h_dict[gene] = b_to_h
        self.b_dict[gene] = (was_background & now_background).sum()
        self.h_dict[gene] = (was_high & now_high).sum()
        self.p_score[gene] = p_score
    return self.percent_change, self.delta_change, self.p_flip, self.p_score
[docs]
def run_Perturbation(self, n_jobs: int = 4) -> tuple:
    """
    Run the in-silico perturbation of every gene and summarise its effect.

    Reloads ``<path>adata_SIDISH.h5ad`` (written by ``train``), sets up the PPI
    network with a threshold of 0.7, perturbs the data in parallel and then calls
    ``analyze_perturbation_effects``.

    Parameters
    ----------
    n_jobs : int, optional
        Number of parallel jobs used for the perturbation (default=4).

    Returns
    -------
    tuple of dict
        ``(percent_change, delta_change, p_flip, p_score)``, each keyed by gene.
    """
    self.adata = sc.read_h5ad("{}adata_SIDISH.h5ad".format(self.path))
    self.genes = list(self.adata.var.index)
    self.percentage_dict = {}
    self.pvalue_dict = {}
    perturbation = InSilicoPerturbation(self.adata)
    perturbation.setup_ppi_network(threshold=0.7)
    # Honour the caller's n_jobs (previously hard-coded to 4 regardless of the argument).
    self.optimized_results = perturbation.run_parallel_processing(self.adata, n_jobs=n_jobs)
    self.percent_change, self.delta_change, self.p_flip, self.p_score = self.analyze_perturbation_effects()
    return self.percent_change, self.delta_change, self.p_flip, self.p_score
[docs]
def plot_KM(self, penalizer=0.1, data_name="DATA", high_risk_label="High-Risk", background_label="Background", colors=("pink", "grey"), fontsize=12):
    """
    Draw Kaplan-Meier curves for bulk patients split at the median Cox risk score.

    A penalised Cox model is fitted on ``self.bulk`` restricted to
    ``self.upregulated_genes`` plus the ``duration`` and ``event`` columns. Each
    patient's risk score is the dot product of their expression values with the fitted
    values taken from the first row of the transposed model summary. Patients at or
    above the median score form the high-risk group, the rest the background group; the
    two groups are compared with a log-rank test whose p-value is printed on the figure.

    Parameters
    ----------
    penalizer : float
        Regularisation strength passed to ``CoxPHFitter``.
    data_name : str
        Dataset name shown as the first legend entry.
    high_risk_label : str
        Name of the high-risk group (used for grouping and in the legend).
    background_label : str
        Name of the background group (used for grouping and in the legend).
    colors : tuple
        Curve colours in the order (high-risk, background).
    fontsize : int
        Font size applied to axis labels, ticks, legend and p-value text.
    """
    # --- Data: up-regulated genes followed by the two survival columns ---
    wanted_columns = np.append(self.upregulated_genes, ["duration", "event"])
    result = self.bulk.filter(wanted_columns)

    # --- Penalised Cox regression ---
    cox_model = CoxPHFitter(penalizer=penalizer)
    cox_model.fit(result, duration_col="duration", event_col="event")

    # --- Per-patient risk score ---
    summary_by_gene = cox_model.summary.T.filter(self.upregulated_genes)
    gene_weights = summary_by_gene.iloc[0].values.reshape(-1, 1)
    # Everything except the trailing two columns (the survival columns) is expression.
    gene_expression = result.iloc[:, :-2].values
    risk_scores = np.dot(gene_expression, gene_weights)

    # --- Median split into the two groups ---
    cutoff = np.median(risk_scores)
    result_df = result.copy()
    result_df["risk"] = np.where(risk_scores >= cutoff, high_risk_label, background_label)
    result_df["scores"] = risk_scores
    high_risk_group = result_df[result_df["risk"] == high_risk_label]
    low_risk_group = result_df[result_df["risk"] == background_label]

    # --- Log-rank comparison of the two groups ---
    p_value = logrank_test(
        durations_A=low_risk_group["duration"],
        durations_B=high_risk_group["duration"],
        event_observed_A=low_risk_group["event"],
        event_observed_B=high_risk_group["event"],
    ).p_value

    # --- Survival curves: high-risk first, then background ---
    fig, ax = plt.subplots(figsize=(8, 6))
    km_fitter = KaplanMeierFitter()
    cohorts = ((high_risk_group, high_risk_label), (low_risk_group, background_label))
    for position, (cohort, cohort_label) in enumerate(cohorts):
        km_fitter.fit(cohort["duration"], event_observed=cohort["event"], label=cohort_label)
        km_fitter.plot_survival_function(ax=ax, color=colors[position], ci_show=False, linewidth=2.5)

    # --- Axes cosmetics ---
    ax.set_title("", fontsize=fontsize)
    ax.set_xlabel("Time (Days)", fontsize=fontsize)
    ax.set_ylabel("Survival Probability", fontsize=fontsize)
    ax.tick_params(axis="both", labelsize=fontsize)
    ax.set_ylim(0, 1)

    # --- Legend: dataset name followed by one coloured square per group ---
    legend_handles = [plt.Line2D([0], [0], color="w", label=data_name)]
    for swatch, group_name in zip(colors, (high_risk_label, background_label)):
        legend_handles.append(
            plt.Line2D([0], [0], marker="s", color="w", markerfacecolor=swatch, markersize=8, label=group_name)
        )
    fig.legend(handles=legend_handles, loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=3, frameon=False, fontsize=fontsize)

    # --- P-value annotation, centred just above the plotting area ---
    x_middle = ax.get_xlim()[1] * 0.5
    plt.text(x_middle, 1.02, f"P = {p_value:.2e}", fontsize=fontsize, ha="center", fontstyle="italic", weight="bold")

    # --- Final touches ---
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    plt.tight_layout()
    plt.show()
[docs]
def plot_top_perturbed_genes(self, gene_data, top_n=20):
    """
    Show a horizontal bar chart of the genes with the largest values in ``gene_data``.

    The ``top_n`` entries with the highest values are kept (largest first), stored on
    ``self.top_genes`` and plotted with the largest bar at the top.

    Parameters
    ----------
    gene_data : dict
        Mapping of gene name to its perturbation effect (plotted as
        "% Reduction of High-Risk Cells").
    top_n : int
        Maximum number of genes to display. Default is 20.
    """
    # Rank entries by value, largest first, and keep the leading top_n
    ranked = sorted(gene_data.items(), key=lambda pair: pair[1], reverse=True)
    self.top_genes = dict(ranked[:top_n])

    gene_names = list(self.top_genes)
    reductions = list(self.top_genes.values())

    plt.figure(figsize=(8, 8))
    plt.barh(gene_names, reductions, color="brown")
    plt.xlabel("% Reduction of High-Risk Cells")
    plt.title(f"Top {top_n} Genes by Reduction in High-Risk Cells After Perturbation")
    # barh draws the first entry at the bottom; flip so the best gene is on top
    plt.gca().invert_yaxis()
    plt.show()
[docs]
def plot_perturbation_UMAP_default(self, genes_of_interest, resolution=None, celltype=True, threshold=0.8):
    """
    Plot, for each requested gene, a UMAP of how cell risk labels change after an
    in-silico knockout of that gene.

    For every gene the annotated AnnData is re-read from
    ``{self.path}adata_SIDISH.h5ad`` and the gene is perturbed: a PPI-network-based
    adjustment is used when the loaded network has edges touching the gene's
    neighbours, a plain knockout otherwise. Cells are then re-annotated with
    ``self.annotateCells`` and the original and perturbed ``SIDISH`` labels
    ("h" / "b") are compared. Each subplot is drawn by ``plot_umap``.

    Parameters
    ----------
    genes_of_interest : list
        Gene names to knock out; one subplot per gene, two subplots per row.
    resolution, celltype
        Forwarded unchanged to ``self.get_embedding``.
    threshold : float
        Forwarded to ``PPINetworkHandler.load_network``.

    Raises
    ------
    TypeError
        If ``genes_of_interest`` is not a list.
    """
    if not isinstance(genes_of_interest, list):
        raise TypeError("genes_of_interest must be a list of gene names.")

    # Dynamic subplot layout
    n_genes = len(genes_of_interest)
    n_cols = 2
    n_rows = math.ceil(n_genes / n_cols)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 4.5 * n_rows), squeeze=False)
    axs = axs.flatten()

    for gene, ax in zip(genes_of_interest, axs):
        # self.adata is replaced further down, so start every gene from the file on disk.
        self.adata = sc.read_h5ad("{}adata_SIDISH.h5ad".format(self.path))
        self.ppi_handler = PPINetworkHandler(self.adata)
        self.ppi_handler.load_network(threshold)

        direct_neighbors, indirect_neighbors = self.ppi_handler.get_neighbors(gene)
        neighbors = direct_neighbors + indirect_neighbors
        ppi_df = self.ppi_handler.ppi_df
        network_df = ppi_df[ppi_df["Source"].isin(neighbors) | ppi_df["Target"].isin(neighbors)]

        if not network_df.empty:
            adata_p = GenePerturbationUtils.adjust_expression(self.adata, gene, network_df)
        else:
            # Fix: build the perturbed object from a copy of the original data. The old
            # code assigned to `adata_p.X` before `adata_p` existed (UnboundLocalError on
            # the first gene, silent reuse of the previous gene's object afterwards).
            adata_p = self.adata.copy()
            adata_p.X = GenePerturbationUtils.knockout_gene(self.adata, gene).tocsr()
        adata_p = self.annotateCells(adata_p, self.percentile_cells, "no")

        # Net share of originally high-risk cells that moved to background
        h_to_b = ((self.adata.obs['SIDISH'] == 'h') & (adata_p.obs['SIDISH'] == 'b')).sum()
        b_to_h = ((self.adata.obs['SIDISH'] == 'b') & (adata_p.obs['SIDISH'] == 'h')).sum()
        percentage_change = ((h_to_b - b_to_h) / (self.adata.obs.SIDISH == "h").values.sum()) * 100

        h = adata_p[adata_p.obs.SIDISH == "h"].shape[0]
        b = adata_p[adata_p.obs.SIDISH == "b"].shape[0]
        adata_p.obs["SIDISH_"] = [
            "High risk cells ({})".format(h) if label == "h" else "Background cells ({})".format(b)
            for label in adata_p.obs.SIDISH
        ]

        self.adata = self.get_embedding(resolution=resolution, celltype=celltype)
        self.set_adata()

        umap_combined = pd.DataFrame(self.adata.obsm["X_umap"], columns=["UMAP1", "UMAP2"])
        umap_combined.index = self.adata.obs.index.values
        umap_combined['risk'] = self.adata.obs['SIDISH'].values

        # Category labels are the cell counts of each (before, after) label combination.
        before = self.adata.obs.SIDISH
        after = adata_p.obs.SIDISH
        label_bb = '{}'.format(((before == "b") & (after == "b")).sum())
        label_hh = '{}'.format(((before == "h") & (after == "h")).sum())
        label_hb = '{}'.format(((before == "h") & (after == "b")).sum())
        label_bh = '{}'.format(((before == "b") & (after == "h")).sum())

        umap_combined['status'] = label_bb
        umap_combined.loc[(umap_combined['risk'] == 'h') & (adata_p.obs['SIDISH'] == 'h'), 'status'] = label_hh
        umap_combined.loc[(umap_combined['risk'] == 'h') & (adata_p.obs['SIDISH'] == 'b'), 'status'] = label_hb
        umap_combined.loc[(umap_combined['risk'] == 'b') & (adata_p.obs['SIDISH'] == 'h'), 'status'] = label_bh

        # NOTE(review): two categories with the same cell count share one label and
        # therefore one palette entry (the later one wins) - confirm this is acceptable.
        palette = {
            label_bb: 'darkgray',
            label_hh: 'red',
            label_hb: '#3DB1EA',
            label_bh: 'purple',
            '{}'.format(percentage_change): 'white',
        }
        plot_umap(ax, umap_combined, palette, percentage_change)
        ax.set_title("In-Silico Knockout of {}".format(gene), fontsize=12, y=0.96)

    # Hide the unused subplot(s) when the gene count does not fill the grid
    for ax in axs[n_genes:]:
        ax.axis('off')
    plt.show()
[docs]
def plot_perturbation_UMAP_differential(self, genes_of_interest, resolution=None, celltype=True, threshold=0.8):
    """
    Plot, for each requested gene, a UMAP coloured by the per-cell
    ``perturbation_score`` obtained after an in-silico knockout of that gene.

    For every gene the annotated AnnData is re-read from
    ``{self.path}adata_SIDISH.h5ad`` and the gene is perturbed: a PPI-network-based
    adjustment is used when the loaded network has edges touching the gene's
    neighbours, a plain knockout otherwise. Cells are then re-annotated with
    ``self.annotateCells(..., perturbation=True)``. The mean score of each gene is
    stored in ``self.delta_change[gene]`` and each subplot is drawn by
    ``plot_umap_differential``.

    Parameters
    ----------
    genes_of_interest : list
        Gene names to knock out; one subplot per gene, two subplots per row.
    resolution, celltype
        Forwarded unchanged to ``self.get_embedding``.
    threshold : float
        Forwarded to ``PPINetworkHandler.load_network``.

    Raises
    ------
    TypeError
        If ``genes_of_interest`` is not a list.
    """
    if not isinstance(genes_of_interest, list):
        raise TypeError("genes_of_interest must be a list of gene names.")

    # Allow plotting before the genome-wide screen has created the dictionary.
    if not hasattr(self, "delta_change"):
        self.delta_change = {}

    # Dynamic subplot layout
    n_genes = len(genes_of_interest)
    n_cols = 2
    n_rows = math.ceil(n_genes / n_cols)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 4.5 * n_rows), squeeze=False)
    axs = axs.flatten()

    for gene, ax in zip(genes_of_interest, axs):
        # self.adata is replaced further down, so start every gene from the file on disk.
        self.adata = sc.read_h5ad("{}adata_SIDISH.h5ad".format(self.path))
        self.ppi_handler = PPINetworkHandler(self.adata)
        self.ppi_handler.load_network(threshold)

        direct_neighbors, indirect_neighbors = self.ppi_handler.get_neighbors(gene)
        neighbors = direct_neighbors + indirect_neighbors
        ppi_df = self.ppi_handler.ppi_df
        network_df = ppi_df[ppi_df["Source"].isin(neighbors) | ppi_df["Target"].isin(neighbors)]

        if not network_df.empty:
            adata_p = GenePerturbationUtils.adjust_expression(self.adata, gene, network_df)
        else:
            # Fix: build the perturbed object from a copy of the original data. The old
            # code assigned to `adata_p.X` before `adata_p` existed (UnboundLocalError on
            # the first gene, silent reuse of the previous gene's object afterwards).
            adata_p = self.adata.copy()
            adata_p.X = GenePerturbationUtils.knockout_gene(self.adata, gene).tocsr()
        adata_p = self.annotateCells(adata_p, self.percentile_cells, mode="no", perturbation=True)

        # Mean per-cell score for this knockout (no statistical test is run here)
        delta = adata_p.obs["perturbation_score"].values
        self.delta_change[gene] = delta.mean()

        self.adata = self.get_embedding(resolution=resolution, celltype=celltype)
        self.set_adata()

        umap_combined = pd.DataFrame(self.adata.obsm["X_umap"], columns=["UMAP1", "UMAP2"])
        umap_combined.index = adata_p.obs.index.values
        umap_combined['risk'] = adata_p.obs['perturbation_score'].values
        plot_umap_differential(ax, umap_combined)
        ax.set_title("In silico Knockout of {} (Change in score)".format(gene), fontsize=12, y=0.96)

    # Hide the unused subplot(s) when the gene count does not fill the grid
    for ax in axs[n_genes:]:
        ax.axis('off')
    plt.show()
[docs]
def run_double_Perturbation(self, genes, top_n=20, threshold=0.8):
    """
    Perturb ordered pairs of the top single-gene hits and count risk-label flips.

    Genes are ranked by ``self.percent_change`` (largest first). For every ordered pair
    of the ``top_n`` best genes, both genes are perturbed one after the other on a copy
    of the annotated AnnData, cells are re-annotated with ``self.annotateCells`` and the
    original and perturbed ``SIDISH`` labels ("h" / "b") are compared. Results are keyed
    by ``"geneA+geneB"``.

    Parameters
    ----------
    genes : sequence
        Expected top genes. Only used for a printed sanity check against the computed
        ranking; it does not influence the result.
    top_n : int
        Number of top-ranked genes from which pairs are formed.
    threshold : float
        Forwarded to ``PPINetworkHandler.load_network``.

    Returns
    -------
    tuple
        ``(percent_change_double, p_flip_double)``: net percentage of originally
        high-risk cells that became background, and the one-sided sign-test p-value on
        label flips, per gene pair. Flip counts are also stored on
        ``self.h_to_b_dict_double``, ``self.b_to_h_dict_double``, ``self.b_dict_double``
        and ``self.h_dict_double``.
    """
    self.percent_change_double, self.p_flip_double = {}, {}
    self.b_dict_double = {}
    self.b_to_h_dict_double = {}
    self.h_to_b_dict_double = {}
    self.h_dict_double = {}

    # Rank genes by single-knockout percentage change.
    # NOTE(review): an FDR-filtered gene list used to be computed from self.p_flip here,
    # but it was overwritten by this ranking before use, so it has been dropped.
    # Confirm that ranking by percentage change alone is the intended selection.
    self.percentage_df_flip = pd.DataFrame([self.percent_change], index=None).T.reset_index()
    self.percentage_df_flip.columns = ["Genes", "Scores"]
    self.percentage_df_flip.sort_values(by=["Scores"], ascending=False, inplace=True)

    self.adata = sc.read_h5ad("{}adata_SIDISH.h5ad".format(self.path))
    self.top_genes_flip = self.percentage_df_flip.Genes.values
    self.top_genes = self.top_genes_flip

    selected_genes = self.top_genes[:top_n]
    # Sanity check against the caller's list; array_equal cannot raise on a length
    # mismatch, unlike the elementwise `==` used before.
    print(np.array_equal(selected_genes, np.asarray(genes)))

    all_combinations = list(itertools.permutations(selected_genes, 2))
    n_hcells = (self.adata.obs.SIDISH == "h").sum()
    self.percentage_double_dict = {}
    self.pvalue_double_dict = {}
    self.ppi_handler = PPINetworkHandler(self.adata)
    self.ppi_handler.load_network(threshold)
    ppi_df = self.ppi_handler.ppi_df

    for combination in tqdm(all_combinations):
        pair_key = "{}+{}".format(combination[0], combination[1])
        adata_p = self.adata.copy()
        for g in combination:
            direct_neighbors, indirect_neighbors = self.ppi_handler.get_neighbors(g)
            neighbors = direct_neighbors + indirect_neighbors
            network_df = ppi_df[ppi_df["Source"].isin(neighbors) | ppi_df["Target"].isin(neighbors)]
            # Fix: perturb the running copy, not the pristine data. Starting from
            # self.adata for each gene made the second knockout discard the first one.
            if not network_df.empty:
                adata_p = GenePerturbationUtils.adjust_expression(adata_p, g, network_df)
            else:
                adata_p.X = GenePerturbationUtils.knockout_gene(adata_p, g).tocsr()
        adata_p = self.annotateCells(adata_p, self.percentile_cells, mode="no", perturbation=True)

        h_to_b = ((self.adata.obs.SIDISH == "h") & (adata_p.obs.SIDISH == "b")).sum()
        b_to_h = ((self.adata.obs.SIDISH == "b") & (adata_p.obs.SIDISH == "h")).sum()
        self.percent_change_double[pair_key] = ((h_to_b - b_to_h) / n_hcells) * 100

        # --- one-sided sign test on label flips ---
        flips = h_to_b + b_to_h
        if flips:  # avoid 0-trials edge case
            p_flip_ = binomtest(k=h_to_b, n=flips, p=0.5, alternative="greater").pvalue
        else:
            p_flip_ = 1.0
        self.p_flip_double[pair_key] = p_flip_

        self.h_to_b_dict_double[pair_key] = h_to_b
        self.b_to_h_dict_double[pair_key] = b_to_h
        self.b_dict_double[pair_key] = ((self.adata.obs["SIDISH"] == "b") & (adata_p.obs["SIDISH"] == "b")).sum()
        self.h_dict_double[pair_key] = ((self.adata.obs["SIDISH"] == "h") & (adata_p.obs["SIDISH"] == "h")).sum()
    return self.percent_change_double, self.p_flip_double
[docs]
def run_double_Perturbation_score(self, genes, top_n=20, threshold=0.8):
    """
    Perturb ordered pairs of the top single-gene hits and test the risk-score change.

    Genes are ranked by ``self.delta_change`` (largest first). For every ordered pair of
    the ``top_n`` best genes, both genes are perturbed one after the other on a copy of
    the annotated AnnData and cells are re-annotated with
    ``self.annotateCells(..., perturbation=True)``. The per-cell
    ``perturbation_score`` is then summarised by its mean and a one-sided Wilcoxon
    signed-rank test. Results are keyed by ``"geneA+geneB"``.

    Parameters
    ----------
    genes : sequence
        Expected top genes. Only used for a printed sanity check against the computed
        ranking; it does not influence the result.
    top_n : int
        Number of top-ranked genes from which pairs are formed.
    threshold : float
        Forwarded to ``PPINetworkHandler.load_network``.

    Returns
    -------
    tuple
        ``(delta_change_double, p_score_double)``: mean ``perturbation_score`` and
        one-sided Wilcoxon p-value per gene pair.
    """
    self.p_score_double = {}
    self.delta_change_double = {}
    # NOTE(review): these four dictionaries are reset but never filled here, which wipes
    # the counts left by run_double_Perturbation. Kept for backward compatibility -
    # confirm whether the reset is intended.
    self.b_dict_double = {}
    self.b_to_h_dict_double = {}
    self.h_to_b_dict_double = {}
    self.h_dict_double = {}

    # Rank genes by mean single-knockout score change.
    # NOTE(review): an FDR-filtered gene list used to be computed from self.p_score
    # here, but it was overwritten by this ranking before use, so it has been dropped.
    # Confirm that ranking by score change alone is the intended selection.
    self.percentage_df_score = pd.DataFrame([self.delta_change], index=None).T.reset_index()
    self.percentage_df_score.columns = ["Genes", "Scores"]
    self.percentage_df_score.sort_values(by=["Scores"], ascending=False, inplace=True)
    self.top_genes_score = self.percentage_df_score.Genes.values

    self.adata = sc.read_h5ad("{}adata_SIDISH.h5ad".format(self.path))
    self.top_genes = self.top_genes_score

    selected_genes = self.top_genes[:top_n]
    # Sanity check against the caller's list; array_equal cannot raise on a length
    # mismatch, unlike the elementwise `==` used before.
    print(np.array_equal(selected_genes, np.asarray(genes)))

    all_combinations = list(itertools.permutations(selected_genes, 2))
    self.percentage_double_dict = {}
    self.pvalue_double_dict = {}
    self.ppi_handler = PPINetworkHandler(self.adata)
    self.ppi_handler.load_network(threshold)
    ppi_df = self.ppi_handler.ppi_df

    for combination in tqdm(all_combinations):
        pair_key = "{}+{}".format(combination[0], combination[1])
        adata_p = self.adata.copy()
        for g in combination:
            direct_neighbors, indirect_neighbors = self.ppi_handler.get_neighbors(g)
            neighbors = direct_neighbors + indirect_neighbors
            network_df = ppi_df[ppi_df["Source"].isin(neighbors) | ppi_df["Target"].isin(neighbors)]
            # Fix: perturb the running copy, not the pristine data. Starting from
            # self.adata for each gene made the second knockout discard the first one.
            if not network_df.empty:
                adata_p = GenePerturbationUtils.adjust_expression(adata_p, g, network_df)
            else:
                adata_p.X = GenePerturbationUtils.knockout_gene(adata_p, g).tocsr()
        adata_p = self.annotateCells(adata_p, self.percentile_cells, mode="no", perturbation=True)

        # --- one-sided Wilcoxon on risk-score delta ---
        delta_double = adata_p.obs["perturbation_score"].values
        self.delta_change_double[pair_key] = delta_double.mean()
        _, p_score_ = wilcoxon(delta_double, alternative="greater")
        self.p_score_double[pair_key] = p_score_
    return self.delta_change_double, self.p_score_double
[docs]
def plot_double_Perturbation_Heatmap(self, percentage_double_dict, top_n=20):
    """
    Plot single- and double-knockout effects of the top genes as side-by-side heatmaps.

    The left panel is a one-column heatmap of ``self.percent_change`` for the ``top_n``
    highest-scoring genes. The right panel is a gene-by-gene matrix built from
    ``percentage_double_dict``: rows are the first gene of each ``"geneA+geneB"`` key,
    columns the second, both ordered like the left panel. Missing cells of a column
    (including the diagonal) are filled with that gene's single-knockout score.

    Parameters
    ----------
    percentage_double_dict : dict
        Mapping ``"geneA+geneB"`` to the pair's percentage change. Entries equal to 0
        are dropped before the matrix is built.
    top_n : int
        Number of top-ranked genes to display.
    """
    percent_change_df = pd.DataFrame(list(self.percent_change.items()), columns=["Genes", "Scores"])
    percent_change_df.sort_values(by='Scores', ascending=False, inplace=True)
    df = percent_change_df.iloc[:top_n].sort_values(by='Genes')
    df = df.sort_values(by="Scores", ascending=False)
    sorted_genes = df.Genes.values

    # Filter out zero values and create a long-format table of the gene pairs
    double_dict = {k: v for k, v in percentage_double_dict.items() if v != 0}
    heatmap_data = pd.DataFrame(
        [{'Gene1': k.split('+')[0], 'Gene2': k.split('+')[1], 'Value': v} for k, v in double_dict.items()]
    )

    # Pivot to a Gene1 x Gene2 matrix, then order rows and columns like the 1-D panel.
    # reindex (not .loc) keeps genes whose pairs were all removed by the zero filter
    # as empty rows/columns instead of raising KeyError.
    heatmap_matrix = heatmap_data.pivot(index='Gene1', columns='Gene2', values='Value')
    heatmap_matrix = heatmap_matrix.reindex(index=sorted_genes, columns=sorted_genes)

    for gene in sorted_genes:
        # Fix: percent_change[gene] is already the scalar score; the former
        # `self.percent_change[gene][id]` indexed it with the builtin `id` and raised.
        score = self.percent_change[gene]
        # NOTE(review): this fills every missing cell of the column, so zero-valued
        # pairs removed above also receive the single-gene score - confirm intended.
        heatmap_matrix[gene] = heatmap_matrix[gene].fillna(score)

    # Narrow left axis for the 1-D heatmap, wide right axis for the matrix
    fig, (ax1, ax2) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 15]}, figsize=(10, 8))

    # Plot 1D heatmap on the first axis (ax1)
    sns.heatmap(df[['Scores']], cmap='Reds', cbar=True, yticklabels=df['Genes'], xticklabels=False, ax=ax1)
    ax1.set_title('', fontsize=12)
    ax1.set_xlabel('')  # No x-axis label
    ax1.set_ylabel('')

    # Plot 2D heatmap on the second axis (ax2)
    sns.heatmap(heatmap_matrix, cmap='Reds', annot=False, ax=ax2)
    ax2.set_title('')
    ax2.set_xlabel('')
    ax2.set_ylabel('')
    plt.show()