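Single-cell sequencing captures instantaneous molecular snapshots of individual cells, which makes it possible to reconstruct dynamic biological processes. During development and disease, cell-fate decisions are orchestrated by a delicate balance between deterministic and stochastic regulatory events. Drift-diffusion equations have proven effective for modeling single-cell dynamics from high-dimensional single-cell measurements. Existing methods can resolve the deterministic dynamics captured by the drift term at the level of individual cell states. The diffusion term, however, is modeled as a constant across all cell states. Fully understanding the dynamic regulatory logic of development and disease requires models that explicitly account for the balance between deterministic and stochastic biology.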
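To address these limitations, the authors developed scDiffEq. It is a generative framework that learns neural stochastic differential equations (SDEs) approximating both the deterministic and the stochastic dynamics of biological systems. Using in silico perturbations of multipotent progenitor cells, they show that scDiffEq accurately recapitulates the dynamics of CRISPR-perturbed haematopoiesis. The approach can model dynamics from single-time-point single-cell data, and it also extends to lineage-traced or multi-time-point datasets. With scDiffEq, one can simulate high-resolution developmental cell trajectories and model their drift and diffusion. This makes it possible to study time-dependent, gene-level dynamics along those trajectories.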
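To make the drift/diffusion distinction concrete, here is a minimal NumPy sketch of an SDE dx = f(x) dt + g(x) dW, integrated with the Euler-Maruyama scheme. The functions f (drift) and g (diffusion) are hand-written toys chosen purely for illustration:

- the drift has two stable fates at x = ±1;
- the noise level peaks at the branch point instead of staying constant.

scDiffEq learns f and g as neural networks, and this is not its code.

```python
import numpy as np

def drift(x):
    # toy drift f(x) = x - x**3: stable states at x = -1 and x = +1, unstable at x = 0
    return x - x**3

def diffusion(x):
    # toy state-dependent noise g(x): about 1.0 at the branch point x = 0,
    # decaying toward 0.2 for committed cells (a constant sigma would ignore this)
    return 0.2 + 0.8 * np.exp(-4.0 * x**2)

def euler_maruyama(x0, t_end, n_steps, rng):
    """Simulate dx = f(x) dt + g(x) dW for a batch of 1-D cell states from t = 0 to t_end."""
    dt = t_end / n_steps
    x = np.asarray(x0, dtype=float).copy()
    path = [x.copy()]
    for _ in range(n_steps):
        dW = rng.normal(0.0, np.sqrt(dt), size=x.shape)  # Brownian increments
        x = x + drift(x) * dt + diffusion(x) * dW
        path.append(x.copy())
    return np.stack(path)  # shape (n_steps + 1, n_cells)

rng = np.random.default_rng(0)
paths = euler_maruyama(np.zeros(1000), t_end=4.0, n_steps=400, rng=rng)
final = paths[-1]
print(f"fraction of cells ending on the x > 0 branch: {(final > 0).mean():.2f}")
```

Replacing diffusion() with a constant recovers the constant-diffusion setting criticized above. Letting g depend on the cell state is the extra flexibility the authors argue for.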
import scanpy as sc
import scdiffeq as sdq
import pandas as pd
import umap
adata_ref = sc.read_h5ad('scdiffeq_data/larry/larry.h5ad')
adata_ref
AnnData object with n_obs × n_vars = 130887 × 2492
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated', 'train', 'ct_score', 'ct_pseudotime', 'ct_num_exp_genes'
    var: 'gene_ids', 'hv_gene', 'must_include', 'exclude', 'use_genes', 'ct_gene_corr', 'ct_correlates'
    uns: 'fate_counts', 'h5ad_path', 'time_occupance'
    obsm: 'X_clone', 'X_pca', 'X_scaled', 'X_umap', 'cell_fate_df'
    layers: 'X_scaled'
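# recompute a 2-D UMAP embedding from the stored PCA coordinates
UMAP = umap.UMAP(n_components=2)
adata_ref.obsm["X_umap"] = UMAP.fit_transform(adata_ref.obsm["X_pca"])
# clones with non-missing counts for both the Monocyte and Neutrophil fates
nm_clones = adata_ref.uns["fate_counts"][["Monocyte", "Neutrophil"]].dropna().index
adata_ref.obs['nm_clones'] = adata_ref.obs["clone_idx"].isin(nm_clones)
# keep monocytes, neutrophils and undifferentiated cells from those clones
MASK = adata_ref.obs["Cell type annotation"].isin(["Monocyte", "Neutrophil", "Undifferentiated"]) & adata_ref.obs['nm_clones']
adata = adata_ref[MASK].copy()
# drop clone/fate matrices that are not needed for modeling
del adata.obsm['X_clone']
del adata.obsm["cell_fate_df"]
# reset obs names to consecutive strings "0", "1", ...
adata.obs.index = adata.obs.reset_index(drop=True).index.astype(str)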
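The filtering step keeps only clones whose row in fate_counts has values in both the Monocyte and Neutrophil columns. Below is a toy illustration of that logic. It assumes fate_counts is a clone-by-fate DataFrame with NaN for fates a clone never produced. This is inferred from the dropna() call, and the numbers are made up:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for adata_ref.uns["fate_counts"]: one row per clone,
# one column per fate, NaN where a clone never reached that fate
fate_counts = pd.DataFrame(
    {
        "Monocyte": [3.0, np.nan, 5.0, 1.0],
        "Neutrophil": [2.0, 4.0, np.nan, 6.0],
        "Baso": [np.nan, 1.0, np.nan, np.nan],
    },
    index=[10, 11, 12, 13],  # clone_idx values
)

# Same logic as the tutorial: keep clones with counts for BOTH fates
nm_clones = fate_counts[["Monocyte", "Neutrophil"]].dropna().index
print(list(nm_clones))  # [10, 13]

# Cells are then flagged by whether their clone is in that set
obs = pd.DataFrame({"clone_idx": [10, 11, 12, 13, 10]})
obs["nm_clones"] = obs["clone_idx"].isin(nm_clones)
print(obs["nm_clones"].tolist())  # [True, False, False, True, True]
```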
With the data prepared, modeling takes just a few lines of code. Training can run on a CPU, but it is slow, so a GPU is recommended:
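model = sdq.scDiffEq(adata)
model.fit(train_epochs=1500)  # train for 1500 epochs
model.drift()      # per-cell drift, stored as the "drift" column of model.adata.obs
model.diffusion()  # per-cell diffusion, stored as the "diffusion" column of model.adata.obs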
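Before a long training run, it is worth confirming that a CUDA GPU is actually visible. This minimal check assumes scdiffeq runs on PyTorch and is written so that it still runs when torch is not installed. How scdiffeq itself selects the device is not covered here, and describe_accelerator is a hypothetical helper:

```python
def describe_accelerator():
    """Report whether PyTorch can see a CUDA GPU (assumes scdiffeq trains with PyTorch)."""
    try:
        import torch
    except ImportError:
        return "torch not installed"
    if torch.cuda.is_available():
        return f"CUDA GPU: {torch.cuda.get_device_name(0)}"
    return "no CUDA GPU found; training will run on CPU"

print(describe_accelerator())
```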
The train_epochs setting affects how well the model fits, so it is worth plotting the training and validation losses to choose a value:
import matplotlib.pyplot as plt
train_loss = model.metrics[['epoch', 'epoch_train_loss']].dropna().reset_index(drop=True)
val_loss = model.metrics[['epoch', 'epoch_validation_loss']].dropna().reset_index(drop=True)
# Plot raw values as scatter points
plt.scatter(train_loss['epoch'], train_loss['epoch_train_loss'], color='blue', s=8, alpha=0.15)
plt.scatter(val_loss['epoch'], val_loss['epoch_validation_loss'], color='orange', s=8, alpha=0.15)
# Compute moving averages
train_loss_ma = train_loss['epoch_train_loss'].rolling(window=10, min_periods=1, center=True).mean()
val_loss_ma = val_loss['epoch_validation_loss'].rolling(window=10, min_periods=1, center=True).mean()
# Plot moving averages as lines
plt.plot(train_loss['epoch'], train_loss_ma, color='blue', linewidth=2, label='Train Loss')
plt.plot(val_loss['epoch'], val_loss_ma, color='orange', linewidth=2, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
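Reading the curves by eye works, but there is a simpler rule you can apply. One heuristic, assumed here and not recommended by the scDiffEq authors, is to take the epoch where the smoothed validation loss is lowest. The sketch below uses the same column names as model.metrics, fed with synthetic losses:

```python
import numpy as np
import pandas as pd

def best_epoch(metrics: pd.DataFrame, window: int = 10) -> int:
    """Epoch with the lowest moving-average validation loss.

    Expects columns 'epoch' and 'epoch_validation_loss', as in model.metrics;
    rows without a validation value are dropped first.
    """
    val = metrics[["epoch", "epoch_validation_loss"]].dropna().reset_index(drop=True)
    smooth = val["epoch_validation_loss"].rolling(window=window, min_periods=1, center=True).mean()
    return int(val.loc[smooth.idxmin(), "epoch"])

# Synthetic metrics: validation loss falls until epoch 600, then rises again (overfitting)
epochs = np.arange(1500)
val_loss = (epochs - 600) ** 2 / 1e5 + 1.0
metrics = pd.DataFrame({"epoch": epochs, "epoch_validation_loss": val_loss})
print(best_epoch(metrics))
```

With a trained model, best_epoch(model.metrics) could then guide the train_epochs value for a retrain.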

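With the model trained, visualize the learned dynamics as a velocity stream plot, with cells colored by their diffusion value: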
sdq.tl.velocity_graph(model.adata)
# colors per cell type, reused in the gene expression plots further down
cmap = {"Undifferentiated": "dimgrey", "Neutrophil": "#023047", "Monocyte": "#F08700"}
sdq.pl.velocity_stream(model.adata, c="diffusion", scatter_kwargs={"vmax": 3})
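Stream plots like this are usually built scVelo-style. Each cell's velocity vector is compared, via cosine similarity, with the displacements to its nearest neighbors, and those similarities define the transition graph that the streamlines follow. Whether sdq.tl.velocity_graph uses exactly this recipe is an assumption. The NumPy sketch below only illustrates the cosine-similarity step on toy data; cosine_velocity_graph is a hypothetical helper:

```python
import numpy as np

def cosine_velocity_graph(X, V, k=10):
    """For each cell i and each of its k nearest neighbors j, return the cosine
    similarity between velocity V[i] and displacement X[j] - X[i]."""
    # brute-force pairwise distances (fine for a toy example, O(n^2) memory)
    d = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)
    np.fill_diagonal(d, np.inf)  # exclude each cell from its own neighbors
    idx = np.argsort(d, axis=1)[:, :k]  # (n_cells, k) neighbor indices
    disp = X[idx] - X[:, None, :]  # (n_cells, k, n_dims) displacements
    num = np.einsum("ikd,id->ik", disp, V)
    denom = np.linalg.norm(disp, axis=2) * np.linalg.norm(V, axis=1)[:, None]
    cos = num / np.maximum(denom, 1e-12)  # guard against zero-length vectors
    return idx, cos

# cells on a line, all moving in the +x direction
X = np.column_stack([np.arange(20.0), np.zeros(20)])
V = np.tile([1.0, 0.0], (20, 1))
idx, cos = cosine_velocity_graph(X, V, k=2)
print(idx[5], cos[5])  # neighbor ahead of cell 5 scores +1, neighbor behind scores -1
```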

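Simulating developmental trajectories starts with choosing the initial cells. Here, three undifferentiated cells are sampled from the earliest time point: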
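# undifferentiated cells from the earliest time point; sample() is random, pass random_state=... for reproducible picks
obs = model.adata.obs
t0 = obs["Time point"].min()
progenitor = obs.loc[(obs["Time point"] == t0) & (obs["Cell type annotation"] == "Undifferentiated")].sample(3)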
progenitor
Library Cell barcode Time point Starting population Cell type annotation Well SPRING-x SPRING-y clone_idx ... ct_pseudotime ct_num_exp_genes nm_clones KEGG test fit_train fit_val drift diffusion
6135 LK_d2 AATACATC-ATCCGCTA 2.0 Lin-Kit+Sca1- Undifferentiated 0 549.977 144.213 4897.0 ... 0.8786757914387843 352 True 1 False True False 43.883648 0.990604
4951 LSK_d2_3 GGGCATCA-TTTATCAC 2.0 Lin-Kit+Sca1+ Undifferentiated 0 298.903 336.970 4445.0 ... 0.9406133190993315 148 True 1 False True False 39.128658 1.169227
2407 d2_2 AGTCACAA-TAGAAATG 2.0 Lin-Kit+Sca1- Undifferentiated 0 353.290 719.613 4061.0 ... 0.9542395878430592 273 True 1 False True False 49.687836 1.136787
[3 rows x 22 columns]
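Next, use the model to simulate how these cells develop. The N parameter deserves special attention: it sets the number of cells simulated from each starting cell at every time point. With 3 starting cells and N=512, each time point therefore holds 3 × 512 = 1536 simulated cells.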
adata_sim = sdq.tl.simulate(adata, idx=progenitor.index, N=512, diffeq=model.DiffEq, time_key="Time point")
sdq.tl.annotate_cell_state(adata_sim, kNN=model.kNN, obs_key="Cell type annotation")
sdq.tl.annotate_cell_fate(adata_sim, state_key="Cell type annotation")
adata_sim.obs
t z0_idx sim_i sim Cell type annotation fate
0 2.0 6135 0 61350 Undifferentiated Neutrophil
1 2.0 4951 0 49510 Undifferentiated Monocyte
2 2.0 2407 0 24070 Undifferentiated Undifferentiated
3 2.0 6135 1 61351 Undifferentiated Neutrophil
4 2.0 4951 1 49511 Undifferentiated Undifferentiated
... ... ... ... ... ... ...
62971 6.0 4951 510 4951510 Undifferentiated Undifferentiated
62972 6.0 2407 510 2407510 Undifferentiated Undifferentiated
62973 6.0 6135 511 6135511 Undifferentiated Undifferentiated
62974 6.0 4951 511 4951511 Undifferentiated Undifferentiated
62975 6.0 2407 511 2407511 Neutrophil Neutrophil
[62976 rows x 6 columns]
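The size of adata_sim follows from N. Each of the 3 starting cells seeds N = 512 simulated cells. The simulated time runs from t = 2.0 to 6.0 in steps of 0.1, giving 41 time points; this grid is read off the crosstab output in this post, not taken from a documented default. A quick check of the row count:

```python
import numpy as np

n_progenitors = 3  # progenitor = ... .sample(3)
N = 512            # simulated cells per starting cell per time point
t_start, t_end, dt = 2.0, 6.0, 0.1  # time grid read off the crosstab output
n_timepoints = int(round((t_end - t_start) / dt)) + 1
t_grid = np.linspace(t_start, t_end, n_timepoints)
cells_per_timepoint = n_progenitors * N
total_cells = cells_per_timepoint * n_timepoints
print(n_timepoints, cells_per_timepoint, total_cells)  # 41 1536 62976
```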
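Once the simulation finishes, cross-tabulating simulated time against annotated cell state shows the trajectories unfolding. The first neutrophil appears at t = 2.9 and the first monocyte at t = 3.5. By t = 6.0, the 1536 cells comprise 846 neutrophils, 108 monocytes and 582 undifferentiated cells: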
pd.crosstab(adata_sim.obs.t, adata_sim.obs['Cell type annotation'])
t Monocyte Neutrophil Undifferentiated
2.0 0 0 1536
2.1 0 0 1536
2.2 0 0 1536
2.3 0 0 1536
2.4 0 0 1536
2.5 0 0 1536
2.6 0 0 1536
2.7 0 0 1536
2.8 0 0 1536
2.9 0 1 1535
3.0 0 1 1535
3.1 0 5 1531
3.2 0 14 1522
3.3 0 54 1482
3.4 0 80 1456
3.5 1 142 1393
3.6 1 170 1365
3.7 2 235 1299
3.8 3 278 1255
3.9 5 336 1195
4.0 4 369 1163
4.1 4 403 1129
4.2 4 430 1102
4.3 5 466 1065
4.4 6 506 1024
4.5 8 530 998
4.6 14 555 967
4.7 17 571 948
4.8 22 618 896
4.9 32 641 863
5.0 35 666 835
5.1 47 687 802
5.2 49 709 778
5.3 53 731 752
5.4 59 746 731
5.5 67 764 705
5.6 74 781 681
5.7 81 800 655
5.8 88 814 634
5.9 100 826 610
6.0 108 846 582
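Since every time point holds the same 1536 simulated cells, the counts can be compared directly. Fractions are often easier to plot, though, and pd.crosstab can return row-normalized proportions via normalize="index". On the real data this would be pd.crosstab(adata_sim.obs.t, adata_sim.obs["Cell type annotation"], normalize="index"). Here is a toy example:

```python
import pandas as pd

# Toy stand-in for adata_sim.obs: a time column and an annotated state per cell
obs = pd.DataFrame({
    "t": [2.0, 2.0, 2.0, 2.0, 6.0, 6.0, 6.0, 6.0],
    "Cell type annotation": ["Undifferentiated"] * 4
    + ["Neutrophil", "Neutrophil", "Monocyte", "Undifferentiated"],
})

# normalize="index" divides each row by its total, so every row sums to 1
frac = pd.crosstab(obs["t"], obs["Cell type annotation"], normalize="index")
print(frac)
```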
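Next, we can look at gene expression in the simulated cells. Because the simulation runs in PCA space, the simulated states have to be mapped back to gene expression values. This requires the fitted StandardScaler and PCA models, both of which can be created with scikit-learn.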
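The sketch below shows, on a toy matrix, how such scaler and PCA models can be fitted with scikit-learn and how the inversion works. Two points are inferences, not taken from the scdiffeq docs: the order of operations (standardize, then PCA), and that the tutorial's .pkl files are plain pickles of these objects.

```python
import pickle
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# toy cells x genes matrix standing in for the real expression data
rng = np.random.default_rng(0)
X = rng.poisson(1.0, size=(200, 50)).astype(float)

# forward direction: standardize genes, then project onto principal components
scaler = StandardScaler().fit(X)
pca = PCA(n_components=10).fit(scaler.transform(X))
Z = pca.transform(scaler.transform(X))  # (200, 10) PCA coordinates

# inverse direction: PCA space -> standardized genes -> original gene scale
X_hat = scaler.inverse_transform(pca.inverse_transform(Z))

# the fitted models can be pickled (e.g. to scaler.pkl / pca.pkl) and restored later
restored_scaler = pickle.loads(pickle.dumps(scaler))
restored_pca = pickle.loads(pickle.dumps(pca))
print(X_hat.shape)
```

Only the top 10 of 50 components are kept, so X_hat approximates X rather than reconstructing it exactly.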
# load the fitted StandardScaler and PCA models
scaler_model = sdq.io.read_pickle("scdiffeq_data/larry/scaler.pkl")
PCA = sdq.io.read_pickle("scdiffeq_data/larry/pca.pkl")
sdq.tl.annotate_gene_features(adata_sim, adata, PCA=PCA, gene_id_key="gene_ids")
# the inverted expression matrix ends up in adata_sim.obsm["X_gene_inv"]
sdq.tl.invert_scaled_gex(adata_sim, scaler_model=scaler_model)
adata_sim.obsm["X_gene_inv"]
gene_ids 1110002J07Rik 1110032F04Rik 1500002F19Rik 1500026H17Rik 1600010M07Rik 1700001C19Rik 1700001O22Rik 1700011I03Rik 1700016F12Rik ... Zfp963 Zfpm2 Zkscan2 Zmiz1os1 Zmpste24 Zmynd15 Znfx1 Zscan2 Zyx
0 0.000935 0.000000 0.000922 0.001447 0.003434 0.000488 0.000000 0.004353 0.000000 ... 0.022585 0.001477 0.000240 0.000000 0.079937 0.003646 0.133242 0.001753 0.118178
1 0.000524 0.000925 0.000347 0.000000 0.008591 0.003485 0.001161 0.003041 0.000000 ... 0.012729 0.000346 0.000589 0.000000 0.236309 0.000000 0.033923 0.001078 0.095812
2 0.000000 0.000000 0.000982 0.004394 0.009979 0.005414 0.000640 0.000962 0.000000 ... 0.012827 0.000000 0.001514 0.000268 0.235095 0.002141 0.040394 0.002379 0.100079
3 0.000935 0.000000 0.000922 0.001447 0.003434 0.000488 0.000000 0.004353 0.000000 ... 0.022585 0.001477 0.000240 0.000000 0.079937 0.003646 0.133242 0.001753 0.118178
4 0.000524 0.000925 0.000347 0.000000 0.008591 0.003485 0.001161 0.003041 0.000000 ... 0.012729 0.000346 0.000589 0.000000 0.236309 0.000000 0.033923 0.001078 0.095812
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
62971 0.001237 0.000000 0.001550 0.001806 0.001409 0.000161 0.002135 0.000567 0.000000 ... 0.011184 0.000052 0.000041 0.000352 0.317569 0.000000 0.092694 0.001254 0.113291
62972 0.000011 0.000874 0.000582 0.000515 0.002702 0.001036 0.000380 0.001900 0.000000 ... 0.012742 0.000551 0.000453 0.001684 0.238807 0.002792 0.123798 0.000524 0.114049
62973 0.000006 0.004426 0.008713 0.047629 0.000000 0.000000 0.000000 0.002204 0.142455 ... 0.023694 0.000000 0.000000 0.004168 0.000000 0.000000 0.452125 0.014248 0.387427
62974 0.001503 0.004170 0.000572 0.000000 0.004613 0.000000 0.002392 0.000000 0.001085 ... 0.019704 0.000207 0.002183 0.000000 0.269295 0.000000 0.002888 0.000000 0.069778
62975 0.000383 0.000039 0.001099 0.003847 0.003830 0.000936 0.000366 0.001295 0.015110 ... 0.013344 0.000756 0.000180 0.000000 0.693576 0.010397 0.160073 0.001989 0.230127
[62976 rows x 2492 columns]
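Now we can plot how selected genes change over simulated time, separately for each fate (mean ± one standard deviation):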
import cellplots
def mean_and_std_expr(df, adata_sim, gene):
    # mean and std of one gene's inverted expression over the cells in a group
    x = adata_sim[df.index].obsm["X_gene_inv"][gene]
    return pd.Series({'mean': x.mean(), 'std': x.std()})
genes = ["Gfi1", "Elane", "Mpo", "Gstm1", "Mmp8", "Gata2"]
means = []
stds = []
for gene in genes:
    # one row per (t, fate) group, with 'mean' and 'std' columns
    res = adata_sim.obs.groupby(["t", "fate"]).apply(mean_and_std_expr, adata_sim=adata_sim, gene=gene)
    # reshape to t x fate tables
    mean_df = res['mean'].unstack()
    std_df = res['std'].unstack()
    means.append(mean_df)
    stds.append(std_df)
fig, axes = cellplots.plot(6, 3, height=0.65, width=0.8, wspace=0.4, hspace=0.4, x_label=["t (d)"], y_label=["norm. expr"], title=genes, delete=[["top", "right"]] * 3)
for en, (mean_df, std_df) in enumerate(zip(means, stds)):
    for col in mean_df:
        if col != "Undifferentiated":
            color = cmap[col]
            # Plot mean with line
            axes[en].plot(mean_df.index, mean_df[col], label=col, c=color)
            # shade mean ± one standard deviation
            lower = mean_df[col] - std_df[col]
            upper = mean_df[col] + std_df[col]
            axes[en].fill_between(mean_df.index, lower, upper, color=color, alpha=0.25)
    axes[en].legend(facecolor="None", edgecolor="None")
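The groupby(...).apply(...) loop slices the AnnData object once per (t, fate) group and per gene, which gets slow at ~63,000 simulated cells. An equivalent calculation with a single groupby().agg() is sketched below on toy data. summarize_gene is a hypothetical helper. It assumes the expression table (adata_sim.obsm["X_gene_inv"] in this post) is a DataFrame sharing the index of obs.

```python
import numpy as np
import pandas as pd

# Toy stand-ins: obs holds time and fate per simulated cell; expr is cells x genes
obs = pd.DataFrame(
    {"t": [2.0] * 4 + [6.0] * 4,
     "fate": ["Neutrophil", "Neutrophil", "Monocyte", "Monocyte"] * 2},
    index=[str(i) for i in range(8)],
)
expr = pd.DataFrame({"Elane": [0.1, 0.3, 0.0, 0.2, 0.8, 1.0, 0.1, 0.1]}, index=obs.index)

def summarize_gene(obs, expr, gene):
    """Mean and std (ddof=1) of one gene per (t, fate), reshaped to t x fate tables."""
    df = obs[["t", "fate"]].assign(x=expr[gene])  # aligns on the shared index
    stats = df.groupby(["t", "fate"])["x"].agg(["mean", "std"])
    return stats["mean"].unstack(), stats["std"].unstack()

mean_df, std_df = summarize_gene(obs, expr, "Elane")
print(mean_df)
```

On the real data, the call would be mean_df, std_df = summarize_gene(adata_sim.obs, adata_sim.obsm["X_gene_inv"], "Elane"). pandas' std() uses ddof=1 in both versions, so the numbers should match the loop above.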

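This walkthrough uses the data from the software documentation, yet the results still differ noticeably from those in the docs. It is a reminder that reproducing someone else's results is not always easy. Part of the gap may come from randomness: the starting cells are drawn with sample(), and both model training and SDE simulation are stochastic.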