# Source code for SIDISH.VAE

from SIDISH.VAE_ARCHITECTURE import ARCHITECTURE as architecture
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import pyro
import torch.optim as optim
from torch_geometric.loader import NeighborLoader
from torch_geometric.data import Data
import os

class VAE:
    """Train a (optionally spatially-aware) VAE on an AnnData-like object.

    The network and its loss live in ``architecture``. This wrapper does the rest:
    it seeds the RNGs, builds the data loaders, runs the training loop and
    writes latent embeddings into ``adata.obsm``.
    """

    def __init__(self, epochs, adata, z_dim, layer_dims, lr, dropout, device,
                 seed, gcn_dims=None):
        """Store hyper-parameters and seed the random number generators.

        Args:
            epochs: Number of passes over the training data in :meth:`train`.
            adata: Object exposing a 2-D ``X`` matrix; only ``X.shape[1]``
                (the input dimension) is read here.
            z_dim: Latent dimension passed to the architecture.
            layer_dims: Hidden layer sizes passed to the architecture.
            lr: Adam learning rate.
            dropout: Dropout rate passed to the architecture.
            device: Torch device the model and batches are moved to.
            seed: Seed for torch (CPU and CUDA) and NumPy.
            gcn_dims: GCN layer sizes. When given, :meth:`initialize` builds
                the spatial variant if a spatial graph is also supplied.
        """
        super().__init__()
        self.seed = seed
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)
        torch.backends.cudnn.deterministic = True
        np.random.seed(self.seed)
        # NOTE: setting PYTHONHASHSEED at runtime only affects processes
        # started afterwards, not the hash seed of this interpreter.
        os.environ["PYTHONHASHSEED"] = str(self.seed)

        ## Model parameters
        self.epochs = epochs
        # (sic) attribute name kept as-is for backward compatibility.
        self.imput_dim = adata.X.shape[1]
        self.lr = lr
        self.z_dim = z_dim
        self.layer_dims = layer_dims
        self.dropout = dropout
        self.device = device

        ## Spatial GCN parameters
        self.gcn_dims = gcn_dims
        self.use_spatial = gcn_dims is not None
        # Becomes True only once initialize() has built a spatial graph loader.
        self.use_neighbor_sampling = False

    @staticmethod
    def _as_dense(X):
        """Return ``X`` as a new dense float32 NumPy array, densifying sparse input."""
        # Sparse matrices (e.g. scipy.sparse) expose ``todense``; ndarrays do not.
        if hasattr(X, "todense"):
            X = X.todense()
        # np.array (subok=False) also turns np.matrix into a plain ndarray.
        return np.array(X, dtype=np.float32)

    def initialize(self, adata, W=None, batch_size=1024, type="Normal",
                   num_workers=8, spatial_graph=None, num_neighbors=5):
        """Build the model, optimiser and data loaders.

        Args:
            adata: Object exposing the matrix ``X`` (dense or sparse).
            W: Per-entry weight matrix with the same shape as ``adata.X``.
                Defaults to all ones.
            batch_size: Mini-batch size for training.
            type: Kept for backward compatibility only. Any input with a
                ``todense`` method is densified automatically, so ``"Dense"``
                and ``"Normal"`` behave the same.
            num_workers: Worker processes for the data loaders.
            spatial_graph: Optional ``(edge_index, edge_weight)`` pair.
                ``edge_weight`` may be ``None``, in which case all-ones weights
                are used. Only honoured when ``gcn_dims`` was given.
            num_neighbors: Neighbours sampled per node by the NeighborLoader,
                both in training and in :meth:`getEmbedding`.

        Returns:
            ``(model, loader)``. ``loader`` is the NeighborLoader when spatial
            training is active, otherwise the plain DataLoader.
        """
        self.adata = adata
        self.num_neighbors = num_neighbors
        spatial = self.use_spatial and spatial_graph is not None

        ## Initialise model
        if spatial:
            model = architecture(self.imput_dim, self.z_dim, self.layer_dims,
                                 self.seed, self.dropout, self.gcn_dims)
        else:
            model = architecture(self.imput_dim, self.z_dim, self.layer_dims,
                                 self.seed, self.dropout)
        self.model = model.to(self.device)
        self.optimizer = optim.Adam(lr=self.lr, params=self.model.parameters())

        ## Initialise weights: by default every entry of X counts equally
        if W is None:
            W = np.ones(self.adata.X.shape)
        self.W = W

        ## Cells-by-genes matrix X and weight matrix W as float tensors
        self.X_tensor = torch.from_numpy(self._as_dense(self.adata.X))
        self.W_tensor = torch.from_numpy(self._as_dense(self.W))
        n_cells = self.X_tensor.shape[0]

        loader_kwargs = {'num_workers': num_workers, 'pin_memory': True}
        dataset = TensorDataset(self.X_tensor, self.W_tensor)
        # drop_last avoids a trailing partial batch. If the whole dataset is
        # smaller than one batch, drop_last would leave no batches at all,
        # so it is only enabled when at least one full batch exists.
        self.train_loader = DataLoader(dataset, batch_size=batch_size,
                                       drop_last=n_cells >= batch_size,
                                       **loader_kwargs)
        # Prediction loader: the whole dataset as one ordered batch.
        self.total_loader = DataLoader(dataset, batch_size=max(n_cells, 1),
                                       **loader_kwargs)
        # Name read by getEmbedding / getBaseEmbedding.
        self.total_loader_nonspatial = self.total_loader

        ## Setup spatial graph and neighbor sampling
        self.use_neighbor_sampling = False
        if not spatial:
            return self.model, self.train_loader

        self.edge_index, self.edge_weight = spatial_graph
        if self.edge_weight is None:
            self.edge_weight = torch.ones(self.edge_index.size(1))
        self.graph_data = Data(x=self.X_tensor, edge_index=self.edge_index,
                               edge_attr=self.edge_weight, y=self.W_tensor)
        self.use_neighbor_sampling = True
        self.train_loader_spatial = NeighborLoader(
            self.graph_data, num_neighbors=[num_neighbors],
            batch_size=batch_size, shuffle=True, num_workers=num_workers)
        return self.model, self.train_loader_spatial

    def _train_epoch_standard(self):
        """Run one pass over ``self.train_loader``; return (summed loss, samples)."""
        epoch_loss, n_samples = 0.0, 0
        # y is the per-entry weight matrix W.
        for x, y in self.train_loader:
            x = x.to(self.device, non_blocking=True)
            y_ = y.to(self.device, non_blocking=True)
            self.optimizer.zero_grad()
            mu_decoder, dropout_logits, mu_encoder, logvar = self.model(x)
            loss = self.model.loss_function(x, y_, mu_decoder, dropout_logits,
                                            mu_encoder, logvar)
            loss.backward()
            self.optimizer.step()
            batch_size = x.size(0)
            epoch_loss += loss.item() * batch_size
            n_samples += batch_size
        return epoch_loss, n_samples

    def _train_epoch_spatial(self):
        """Run one pass over the neighbour-sampled graph loader; return (summed loss, samples)."""
        epoch_loss, n_samples = 0.0, 0
        for batch in self.train_loader_spatial:
            batch = batch.to(self.device)
            self.optimizer.zero_grad()
            mu_decoder, dropout_logits, mu_encoder, logvar = self.model(
                batch.x, batch.edge_index, batch.edge_attr)
            # NOTE(review): the loss covers the seed nodes *and* their sampled
            # neighbours. ``batch.batch_size`` gives the seed count if only
            # seeds should contribute.
            loss = self.model.loss_function(batch.x, batch.y, mu_decoder,
                                            dropout_logits, mu_encoder, logvar)
            loss.backward()
            self.optimizer.step()
            batch_size = batch.x.size(0)
            epoch_loss += loss.item() * batch_size
            n_samples += batch_size
        return epoch_loss, n_samples

    def train(self):
        """Train for ``self.epochs`` epochs.

        Uses neighbour-sampled graph mini-batches when :meth:`initialize` set
        up spatial training, and plain cell mini-batches otherwise. After each
        epoch the sample-weighted mean loss is printed and stored in
        ``self.epoch_loss``, and its negation is appended to ``self.loss``.

        Raises:
            ValueError: If an epoch produces no training samples.
        """
        self.loss = []
        # getEmbedding() puts the model in eval mode. Re-enable training-mode
        # layers (e.g. dropout) in case train() is called afterwards.
        self.model.train()
        spatial = self.use_spatial and self.use_neighbor_sampling
        for epoch in range(self.epochs):
            if spatial:
                epoch_loss, samples_processed = self._train_epoch_spatial()
            else:
                epoch_loss, samples_processed = self._train_epoch_standard()
            if samples_processed == 0:
                raise ValueError(
                    "no training samples were produced in this epoch; "
                    "check that the dataset is not empty")
            # Average loss for the epoch
            total_epoch_loss_train = epoch_loss / samples_processed
            self.epoch_loss = total_epoch_loss_train
            self.loss.append(-total_epoch_loss_train)
            print("[epoch %03d] average training loss: %.4f" % (epoch, self.epoch_loss))

    def getBaseEncoder(self):
        """Return the model's ``base_encoder`` sub-module."""
        return self.model.base_encoder

    def getLoss(self):
        """Return the per-epoch values recorded by :meth:`train` as a NumPy array.

        The values are the negated epoch losses. They are also cached in
        ``self.error``.
        """
        self.error = np.array(self.loss)
        return self.error

    def _spatial_embeddings(self, chunk_size=256):
        """Embed all graph nodes in node order, chunk by chunk, to bound memory use."""
        num_cells = self.graph_data.num_nodes
        loader = NeighborLoader(
            self.graph_data,
            num_neighbors=[self.num_neighbors],
            batch_size=chunk_size,
            input_nodes=torch.arange(num_cells, dtype=torch.long),
            shuffle=False,
            num_workers=4,
        )
        chunks = []
        with torch.no_grad():
            for batch in loader:
                batch = batch.to(self.device)
                z = self.model.get_latent_representation(
                    batch.x, batch.edge_index, batch.edge_attr)
                # NeighborLoader places the ``batch.batch_size`` seed nodes
                # first; the remaining rows are sampled neighbours.
                chunks.append(z[:batch.batch_size].cpu().numpy())
        if not chunks:
            return []
        return np.vstack(chunks).tolist()

    def getEmbedding(self, clustering=True):
        """Compute latent embeddings for every cell.

        The embeddings are kept in ``self.TZ`` as a list of per-cell lists.
        If ``clustering`` is true, they are also written to
        ``adata.obsm['latent']`` as float32.

        Returns:
            The ``adata`` object passed to :meth:`initialize`.
        """
        self.model.eval()
        if self.use_spatial and self.use_neighbor_sampling:
            self.TZ = self._spatial_embeddings()
        else:
            self.TZ = []
            with torch.no_grad():
                for x, _ in self.total_loader_nonspatial:
                    pyro.clear_param_store()
                    x = x.to(self.device, non_blocking=True)
                    z = self.model.get_latent_representation(x)
                    self.TZ += z.cpu().detach().numpy().tolist()
        if clustering:
            self.adata.obsm['latent'] = np.array(self.TZ).astype(np.float32)
        return self.adata

    def getBaseEmbedding(self, clustering=True):
        """Compute embeddings with the base encoder only, without spatial layers.

        The embeddings are kept in ``self.base_TZ``. If ``clustering`` is
        true, they are also written to ``adata.obsm['base_latent']`` as
        float32.

        Returns:
            The ``adata`` object passed to :meth:`initialize`.
        """
        self.base_TZ = []
        self.model.eval()
        with torch.no_grad():
            for x, _ in self.total_loader_nonspatial:
                pyro.clear_param_store()
                x = x.to(self.device, non_blocking=True)
                z = self.model.get_base_latent_representation(x)
                self.base_TZ += z.cpu().detach().numpy().tolist()
        if clustering:
            self.adata.obsm['base_latent'] = np.array(self.base_TZ).astype(np.float32)
        return self.adata