CARD 空间转录组和单细胞转录组之间的反卷积

首先附上参考文献:《Spatially informed cell-type deconvolution for spatial transcriptomics》

原理分析

截至目前,大部分的反卷积求细胞组分的基本逻辑都是回归,CARD的基本逻辑同样采用的是回归策略:

其中:

  1. B: as the G-by-K cell-type-specific expression matrix for the informative genes(单细胞表达矩阵)
  2. X: as the G-by-N gene expression matrix for the same set of informative genes measured on N spatial locations(空转表达矩阵)
  3. V: as the N-by-K cell-type composition matrix(空间每个location的细胞组分矩阵)
  4. E: 服从
  5. 以上矩阵均为非负矩阵

同时作者也考虑到了相邻位置的两个location细胞成分可能会很相近,因此考虑了矫正,这里的矫正作者利用了回归到思想,建立了 VikVjk 之间的线性关系,表明第 i 个位置的细胞组分首临近位置细胞组分的影响(第 i 个位置的细胞组分为非自由项):

其中:

  1. Vik:represents the proportion of cell type k on the ith location,即第 i 个位置 cell type k 的比例
  2. bk: is the kth cell-type-specific intercept that represents the average cell-type composition across locations,类似于回归问题中的截距,表示 cell type k 在所有位置比例的均值,为一列向量
  3. W: an N-by-N non-negative weight matrix with each element Wij specifying the weight used for inferring the cell-type composition on the ith location based on the cell-type composition information on the jth location,类似于回归问题中的回归系数
  4. ϕ: is a spatial autocorrelation parameter that determines the strength of the spatial correlation in cell-type composition,表征为决定系数
  5. Vjk:kth cell-type compositions on all other locations

而我们的目的就是要推断出 Vik 矩阵,即每个location的细胞组分

统计推断部分:
根据引理:

引理

假设下面式子中的εik服从正态分布

εik 代入引理可得到公式1,我们可以得到如下关系,公式1主要描述的是寻找一个合适的 Vik 矩阵,使得误差 εik 尽可能小:
公式1

而整个统计推断将转变为一个最优化问题,即寻找一个合适的 Vik 矩阵,使得 Vik
之间的误差 εik 尽可能小

将公式1化简以后:


这里巧用矩阵乘法法则,将加和(∑)改变成为了两个矩阵乘积形式,其中:



提出公因式后

接下来构造似然函数,利用极大似然法求解参数:
首先,作者先定义协方差矩阵如下:


第二步将所有函数的似然值相乘构造似然函数:( k 为 k 个cell type,i 为 i 个location)
这里的似然函数要估计 4 个参数 V,λk,σe2,bk,并且在这个似然函数中要优化两个回归:


于是联合起来得到的似然函数,这个似然函数的作用是使得误差 Egi 和 εik 以及σe2在 0 处的似然值最大λk 感觉像是一个正则项
公式2


这里求解的是似然函数的最小值,作者把求解最小值的问题转换为求解最大值:
公式3-1

公式3-2

最后利用极大似然法求出最优 V矩阵 的参数解,这里的G代表基因数目,N为 location 数目

代码分析

首先下载相关数据:

  1. sc_count.RData为单细胞表达矩阵
  2. sc_meta.RData为单细胞基本信息
  3. spatial_count.RData为空间转录组表达矩阵
    spatial_count
    行代表基因,列代表位置信息,矩阵元素代表基因在不同 location 的表达量
  4. spatial_location.RData为空间转录组位置信息
    spatial_location
    x,y相当于二维坐标,空间转录组相当于在一个平面上按照矩形形式标记不同的像素点,10×10代表平面第10行第10列的像素点,每一个像素点相当于一个小的 bulk-seq
    每一个小圆点相当于一个像素点,称之为一个 location

示例数据给的流程如下:

# 载入数据
load("C:/Users/lenovo/Downloads/spatial_count.RData")
load("C:/Users/lenovo/Downloads/spatial_location.RData")
load("C:/Users/lenovo/Downloads/sc_count.RData")
load("C:/Users/lenovo/Downloads/sc_meta.RData")

library(CARD)

