Geneformer | Cell Annotation

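Geneformer is a Transformer model pretrained on about 30 million single-cell transcriptomes (scRNA-seq data) spanning a broad range of human tissues and organs. Geneformer can be used for cell-level classification and for gene-level classification (for example, predicting whether a gene is a drug-resistance gene). Here we follow the tutorial to walk through the steps for cell type prediction.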

Geneformer

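Geneformer first converts each cell's gene expression profile into a rank value encoding. This encoding is the input to a Transformer architecture for pretraining. For a downstream task, the model can then be fine-tuned on a task-specific dataset with a final output layer added on top.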

Both Geneformer and its training corpus, Genecorpus-30M, are available on Hugging Face.

Geneformer environment setup

We first set up the environment needed for the Geneformer analysis.

Create a conda environment

conda create --name geneformer python
conda activate geneformer

To use this environment in Jupyter Notebook, we need to install ipykernel:

https://blog.csdn.net/mighty13/article/details/119859242

conda install -c anaconda ipykernel
python -m ipykernel install --user --name geneformer

Installation

Next, we download Geneformer and the associated example dataset.

Clone project

git clone https://huggingface.co/ctheodoris/Geneformer
cd Geneformer
pip install .
git lfs install
git lfs pull

Download associated dataset

git clone https://huggingface.co/datasets/ctheodoris/Genecorpus-30M

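The .dataset files obtained by git clone are only about 1 KB each. They are Git LFS pointer files rather than the actual data, so we download the training data manually: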

cd Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/dataset.arrow
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/dataset_info.json
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/state.json
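If wget is not available (for example on the Windows setup used in this walkthrough), Python's standard library can fetch the same three files. This is a minimal sketch. The helper names `resolve_urls` and `download_dataset_files` are hypothetical, and the URL pattern is simply the one used in the wget commands above.

```python
import os
import urllib.request

# Repository and sub-directory taken from the wget commands above
REPO = "ctheodoris/Genecorpus-30M"
SUBDIR = ("example_input_files/cell_classification/cell_type_annotation/"
          "cell_type_train_data.dataset")
FILES = ["dataset.arrow", "dataset_info.json", "state.json"]


def resolve_urls(repo=REPO, subdir=SUBDIR, files=FILES, revision="main"):
    """Build the Hugging Face 'resolve' URLs for the given files (hypothetical helper)."""
    base = f"https://huggingface.co/datasets/{repo}/resolve/{revision}/{subdir}"
    return [f"{base}/{name}" for name in files]


def download_dataset_files(dest_dir, **url_kwargs):
    """Download each file into dest_dir, overwriting the small LFS pointer files."""
    os.makedirs(dest_dir, exist_ok=True)
    for url in resolve_urls(**url_kwargs):
        target = os.path.join(dest_dir, url.rsplit("/", 1)[-1])
        urllib.request.urlretrieve(url, target)  # urllib follows HTTP redirects by default
```

Run it from the directory that contains the clone, for example with `download_dataset_files(os.path.join("Genecorpus-30M", SUBDIR))`.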

Install related modules

pip3 install seaborn
pip3 install datasets
pip3 install -U scikit-learn
pip3 install transformers
# GPU build of torch: the right command depends on your OS and CUDA version (see the install selector on pytorch.org)
pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
pip3 install statsmodels
pip3 install -U accelerate

Import modules

import os
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification

Prepare training and evaluation datasets

# load cell type dataset (includes all tissues)
train_dataset=load_from_disk("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset")

We load the cell type annotation example data from Genecorpus-30M, the dataset released with the paper. It is provided in Apache Arrow format.

Data Fields

  • cell_type: annotated cell type
  • organ_major: organ of origin
  • input_ids: rank value encoding for an example cell
  • length: length of rank value encoding for that example cell

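How the rank values are computed

  1. Compute the nonzero median expression of each detected gene across all cells;
  2. Divide each gene's read count in a cell by that cell's total read count to correct for sequencing depth;
  3. Divide each gene's depth-corrected value in each cell by that gene's nonzero median to obtain its normalized expression;
  4. Rank the genes within each cell by normalized expression to obtain the rank values.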
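The four steps above can be sketched in NumPy as follows. This is a simplified illustration, not Geneformer's TranscriptomeTokenizer. The function name `rank_value_encode` is hypothetical. The sketch makes three further assumptions:

- The per-gene nonzero medians (step 1) are precomputed and passed in.
- Genes with zero counts are dropped.
- Truncation to the model's 2,048-token input size is omitted.

```python
import numpy as np


def rank_value_encode(counts, gene_medians, gene_tokens):
    """Simplified rank value encoding (illustration only).

    counts       : (n_cells, n_genes) raw read counts
    gene_medians : (n_genes,) precomputed nonzero median of each gene (step 1)
    gene_tokens  : (n_genes,) integer token ID of each gene
    Returns one array of token IDs per cell, highest normalized expression first.
    """
    counts = np.asarray(counts, dtype=float)
    gene_medians = np.asarray(gene_medians, dtype=float)
    gene_tokens = np.asarray(gene_tokens)

    # step 2: divide by each cell's total counts to correct for sequencing depth
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0  # leave empty cells as all zeros instead of dividing by 0
    normalized = counts / totals
    # step 3: divide each gene by its nonzero median
    normalized = normalized / gene_medians

    encodings = []
    for cell in normalized:
        detected = np.nonzero(cell)[0]  # undetected genes get no token
        # step 4: sort detected genes by normalized expression, descending
        order = detected[np.argsort(-cell[detected], kind="stable")]
        encodings.append(gene_tokens[order])
    return encodings
```

Step 3 lets a gene with a low typical expression (small median) rank ahead of a gene with higher raw counts. This is why the ranking differs from simply sorting raw counts.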

The rank value encodings for each single-cell transcriptome were then tokenized based on a total vocabulary of 25,424 protein-coding or miRNA genes detected within Genecorpus-30M. The token dictionary mapping each token ID to special tokens (pad and mask) or Ensembl IDs for each gene is included within the repository as a pickle file (token_dictionary.pkl).

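Why do the values in input_ids not range from 1 to the number of genes in the cell? Each element of input_ids is a gene's token ID from token_dictionary.pkl, not a rank number. The rank is carried by the token's position in the sequence: the gene with the highest normalized expression comes first. This is why the values for cell 2 below range from 9 to 25414.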
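To check which genes a cell's input_ids refer to, you can invert the token dictionary. This is a minimal sketch. It assumes token_dictionary.pkl holds a plain dict from token strings (Ensembl IDs, "<pad>", "<mask>") to integer IDs, which is what its later use with `token_dictionary.get("<pad>")` suggests. Both helper names are hypothetical.

```python
import pickle


def load_token_dictionary(path):
    """Load the pickled token dictionary: token string -> integer token ID."""
    with open(path, "rb") as f:
        return pickle.load(f)


def decode_input_ids(input_ids, token_dictionary):
    """Map token IDs back to Ensembl IDs or special tokens, keeping rank order."""
    id_to_token = {token_id: token for token, token_id in token_dictionary.items()}
    return [id_to_token[int(i)] for i in input_ids]
```

Here is an example, assuming the pickle sits at geneformer/token_dictionary.pkl inside the cloned repository: `decode_input_ids(train_dataset[2]["input_ids"][:10], load_token_dictionary("geneformer/token_dictionary.pkl"))` should list the ten top-ranked genes of cell 2.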

# elements of train_dataset
# Includes 249,556 cells in total
print(train_dataset)

# number of cells of each cell type
print("Celltypes:")
print(Counter(train_dataset['cell_type']))

# number of cells from different organs
print("\nOrgans:")
print(Counter(train_dataset['organ_major']))

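# length of the rank value encoding of cell 1
# (index a single row instead of loading the whole 'input_ids' column into memory)
print(len(train_dataset[1]['input_ids']))

# number of genes detected in cell 1 (equal to the encoding length above)
print(train_dataset[1]['length'])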
Dataset({
    features: ['cell_type', 'input_ids', 'length', 'organ_major'],
    num_rows: 249556
})
Counter({'B cell (Plasmocyte)': 20728, 'T cell': 16695, 'Enterocyte progenitor': 15441, 'Fetal epithelial progenitor': 14580, 'Fetal neuron': 12287, 'Fetal mesenchymal progenitor': 11905, 'Erythroid progenitor cell (RP high)': 10819, 'Hepatocyte/Endodermal cell': 9781, 'Fetal enterocyte ': 9613, 'Erythroid cell': 9089, 'Loop of Henle': 8439, 'Macrophage': 7854, 'Monocyte': 7541, 'Epithelial cell': 7458, 'AT2 cell': 7333, 'Neutrophil': 7276, 'Fibroblast': 6980, 'Dendritic cell': 5960, 'Pancreas exocrine cell': 5538, 'Endothelial cell (APC)': 5431, 'M2 Macrophage': 5373, 'Endothelial cell': 4738, 'Antigen presenting cell (RPS high)': 4658, 'Intercalated cell': 4414, 'Sinusoidal endothelial cell': 3844, 'B cell': 3783, 'Endothelial cell (endothelial to mesenchymal transition)': 3647, 'Fetal acinar cell': 3214, 'Ureteric bud cell': 2472, 'Enterocyte': 1994, 'Proximal tubule progenitor': 1846, 'Smooth muscle cell': 1794, 'Fetal stromal cell': 1061, 'Stromal cell': 1029, 'Mast cell': 968, 'Fetal endocrine cell': 834, 'Neutrophil (RPS high)': 604, 'Intermediated cell': 529, 'Proliferating T cell': 528, 'CB CD34+': 423, 'Basal cell': 236, 'Primordial germ cell': 230, 'Fetal fibroblast': 125, 'Fetal Neuron': 114, 'Stratified epithelial cell': 113, 'Fetal skeletal muscle cell': 57, 'Fetal chondrocyte': 49, 'Mesothelial cell': 30, 'Goblet cell': 19, 'Chondrocyte': 19, 'hESC': 18, 'Fasciculata cell': 13, 'Gastric endocrine cell': 7, 'Myeloid cell': 7, 'Epithelial cell (intermediated)': 7, 'Astrocyte': 4, 'Kidney intercalated cell': 3, 'Ventricle cardiomyocyte': 2, 'Immature sertoli cell (Pre-Sertoli cell)': 2})
Counter({'large_intestine': 50363, 'kidney': 45059, 'lung': 33309, 'liver': 28376, 'pancreas': 28116, 'immune': 17110, 'spleen': 15614, 'brain': 13440, 'placenta': 9509, 'bone_marrow': 8660})
569
569
# cell 2: encoding length, smallest and largest token ID
print(train_dataset[2]['length'])
print(min(train_dataset[2]['input_ids']))
print(max(train_dataset[2]['input_ids']))
625
9
25414

Next, we split the cells of each organ into an 80% training set and a 20% evaluation set for the subsequent fine-tuning.

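Every cell here carries a cell type label. Within each organ, the labels are converted to numerical IDs (e.g., "B cell" → 1) and then passed on for downstream fine-tuning.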
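The per-organ preprocessing in the loop below can be mirrored on a plain list of labels. It has four parts: the rare-cell-type filter, label encoding, the 80/20 split, and dropping evaluation labels unseen in training. This is a minimal sketch with a hypothetical helper, `prepare_organ_split`. It assumes the labels are already shuffled and reproduces the loop's logic, not its `datasets` calls.

```python
from collections import Counter


def prepare_organ_split(labels, min_frac=0.005, train_frac=0.8):
    """Mirror the per-organ preprocessing on a list of (already shuffled) cell type labels.

    Returns (train_ids, eval_ids, name_to_id).
    """
    counts = Counter(labels)
    total = sum(counts.values())
    # drop cell types representing <= min_frac of the cells (same '>' test as the loop)
    keep = {name for name, n in counts.items() if n > min_frac * total}
    kept = [name for name in labels if name in keep]
    # label IDs follow the order in which cell types first appear
    name_to_id = {name: i for i, name in enumerate(Counter(kept))}
    ids = [name_to_id[name] for name in kept]
    # 80/20 split by position
    cut = round(len(ids) * train_frac)
    train_ids, eval_ids = ids[:cut], ids[cut:]
    # remove evaluation cells whose label never occurs in the training split
    seen = set(train_ids)
    eval_ids = [y for y in eval_ids if y in seen]
    return train_ids, eval_ids, name_to_id
```

The last filter matters because a fine-tuned classifier can only output labels it saw during training. An evaluation cell of an unseen type could never be predicted correctly.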

dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []

for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    # for each organ
    if organ in ["bone_marrow"]:
        continue
    elif organ=="immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else: 
        organ_ids = [organ]
        organ_list += [organ]
        
    print(organ)
    
    # filter datasets for given organ 
    def if_organ(example, organ_ids):
        return example["organ_major"] in organ_ids
    # pass organ_ids explicitly through fn_kwargs so that each worker process receives it
    trainset_organ = train_dataset.filter(if_organ, num_proc=4, fn_kwargs={"organ_ids": organ_ids}) # `num_proc` indicates the number of processors
    
    # per scDeepsort published method, drop cell types representing < 0.5% of cells
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    # keep only cell types that make up more than 0.5% of the cells in this organ
    cells_to_keep = [k for k,v in celltype_counter.items() if v > (0.005*total_cells)]
    
    def if_not_rare_celltype(example, cells_to_keep):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=6, fn_kwargs={"cells_to_keep": cells_to_keep}) # filter rare cell type (cells < 0.5% )
    
    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type","label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")
    
    # create dictionary of cell types : label ids
    # Counter returns dictionary with element as key and number of occurrences as value 
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys()) 
    target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]
    
    # change labels to numerical ids
    def classes_to_ids(example, target_name_id_dict):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=4, fn_kwargs={"target_name_id_dict": target_name_id_dict})
    
    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select([i for i in range(0, round(len(labeled_trainset)*0.8))])
    labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])
    
    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example, trained_labels):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=4, fn_kwargs={"trained_labels": trained_labels})
    
    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]
spleen
kidney
lung
brain
placenta
immune
large_intestine
pancreas
liver

The mapping between cell types and label IDs for each organ is stored in the list target_dict_list:

# cell type -> label ID mapping for each organ
for i in range(0,len(organ_list)):
    print(organ_list[i])
    print(Counter(target_dict_list[i]))

spleen
Counter({'Endothelial cell (APC)': 5, 'Macrophage': 4, 'Neutrophil': 3, 'T cell': 2, 'B cell': 1, 'B cell (Plasmocyte)': 0})

kidney
Counter({'T cell': 14, 'Dendritic cell': 13, 'Fetal stromal cell': 12, 'Smooth muscle cell': 11, 'Endothelial cell': 10, 'Macrophage': 9, 'Proximal tubule progenitor': 8, 'Intermediated cell': 7, 'Fetal mesenchymal progenitor': 6, 'Intercalated cell': 5, 'Ureteric bud cell': 4, 'Loop of Henle': 3, 'Fetal epithelial progenitor': 2, 'Epithelial cell': 1, 'Endothelial cell (APC)': 0})

lung
Counter({'Basal cell': 15, 'Monocyte': 14, 'Endothelial cell': 13, 'Proliferating T cell': 12, 'Dendritic cell': 11, 'Fetal epithelial progenitor': 10, 'B cell (Plasmocyte)': 9, 'Mast cell': 8, 'Endothelial cell (APC)': 7, 'T cell': 6, 'Endothelial cell (endothelial to mesenchymal transition)': 5, 'M2 Macrophage': 4, 'Macrophage': 3, 'Fetal mesenchymal progenitor': 2, 'AT2 cell': 1, 'Smooth muscle cell': 0})

brain
Counter({'Fetal epithelial progenitor': 5, 'Fetal endocrine cell': 4, 'Erythroid cell': 3, 'Macrophage': 2, 'Fetal mesenchymal progenitor': 1, 'Fetal neuron': 0})

placenta
Counter({'Macrophage': 2, 'Epithelial cell': 1, 'Fibroblast': 0})

immune
Counter({'B cell': 9, 'Neutrophil (RPS high)': 8, 'B cell (Plasmocyte)': 7, 'Dendritic cell': 6, 'Erythroid progenitor cell (RP high)': 5, 'Erythroid cell': 4, 'Neutrophil': 3, 'T cell': 2, 'Monocyte': 1, 'Antigen presenting cell (RPS high)': 0})

large_intestine
Counter({'Endothelial cell': 15, 'Smooth muscle cell': 14, 'Fetal stromal cell': 13, 'B cell': 12, 'Stromal cell': 11, 'Enterocyte': 10, 'T cell': 9, 'Macrophage': 8, 'Dendritic cell': 7, 'Epithelial cell': 6, 'Fetal neuron': 5, 'Fetal mesenchymal progenitor': 4, 'B cell (Plasmocyte)': 3, 'Fetal enterocyte ': 2, 'Hepatocyte/Endodermal cell': 1, 'Enterocyte progenitor': 0})

pancreas
Counter({'Endothelial cell (APC)': 14, 'Smooth muscle cell': 13, 'Dendritic cell': 12, 'Fetal epithelial progenitor': 11, 'Erythroid cell': 10, 'Fetal endocrine cell': 9, 'Endothelial cell': 8, 'Enterocyte progenitor': 7, 'Fetal neuron': 6, 'Macrophage': 5, 'Fetal mesenchymal progenitor': 4, 'Pancreas exocrine cell': 3, 'Fetal acinar cell': 2, 'T cell': 1, 'B cell': 0})

liver
Counter({'CB CD34+': 11, 'B cell': 10, 'Neutrophil (RPS high)': 9, 'B cell (Plasmocyte)': 8, 'T cell': 7, 'Neutrophil': 6, 'Monocyte': 5, 'Dendritic cell': 4, 'Macrophage': 3, 'Sinusoidal endothelial cell': 2, 'Erythroid cell': 1, 'Erythroid progenitor cell (RP high)': 0})

# look up each organ's training set, label dictionary and evaluation set by organ name
trainset_dict = dict(zip(organ_list, dataset_list))
traintargetdict_dict = dict(zip(organ_list, target_dict_list))
evalset_dict = dict(zip(organ_list, evalset_list))

Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

Next, we train with preset hyperparameters. The authors recommend tuning them for each downstream task.

We also define a function, compute_metrics, that evaluates the model's predictive performance (accuracy and macro F1).

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
      'accuracy': acc,
      'macro_f1': macro_f1
    }
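Macro F1 is reported alongside accuracy because accuracy alone can look good even when rare cell types are misclassified. The sketch below illustrates this on an imbalanced toy example using NumPy only. `toy_compute_metrics` is a hypothetical helper written to mirror compute_metrics above: it takes the argmax over the logits, then computes accuracy and the unweighted mean of per-class F1 over the classes present in labels or predictions.

```python
import numpy as np


def toy_compute_metrics(logits, labels):
    """Accuracy and macro F1 from raw logits, without scikit-learn."""
    labels = np.asarray(labels)
    preds = np.asarray(logits).argmax(-1)
    acc = float(np.mean(preds == labels))
    f1_per_class = []
    for c in np.union1d(labels, preds):
        tp = np.sum((preds == c) & (labels == c))
        fp = np.sum((preds == c) & (labels != c))
        fn = np.sum((preds != c) & (labels == c))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN)
        f1_per_class.append(2 * tp / denom if denom > 0 else 0.0)
    return {"accuracy": acc, "macro_f1": float(np.mean(f1_per_class))}


# 8 cells of class 0 and 2 of class 1; the "model" always predicts class 0
labels = [0] * 8 + [1] * 2
logits = [[2.0, 0.1]] * 10
print(toy_compute_metrics(logits, labels))
```

The always-class-0 model reaches 0.8 accuracy, but its macro F1 is only about 0.44 because the minority class gets an F1 of 0.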

# set model parameters
# max input size
max_input_size = 2 ** 11 # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number of GPUs
num_gpus = 1
# number of CPU cores
num_proc = 6
# batch size for training and eval
# reduced batch size to fit limited GPU memory
geneformer_batch_size = 4
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs 
epochs = 10
# optimizer 
optimizer = "adamw"
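In the loop below, freeze_layers = 0 only shows up in the output directory name, so nothing is frozen. Likewise, optimizer is not passed to TrainingArguments; the Trainer's default optimizer is already an AdamW variant. If you do want to freeze the lower encoder blocks, a minimal sketch might look like this. It assumes the standard BertForSequenceClassification layout shown in the model printout later on, where encoder blocks live in model.bert.encoder.layer. `freeze_encoder_layers` is a hypothetical helper, to be called right after from_pretrained.

```python
def freeze_encoder_layers(model, n_layers):
    """Set requires_grad=False on the first n_layers BERT encoder blocks.

    Assumes a Hugging Face BERT-style model where the blocks live in
    model.bert.encoder.layer (as in BertForSequenceClassification).
    """
    if not n_layers:
        return model  # freeze_layers = 0: train all layers
    for block in model.bert.encoder.layer[:n_layers]:
        for param in block.parameters():
            param.requires_grad = False
    return model


# Example (inside the fine-tuning loop, after from_pretrained):
# model = freeze_encoder_layers(model, freeze_layers)
```

Frozen parameters receive no gradient updates, so only the upper blocks and the new classifier head adapt to the organ-specific data.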

Next, we fine-tune a separate cell type classifier for each organ: brain, immune, kidney, large intestine, liver, lung, pancreas, placenta, and spleen.

This step takes a long time.

The pretrained model is loaded with BertForSequenceClassification.from_pretrained. num_labels is the number of output classes, which we set to the number of cell types in the corresponding organ.

We then create a Trainer, train the model, and predict on the evaluation set with .predict().

for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]
    
    # set logging steps
    logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)
    
    # reload pretrained model
    model = BertForSequenceClassification.from_pretrained("D:\\jupyterNote\\Geneformer",
                                                         num_labels = len(organ_label_dict.keys()),
                                                         output_attentions = False, 
                                                         output_hidden_states = False).to("cuda")
    
    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"D:\\jupyterNote\\Geneformer\\examples\\cell_class_test\\{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}\\"
    
    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        raise Exception("Model already saved to this directory.")
    
    # make output directory (portable; no shell call to mkdir needed)
    os.makedirs(output_dir, exist_ok=True)
    
    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",  # called "eval_strategy" in newer transformers releases
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }
    
    training_args_init = TrainingArguments(**training_args)
    
    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics
    )
    
    # train the cell type classifier
    trainer.train()
    predictions = trainer.predict(organ_evalset)
    with open(os.path.join(output_dir, "predictions.pickle"), "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval",predictions.metrics)
    trainer.save_model(output_dir)
Some weights of the model checkpoint at D:\jupyterNote\Geneformer were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at D:\jupyterNote\Geneformer and are newly initialized: ['bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<...training message...>

After training, each fine-tuned model and its predictions are written to the configured output_dir.

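Each folder contains the fine-tuned model files (training_args.bin, config.json, pytorch_model.bin), the prediction results (predictions.pickle, eval_results.json, all_results.json), and one checkpoint-* directory per epoch: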

$ ls 230719_geneformer_CellClassifier_brain_L2048_B4_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/

all_results.json  checkpoint-15984  checkpoint-23976  checkpoint-5328  eval_results.json   training_args.bin
checkpoint-10656  checkpoint-18648  checkpoint-2664   checkpoint-7992  predictions.pickle
checkpoint-13320  checkpoint-21312  checkpoint-26640  config.json      pytorch_model.bin
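predictions.pickle holds the object returned by trainer.predict: its .predictions are logits and its .label_ids are the true labels. The sketch below turns it into per-cell-type accuracy, using the organ's label dictionary from traintargetdict_dict to recover the names. `per_celltype_accuracy` is a hypothetical helper. Unpickling the file requires transformers to be importable.

```python
import numpy as np


def per_celltype_accuracy(pred_output, label_dict):
    """Fraction of correctly predicted cells for each true cell type.

    pred_output: object with .predictions (logits, cells x classes) and .label_ids
    label_dict: cell type name -> label ID, e.g. traintargetdict_dict["immune"]
    """
    id_to_name = {label_id: name for name, label_id in label_dict.items()}
    preds = np.asarray(pred_output.predictions).argmax(-1)
    labels = np.asarray(pred_output.label_ids)
    accuracy = {}
    for label_id in np.unique(labels):
        in_class = labels == label_id
        accuracy[id_to_name[int(label_id)]] = float(np.mean(preds[in_class] == label_id))
    return accuracy


# Usage with a saved run (output_dir and organ as in the fine-tuning loop):
# with open(os.path.join(output_dir, "predictions.pickle"), "rb") as fp:
#     predictions = pickle.load(fp)
# print(per_celltype_accuracy(predictions, traintargetdict_dict[organ]))
```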
# clear GPU memory after pytorch training 
import torch
torch.cuda.empty_cache()
# the model still in memory is the last one fine-tuned in the loop (liver, 12 cell types)
model
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=256, out_features=512, bias=True)
            (intermediate_act_fn): ReLU()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=256, bias=True)
            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.02, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=256, out_features=256, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.02, inplace=False)
  (classifier): Linear(in_features=256, out_features=12, bias=True)
)

Next, we apply the model fine-tuned on the immune data to the 3k PBMCs scRNA-seq dataset for cell type prediction.

First, we convert the raw sequencing counts into rank value encodings with tk.tokenize_data.

# applying fine-tuned model on new datasets
# using the 3k PBMCs dataset
# 1. transform scRNA-seq expression data to rank value .dataset format
from geneformer import TranscriptomeTokenizer

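# custom_attr_name_dict: keys are column attributes in the .loom file, values are the column names in the output dataset
tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)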
tk.tokenize_data("D:/jupyterNote/pySC/output/", output_directory="token_data/", output_prefix="tk_pbmc3k")
Tokenizing D:\jupyterNote\pySC\output\pbmc3k.loom
D:\jupyterNote\pySC\output\pbmc3k.loom has no column attribute 'filter_pass'; tokenizing all cells.

Load the tokenized dataset:

# 2. load new dataset
new_dataset = load_from_disk("D:/jupyterNote/Geneformer/examples/token_data/tk_pbmc3k.dataset/")
new_dataset
Dataset({
    features: ['input_ids', 'cell_type', 'organ_major', 'length'],
    num_rows: 2638
})
import pandas as pd

# input_ids represent rank encodings
new_dataset.to_pandas()

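The model requires input tensors of equal length, so we pad each cell's rank encoding to a common length: the largest number of genes detected in any cell. We also build an attention mask that marks the padded positions.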

import numpy as np
from geneformer.pretrainer import token_dictionary

def preprocess_classifier_batch(cell_batch, max_len=None):
    # if no target length is given, pad to the longest cell in the dataset
    if max_len is None:
        max_len = max(len(i) for i in cell_batch["input_ids"])
    pad_id = token_dictionary.get("<pad>")
    def pad_label_example(example):
        # right-pad input_ids with the <pad> token up to max_len
        example["input_ids"] = np.pad(example["input_ids"], 
                                      (0, max_len-len(example["input_ids"])), 
                                      mode='constant', constant_values=pad_id)
        # attention mask: 1 for gene tokens, 0 for padding
        example["attention_mask"] = (example["input_ids"] != pad_id).astype(int)
        return example
    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch

# Function to find the largest number smaller
# than or equal to N that is divisible by K
def find_largest_div(N, K):
    return N - N % K
    
# pad every cell to the length of the longest cell
max_set_len = max(new_dataset["length"])
padded_dataset = preprocess_classifier_batch(new_dataset, max_set_len)
Loading cached processed dataset at D:\jupyterNote\Geneformer\examples\token_data\tk_pbmc3k.dataset\cache-4214517ba3d677b2.arrow
pd.DataFrame(padded_dataset)
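To make the padding step concrete, here is a minimal NumPy sketch on a toy batch, independent of the Hugging Face `datasets` object. It assumes the `<pad>` token ID is 0, which matches `padding_idx=0` in the model printout further below. The function name `pad_batch` and the token values are hypothetical.

```python
import numpy as np

PAD_ID = 0  # assumption: <pad> token ID, consistent with padding_idx=0 in the model

def pad_batch(token_lists, pad_id=PAD_ID, max_len=None):
    """Right-pad variable-length token lists and build the attention mask."""
    if max_len is None:
        max_len = max(len(tokens) for tokens in token_lists)
    input_ids = np.full((len(token_lists), max_len), pad_id, dtype=np.int64)
    attention_mask = np.zeros((len(token_lists), max_len), dtype=np.int64)
    for row, tokens in enumerate(token_lists):
        input_ids[row, :len(tokens)] = tokens   # real tokens first, padding after
        attention_mask[row, :len(tokens)] = 1   # 1 = attend, 0 = ignore padding
    return input_ids, attention_mask

ids, mask = pad_batch([[5, 9, 3], [7], [4, 8]])
print(ids.tolist())   # [[5, 9, 3], [7, 0, 0], [4, 8, 0]]
print(mask.tolist())  # [[1, 1, 1], [1, 0, 0], [1, 1, 0]]
```

`preprocess_classifier_batch` builds the mask by comparing against the pad ID. This sketch builds it from each cell's length instead. Both give the same result as long as no real gene token uses ID 0.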

Next, load the fine-tuned model and run prediction.

# 3. load the fine-tuned model
from transformers import BertForSequenceClassification, Trainer

# reload the fine-tuned model
ft_model = BertForSequenceClassification.from_pretrained("cell_class_test/230719_geneformer_CellClassifier_immune_L2048_B4_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/")
# the immune training set contains only 10 cell types, so the final output layer has out_features=10
print(ft_model)

ft_trainer = Trainer(model=ft_model)

# 4. perform prediction
ct_predictions = ft_trainer.predict(padded_dataset)
ct_pred = ct_predictions.predictions  # logits, shape (n_cells, num_labels)
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=256, out_features=512, bias=True)
            (intermediate_act_fn): ReLU()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=256, bias=True)
            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.02, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=256, out_features=256, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.02, inplace=False)
  (classifier): Linear(in_features=256, out_features=10, bias=True)
)

We use the fine-tuned model for classification. Each cell is assigned the cell type with the highest predicted score (logit).

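# cell type -> index mapping used during fine-tuning
immune_label_idx_dict = target_dict_list[5]

# predicted class = index of the largest logit for each cell
ct_pred_id = ct_pred.argmax(-1)  # numpy array of class indices
# invert the mapping (index -> cell type) and look up each prediction
idx_label_dict = {v: k for k, v in immune_label_idx_dict.items()}
ct_pred_label = [idx_label_dict[idx] for idx in ct_pred_id]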
print(len(ct_pred_label))
2638
print(ct_pred_id[0:10])
print(ct_pred_label[0:10])
print(immune_label_idx_dict)
[2 9 2 1 2 9 2 2 2 1]
['T cell', 'B cell', 'T cell', 'Monocyte', 'T cell', 'B cell', 'T cell', 'T cell', 'T cell', 'Monocyte']
{'Antigen presenting cell (RPS high)': 0, 'Monocyte': 1, 'T cell': 2, 'Neutrophil': 3, 'Erythroid cell': 4, 'Erythroid progenitor cell (RP high)': 5, 'Dendritic cell': 6, 'B cell (Plasmocyte)': 7, 'Neutrophil (RPS high)': 8, 'B cell': 9}
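`ct_pred` holds the raw logits returned by `Trainer.predict`, with one row per cell and one column per class. Taking the argmax of the logits gives the same label as taking it after a softmax. The softmax, however, also yields a per-cell probability that can be used to flag low-confidence cells. Below is a minimal NumPy sketch with made-up logits and a hypothetical three-class mapping. The function name `logits_to_labels` is also hypothetical.

```python
import numpy as np

def logits_to_labels(logits, label_to_idx):
    """Convert a (n_cells, n_classes) logit matrix into labels and softmax confidences."""
    logits = np.asarray(logits, dtype=float)
    # softmax over classes; subtracting the row max avoids overflow in exp
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # invert the label -> index mapping once, then look up each prediction
    idx_to_label = {idx: label for label, idx in label_to_idx.items()}
    pred_idx = probs.argmax(axis=-1)
    labels = [idx_to_label[int(i)] for i in pred_idx]
    return labels, probs.max(axis=-1)

toy_dict = {"B cell": 0, "Monocyte": 1, "T cell": 2}   # hypothetical mapping
toy_logits = [[0.1, 0.2, 3.0],
              [2.5, 0.0, 0.0]]                         # made-up logits
labels, confidence = logits_to_labels(toy_logits, toy_dict)
print(labels)  # ['T cell', 'B cell']
```

With the real output, `logits_to_labels(ct_pred, immune_label_idx_dict)` should reproduce `ct_pred_label`. Geneformer does not provide a cut-off for calling a cell "uncertain"; any confidence threshold would be your own choice.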

Visualize the classification results with UMAP

import numpy as np
import pandas as pd
import scanpy as sc
import anndata

adata = anndata.read_h5ad("D:/jupyterNote/pySC/output/pbmc3k.h5ad")
adata
AnnData object with n_obs × n_vars = 2638 × 1838
    obs: 'n_genes', 'n_genes_by_counts', 'n_counts', 'total_counts_mt', 'pct_counts_mt', 'leiden', 'cell_type', 'organ_major'
    var: 'ensembl_id', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std', 'gene_name'
    uns: 'hvg', 'leiden', 'leiden_colors', 'log1p', 'neighbors', 'pca', 'rank_genes_groups', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'counts', 'data', 'scaled'
    obsp: 'connectivities', 'distances'

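The cell-type names assigned by Geneformer differ from those in the original annotation, but the Geneformer annotation broadly agrees with it. Overall, Geneformer can serve as a cell-type prediction tool. Ideally, though, the pre-trained model should be fine-tuned first, which requires a relevant labeled single-cell dataset.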

adata.obs['geneformer_pred'] = ct_pred_label
sc.pl.umap(adata, color='geneformer_pred')
sc.pl.umap(adata, color='cell_type')
E:\miniconda3\envs\geneformer\lib\site-packages\scanpy\plotting\_tools\scatterplots.py:392: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  cax = scatter(
E:\miniconda3\envs\geneformer\lib\site-packages\scanpy\plotting\_tools\scatterplots.py:392: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  cax = scatter(
adata.obs
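The UMAP comparison above is visual. One way to quantify the agreement, sketched here under assumptions, uses two pieces. The first is a cross-tabulation of original versus predicted labels. The second is an agreement rate computed after mapping the original (finer) labels onto Geneformer's coarser ones. The label names and the `coarse_map` dictionary below are hypothetical. For the real data, a mapping would have to be written by hand from the actual `adata.obs['cell_type']` categories.

```python
import pandas as pd

def agreement_table(original, predicted):
    """Counts of cells for each (original label, predicted label) pair."""
    # wrap as plain lists so pandas does not try to align differing indexes
    return pd.crosstab(pd.Series(list(original), name="cell_type"),
                       pd.Series(list(predicted), name="geneformer_pred"))

def agreement_rate(original, predicted, coarse_map):
    """Fraction of cells whose mapped original label equals the prediction."""
    mapped = [coarse_map.get(label, label) for label in original]
    return sum(m == p for m, p in zip(mapped, predicted)) / len(mapped)

# hypothetical toy labels, not taken from pbmc3k
original = ["CD4 T", "CD8 T", "B", "CD14 Mono"]
predicted = ["T cell", "T cell", "B cell", "T cell"]
coarse_map = {"CD4 T": "T cell", "CD8 T": "T cell", "B": "B cell", "CD14 Mono": "Monocyte"}

print(agreement_table(original, predicted))
print(agreement_rate(original, predicted, coarse_map))  # 0.75
```

On the real data this would be called as `agreement_table(adata.obs['cell_type'], adata.obs['geneformer_pred'])`. Off-diagonal counts point to the cell types where the two annotations disagree.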

Summary

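To fine-tune Geneformer for cell classification, we need to:

  1. Obtain a fine-tuning dataset for the tissue of interest that includes cell labels, e.g. cell types;

    Regarding dataset size: in the authors' examples, the smallest fine-tuning set had 884 cells, while every other downstream task used more than 10k cells.

  2. Load the pre-trained model with BertForSequenceClassification, setting num_labels to the number of classes;

  3. Train on the fine-tuning dataset, which trains the final task-specific output layer (the `classifier` Linear layer in the printout above), and evaluate the fine-tuned model's predictive performance;

  4. Apply the fine-tuned model to a new dataset to make predictions.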
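Steps 1 and 2 depend on a fixed mapping from cell-type names to integer class IDs. Its size is the `num_labels` passed to `BertForSequenceClassification`, and the same mapping is needed afterwards to decode predictions, as `immune_label_idx_dict` was above. Below is a minimal pure-Python sketch with hypothetical function names and toy labels. Sorting the names is just one way to make the assignment reproducible; the ordering in `immune_label_idx_dict` is evidently not alphabetical.

```python
def build_label_dict(cell_types):
    """Map each distinct cell type to an integer class ID (sorted for reproducibility)."""
    return {ct: i for i, ct in enumerate(sorted(set(cell_types)))}

def encode_labels(cell_types, label_dict):
    """Replace each cell's type name with its class ID."""
    return [label_dict[ct] for ct in cell_types]

train_types = ["T cell", "B cell", "T cell", "Monocyte"]   # toy training labels
label_dict = build_label_dict(train_types)
num_labels = len(label_dict)

print(label_dict)                              # {'B cell': 0, 'Monocyte': 1, 'T cell': 2}
print(encode_labels(train_types, label_dict))  # [2, 0, 2, 1]
print(num_labels)                              # 3

# With transformers installed, num_labels would then be passed when loading the
# pre-trained model, e.g. (the path is a placeholder):
# BertForSequenceClassification.from_pretrained("<pretrained Geneformer dir>", num_labels=num_labels)
```

A cell type that never appears in the fine-tuning labels gets no class ID and cannot be predicted later. This is why the immune model above has exactly 10 outputs.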

Ref:

Transfer learning enables predictions in network biology: https://doi.org/10.1038/s41586-023-06139-9

https://huggingface.co/ctheodoris/Geneformer
