scdiffeq | 建模多时间点单细胞数据的发育轨迹

单细胞测序技术通过捕获单个细胞的瞬时分子图谱,为重建动态生物学过程提供了可能。在发育和疾病过程中,细胞命运决定是由确定性调控事件和随机性调控事件之间复杂的平衡所协调的。漂移-扩散方程在从高维单细胞测量数据建模单细胞动态方面具有显著效果。虽然现有的解决方案能够在细胞状态水平上描述这些方程中漂移项所关联的确定性动态,但扩散项在所有细胞状态中都被建模为常数。为了充分理解发育和疾病中的动态调控逻辑,需要建立能够明确关注确定性与随机性生物学之间平衡的模型。

为解决这些局限性,作者开发了scDiffEq,一个生成式框架,用于学习能够近似生物学确定性和随机性动态的神经随机微分方程。通过对多能祖细胞进行计算机模拟扰动,我们发现scDiffEq能够准确重现CRISP 扰动造血的动态过程。这一方法可以从单一时间点的单细胞数据的动态变化进行建模,推广到谱系示踪或多时间点数据集。利用scDiffEq,可以模拟高分辨率的发育细胞轨迹,并对其漂移和扩散进行建模,从而帮助研究这些轨迹的时间依赖性基因水平动态。

import scanpy as sc
import scdiffeq as sdq
import pandas as pd
import umap

adata_ref = sc.read_h5ad('scdiffeq_data/larry/larry.h5ad')
adata_ref
AnnData object with n_obs × n_vars = 130887 × 2492
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated', 'train', 'ct_score', 'ct_pseudotime', 'ct_num_exp_genes'
    var: 'gene_ids', 'hv_gene', 'must_include', 'exclude', 'use_genes', 'ct_gene_corr', 'ct_correlates'
    uns: 'fate_counts', 'h5ad_path', 'time_occupance'
    obsm: 'X_clone', 'X_pca', 'X_scaled', 'X_umap', 'cell_fate_df'
    layers: 'X_scaled'

UMAP = umap.UMAP(n_components=2)
adata_ref.obsm["X_umap"] = UMAP.fit_transform(adata_ref.obsm["X_pca"])

nm_clones = adata_ref.uns["fate_counts"][["Monocyte", "Neutrophil"]].dropna().index
adata_ref.obs['nm_clones'] = adata_ref.obs["clone_idx"].isin(nm_clones)
MASK = adata_ref.obs["Cell type annotation"].isin(["Monocyte", "Neutrophil", "Undifferentiated"]) & adata_ref.obs['nm_clones']

adata = adata_ref[MASK].copy()
del adata.obsm['X_clone']
del adata.obsm["cell_fate_df"]
adata.obs.index = adata.obs.reset_index(drop=True).index.astype(str)

准备好数据,建模过程下面几行代码即可,虽然建模可以在cpu环境运行,但速度太慢,还是用gpu来加速吧:

model = sdq.scDiffEq(adata)
model.fit(train_epochs=1500)
model.drift()
model.diffusion()

不过,train_epochs参数会影响模型的拟合效果,可以用下面的方式来查看模型的损失情况来确定:

import matplotlib.pyplot as plt

train_loss = model.metrics[['epoch', 'epoch_train_loss']].dropna().reset_index(drop=True)
val_loss = model.metrics[['epoch', 'epoch_validation_loss']].dropna().reset_index(drop=True)

# Plot raw values as scatter points
plt.scatter(train_loss['epoch'], train_loss['epoch_train_loss'], color='blue', s=8, alpha=0.15)
plt.scatter(val_loss['epoch'], val_loss['epoch_validation_loss'], color='orange', s=8, alpha=0.15)

# Compute moving averages
train_loss_ma = train_loss['epoch_train_loss'].rolling(window=10, min_periods=1, center=True).mean()
val_loss_ma = val_loss['epoch_validation_loss'].rolling(window=10, min_periods=1, center=True).mean()