CARD_obj = createCARDObject(
  sc_count = sc_count,
  sc_meta = sc_meta,
  spatial_count = spatial_count,
  spatial_location = spatial_location,
  ct.varname = "cellType",
  ct.select = unique(sc_meta$cellType),
  sample.varname = "sampleInfo",
  minCountGene = 100,
  minCountSpot = 5) 

CARD_obj = CARD_deconvolution(CARD_object = CARD_obj)
# 细胞组分矩阵
CARD_obj@Proportion_CARD

细胞组分矩阵:CARD_obj@Proportion_CARD

1.解析createCARDObject()函数

# 载入数据
sc_count = sc_count
sc_meta = sc_meta
spatial_count = spatial_count
spatial_location = spatial_location
ct.varname = "cellType"
ct.select = unique(sc_meta$cellType)
sample.varname = "sampleInfo"
minCountGene = 100
minCountSpot = 5

# step 1 对单细胞的数据进行质控
sc_countMat  <- sc_count
ct.select <- as.character(ct.select[!is.na(ct.select)])
sc_eset = sc_QC(sc_countMat,sc_meta,ct.varname,ct.select,sample.varname)
#### Check the spatial count dataset
#### QC on spatial dataset
spatial_countMat <- spatial_count
commonGene = intersect(rownames(spatial_countMat),rownames(assays(sc_eset)$counts))

# step2 对空转的数据进行过滤
#### QC on spatial dataset
spatial_countMat = spatial_countMat[rowSums(spatial_countMat > 0) > minCountSpot,]
spatial_countMat = spatial_countMat[,(colSums(spatial_countMat) >= minCountGene & colSums(spatial_countMat) <= 1e6)]
spatial_location = spatial_location[rownames(spatial_location) %in% colnames(spatial_countMat),]
spatial_location = spatial_location[match(colnames(spatial_countMat),rownames(spatial_location)),]

object <- new(
  Class = "CARD",
 # 质控后的单细胞不同 cell type 的基因表达矩阵
  sc_eset = sc_eset,
  spatial_countMat = spatial_countMat,
  spatial_location = spatial_location,
  project = "Deconvolution",
  info_parameters = list(ct.varname = ct.varname,ct.select = ct.select,sample.varname = sample.varname)
)
return(object)


## sc_QC 函数的作用是过滤一些低表达的基因和低质量的细胞
sc_QC <- function(counts_in,metaData,ct.varname,ct.select,sample.varname = NULL, min.cells = 0,min.genes = 0){
# Filter based on min.features
    coldf = metaData
    counts = counts_in
    if (min.genes >= 0) {
        nfeatures <- colSums(x = counts )
        counts <- counts[, which(x = nfeatures > min.genes)]
        coldf <- coldf[which(x = nfeatures > min.genes),]
    }
    # filter genes on the number of cells expressing
    if (min.cells >= 0) {
        num.cells <- rowSums(x = counts > 0)
        counts <- counts[which(x = num.cells > min.cells), ]
    }
    fdata = as.data.frame(rownames(counts))
    rownames(fdata) = rownames(counts)
    keepCell = as.character(coldf[,ct.varname]) %in% ct.select
    counts = counts[,keepCell]
    coldf = coldf[keepCell,]
    keepGene = rowSums(counts) > 0
    fdata = as.data.frame(fdata[keepGene,])
    counts = counts[keepGene,]
    sce <- SingleCellExperiment(list(counts=counts),
    colData=as.data.frame(coldf),
    rowData=as.data.frame(fdata))
    return(sce)
}

2.解析CARD_deconvolution()函数

# 读取createCARDObject()的结果文件
CARD_object = CARD_obj

# 获取不同的 cellType 名称, ct.select 为不同 cellType 的名称
ct.select = CARD_object@info_parameters$ct.select
# ct.varname 为字符串 "cellType"
ct.varname = CARD_object@info_parameters$ct.varname
sample.varname = CARD_object@info_parameters$sample.varname

# sc_eset 为单细胞表达矩阵, 利用 counts(sc_eset) 查看表达矩阵
sc_eset = CARD_object@sc_eset

# 对单细胞表达矩阵进行标准化
Basis_ref = createscRef(sc_eset, ct.select, ct.varname, sample.varname)

