局部敏感哈希算法及其SparkScala实现

我们想要在海量embedding数据中,找到一些与目标embedding相似的数据。这个相似可以是欧氏距离。在数据量太大的时候,例如1亿的数据量,把数据库中每一条embedding都取出来算一遍,然后排序,时间代价太大了。局部敏感哈希算法(locality sensitive hashing)就是为了解决此问题的。这个算法的应用场景主要是推荐系统的召回、图像检索等。
此算法构造一个哈希函数,期望使得相似的embedding很大概率进入同一个桶里面,而不相似的embedding很大概率在不同的桶里面。这样就能大大减少搜索时间了。用数学的语言描述是:
M是向量空间,dM上的距离测度,那么局部敏感哈希族(LSH family)是满足以下性质的一族函数h\forall p,q \in M d(p,q)<r_1\Rightarrow prob(h(p)=h(q))\ge p_1 d(p,q)>r_2\Rightarrow prob(h(p)=h(q))\le p_2我们称这个LSH family是(r_1,r_2,p_1,p_2)敏感的。如果我们构造的哈希函数足够“敏感”,那么我们根据分桶策略搜索相似embedding的准确率就越高。不同的距离测度,适合的哈希函数也不一样。
以下介绍集中常用的哈希函数:

Euclidean距离

对于Euclidean距离,通常把embedding映射到某个低维Euclidean空间。我们定义这么一个样子的哈希函数:h(x)=\left \lfloor \frac{\textbf {v}^T\textbf {x}}{r} \right \rfloor
其中\left \lfloor · \right \rfloor是向下取整符号。这个哈希函数是属于LSH family的,我们把这样的函数成为分桶随机映射(BuketedRandomProjection)。其中v是用于降维的随机向量,r是自定义参数,决定了桶的大小。对于任意两个embedding,如果r越大,那么这两个embedding进入同一个桶的概率就越大。
由于哈希函数是随机的,因此算法可以生成好几个哈希函数。当一个检索请求A过来时,假设B是候选集里的某个embedding。可以规定A和B在所有哈希函数下都落在同一个桶里,才算相似,这样提高了检索精度,也缩小了检索得到的相似物品数量。也可以规定A和B存在至少一个哈希函数落在同一个桶里,就算相似。这样做能够得到更多相似物品数量。

Jaccard距离

Jaccard距离是衡量两个集合之间的差异度的指标,Jaccard系数是1减去Jaccary距离,为两个集合的交集元素数量与两个集合元素总数量之差:J(A,B)=\frac{\left | A\cap B \right | }{\left | A\cup B \right | }
比较常见的场景是稀疏矩阵。例如我有一个稀疏向量A=[0, 0, 1, 0, 0, 1],那么同时可以用一个集合{2,5}来表示稀疏向量A,集合中的元素表示稀疏向量A的非零元素的索引。那么Jaccard就可以用来计算两个稀疏向量的相似度了。Jaccard距离测度下,通常用minhash作为其哈希函数:h(A)=\min_{a\in A}\left ( g\left ( a \right ) \right )函数g是随机一对一映射。例如在刚才的例子里,程序提前用随机方法确定这个映射g。2映射到了3,5映射到了1,那么其哈希值是1。
还可以换一个说法。程序对稀疏向量进行随机打乱,每一个向量都按照同样地顺序进行打乱,A变成了[0,1,0,1,0,0],那么其哈希值就是1。
不难证明,prob(h(A)=h(B))=J(A,B)。因为随机事件h(A)=h(B)等价于,在函数g的映射下,最小哈希值的逆映射,也就是A的某个元素h^{-1}(A),同时也在B中。A中的某个元素同时在B中的概率,这个概率就是Jaccard系数。

SparkScala LSH源代码阅读

经过以上的学习,我们可以总结出,一个局部敏感哈希大概要有以下几个接口:

  • 初始化LSH模型,生成随机向量,储存到内存中。定义桶的长度(就是上文的参数r),定义哈希函数数量。
  • hashFunction接口,输入一个或者多个embedding向量,输出哈希值
  • approxNearestNeighbors接口,输入一个embedding向量,根据敏感哈希函数查找最相似的topk个物品。

在spark.ml.feature路径下,有以下几个相关的程序:

  • BuketedRandomProjectionLSH.scala实现了欧式距离下的局部敏感哈希
  • MinHashLSH.scala实现了Jaccard距离下的局部敏感哈希
  • LSH.Scala是上述局部敏感哈希模型的基类

我抽了一些核心代码出来,以下是代码笔记:

  // BuketedRandomProjectionLSHModel的hashFunction
  @Since("2.1.0")
  override protected[ml] def hashFunction(elems: Vector): Array[Vector] = {
    // hashVec: vector with length numHashTables
    // randMatrix: matrix with shape (numHashTables, embeddingDim)
    // elems: vector with length embeddingDim
    val hashVec = new DenseVector(Array.ofDim[Double](randMatrix.numRows))
    BLAS.gemv(1.0 / $(bucketLength), randMatrix, elems, 0.0, hashVec)
    // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
    hashVec.values.map(h => Vectors.dense(h.floor))
  }

