# Source code for scfocus.focus

from .environment import Env, ReplayBuffer, train_off_policy
from .model import SAC
import numpy as np
import torch
import tqdm
import time
from scipy.stats import multivariate_normal
from sklearn.preprocessing import minmax_scale
# Module-level device: CUDA when available, otherwise CPU.
# NOTE: `focus.__init__` defaults its own `device` argument to CPU, not to this value.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  

class focus:
    """
    Focus class for performing advanced reinforcement learning-based analysis on single-cell data.

    This class utilizes the Soft Actor-Critic (SAC) reinforcement learning framework to enhance
    cell subtype discrimination and identify distinct lineage branches within single-cell data.
    It manages the environment, memory buffers, and the ensemble of SAC agents to train and
    evaluate the model over multiple episodes.

    Parameters
    ----------
    f : array-like
        Latent space of the original data, with shape (num_samples, num_features).
    hidden_dim : int, optional
        Number of hidden units in the neural networks, by default 128.
    n : int, optional
        Number of agents or parallel environments, by default 8.
    max_steps : int, optional
        Maximum number of steps per episode, by default 5.
    pct_samples : float, optional
        Percentage of samples to use for each state, by default 0.125.
    n_states : int, optional
        Number of state variables, by default 2.
    err_scale : float, optional
        Error scaling factor for reward calculation, by default 1.
    bins : int, optional
        Number of bins for histogram-based state discretization, by default 15.
    capacity : float, optional
        Capacity of the replay buffer, by default 1e4.
    actor_lr : float, optional
        Learning rate for the actor network, by default 1e-4.
    critic_lr : float, optional
        Learning rate for the critic network, by default 1e-3.
    alpha_lr : float, optional
        Learning rate for the entropy coefficient, by default 1e-4.
    target_entropy : float, optional
        Target entropy for the SAC algorithm, by default -1.
    tau : float, optional
        Soft update coefficient for target networks, by default 5e-3.
    gamma : float, optional
        Discount factor for future rewards, by default 0.99.
    num_episodes : int, optional
        Number of training episodes, by default 1000 (given as ``1e3``).
    batch_size : int, optional
        Batch size for training, by default 64.
    res : float, optional
        Resolution parameter for merging focus patterns, by default 0.05.
    device : torch.device, optional
        Device to run the computations on (e.g., CPU or GPU), by default torch.device('cpu').
    """

    def __init__(self, f, hidden_dim=128, n=8, max_steps=5, pct_samples=.125, n_states=2,
                 err_scale=1, bins=15, capacity=1e4, actor_lr=1e-4, critic_lr=1e-3,
                 alpha_lr=1e-4, target_entropy=-1, tau=5e-3, gamma=.99, num_episodes=1e3,
                 batch_size=64, res=.05, device=torch.device('cpu')):
        """
        Initialize the Focus class.

        Stores the SAC hyper-parameters, builds the environment from ``f`` and
        creates the empty containers that the fitting methods fill in.
        The parameters are described in the class docstring.
        """
        # Network input/output sizes are derived from the environment layout.
        self.state_d = (2 + bins) * n_states * n
        self.hidden_dim = hidden_dim
        self.action_d = 2 * n_states * n
        # Action bounds: global min/max over the first `n_states` latent dimensions.
        self.action_space = (f[:, :n_states].min().item(), f[:, :n_states].max().item())

        # SAC hyper-parameters, forwarded unchanged to every new agent.
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.alpha_lr = alpha_lr
        self.target_entropy = target_entropy
        self.tau = tau
        self.gamma = gamma
        self.device = device

        self.capacity = capacity
        self.ensemble = []  # one SAC agent per meta_fit() call
        self.env = Env(n, f, max_steps, pct_samples, n_states, err_scale, bins)
        self.memory = []    # one ReplayBuffer per meta_fit() call

        self.max_steps = max_steps
        self.num_episodes = num_episodes
        # Passed to train_off_policy as its minimal-size argument.
        self.minimal_size = num_episodes / 10 * max_steps
        self.batch_size = batch_size
        self.res = res

        self.fp = []  # focus patterns, one array per focus_fit() call
        self.r = []   # flattened first output of train_off_policy, per meta_fit()
        self.e = []   # flattened second output of train_off_policy, per meta_fit()

    def meta_focusing(self, n, episodes=10):
        """
        Perform meta focusing by iteratively fitting the ensemble and refining focus.

        Parameters
        ----------
        n : int
            Number of meta focusing iterations to perform.
        episodes : int, optional
            Number of episodes passed to `focus_fit` in each iteration, by default 10.

        Returns
        -------
        self : Focus
            Returns the instance itself for method chaining.
        """
        start = time.time()
        for _ in range(n):
            self.meta_fit()
            self.focus_fit(episodes)
        end = time.time()
        print(f'Meta focusing time used: {(end - start):.2f} seconds')
        return self

    def meta_fit(self):
        """
        Perform a single meta fitting step by training a new SAC agent and updating memory.

        A fresh agent and a fresh replay buffer are appended to `self.ensemble`
        and `self.memory`; the two sequences returned by `train_off_policy` are
        flattened and appended to `self.r` and `self.e`.

        Returns
        -------
        self : Focus
            Returns the instance itself for method chaining.
        """
        start = time.time()
        self.ensemble.append(SAC(
            self.state_d,
            self.hidden_dim,
            self.action_d,
            self.action_space,
            self.actor_lr,
            self.critic_lr,
            self.alpha_lr,
            self.target_entropy,
            self.tau,
            self.gamma,
            self.device
        ))
        self.memory.append(ReplayBuffer(self.capacity))
        r, e = train_off_policy(
            self.env,
            self.ensemble[-1],
            self.memory[-1],
            self.num_episodes,
            self.minimal_size,
            self.batch_size
        )
        self.r.append(np.vstack(r).ravel())
        self.e.append(np.vstack(e).ravel())
        end = time.time()
        print(f'Meta fitting time used: {(end - start):.2f} seconds')
        return self

    def focus_fit(self, episodes):
        """
        Fit the focus model over a specified number of episodes.

        For every step of every episode the most recent SAC agent produces an
        action; the action is split into means and log-std values, one slice
        per environment branch, and each branch is turned into per-sample
        weights (min-max scaled Gaussian log-density over the first
        `n_states` latent dimensions). Fitting stops early once the weights of
        two consecutive episodes differ by a norm below 3 (after at least
        three completed episodes); the episode that triggers the stop is not
        recorded. The weights averaged over episodes and steps are appended to
        `self.fp`.

        Parameters
        ----------
        episodes : int
            Number of episodes to train the focus model.

        Returns
        -------
        self : Focus
            Returns the instance itself for method chaining.
        """
        start = time.time()
        episode_weight = []
        with tqdm.tqdm(total=int(episodes), desc='Focus fitting...') as pbar:
            self.weights = None
            for i_episode in range(int(episodes)):
                ls_weights = []
                state = self.env.reset()
                for _ in range(self.max_steps):
                    with torch.no_grad():
                        action = self.ensemble[-1].take_action(state)
                    action = action.ravel()
                    # First half of the action vector: means; second half: log-stds.
                    half = action.shape[-1] // 2
                    mus = action[:half]
                    logstds = action[half:]
                    L = self.env.n_states
                    bra_weights = []
                    for j in range(self.env.n):
                        mu = mus[L * j:L * (j + 1)]
                        logstd = logstds[L * j:L * (j + 1)]
                        # Softplus keeps the scale positive; the sigmoid then
                        # bounds the variance by the environment's sigma.
                        std = np.log1p(np.exp(logstd))
                        mn = multivariate_normal(mu, np.diag(self.env.sigma / (1 + np.exp(-std))))
                        weights = minmax_scale(mn.logpdf(self.env.f[:, :self.env.n_states]))
                        bra_weights.append(weights)
                    ls_weights.append(bra_weights)
                    next_state, _, _ = self.env.step(action)
                    state = next_state
                weights = np.array(ls_weights)
                if self.weights is not None:
                    err = np.linalg.norm(weights - self.weights)
                    if err < 3 and i_episode > 2:
                        break
                self.weights = weights
                episode_weight.append(ls_weights)
                pbar.update(1)
        # Stacked as (episodes, steps, branches, samples); transpose and
        # average out episodes, then steps -> (samples, branches).
        fp = np.array(episode_weight)
        self.fp.append(fp.T.mean(axis=-1).mean(axis=-1))
        end = time.time()
        print(f'Focus fitting time used: {(end - start):.2f} seconds')
        return self

    def merge_fp2(self):
        """
        Merge focus patterns by performing two levels of merging.

        This method first calls `merge_fp` to merge the patterns of each fit,
        then replaces `self.fp` with the column-wise concatenation of all
        merged patterns and calls `merge_fp` once more on that single array.
        Note that `self.fp` is overwritten in the process.

        Returns
        -------
        self : Focus
            Returns the instance itself for method chaining.
        """
        self.merge_fp()
        self.fp = [np.hstack(self.mfp)]
        self.merge_fp()
        return self

    def merge_fp(self):
        """
        Merge focus patterns based on similarity thresholds.

        For every array in `self.fp` (rows: samples, columns: patterns) the
        top ``res`` fraction of samples of each pattern is collected. Walking
        the columns left to right, each not-yet-grouped column seeds a group
        and absorbs every later column whose top samples overlap its own by
        more than 25 %. Grouped columns are averaged; the results are stored
        in `self.mfp`, one merged array per input array.

        Returns
        -------
        self : Focus
            Returns the instance itself after merging focus patterns.
        """
        self.mfp = []
        for fp in self.fp:
            n_cols = fp.shape[1]
            # Keep at least one top sample: with n == 0 the slice `[-0:]`
            # would select every row and all patterns would be merged.
            n = max(1, int(fp.shape[0] * self.res))
            ord_indices = np.argsort(fp, axis=0)[-n:, :]
            # Build each column's top-sample set once instead of per pair.
            top_sets = [set(ord_indices[:, c]) for c in range(n_cols)]

            groups = []
            grouped = set()  # every column that already belongs to some group
            for i in range(n_cols):
                if i in grouped:
                    continue
                g_ = [i]
                for j in range(i + 1, n_cols):
                    if len(top_sets[i].intersection(top_sets[j])) > 0.25 * n:
                        g_.append(j)
                groups.append(g_)
                grouped.update(g_)

            mfp = [
                fp[:, g].mean(axis=1)[:, np.newaxis] if len(g) > 1 else fp[:, g]
                for g in groups
            ]
            self.mfp.append(np.hstack(mfp))
        return self

    def focus_diff(self):
        """
        Calculate entropy and pseudotime based on merged focus patterns.

        The entropy of each sample is ``sum(-w * log(w))`` over its merged
        focus-pattern weights, and pseudotime is one minus the min-max scaled
        entropy. `self.mfp` may be a 2-D array (samples x patterns) or the
        list produced by `merge_fp` / `merge_fp2`, whose arrays are stacked
        column-wise first.

        Returns
        -------
        self : Focus
            Returns the instance itself with updated entropy and pseudotime attributes.
        """
        mfp = self.mfp
        if isinstance(mfp, (list, tuple)):
            # A list would broadcast to (1, samples, patterns) and the sum
            # below would run over samples instead of patterns.
            mfp = np.hstack(mfp)
        weights = np.asarray(mfp, dtype=float)
        # A weight of exactly 0 contributes 0 (limit of -w*log(w)), not NaN.
        with np.errstate(divide='ignore', invalid='ignore'):
            terms = np.where(weights == 0, 0.0, weights * -np.log(weights))
        self.entropy = terms.sum(axis=1)
        self.pseudotime = 1 - minmax_scale(self.entropy)
        return self