Basis = Basis_ref$basis
Basis = Basis[,colnames(Basis) %in% ct.select]
Basis = Basis[,match(ct.select,colnames(Basis))]
# 获得空间转录组表达矩阵 spatial_count 
spatial_count = CARD_object@spatial_countMat
commonGene = intersect(rownames(spatial_count),rownames(Basis))
#### remove mitochondrial and ribosomal genes
#### 去除 mt DNA
commonGene  = commonGene[!(commonGene %in% commonGene[grep("mt-",commonGene)])]

common = selectInfo(Basis,sc_eset,commonGene,ct.select,ct.varname)
# 空转表达矩阵 Xinput 
Xinput = spatial_count
# 单细胞表达矩阵 B
B = Basis

##### match the common gene names
##### 对空间表达矩阵选择 common 的 gene 进行后续分析
Xinput = Xinput[order(rownames(Xinput)),]
B = B[order(rownames(B)),]
B = B[rownames(B) %in% common,]
Xinput = Xinput[rownames(Xinput) %in% common,]

##### filter out non expressed genes or cells again
##### 对空间表达矩阵过滤掉没有表达的细胞和基因
Xinput = Xinput[rowSums(Xinput) > 0,]
Xinput = Xinput[,colSums(Xinput) > 0]

##### normalize count data
##### 对空转表达矩阵进行标准化
colsumvec = colSums(Xinput)
### 相当于每个基因对相应位置的总测序深度做标准化
Xinput_norm = sweep(Xinput,2,colsumvec,"/")
B = B[rownames(B) %in% rownames(Xinput_norm),]    
B = B[match(rownames(Xinput_norm),rownames(B)),]

#### spatial location
#### 获取空转的位置信息
spatial_location = CARD_object@spatial_location
spatial_location = spatial_location[rownames(spatial_location) %in% colnames(Xinput_norm),]
spatial_location = spatial_location[match(colnames(Xinput_norm),rownames(spatial_location)),]

##### normalize the coordinates without changing the shape and relative position
### 对空转的位置进行标准化,转换为相对位置
norm_cords = spatial_location[ ,c("x","y")]
norm_cords$x = norm_cords$x - min(norm_cords$x)
norm_cords$y = norm_cords$y - min(norm_cords$y)
scaleFactor = max(norm_cords$x,norm_cords$y)
norm_cords$x = norm_cords$x / scaleFactor
norm_cords$y = norm_cords$y / scaleFactor

##### initialize the proportion matrix
### 计算空转位置间的欧式距离
ED <- rdist::rdist(as.matrix(norm_cords))##Euclidean distance matrix

set.seed(20200107)
Vint1 = as.matrix(gtools::rdirichlet(ncol(Xinput_norm), rep(10,ncol(B))))
colnames(Vint1) = colnames(B)
rownames(Vint1) = colnames(Xinput_norm)
b = rep(0,length(ct.select))

###### parameters that need to be set
isigma = 0.1 ####construct Gaussian kernel with the default scale /length parameter to be 0.1
epsilon = 1e-04  #### convergence epsion 
phi = c(0.01,0.1,0.3,0.5,0.7,0.9,0.99) #### grided values for phi

## 随机生成 W 矩阵
kernel_mat <- exp(-ED^2 / (2 * isigma^2))
diag(kernel_mat) <- 0

###### scale the Xinput_norm and B to speed up the convergence. 
mean_X = mean(Xinput_norm)
mean_B = mean(B)
Xinput_norm = Xinput_norm * 1e-01 / mean_X
B = B * 1e-01 / mean_B
ResList = list()
Obj = c()
## 利用不同的参数 phi 来估计模型
for(iphi in 1:length(phi)){
  res = CARDref(
    XinputIn = as.matrix(Xinput_norm),
    UIn = as.matrix(B),
    WIn = kernel_mat, 
    phiIn = phi[iphi],
    max_iterIn =1000,
    epsilonIn = epsilon,
    initV = Vint1,
    initb = rep(0,ncol(B)),
    initSigma_e2 = 0.1, 
    initLambda = rep(10,length(ct.select)))
  rownames(res$V) = colnames(Xinput_norm)
  colnames(res$V) = colnames(B)
  ResList[[iphi]] = res
  Obj = c(Obj,res$Obj)
}

## 选择最优的参数下的模型
Optimal = which(Obj == max(Obj))
Optimal = Optimal[length(Optimal)] #### just in case if there are two equal objective function values
OptimalPhi = phi[Optimal]
OptimalRes = ResList[[Optimal]]
cat(paste0("## Deconvolution Finish! ...\n"))
CARD_object@info_parameters$phi = OptimalPhi