以上是Euclidean距离下的敏感哈希函数。可以看到,gemv方法实现了一个随机矩阵randMatrix与原始的向量elem相乘,然后除以参数bucketLength,把结果赋值给hashVec。最后对hashVec用floor方法向下取整。

  // MinHashLSHModel的hashFunction
  @Since("2.1.0")
  override protected[ml] def hashFunction(elems: Vector): Array[Vector] = {
    require(elems.nonZeroIterator.nonEmpty, "Must have at least 1 non zero entry.")
    val hashValues = randCoefficients.map { case (a, b) =>
      elems.nonZeroIterator.map { case (i, _) =>
        ((1L + i) * a + b) % MinHashLSH.HASH_PRIME
      }.min.toDouble
    }
    // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
    hashValues.map(Vectors.dense(_))
  }

以上是Jaccard测度下的敏感哈希函数。Spark的定义函数g为((1+i)\times a+b)\space mod \space 某常数其中a,b是随机数,i是稀疏向量非零元素的索引。得到的结果对一个很大的常数求模,从而保证输出的数字不会很大。
原始空间的距离和哈希值的距离的计算也有很多细节,这里就不一一考究了。以下是搜索最近哈希值距离的接口:

  private[feature] def approxNearestNeighbors(
      dataset: Dataset[_],
      key: Vector,
      numNearestNeighbors: Int,
      singleProbe: Boolean,
      distCol: String): Dataset[_] = {
    require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
    // Get Hash Value of the key
    val keyHash = hashFunction(key)
    val modelDataset = if (!dataset.columns.contains($(outputCol))) {
        transform(dataset)
      } else {
        dataset.toDF()
      }

    val modelSubset = if (singleProbe) {
      def sameBucket(x: Array[Vector], y: Array[Vector]): Boolean = {
        x.iterator.zip(y.iterator).exists(tuple => tuple._1 == tuple._2)
      }

      // In the origin dataset, find the hash value that hash the same bucket with the key
      val sameBucketWithKeyUDF = udf((x: Array[Vector]) => sameBucket(x, keyHash))

      modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
    } else {
      // In the origin dataset, find the hash value that is closest to the key
      // Limit the use of hashDist since it's controversial
      val hashDistUDF = udf((x: Array[Vector]) => hashDistance(x, keyHash))
      val hashDistCol = hashDistUDF(col($(outputCol)))
      val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)

      val relativeError = 0.05
      val summary = modelDatasetWithDist.select(distCol).rdd.mapPartitions { iter =>
        if (iter.hasNext) {
          var s = new QuantileSummaries(
            QuantileSummaries.defaultCompressThreshold, relativeError)
          while (iter.hasNext) {
            val row = iter.next
            if (!row.isNullAt(0)) {
              val v = row.getDouble(0)
              if (!v.isNaN) s = s.insert(v)
            }
          }
          Iterator.single(s.compress)
        } else Iterator.empty
      }.treeReduce((s1, s2) => s1.merge(s2))
      val count = summary.count

      // Compute threshold to get around k elements.
      // To guarantee to have enough neighbors in one pass, we need (p - err) * N >= M
      // so we pick quantile p = M / N + err
      // M: the number of nearest neighbors; N: the number of elements in dataset
      val approxQuantile = numNearestNeighbors.toDouble / count + relativeError

      if (approxQuantile >= 1) {
        modelDatasetWithDist
      } else {
        val hashThreshold = summary.query(approxQuantile).get
        // Filter the dataset where the hash value is less than the threshold.
        modelDatasetWithDist.filter(hashDistCol <= hashThreshold)
      }
    }

    // Get the top k nearest neighbor by their distance to the key
    val keyDistUDF = udf((x: Vector) => keyDistance(x, key))
    val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol))))
    modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors)
  }

变量dataset是储存了海量embedding数据的对象,keyHash是需要请求进来要查询的embedding向量的hash值,modelDataset是dataset经过了哈希映射。modelSubset是生成的候选集,生成候选集有两种方法:

  • 一种是把在同一个桶里的所有元素作为候选集,不在同一个桶里的元素不管他。
  • 另一种方法是找到好几个相似的桶生成候选集。这里使用了spark.sql.catalyst.util里的QuantileSummaries类,这个类的作用是统计hashDistCol的分位数。这个哈希值距离hashDistCol是什么呢?大概可以理解成一个RDD数组,这个数组的第i个元素表示查询数据库里的第i个元素和query元素的哈希距离。对这个数组统计分位数,然后按照分位数算出来要取得topK相似元素必须有多大的哈希距离阈值(K值是调用approxNearestNeighbors方法传入的参数)。

最后是对候选集进行排序,输出topK相似item。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容