# Plot moving averages as lines
plt.plot(train_loss['epoch'], train_loss_ma, color='blue', linewidth=2, label='Train Loss')
plt.plot(val_loss['epoch'], val_loss_ma, color='orange', linewidth=2, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

建模完成,接下来看看细胞的拟时序:

sdq.tl.velocity_graph(model.adata)
cmap = {"Undifferentiated": "dimgrey", "Neutrophil": "#023047", "Monocyte": "#F08700"}
sdq.pl.velocity_stream(model.adata, c="diffusion", scatter_kwargs={"vmax": 3})

做细胞发育轨迹推断需要先选择起始细胞:

progenitor = (model.adata.obs.loc[model.adata.obs["Time point"] == model.adata.obs["Time point"].min()].loc[model.adata.obs["Cell type annotation"] == "Undifferentiated"].sample(3))
progenitor
       Library       Cell barcode  Time point Starting population Cell type annotation  Well  SPRING-x  SPRING-y  clone_idx  ...       ct_pseudotime  ct_num_exp_genes  nm_clones KEGG   test fit_train  fit_val      drift  diffusion
6135     LK_d2  AATACATC-ATCCGCTA         2.0       Lin-Kit+Sca1-     Undifferentiated     0   549.977   144.213     4897.0  ...  0.8786757914387843               352       True    1  False      True    False  43.883648   0.990604
4951  LSK_d2_3  GGGCATCA-TTTATCAC         2.0       Lin-Kit+Sca1+     Undifferentiated     0   298.903   336.970     4445.0  ...  0.9406133190993315               148       True    1  False      True    False  39.128658   1.169227
2407      d2_2  AGTCACAA-TAGAAATG         2.0       Lin-Kit+Sca1-     Undifferentiated     0   353.290   719.613     4061.0  ...  0.9542395878430592               273       True    1  False      True    False  49.687836   1.136787

[3 rows x 22 columns]

然后,利用模型模拟细胞的发育过程,其中N这个参数比较特别,决定了每个起始细胞在各个时间生成的细胞数量。

adata_sim = sdq.tl.simulate(adata, idx=progenitor.index, N=512, diffeq=model.DiffEq, time_key="Time point")
sdq.tl.annotate_cell_state(adata_sim, kNN=model.kNN, obs_key="Cell type annotation")
sdq.tl.annotate_cell_fate(adata_sim, state_key="Cell type annotation")
adata_sim.obs
         t z0_idx  sim_i      sim Cell type annotation              fate
0      2.0   6135      0    61350     Undifferentiated        Neutrophil
1      2.0   4951      0    49510     Undifferentiated          Monocyte
2      2.0   2407      0    24070     Undifferentiated  Undifferentiated
3      2.0   6135      1    61351     Undifferentiated        Neutrophil
4      2.0   4951      1    49511     Undifferentiated  Undifferentiated
...    ...    ...    ...      ...                  ...               ...
62971  6.0   4951    510  4951510     Undifferentiated  Undifferentiated
62972  6.0   2407    510  2407510     Undifferentiated  Undifferentiated
62973  6.0   6135    511  6135511     Undifferentiated  Undifferentiated
62974  6.0   4951    511  4951511     Undifferentiated  Undifferentiated
62975  6.0   2407    511  2407511           Neutrophil        Neutrophil

[62976 rows x 6 columns]

模拟完成后,就可以清晰地看到随着时间推移过程中细胞的发育轨迹:

pd.crosstab(adata_sim.obs.t, adata_sim.obs['Cell type annotation'])
t    Monocyte  Neutrophil  Undifferentiated
2.0         0           0              1536
2.1         0           0              1536
2.2         0           0              1536
2.3         0           0              1536
2.4         0           0              1536
2.5         0           0              1536
2.6         0           0              1536
2.7         0           0              1536
2.8         0           0              1536
2.9         0           1              1535
3.0         0           1              1535
3.1         0           5              1531
3.2         0          14              1522
3.3         0          54              1482
3.4         0          80              1456
3.5         1         142              1393
3.6         1         170              1365
3.7         2         235              1299
3.8         3         278              1255
3.9         5         336              1195
4.0         4         369              1163
4.1         4         403              1129
4.2         4         430              1102
4.3         5         466              1065
4.4         6         506              1024
4.5         8         530               998
4.6        14         555               967
4.7        17         571               948
4.8        22         618               896
4.9        32         641               863
5.0        35         666               835
5.1        47         687               802
5.2        49         709               778
5.3        53         731               752
5.4        59         746               731
5.5        67         764               705
5.6        74         781               681
5.7        81         800               655
5.8        88         814               634
5.9       100         826               610
6.0       108         846               582

接着,可以研究这些模拟细胞的基因表达情况。由于模拟过程使用的是PCA数据,所以需要将其逆转为基因表达值,这个过程需要提供StandardScaler和PCA模型,可以用sklearn生成。

scaler_model = sdq.io.read_pickle("scdiffeq_data/larry/scaler.pkl")
PCA = sdq.io.read_pickle("scdiffeq_data/larry/pca.pkl")
sdq.tl.annotate_gene_features(adata_sim, adata, PCA=PCA, gene_id_key="gene_ids")
sdq.tl.invert_scaled_gex(adata_sim, scaler_model = scaler_model)
adata_sim.obsm["X_gene_inv"]
gene_ids  1110002J07Rik  1110032F04Rik  1500002F19Rik  1500026H17Rik  1600010M07Rik  1700001C19Rik  1700001O22Rik  1700011I03Rik  1700016F12Rik  ...    Zfp963     Zfpm2   Zkscan2  Zmiz1os1  Zmpste24   Zmynd15     Znfx1    Zscan2       Zyx
0              0.000935       0.000000       0.000922       0.001447       0.003434       0.000488       0.000000       0.004353       0.000000  ...  0.022585  0.001477  0.000240  0.000000  0.079937  0.003646  0.133242  0.001753  0.118178
1              0.000524       0.000925       0.000347       0.000000       0.008591       0.003485       0.001161       0.003041       0.000000  ...  0.012729  0.000346  0.000589  0.000000  0.236309  0.000000  0.033923  0.001078  0.095812
2              0.000000       0.000000       0.000982       0.004394       0.009979       0.005414       0.000640       0.000962       0.000000  ...  0.012827  0.000000  0.001514  0.000268  0.235095  0.002141  0.040394  0.002379  0.100079
3              0.000935       0.000000       0.000922       0.001447       0.003434       0.000488       0.000000       0.004353       0.000000  ...  0.022585  0.001477  0.000240  0.000000  0.079937  0.003646  0.133242  0.001753  0.118178
4              0.000524       0.000925       0.000347       0.000000       0.008591       0.003485       0.001161       0.003041       0.000000  ...  0.012729  0.000346  0.000589  0.000000  0.236309  0.000000  0.033923  0.001078  0.095812
...                 ...            ...            ...            ...            ...            ...            ...            ...            ...  ...       ...       ...       ...       ...       ...       ...       ...       ...       ...
62971          0.001237       0.000000       0.001550       0.001806       0.001409       0.000161       0.002135       0.000567       0.000000  ...  0.011184  0.000052  0.000041  0.000352  0.317569  0.000000  0.092694  0.001254  0.113291
62972          0.000011       0.000874       0.000582       0.000515       0.002702       0.001036       0.000380       0.001900       0.000000  ...  0.012742  0.000551  0.000453  0.001684  0.238807  0.002792  0.123798  0.000524  0.114049
62973          0.000006       0.004426       0.008713       0.047629       0.000000       0.000000       0.000000       0.002204       0.142455  ...  0.023694  0.000000  0.000000  0.004168  0.000000  0.000000  0.452125  0.014248  0.387427
62974          0.001503       0.004170       0.000572       0.000000       0.004613       0.000000       0.002392       0.000000       0.001085  ...  0.019704  0.000207  0.002183  0.000000  0.269295  0.000000  0.002888  0.000000  0.069778
62975          0.000383       0.000039       0.001099       0.003847       0.003830       0.000936       0.000366       0.001295       0.015110  ...  0.013344  0.000756  0.000180  0.000000  0.693576  0.010397  0.160073  0.001989  0.230127

[62976 rows x 2492 columns]

下面就可以看基因的表达情况了:

import cellplots

def mean_and_std_expr(df, adata_sim, gene):
    x = adata_sim[df.index].obsm["X_gene_inv"][gene]
    return pd.Series({'mean': x.mean(), 'std': x.std()})

genes = ["Gfi1", "Elane", "Mpo", "Gstm1", "Mmp8", "Gata2"]
means = []
stds = []
for gene in genes:
    res = adata_sim.obs.groupby(["t", "fate"]).apply(mean_and_std_expr, adata_sim=adata_sim, gene=gene)
    mean_df = res['mean'].unstack()
    std_df = res['std'].unstack()
    means.append(mean_df)
    stds.append(std_df)

fig, axes = cellplots.plot(6, 3, height=0.65, width=0.8, wspace=0.4, hspace=0.4, x_label=["t (d)"], y_label=["norm. expr"], title=genes, delete=[["top", "right"]] * 3)

for en, (mean_df, std_df) in enumerate(zip(means, stds)):
    for col in mean_df:
        if col != "Undifferentiated":
            color = cmap[col]
            # Plot mean with line
            axes[en].plot(mean_df.index, mean_df[col], label=col, c=color)
            lower = mean_df[col] - std_df[col]
            upper = mean_df[col] + std_df[col]
            axes[en].fill_between(mean_df.index, lower, upper, color=color, alpha=0.25)
    axes[en].legend(facecolor="None", edgecolor="None")

虽然用了软件文档里面的数据,但结果还是差别有点大,可见有时候重现别人的结果也不是一件容易的事。

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容