### 获得细胞组分矩阵 Proportion_CARD
CARD_object@Proportion_CARD = sweep(OptimalRes$V,1,rowSums(OptimalRes$V),"/")
CARD_object@algorithm_matrix = list(B = B * mean_B / 1e-01, Xinput_norm = Xinput_norm * mean_X / 1e-01, Res = OptimalRes)
CARD_object@spatial_location = spatial_location


################################## 其中 createscRef() 函数
# 读取数据
x = sc_eset # 单细胞表达矩阵
ct.select = ct.select # 获取不同的 cellType 名称, ct.select 为不同 cellType 的名称
ct.varname = ct.varname # ct.varname 为字符串 "cellType"
sample.varname = sample.varname

# 其中 createscRef() 函数
createscRef <- function(x, ct.select = NULL, ct.varname, sample.varname = NULL){
  library(MuSiC)
  if (is.null(ct.select)) {
    ct.select <- unique(colData(x)[, ct.varname])
  }
  # 去除 cellType 为 NA 的 cell Type
  ct.select <- ct.select[!is.na(ct.select)]
  # countMat <- as.matrix(assays(x)$counts)
  # 将单细胞表达矩阵取出来赋予 countMat 
  countMat <- as(SummarizedExperiment::assays(x)$counts,"sparseMatrix")
  # ct.id 的作用相当于将每个 cell 赋予对应的 cell type
  ct.id <- droplevels(as.factor(SummarizedExperiment::colData(x)[, ct.varname]))
  #if(length(unique(colData(x)[,sample.varname])) > 1){
  if(is.null(sample.varname)){
    SummarizedExperiment::colData(x)$sampleID = "Sample"
    sample.varname = "sampleID"
  }
  sample.id <- as.character(SummarizedExperiment::colData(x)[, sample.varname])
  ct_sample.id <- paste(ct.id, sample.id, sep = "$*$")
  colSums_countMat <- colSums(countMat)
  colSums_countMat_Ct = aggregate(colSums_countMat ~ ct.id + sample.id, FUN = 'sum')
  colSums_countMat_Ct_wide = reshape(colSums_countMat_Ct, idvar = "sample.id", timevar = "ct.id", direction = "wide")
  colnames(colSums_countMat_Ct_wide) = gsub("colSums_countMat.","",colnames(colSums_countMat_Ct_wide))
  rownames(colSums_countMat_Ct_wide) = colSums_countMat_Ct_wide$sample.id
  colSums_countMat_Ct_wide$sample.id <- NULL
  tbl <- table(sample.id,ct.id)
  colSums_countMat_Ct_wide = colSums_countMat_Ct_wide[,match(colnames(tbl),colnames(colSums_countMat_Ct_wide))]
  colSums_countMat_Ct_wide = colSums_countMat_Ct_wide[match(rownames(tbl),rownames(colSums_countMat_Ct_wide)),]
  S_JK <- colSums_countMat_Ct_wide / tbl
  S_JK <- as.matrix(S_JK)
  S_JK[S_JK == 0] = NA
  S_JK[!is.finite(S_JK)] = NA
  S = colMeans(S_JK, na.rm = TRUE)
  S = S[match(unique(ct.id),names(S))]
  library("wrMisc")
  if(nrow(countMat) > 10000 & ncol(countMat) > 50000){ ### to save memory 
    seqID = seq(1,nrow(countMat),by = 10000)
    Theta_S_rowMean = NULL
    for(igs in seqID){
      if(igs != seqID[length(seqID)]){
        Theta_S_rowMean_Tmp <- rowGrpMeans(as.matrix(countMat[(igs:(igs+9999)),]), grp = ct_sample.id, na.rm = TRUE)
      }else{
        Theta_S_rowMean_Tmp <- rowGrpMeans(as.matrix(countMat[igs:nrow(countMat),]), grp = ct_sample.id, na.rm = TRUE)
        
      }
      Theta_S_rowMean <- rbind(Theta_S_rowMean,Theta_S_rowMean_Tmp)
      
    }
  }else{
    Theta_S_rowMean <- rowGrpMeans(as.matrix(countMat), grp = ct_sample.id, na.rm = TRUE)
  }
  tbl_sample = table(ct_sample.id)
  tbl_sample = tbl_sample[match(colnames(Theta_S_rowMean),names(tbl_sample))]
  Theta_S_rowSums <- sweep(Theta_S_rowMean,2,tbl_sample,"*")
  Theta_S <- sweep(Theta_S_rowSums,2,colSums(Theta_S_rowSums),"/")
  grp <- sapply(strsplit(colnames(Theta_S),split="$*$",fixed = TRUE),"[",1)
  Theta = rowGrpMeans(Theta_S, grp = grp, na.rm = TRUE)
  Theta = Theta[,match(unique(ct.id),colnames(Theta))]
  S = S[match(colnames(Theta),names(S))]
  basis = sweep(Theta,2,S,"*")
  colnames(basis) = colnames(Theta)
  rownames(basis) = rownames(Theta)
  return(list(basis = basis))
}



