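Geneformer is a Transformer model pretrained on about 30 million single-cell transcriptomes covering a broad range of human tissues and organs. It can be used for cell-level classification and for gene-level classification (for example, predicting whether a gene is dosage-sensitive). Here we follow the tutorial to walk through the steps for cell type prediction.
Geneformer first converts each cell's gene expression profile into a rank value encoding, which is passed into a transformer architecture for pretraining. For downstream tasks, the model can then be fine-tuned on a task-specific dataset with an added final output layer.
Both Geneformer and its training corpus Genecorpus-30M are available on Hugging Face.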
Geneformer environment setup
We first set up the environment required for the Geneformer analysis.
Create a conda environment
conda create --name geneformer python
conda activate geneformer
To use this environment in Jupyter Notebook, install ipykernel and register the environment as a kernel (see https://blog.csdn.net/mighty13/article/details/119859242):
conda install -c anaconda ipykernel
python -m ipykernel install --user --name geneformer
Installation
Next, we download Geneformer and the associated example dataset.
Clone project
git clone https://huggingface.co/ctheodoris/Geneformer
cd Geneformer
pip install .
git lfs install
git lfs pull
Download associated dataset
git clone https://huggingface.co/datasets/ctheodoris/Genecorpus-30M
Because the cloned .dataset files are only about 1 KB (placeholders rather than the actual data), we download the training data manually:
cd Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/dataset.arrow
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/dataset_info.json
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/state.json
Install related modules
pip3 install seaborn
pip3 install datasets
pip3 install -U scikit-learn
pip3 install transformers
# gpu version of torch
pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
pip3 install statsmodels
pip3 install -U accelerate
Import modules
import os
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification
Prepare training and evaluation datasets
# load cell type dataset (includes all tissues)
train_dataset=load_from_disk("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset")
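Here we load the cell type annotation example dataset from Genecorpus-30M, provided with the paper; it is stored in Apache Arrow format.
Data Fields
- cell_type: annotated cell type
- organ_major: organ/tissue of origin
- input_ids: rank value encoding for an example cell
- length: length of rank value encoding for that example cell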
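How the rank values are computed:
- compute the nonzero median expression of each detected gene across all cells;
- divide each cell's gene read counts by that cell's total read counts to correct for sequencing depth;
- divide each gene's value in each cell by that gene's nonzero median to obtain normalized expression;
- rank the genes within each cell by normalized expression to obtain the rank values.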
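The four steps above can be sketched in plain Python on a toy count matrix. This is a simplified illustration, not the TranscriptomeTokenizer implementation: the counts and gene names are made up, and the nonzero medians are computed from the toy cells themselves, whereas Geneformer uses medians precomputed once over the whole corpus.

```python
from statistics import median

# Toy data (made up): 3 cells x 5 genes, raw read counts
counts = [
    [10, 0, 5, 20, 1],
    [0, 3, 8, 4, 0],
    [6, 2, 0, 12, 9],
]
gene_ids = ["G0", "G1", "G2", "G3", "G4"]
n_genes = len(gene_ids)

# Correct for sequencing depth: divide by each cell's total counts
depth_norm = [[c / sum(row) for c in row] for row in counts]

# Nonzero median of each gene across cells (toy stand-in for corpus-wide medians)
nonzero_medians = [
    median([cell[j] for cell in depth_norm if cell[j] > 0]) for j in range(n_genes)
]

def rank_encode(cell):
    """Scale detected genes by their nonzero median, then order them high -> low."""
    normalized = {
        gene_ids[j]: cell[j] / nonzero_medians[j] for j in range(n_genes) if cell[j] > 0
    }
    return sorted(normalized, key=normalized.get, reverse=True)

encodings = [rank_encode(cell) for cell in depth_norm]
for enc in encodings:
    print(enc)
```

Only detected (nonzero) genes enter the encoding, so each cell's sequence length equals its number of detected genes. Dividing by the nonzero median deprioritizes genes that are highly expressed everywhere and promotes genes that are high in this cell relative to their typical level.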
The rank value encodings for each single cell transcriptome were then tokenized based on a total vocabulary of 25,424 protein-coding or miRNA genes detected within Genecorpus-30M. The token dictionary mapping each token ID to special tokens (pad and mask) or Ensembl IDs for each gene is included within the repository as a pickle file (token_dictionary.pkl).
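Because input_ids stores token IDs, reading a cell's encoding back as genes means inverting the token dictionary. The sketch below assumes, consistent with the token_dictionary.get("<pad>") call used later in this tutorial, that the pickle maps token strings (special tokens or Ensembl IDs) to integer IDs; the helper names and the toy dictionary are hypothetical.

```python
import pickle

def load_token_dictionary(path):
    """Load the token dictionary pickle (token string -> token ID)."""
    with open(path, "rb") as f:
        return pickle.load(f)

def decode_cell(input_ids, token_dictionary, top_n=None):
    """Return (rank, token) pairs; rank is the 1-based position in the encoding."""
    id_to_token = {token_id: token for token, token_id in token_dictionary.items()}
    ids = input_ids if top_n is None else input_ids[:top_n]
    return [(rank, id_to_token[t]) for rank, t in enumerate(ids, start=1)]

# Toy dictionary with made-up gene names (real keys are Ensembl IDs)
toy_dict = {"<pad>": 0, "<mask>": 1, "GENE_A": 2, "GENE_B": 3, "GENE_C": 4}
print(decode_cell([4, 2, 3], toy_dict))
```

With the real data this would look like decode_cell(train_dataset[1]["input_ids"], load_token_dictionary(path), top_n=10), where path points to token_dictionary.pkl in your local clone.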
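Why don't the values in input_ids range from 1 to the number of genes in the cell? Because each value is a gene's token ID (its index in the vocabulary), not its rank; a gene's rank is given by its position in the sequence, with the highest-ranked gene first.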
# overview of train_dataset
# 249,556 cells in total
print(train_dataset)
# number of cells for each cell type
print("Celltypes:")
print(Counter(train_dataset['cell_type']))
# number of cells from each organ
print("\nOrgans:")
print(Counter(train_dataset['organ_major']))
# length of the rank value encoding (input_ids) of one cell
print(len(train_dataset[1]['input_ids']))
# stored number of detected genes for the same cell
print(train_dataset[1]['length'])
Dataset({
features: ['cell_type', 'input_ids', 'length', 'organ_major'],
num_rows: 249556
})
Counter({'B cell (Plasmocyte)': 20728, 'T cell': 16695, 'Enterocyte progenitor': 15441, 'Fetal epithelial progenitor': 14580, 'Fetal neuron': 12287, 'Fetal mesenchymal progenitor': 11905, 'Erythroid progenitor cell (RP high)': 10819, 'Hepatocyte/Endodermal cell': 9781, 'Fetal enterocyte ': 9613, 'Erythroid cell': 9089, 'Loop of Henle': 8439, 'Macrophage': 7854, 'Monocyte': 7541, 'Epithelial cell': 7458, 'AT2 cell': 7333, 'Neutrophil': 7276, 'Fibroblast': 6980, 'Dendritic cell': 5960, 'Pancreas exocrine cell': 5538, 'Endothelial cell (APC)': 5431, 'M2 Macrophage': 5373, 'Endothelial cell': 4738, 'Antigen presenting cell (RPS high)': 4658, 'Intercalated cell': 4414, 'Sinusoidal endothelial cell': 3844, 'B cell': 3783, 'Endothelial cell (endothelial to mesenchymal transition)': 3647, 'Fetal acinar cell': 3214, 'Ureteric bud cell': 2472, 'Enterocyte': 1994, 'Proximal tubule progenitor': 1846, 'Smooth muscle cell': 1794, 'Fetal stromal cell': 1061, 'Stromal cell': 1029, 'Mast cell': 968, 'Fetal endocrine cell': 834, 'Neutrophil (RPS high)': 604, 'Intermediated cell': 529, 'Proliferating T cell': 528, 'CB CD34+': 423, 'Basal cell': 236, 'Primordial germ cell': 230, 'Fetal fibroblast': 125, 'Fetal Neuron': 114, 'Stratified epithelial cell': 113, 'Fetal skeletal muscle cell': 57, 'Fetal chondrocyte': 49, 'Mesothelial cell': 30, 'Goblet cell': 19, 'Chondrocyte': 19, 'hESC': 18, 'Fasciculata cell': 13, 'Gastric endocrine cell': 7, 'Myeloid cell': 7, 'Epithelial cell (intermediated)': 7, 'Astrocyte': 4, 'Kidney intercalated cell': 3, 'Ventricle cardiomyocyte': 2, 'Immature sertoli cell (Pre-Sertoli cell)': 2})
Counter({'large_intestine': 50363, 'kidney': 45059, 'lung': 33309, 'liver': 28376, 'pancreas': 28116, 'immune': 17110, 'spleen': 15614, 'brain': 13440, 'placenta': 9509, 'bone_marrow': 8660})
569
569
print(train_dataset['length'][2])
print(min(train_dataset['input_ids'][2]))
print(max(train_dataset['input_ids'][2]))
625
9
25414
Next, we split each tissue's cells into an 80% training set and a 20% evaluation set for fine-tuning.
All cells carry cell type labels, which are converted to numeric IDs (e.g. "B cell" → 1) before being passed to fine-tuning.
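Stripped of the datasets API, the per-organ preprocessing in the loop below reduces to a few steps: drop rare cell types, shuffle, encode labels, split 80/20, and drop evaluation cells whose label never appears in training. The pure-Python sketch below runs on a toy list of labels; prepare_organ_split is a hypothetical helper, not part of Geneformer, and random.Random(seed).shuffle will not reproduce the order produced by datasets' shuffle(seed=42).

```python
from collections import Counter
import random

def prepare_organ_split(cell_types, min_frac=0.005, train_frac=0.8, seed=42):
    counts = Counter(cell_types)
    total = sum(counts.values())
    # keep cell types that make up more than min_frac of this organ's cells
    keep = {ct for ct, n in counts.items() if n > min_frac * total}
    cells = [ct for ct in cell_types if ct in keep]
    random.Random(seed).shuffle(cells)
    # label ids follow first-appearance order, like Counter(...).keys() in the loop
    label_to_id = {ct: i for i, ct in enumerate(dict.fromkeys(cells))}
    labels = [label_to_id[ct] for ct in cells]
    n_train = round(len(labels) * train_frac)
    train, evalset = labels[:n_train], labels[n_train:]
    # evaluate only on labels the classifier has seen during training
    seen = set(train)
    evalset = [y for y in evalset if y in seen]
    return label_to_id, train, evalset

toy = ["T cell"] * 300 + ["B cell"] * 100 + ["Mast cell"] * 1
label_to_id, train_labels, eval_labels = prepare_organ_split(toy)
print(label_to_id, len(train_labels), len(eval_labels))
```

Here "Mast cell" (1 of 401 cells, below the 0.5% threshold of about 2 cells) is dropped before the split.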
dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []

for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]

    print(organ)

    # filter datasets for given organ
    def if_organ(example, organ_ids):
        return example["organ_major"] in organ_ids
    # pass organ_ids explicitly through fn_kwargs instead of relying on the enclosing variable
    trainset_organ = train_dataset.filter(if_organ, num_proc=4, fn_kwargs={"organ_ids": organ_ids})  # `num_proc`: number of worker processes

    # per scDeepsort published method, drop cell types representing < 0.5% of cells
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    # keep only cell types that make up more than 0.5% of this organ's cells
    cells_to_keep = [k for k, v in celltype_counter.items() if v > (0.005*total_cells)]
    def if_not_rare_celltype(example, cells_to_keep):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=6, fn_kwargs={"cells_to_keep": cells_to_keep})  # filter rare cell types (< 0.5% of cells)

    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")

    # create dictionary of cell types : label ids
    # Counter keys are the distinct labels in order of first appearance
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]

    # change labels to numerical ids
    def classes_to_ids(example, target_name_id_dict):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=4, fn_kwargs={"target_name_id_dict": target_name_id_dict})

    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select([i for i in range(0, round(len(labeled_trainset)*0.8))])
    labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8), len(labeled_trainset))])

    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example, trained_labels):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=4, fn_kwargs={"trained_labels": trained_labels})

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-d895a8dc1f433c21_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-5f89f392ef5206d4_*_of_00006.arrow
Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-ed12a33637a220d0.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-37f89eeea757d6c5_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-74f78f156da669e5_*_of_00004.arrow
spleen
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-f6dc5b7fe424bf2d_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-896e8ac5f576851c_*_of_00006.arrow
Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-820cdcb7f7383e21.arrow
kidney
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-d4006d9701718093_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-3d777eac360cb136_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-0292fb0af10803a4_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-a996d2bf76adce7c_*_of_00006.arrow
Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-65960d687cc54b82.arrow
lung
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-456daf9851000084_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-60719877e54cdb77_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-5304f89297ce82b0_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-5284c824687e6edd_*_of_00006.arrow
Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-01ed43e584533226.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-53c0460a585f1ee6_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-3d719e36692d0047_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-f94573c555aec2c1_*_of_00004.arrow
brain
placenta
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-c546b4750f90df75_*_of_00006.arrow
Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-f29113e9ecf5a9a6.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-3e7c6c00c9efd043_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-f57326d9b39be686_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-d155ab4f91b9b109_*_of_00004.arrow
immune
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-7699e47999d0e20a_*_of_00006.arrow
Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-c7fa1566fa301d6f.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-c30b7a2b73e574a5_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-b2e2f437bfe2e62b_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-f30bce417351df8c_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-f040851bda405e47_*_of_00006.arrow
large_intestine
Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-fc4569333911ef13.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-791285ddf886e3a6_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-06ee2b3139e1b0a8_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-2ae398bc8c532f07_*_of_00004.arrow
pancreas
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-a47c5e32577914f3_*_of_00006.arrow
Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-0bc8012540760dc1.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-d4dcf029b0b1437a_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-f2a3bf8b55c9560b_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-8434ffd865a76d79_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-dbd87150b1b95134_*_of_00006.arrow
Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-b56b8ddf9ca9920d.arrow
liver
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-a31834f681c29f0f_*_of_00004.arrow
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\cell_type_annotation\cell_type_train_data.dataset\cache-c65ae92f453eb94b_*_of_00004.arrow
The cell type → label ID mapping for each tissue is stored in the list target_dict_list.
# cell type -> label id mapping for each organ
# (wrapping the dict in Counter only changes how it prints: sorted by id, descending)
for i in range(len(organ_list)):
    print(organ_list[i])
    print(Counter(target_dict_list[i]))
spleen
Counter({'Endothelial cell (APC)': 5, 'Macrophage': 4, 'Neutrophil': 3, 'T cell': 2, 'B cell': 1, 'B cell (Plasmocyte)': 0})
kidney
Counter({'T cell': 14, 'Dendritic cell': 13, 'Fetal stromal cell': 12, 'Smooth muscle cell': 11, 'Endothelial cell': 10, 'Macrophage': 9, 'Proximal tubule progenitor': 8, 'Intermediated cell': 7, 'Fetal mesenchymal progenitor': 6, 'Intercalated cell': 5, 'Ureteric bud cell': 4, 'Loop of Henle': 3, 'Fetal epithelial progenitor': 2, 'Epithelial cell': 1, 'Endothelial cell (APC)': 0})
lung
Counter({'Basal cell': 15, 'Monocyte': 14, 'Endothelial cell': 13, 'Proliferating T cell': 12, 'Dendritic cell': 11, 'Fetal epithelial progenitor': 10, 'B cell (Plasmocyte)': 9, 'Mast cell': 8, 'Endothelial cell (APC)': 7, 'T cell': 6, 'Endothelial cell (endothelial to mesenchymal transition)': 5, 'M2 Macrophage': 4, 'Macrophage': 3, 'Fetal mesenchymal progenitor': 2, 'AT2 cell': 1, 'Smooth muscle cell': 0})
brain
Counter({'Fetal epithelial progenitor': 5, 'Fetal endocrine cell': 4, 'Erythroid cell': 3, 'Macrophage': 2, 'Fetal mesenchymal progenitor': 1, 'Fetal neuron': 0})
placenta
Counter({'Macrophage': 2, 'Epithelial cell': 1, 'Fibroblast': 0})
immune
Counter({'B cell': 9, 'Neutrophil (RPS high)': 8, 'B cell (Plasmocyte)': 7, 'Dendritic cell': 6, 'Erythroid progenitor cell (RP high)': 5, 'Erythroid cell': 4, 'Neutrophil': 3, 'T cell': 2, 'Monocyte': 1, 'Antigen presenting cell (RPS high)': 0})
large_intestine
Counter({'Endothelial cell': 15, 'Smooth muscle cell': 14, 'Fetal stromal cell': 13, 'B cell': 12, 'Stromal cell': 11, 'Enterocyte': 10, 'T cell': 9, 'Macrophage': 8, 'Dendritic cell': 7, 'Epithelial cell': 6, 'Fetal neuron': 5, 'Fetal mesenchymal progenitor': 4, 'B cell (Plasmocyte)': 3, 'Fetal enterocyte ': 2, 'Hepatocyte/Endodermal cell': 1, 'Enterocyte progenitor': 0})
pancreas
Counter({'Endothelial cell (APC)': 14, 'Smooth muscle cell': 13, 'Dendritic cell': 12, 'Fetal epithelial progenitor': 11, 'Erythroid cell': 10, 'Fetal endocrine cell': 9, 'Endothelial cell': 8, 'Enterocyte progenitor': 7, 'Fetal neuron': 6, 'Macrophage': 5, 'Fetal mesenchymal progenitor': 4, 'Pancreas exocrine cell': 3, 'Fetal acinar cell': 2, 'T cell': 1, 'B cell': 0})
liver
Counter({'CB CD34+': 11, 'B cell': 10, 'Neutrophil (RPS high)': 9, 'B cell (Plasmocyte)': 8, 'T cell': 7, 'Neutrophil': 6, 'Monocyte': 5, 'Dendritic cell': 4, 'Macrophage': 3, 'Sinusoidal endothelial cell': 2, 'Erythroid cell': 1, 'Erythroid progenitor cell (RP high)': 0})
trainset_dict = dict(zip(organ_list, dataset_list))
traintargetdict_dict = dict(zip(organ_list, target_dict_list))
evalset_dict = dict(zip(organ_list, evalset_list))
Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance
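Next, we train with preset hyperparameters; the authors recommend tuning the hyperparameters for each downstream task.
We also define compute_metrics to evaluate the model's predictive performance (accuracy and macro F1).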
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro F1 using sklearn's functions
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
        'accuracy': acc,
        'macro_f1': macro_f1
    }
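Macro F1 averages per-class F1 scores without weighting by class size, so it exposes a classifier that ignores rare cell types even when accuracy looks good. The toy sketch below recomputes both metrics by hand to show the difference; in the tutorial itself, sklearn's accuracy_score and f1_score do this work, and the helper names here are illustrative.

```python
def f1_for_class(labels, preds, cls):
    tp = sum(y == cls and p == cls for y, p in zip(labels, preds))
    fp = sum(y != cls and p == cls for y, p in zip(labels, preds))
    fn = sum(y == cls and p != cls for y, p in zip(labels, preds))
    if tp == 0:
        return 0.0  # sklearn's f1_score also yields 0 for such a class by default
    precision, recall = tp / (tp + fp), tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

def macro_f1(labels, preds):
    # unweighted mean over every class present in either labels or predictions
    classes = sorted(set(labels) | set(preds))
    return sum(f1_for_class(labels, preds, c) for c in classes) / len(classes)

# Toy case: 8 cells of type 0, 2 of a rare type 1; the classifier always predicts 0
labels = [0] * 8 + [1] * 2
preds = [0] * 10
accuracy = sum(y == p for y, p in zip(labels, preds)) / len(labels)
print(accuracy, round(macro_f1(labels, preds), 3))
```

Accuracy is 0.8, but macro F1 is only (8/9 + 0) / 2 ≈ 0.444 because the rare class is never predicted.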
# set model parameters
# max input size
max_input_size = 2 ** 11 # 2048
# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 6
# batch size for training and eval
# reduce batch size if GPU memory is limited
geneformer_batch_size = 4
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 10
# optimizer
optimizer = "adamw"
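Note that freeze_layers only appears in the output directory name in the loop below; nothing in that loop actually freezes layers. If you want freeze_layers > 0 to take effect, a helper along these lines could be called right after from_pretrained and before creating the Trainer. It is a hypothetical sketch that assumes the model.bert.encoder.layer layout shown in the BertForSequenceClassification printout later in this tutorial.

```python
def freeze_bottom_layers(model, n_layers):
    """Disable gradients for the first n_layers encoder blocks of a BERT-style model.

    Assumes model.bert.encoder.layer is a sliceable sequence of blocks that
    each expose .parameters(), as torch ModuleList / nn.Module objects do.
    """
    if n_layers <= 0:
        return model
    for block in model.bert.encoder.layer[:n_layers]:
        for param in block.parameters():
            param.requires_grad = False  # excluded from gradient updates
    return model
```

Inside the training loop this would read model = freeze_bottom_layers(model, freeze_layers). Freezing the lower layers reduces the number of trainable parameters, which can help when the fine-tuning set is small.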
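Next, we fine-tune one cell type classifier per tissue: brain, immune, kidney, large intestine, liver, lung, pancreas, placenta, and spleen.
This step takes a long time.
The model is loaded with BertForSequenceClassification.from_pretrained; num_labels is the number of output classes, set here to the number of cell types in each tissue.
We then create a Trainer, train it, and use .predict() to predict on the evaluation set.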
for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]

    # set logging steps (about 10 log entries per epoch)
    logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)

    # reload the pretrained model for each organ; num_labels = number of cell types in this organ
    model = BertForSequenceClassification.from_pretrained("D:\\jupyterNote\\Geneformer",
                                                          num_labels=len(organ_label_dict.keys()),
                                                          output_attentions=False,
                                                          output_hidden_states=False).to("cuda")

    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"D:\\jupyterNote\\Geneformer\\examples\\cell_class_test\\{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}\\"

    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        raise Exception("Model already saved to this directory.")

    # make output directory (portable alternative to calling `mkdir` through a shell)
    os.makedirs(output_dir, exist_ok=True)

    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",  # renamed to "eval_strategy" in newer transformers releases
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }

    training_args_init = TrainingArguments(**training_args)

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics
    )
    # train the cell type classifier
    trainer.train()
    # predict on the held-out split, then save predictions, metrics and the model
    predictions = trainer.predict(organ_evalset)
    with open(os.path.join(output_dir, "predictions.pickle"), "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval", predictions.metrics)
    trainer.save_model(output_dir)
Some weights of the model checkpoint at D:\jupyterNote\Geneformer were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at D:\jupyterNote\Geneformer and are newly initialized: ['bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
<...training message...>
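After training, the fine-tuned models and their predictions are written to the configured output_dir.
Each folder contains the fine-tuned model files (training_args.bin, config.json, pytorch_model.bin) and the prediction results (predictions.pickle), along with one checkpoint-* folder per epoch.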
$ ls 230719_geneformer_CellClassifier_brain_L2048_B4_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/
all_results.json checkpoint-15984 checkpoint-23976 checkpoint-5328 eval_results.json training_args.bin
checkpoint-10656 checkpoint-18648 checkpoint-2664 checkpoint-7992 predictions.pickle
checkpoint-13320 checkpoint-21312 checkpoint-26640 config.json pytorch_model.bin
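The checkpoint folder names are global step counts. With save_strategy="epoch", a single GPU and no gradient accumulation (the settings above), each epoch takes ceil(n_train / batch_size) optimizer steps, so the listing can be checked with a little arithmetic. The helper names below are illustrative.

```python
import math

def checkpoint_steps(n_train_cells, batch_size, epochs):
    # the last, partially filled batch of each epoch still counts as one step
    steps_per_epoch = math.ceil(n_train_cells / batch_size)
    return [steps_per_epoch * (e + 1) for e in range(epochs)]

def logging_steps(n_train_cells, batch_size):
    # same formula as in the training loop: about 10 log entries per epoch
    return round(n_train_cells / batch_size / 10)

print(checkpoint_steps(10656, 4, 10))
print(logging_steps(10656, 4))
```

Since ceil(n / 4) = 2664 steps per epoch, the brain training split must contain between 10,653 and 10,656 cells, and 10 epochs end at checkpoint-26640, matching the listing.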
# clear GPU memory after pytorch training
import torch
torch.cuda.empty_cache()
# the last model trained in the loop above (liver: 12 cell types, hence out_features=12)
model
BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(25426, 256, padding_idx=0)
(position_embeddings): Embedding(2048, 256)
(token_type_embeddings): Embedding(2, 256)
(LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.02, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-5): 6 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=256, out_features=256, bias=True)
(key): Linear(in_features=256, out_features=256, bias=True)
(value): Linear(in_features=256, out_features=256, bias=True)
(dropout): Dropout(p=0.02, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=256, out_features=256, bias=True)
(LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.02, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=256, out_features=512, bias=True)
(intermediate_act_fn): ReLU()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=256, bias=True)
(LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.02, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=256, out_features=256, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.02, inplace=False)
(classifier): Linear(in_features=256, out_features=12, bias=True)
)
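The printout is enough to estimate the model size by hand. The dimensions below are copied from the printed liver classifier (out_features=12); each Linear layer contributes in × out weights plus out biases, and each LayerNorm a weight and a bias vector. If the printout is complete, the total should agree with sum(p.numel() for p in model.parameters()).

```python
# Dimensions read off the printed BertForSequenceClassification
vocab, max_pos, type_vocab = 25426, 2048, 2
hidden, ffn, n_layers, n_labels = 256, 512, 6, 12

def linear(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias

layer_norm = 2 * hidden  # weight + bias

embeddings = (vocab + max_pos + type_vocab) * hidden + layer_norm
attention = 3 * linear(hidden, hidden) + linear(hidden, hidden) + layer_norm  # q, k, v + output
feed_forward = linear(hidden, ffn) + linear(ffn, hidden) + layer_norm
encoder = n_layers * (attention + feed_forward)
pooler = linear(hidden, hidden)
classifier = linear(hidden, n_labels)

total = embeddings + encoder + pooler + classifier
print(f"{total:,}")
```

This gives about 10.3 million parameters, of which the gene token embedding alone (25,426 × 256 = 6,509,056) accounts for well over half.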
Next, we apply the model fine-tuned on the immune data to the 3k PBMCs scRNA-seq dataset for cell type prediction.
First, the raw sequencing counts must be converted to rank value encodings (tk.tokenize_data).
# applying fine-tuned model on new datasets
# using the 3k PBMCs dataset
# 1. transform scRNA-seq expression data to rank value .dataset format
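from geneformer import TranscriptomeTokenizer
# keys: column (cell) attributes in the .loom file; values: column names in the tokenized dataset
tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
# tokenizes every .loom file in the input directory; genes need an "ensembl_id" row attribute
# and cells an "n_counts" column attribute
tk.tokenize_data("D:/jupyterNote/pySC/output/", output_directory="token_data/", output_prefix="tk_pbmc3k")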
Tokenizing D:\jupyterNote\pySC\output\pbmc3k.loom
D:\jupyterNote\pySC\output\pbmc3k.loom has no column attribute 'filter_pass'; tokenizing all cells.
Load the tokenized dataset
# 2. load new dataset
new_dataset = load_from_disk("D:/jupyterNote/Geneformer/examples/token_data/tk_pbmc3k.dataset/")
new_dataset
Dataset({
features: ['input_ids', 'cell_type', 'organ_major', 'length'],
num_rows: 2638
})
import pandas as pd
# input_ids represent rank encodings
pd.DataFrame(new_dataset)
The model requires input tensors of the same length, so each cell's rank encoding is padded to a common length (the largest number of genes in any cell).
import numpy as np
from geneformer.pretrainer import token_dictionary

def preprocess_classifier_batch(cell_batch, max_len):
    if max_len is None:
        max_len = max([len(i) for i in cell_batch["input_ids"]])
    def pad_label_example(example):
        #example["labels"] = np.pad(example["labels"],
        #                           (0, max_len-len(example["input_ids"])),
        #                           mode='constant', constant_values=-100)
        # right-pad with the <pad> token, then mask out the padded positions
        example["input_ids"] = np.pad(example["input_ids"],
                                      (0, max_len-len(example["input_ids"])),
                                      mode='constant', constant_values=token_dictionary.get("<pad>"))
        example["attention_mask"] = (example["input_ids"] != token_dictionary.get("<pad>")).astype(int)
        return example
    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch

# Function to find the largest number smaller
# than or equal to N that is divisible by K (not used below)
def find_largest_div(N, K):
    rem = N % K
    if rem == 0:
        return N
    else:
        return N - rem

# pad all cells to the same length (the largest number of genes in any cell)
max_set_len = max(new_dataset["length"])
padded_dataset = preprocess_classifier_batch(new_dataset, max_set_len)
Loading cached processed dataset at D:\jupyterNote\Geneformer\examples\token_data\tk_pbmc3k.dataset\cache-4214517ba3d677b2.arrow
pd.DataFrame(padded_dataset)
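What pad_label_example does to each cell can be shown on plain lists. The token IDs below are made up, and the pad ID of 0 is an assumption consistent with padding_idx=0 in the printed word_embeddings layer (the code above reads it from token_dictionary instead).

```python
PAD_ID = 0  # assumed <pad> token id

def pad_batch(sequences, max_len=None, pad_id=PAD_ID):
    """Right-pad token id lists to max_len and build the matching attention masks."""
    if max_len is None:
        max_len = max(len(s) for s in sequences)
    input_ids = [list(s) + [pad_id] * (max_len - len(s)) for s in sequences]
    # 1 = real gene token, 0 = padding, mirroring the != <pad> comparison above
    attention_mask = [[int(t != pad_id) for t in row] for row in input_ids]
    return input_ids, attention_mask

ids, mask = pad_batch([[812, 95, 4031], [77, 2600]])
print(ids)
print(mask)
```

The attention mask keeps the padded positions from influencing the other tokens, so padding changes the tensor shape but not the cell's representation.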
Next, load the fine-tuned model and run prediction.
# 3. load the fine-tuned model
# reload fine-tuned model
ft_model = BertForSequenceClassification.from_pretrained("cell_class_test/230719_geneformer_CellClassifier_immune_L2048_B4_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/")
# the immune training set contains only 10 cell types, so the final output layer has out_features=10
print(ft_model)
ft_trainer = Trainer(model=ft_model)
# 4. perform prediction
ct_predictions = ft_trainer.predict(padded_dataset)
ct_pred = ct_predictions.predictions
BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(25426, 256, padding_idx=0)
(position_embeddings): Embedding(2048, 256)
(token_type_embeddings): Embedding(2, 256)
(LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.02, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-5): 6 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=256, out_features=256, bias=True)
(key): Linear(in_features=256, out_features=256, bias=True)
(value): Linear(in_features=256, out_features=256, bias=True)
(dropout): Dropout(p=0.02, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=256, out_features=256, bias=True)
(LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.02, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=256, out_features=512, bias=True)
(intermediate_act_fn): ReLU()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=256, bias=True)
(LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.02, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=256, out_features=256, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.02, inplace=False)
(classifier): Linear(in_features=256, out_features=10, bias=True)
)
We classify the cells with the fine-tuned model, taking the class with the highest predicted score as each cell's type.
# cell type -> label id mapping used when fine-tuning the immune model
immune_label_idx_dict = traintargetdict_dict["immune"]
id_to_celltype = {v: k for k, v in immune_label_idx_dict.items()}
# the predicted cell type is the class with the highest prediction score
ct_pred_id = ct_pred.argmax(-1)  # numpy array of label ids
ct_pred_label = [id_to_celltype[idx] for idx in ct_pred_id]
print(len(ct_pred_label))
2638
print(ct_pred_id[0:10])
print(ct_pred_label[0:10])
print(immune_label_idx_dict)
[2 9 2 1 2 9 2 2 2 1]
['T cell', 'B cell', 'T cell', 'Monocyte', 'T cell', 'B cell', 'T cell', 'T cell', 'T cell', 'Monocyte']
{'Antigen presenting cell (RPS high)': 0, 'Monocyte': 1, 'T cell': 2, 'Neutrophil': 3, 'Erythroid cell': 4, 'Erythroid progenitor cell (RP high)': 5, 'Dendritic cell': 6, 'B cell (Plasmocyte)': 7, 'Neutrophil (RPS high)': 8, 'B cell': 9}
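argmax always returns one of the 10 immune labels, even for cells whose type was never seen during fine-tuning (PBMC data typically contain NK cells and megakaryocytes, neither of which is among the labels above). One option is to convert the scores in ct_pred to softmax probabilities and mark low-confidence cells. The 0.5 threshold and the "uncertain" label below are arbitrary illustrative choices, not part of Geneformer.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def label_with_confidence(logits_per_cell, id_to_label, min_prob=0.5):
    """Return (label, probability) per cell; below min_prob the label is 'uncertain'."""
    results = []
    for logits in logits_per_cell:
        probs = softmax(logits)
        best = max(range(len(probs)), key=probs.__getitem__)
        label = id_to_label[best] if probs[best] >= min_prob else "uncertain"
        results.append((label, probs[best]))
    return results

id_to_label = {0: "Monocyte", 1: "T cell", 2: "B cell"}  # toy mapping
toy_logits = [[0.1, 4.0, 0.2], [1.0, 1.1, 0.9]]
results = label_with_confidence(toy_logits, id_to_label)
for label, p in results:
    print(label, round(p, 3))
```

On the real predictions this would be called as label_with_confidence(ct_pred.tolist(), id_to_celltype), with id_to_celltype the inverted immune_label_idx_dict.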
Visualize the classification results on the UMAP.
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
adata = anndata.read_h5ad("D:/jupyterNote/pySC/output/pbmc3k.h5ad")
adata
AnnData object with n_obs × n_vars = 2638 × 1838
obs: 'n_genes', 'n_genes_by_counts', 'n_counts', 'total_counts_mt', 'pct_counts_mt', 'leiden', 'cell_type', 'organ_major'
var: 'ensembl_id', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std', 'gene_name'
uns: 'hvg', 'leiden', 'leiden_colors', 'log1p', 'neighbors', 'pca', 'rank_genes_groups', 'umap'
obsm: 'X_pca', 'X_umap'
varm: 'PCs'
layers: 'counts', 'data', 'scaled'
obsp: 'connectivities', 'distances'
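Although Geneformer's cell type names differ from those in the original annotation, its labels are broadly consistent with the original annotation on the UMAP. Overall, Geneformer can be used as a cell type prediction tool, but the pretrained model should first be fine-tuned, which requires a relevant labeled single-cell dataset.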
adata.obs['geneformer_pred'] = ct_pred_label
sc.pl.umap(adata, color='geneformer_pred')
sc.pl.umap(adata, color='cell_type')
E:\miniconda3\envs\geneformer\lib\site-packages\scanpy\plotting\_tools\scatterplots.py:392: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
cax = scatter(
E:\miniconda3\envs\geneformer\lib\site-packages\scanpy\plotting\_tools\scatterplots.py:392: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
cax = scatter(
adata.obs
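To quantify the agreement beyond comparing the two UMAPs by eye, one can count how often each original label co-occurs with each Geneformer label. With pandas this is pd.crosstab(adata.obs["cell_type"], adata.obs["geneformer_pred"]). The dependency-free sketch below does the same counting on made-up labels and reports, for each original label, its most frequent predicted label.

```python
from collections import Counter

def contingency(reference, predicted):
    """Count (original label, predicted label) pairs."""
    return Counter(zip(reference, predicted))

def majority_prediction(table):
    """For each original label: the most frequent predicted label and its share."""
    totals = Counter()
    best = {}
    for (ref, pred), n in table.items():
        totals[ref] += n
        if ref not in best or n > best[ref][1]:
            best[ref] = (pred, n)
    return {ref: (pred, n / totals[ref]) for ref, (pred, n) in best.items()}

# Toy labels (made up): original annotation vs. Geneformer prediction
original = ["CD4 T", "CD4 T", "CD4 T", "B", "B", "NK"]
predicted = ["T cell", "T cell", "Monocyte", "B cell", "B cell", "T cell"]
table = contingency(original, predicted)
for ref, (pred, share) in sorted(majority_prediction(table).items()):
    print(f"{ref}: {pred} ({share:.0%})")
```

A label that maps almost entirely to one prediction (like "B" here) indicates agreement despite different naming, while a label absorbed into an unrelated class (like "NK" here) points to a cell type missing from the fine-tuning labels.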
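Summary
To fine-tune Geneformer for cell classification, we need to:
- Obtain a fine-tuning dataset for the tissue of interest with cell labels (e.g. cell types). Regarding dataset size, the smallest in the authors' examples had 884 cells, while all other downstream tasks used more than 10k cells;
- Load the pretrained model as BertForSequenceClassification and set num_labels to the number of classes; train on the fine-tuning dataset, which adds a final task-specific classification layer, and evaluate the fine-tuned model's predictive performance;
- Apply the fine-tuned model to a new dataset for prediction.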
Ref:
Transfer learning enables predictions in network biology: https://doi.org/10.1038/s41586-023-06139-9