################################## 其中 selectInfo() 函数
selectInfo <- function(Basis,sc_eset,commonGene,ct.select,ct.varname){
#### log2 mean fold change >0.5
gene1 = lapply(ct.select,function(ict){
rest = rowMeans(Basis[,colnames(Basis) != ict])
FC = log((Basis[,ict] + 1e-06)) - log((rest + 1e-06))
rownames(Basis)[FC > 1.25 & Basis[,ict] > 0]
})
gene1 = unique(unlist(gene1))
gene1 = intersect(gene1,commonGene)
counts = assays(sc_eset)$counts
counts = counts[rownames(counts) %in% gene1,]
##### only check the cell type that contains at least 2 cells
ct.select = names(table(colData(sc_eset)[,ct.varname]))[table(colData(sc_eset)[,ct.varname]) > 1]
sd_within = sapply(ct.select,function(ict){
  temp = counts[,colData(sc_eset)[,ct.varname] == ict]
  apply(temp,1,var) / apply(temp,1,mean)
  })
##### remove the outliers that have high dispersion across cell types
gene2 = rownames(sd_within)[apply(sd_within,1,mean,na.rm = T) < quantile(apply(sd_within,1,mean,na.rm = T),prob = 0.99,na.rm = T)]
return(gene2)
}

关于 CARD_deconvolution()中的变量

  1. 有关 Basis_ref$basis
    Basis_ref$basis
  2. 有关 spatial_count
    spatial_count
  3. 有关标准化后的位置信息 norm_cords:
    norm_cords

关于 createscRef()中的变量

  1. 有关 ct.id:
    ct.id
  2. 有关 ct.select:
    ct.select

3. Cpp 函数 CARDref()

#include <iostream>
#include <fstream>
#define ARMA_64BIT_WORD 1
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

#include <R.h>
#include <Rmath.h>
#include <cmath>
#include <stdio.h>
#include <stdlib.h>
#include <cstring>
#include <ctime>
#include <Rcpp.h>

// Enable C++11 via this plugin (Rcpp 0.10.3 or later)
// [[Rcpp::plugins(cpp11)]]

using namespace std;
using namespace arma;
using namespace Rcpp;

#define ARMA_DONT_PRINT_ERRORS


//*******************************************************************//
//              spatially informed deconvolution:CARD                        //
//*******************************************************************//
//' SpatialDeconv function based on Conditional Autoregressive model
//' @param XinputIn The input of normalized spatial data
//' @param UIn The input of cell type specific basis matrix B
//' @param WIn The constructed W weight matrix from Gaussian kernel
//' @param phiIn The phi value
//' @param max_iterIn Maximum iterations
//' @param epsilonIn epsilon for convergence 
//' @param initV Initial matrix of cell type compositions V
//' @param initb Initial vector of cell type specific intercept
//' @param initSigma_e2 Initial value of residual variance
//' @param initLambda Initial vector of cell type sepcific scalar. 
//'
//' @return A list
//'
//' @export
// [[Rcpp::export]]
SEXP CARDref(SEXP XinputIn, SEXP UIn, SEXP WIn, SEXP phiIn, SEXP max_iterIn, SEXP epsilonIn, SEXP initV, SEXP initb, SEXP initSigma_e2, SEXP initLambda)
{    
    try {
        // read in the data
        arma::mat Xinput = as<mat>(XinputIn);
        arma::mat U = as<mat>(UIn);
        arma::mat W = as<mat>(WIn);
        double phi = as<double>(phiIn);
        int max_iter = Rcpp::as<int>(max_iterIn);
        double epsilon = as<double>(epsilonIn);
        arma::mat V = as<mat>(initV);
        arma::vec b = as<vec>(initb);
        double sigma_e2 = as<double>(initSigma_e2);
        arma::vec lambda = as<vec>(initLambda);
        // initialize some useful items
        int nSample = (int)Xinput.n_cols; // number of spatial sample points
        int mGene = (int)Xinput.n_rows; // number of genes in spatial deconvolution
        int k = (int)U.n_cols; // number of cell type
        arma::mat L = zeros<mat>(nSample,nSample);
        arma::mat D = zeros<mat>(nSample,nSample);
        arma::mat V_old = zeros<mat>(nSample,k);
        arma::mat UtU = zeros<mat>(k,k);
        arma::mat VtV = zeros<mat>(k,k);
        arma::vec colsum_W = zeros<vec>(nSample);
        arma::mat UtX = zeros<mat>(k,nSample);
        arma::mat XtU = zeros<mat>(nSample,k);
        arma::mat UtXV = zeros<mat>(k,k);
        arma::mat temp = zeros<mat>(k,k);
        arma::mat part1 = zeros<mat>(nSample,k);
        arma::mat part2 = zeros<mat>(nSample,k);
        arma::vec updateV_k = zeros<vec>(k);
        arma::vec updateV_den_k = zeros<vec>(k);
        arma::vec vecOne = ones<vec>( nSample);
        arma::vec diag_UtU = zeros<vec>(k);
        bool logicalLogL = FALSE;
        double obj = 0;
        double obj_old = 0;
        double normNMF = 0;
        double logX = 0;
        double logV = 0;
        double alpha = 1.0;
        double beta = nSample / 2.0;
        double logSigmaL2 = 0.0;
        double accu_L = 0.0;
        double trac_xxt = accu(Xinput % Xinput);
        
        // initialize values
        // constant matrix caculations for increasing speed 
        UtX = U.t() * Xinput;
        XtU = UtX.t();
        colsum_W = sum(W,1);
        D =  diagmat(colsum_W);// diagnol matrix whose entries are column
        L = D -  phi*W; // graph laplacian
        accu_L = accu(L);
        UtXV = UtX * V;
        VtV = V.t() * V;
        UtU = U.t() * U;
        diag_UtU = UtU.diag();
        // calculate initial objective function 
        normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV);
        logX = -(double)(mGene * nSample) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2);
        temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());
        logV = - (double)(nSample) * 0.5 * sum(log(lambda )) - 0.5 * (sum(temp.diag() / lambda )); 
        logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda);
        obj_old = logX + logV + logSigmaL2;
        V_old = V;
        // iteration starts
        for(int i = 1; i <= max_iter; ++i) {
            logV = 0.0;  
            b = sum(V.t() * L, 1) / accu_L;
            lambda = (temp.diag() / 2.0 + beta ) / (double(nSample) / 2.0 + alpha + 1.0);  
            part1 = sigma_e2 * (D * V + phi * colsum_W * b.t());
            part2 = sigma_e2 * (phi * W * V + colsum_W * b.t());
            for(int nCT = 0; nCT < k; ++nCT){
                updateV_den_k = lambda(nCT) * (V.col(nCT) * diag_UtU(nCT) + (V * UtU.col(nCT) - V.col(nCT) * diag_UtU(nCT))) +  part1.col(nCT);
                updateV_k = (lambda(nCT) * XtU.col(nCT) + part2.col(nCT)) / updateV_den_k;
                V.col(nCT) %= updateV_k;
            }
            UtXV = UtX * V;
            VtV = V.t() * V;
            normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV);
            sigma_e2 = normNMF / (double)(mGene * nSample);
            temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());
            logX = -(double)(nSample * mGene) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2);
            logV = -(double)(nSample) * 0.5 * sum(log(lambda))- 0.5 * (sum(temp.diag() / lambda )); 
            logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda);
            obj = logX + logV + logSigmaL2;
            logicalLogL = (obj > obj_old) && (abs(obj - obj_old) * 2.0 / abs(obj + obj_old) < epsilon);
            if(isnan(obj) || (sqrt(accu((V - V_old) % (V - V_old)) / double(nSample * k))  < epsilon) || logicalLogL){
               if(i > 5){ // run at least 5 iterations 
                   break;
               }
       }else{
            obj_old = obj;
            V_old = V;
         }
       }
       return List::create(Named("V") = V,
                           Named("sigma_e2") = sigma_e2,
                           Named("lambda") = lambda,
                           Named("b") = b,
                           Named("Obj") = obj);
        }//end try 
        catch (std::exception &ex)
        {
            forward_exception_to_r(ex);
        }
        catch (...)
        {
            ::Rf_error("C++ exception (unknown reason)...");
        }
        return R_NilValue;
} // end funcs

Cpp 中的变量解释:

  1. trac_xxt = accu(Xinput % Xinput);,与公式3-1的:
  2. 2.0 * trace(UtXV)
# 根据矩阵乘法的性质, 体现出加和的形式
UtX = U.t() * Xinput;
UtXV = UtX * V;
trace(UtXV);

UtX = U.t() * Xinput; UtXV = UtX * V;,代表 BTXV

  1. trace(UtU * VtV)
# 根据矩阵乘法的性质, 体现出加和的形式
VtV = V.t() * V;
UtU = U.t() * U; 
trace(UtU * VtV);

UtU = U.t() * U;,代表 BTBVtV = V.t() * V;,代表 VTV

其中:

  1. 初始化各项矩阵:
UtX = U.t() * Xinput;
XtU = UtX.t();
colsum_W = sum(W,1);
D =  diagmat(colsum_W);// diagnol matrix whose entries are column
L = D -  phi*W; // graph laplacian
accu_L = accu(L);
UtXV = UtX * V;
VtV = V.t() * V;
UtU = U.t() * U;
diag_UtU = UtU.diag();
  1. 构造似然函数:
normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV);
logX = -(double)(mGene * nSample) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2);
temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());
logV = -(double)(nSample) * 0.5 * sum(log(lambda )) - 0.5 * (sum(temp.diag() / lambda )); 
logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda);
obj_old = logX + logV + logSigmaL2;
V_old = V;
  1. normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV);代表公式3-1中的:
  2. logX = -(double)(mGene * nSample) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2);代表公式3-1中的(不清楚为什么这里的符号是反的):
  3. temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());代表公式3-1中的:
  4. logV = - (double)(nSample) * 0.5 * sum(log(lambda )) - 0.5 * (sum(temp.diag() / lambda ));代表公式3-1中的:
  5. logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda);代表公式3-1中的:
  6. obj_old = logX + logV + logSigmaL2;将他们加和
  7. 迭代终止条件:
if(isnan(obj) || (sqrt(accu((V - V_old) % (V - V_old)) / double(nSample * k))  < epsilon) || logicalLogL){
     if(i > 5){ // run at least 5 iterations 
     break;
}

满足下式小于定义的epsilon即可
  1. 每次迭代更新 V 矩阵:
for(int nCT = 0; nCT < k; ++nCT){
     // nCT 相当于cell type k,V.col(nCT) 代表提取 V 矩阵的第 nCT 列
     updateV_den_k = lambda(nCT) * (V.col(nCT) * diag_UtU(nCT) + (V * UtU.col(nCT) - V.col(nCT) * diag_UtU(nCT))) +  part1.col(nCT);
     // 计算每一列的 updateV_k 
     updateV_k = (lambda(nCT) * XtU.col(nCT) + part2.col(nCT)) / updateV_den_k;
     // V 矩阵的每一列除以 updateV_k,从而更新 V 矩阵
     V.col(nCT) %= updateV_k;
}

然后基于更新后的 V 矩阵更新 lambda :

temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());
lambda = (temp.diag() / 2.0 + beta ) / (double(nSample) / 2.0 + alpha + 1.0);  

然后基于更新后的 lambda 在更新 V 矩阵:

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,417评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,921评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,850评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,945评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,069评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,188评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,239评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,994评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,409评论 1 304
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,735评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,898评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,578评论 4 336
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,205评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,916评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,156评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,722评论 2 363
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,781评论 2 351

推荐阅读更多精彩内容