局部敏感哈希算法及其SparkScala实现

我们想要在海量embedding数据中,找到一些与目标embedding相似的数据。这个相似可以是欧氏距离。在数据量太大的时候,例如1亿的数据量,把数据库中每一条embedding都取出来算一遍,然后排序,时间代价太大了。局部敏感哈希算法(locality sensitive hashing)就是为了解决此问题的。这个算法的应用场景主要是推荐系统的召回、图像检索等。
此算法构造一个哈希函数,期望使得相似的embedding很大概率进入同一个桶里面,而不相似的embedding很大概率在不同的桶里面。这样就能大大减少搜索时间了。用数学的语言描述是:
M是向量空间,dM上的距离测度,那么局部敏感哈希族(LSH family)是满足以下性质的一族函数h\forall p,q \in M d(p,q)<r_1\Rightarrow prob(h(p)=h(q))\ge p_1 d(p,q)>r_2\Rightarrow prob(h(p)=h(q))\le p_2我们称这个LSH family是(r_1,r_2,p_1,p_2)敏感的。如果我们构造的哈希函数足够“敏感”,那么我们根据分桶策略搜索相似embedding的准确率就越高。不同的距离测度,适合的哈希函数也不一样。
以下介绍集中常用的哈希函数:

Euclidean距离

对于Euclidean距离,通常把embedding映射到某个低维Euclidean空间。我们定义这么一个样子的哈希函数:h(x)=\left \lfloor \frac{\textbf {v}^T\textbf {x}}{r} \right \rfloor
其中\left \lfloor · \right \rfloor是向下取整符号。这个哈希函数是属于LSH family的,我们把这样的函数成为分桶随机映射(BuketedRandomProjection)。其中v是用于降维的随机向量,r是自定义参数,决定了桶的大小。对于任意两个embedding,如果r越大,那么这两个embedding进入同一个桶的概率就越大。
由于哈希函数是随机的,因此算法可以生成好几个哈希函数。当一个检索请求A过来时,假设B是候选集里的某个embedding。可以规定A和B在所有哈希函数下都落在同一个桶里,才算相似,这样提高了检索精度,也缩小了检索得到的相似物品数量。也可以规定A和B存在至少一个哈希函数落在同一个桶里,就算相似。这样做能够得到更多相似物品数量。

Jaccard距离

Jaccard距离是衡量两个集合之间的差异度的指标,Jaccard系数是1减去Jaccary距离,为两个集合的交集元素数量与两个集合元素总数量之差:J(A,B)=\frac{\left | A\cap B \right | }{\left | A\cup B \right | }
比较常见的场景是稀疏矩阵。例如我有一个稀疏向量A=[0, 0, 1, 0, 0, 1],那么同时可以用一个集合{2,5}来表示稀疏向量A,集合中的元素表示稀疏向量A的非零元素的索引。那么Jaccard就可以用来计算两个稀疏向量的相似度了。Jaccard距离测度下,通常用minhash作为其哈希函数:h(A)=\min_{a\in A}\left ( g\left ( a \right ) \right )函数g是随机一对一映射。例如在刚才的例子里,程序提前用随机方法确定这个映射g。2映射到了3,5映射到了1,那么其哈希值是1。
还可以换一个说法。程序对稀疏向量进行随机打乱,每一个向量都按照同样地顺序进行打乱,A变成了[0,1,0,1,0,0],那么其哈希值就是1。
不难证明,prob(h(A)=h(B))=J(A,B)。因为随机事件h(A)=h(B)等价于,在函数g的映射下,最小哈希值的逆映射,也就是A的某个元素h^{-1}(A),同时也在B中。A中的某个元素同时在B中的概率,这个概率就是Jaccard系数。

SparkScala LSH源代码阅读

经过以上的学习,我们可以总结出,一个局部敏感哈希大概要有以下几个接口:

  • 初始化LSH模型,生成随机向量,储存到内存中。定义桶的长度(就是上文的参数r),定义哈希函数数量。
  • hashFunction接口,输入一个或者多个embedding向量,输出哈希值
  • approxNearestNeighbors接口,输入一个embedding向量,根据敏感哈希函数查找最相似的topk个物品。

在spark.ml.feature路径下,有以下几个相关的程序:

  • BuketedRandomProjectionLSH.scala实现了欧式距离下的局部敏感哈希
  • MinHashLSH.scala实现了Jaccard距离下的局部敏感哈希
  • LSH.Scala是上述局部敏感哈希模型的基类

我抽了一些核心代码出来,以下是代码笔记:

  // BuketedRandomProjectionLSHModel的hashFunction
  @Since("2.1.0")
  override protected[ml] def hashFunction(elems: Vector): Array[Vector] = {
    // hashVec: vector with length numHashTables
    // randMatrix: matrix with shape (numHashTables, embeddingDim)
    // elems: vector with length embeddingDim
    val hashVec = new DenseVector(Array.ofDim[Double](randMatrix.numRows))
    BLAS.gemv(1.0 / $(bucketLength), randMatrix, elems, 0.0, hashVec)
    // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
    hashVec.values.map(h => Vectors.dense(h.floor))
  }

以上是Euclidean距离下的敏感哈希函数。可以看到,gemv方法实现了一个随机矩阵randMatrix与原始的向量elem相乘,然后除以参数bucketLength,把结果赋值给hashVec。最后对hashVec用floor方法向下取整。

  // MinHashLSHModel的hashFunction
  @Since("2.1.0")
  override protected[ml] def hashFunction(elems: Vector): Array[Vector] = {
    require(elems.nonZeroIterator.nonEmpty, "Must have at least 1 non zero entry.")
    val hashValues = randCoefficients.map { case (a, b) =>
      elems.nonZeroIterator.map { case (i, _) =>
        ((1L + i) * a + b) % MinHashLSH.HASH_PRIME
      }.min.toDouble
    }
    // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
    hashValues.map(Vectors.dense(_))
  }

以上是Jaccard测度下的敏感哈希函数。Spark的定义函数g为((1+i)\times a+b)\space mod \space 某常数其中a,b是随机数,i是稀疏向量非零元素的索引。得到的结果对一个很大的常数求模,从而保证输出的数字不会很大。
原始空间的距离和哈希值的距离的计算也有很多细节,这里就不一一考究了。以下是搜索最近哈希值距离的接口:

  private[feature] def approxNearestNeighbors(
      dataset: Dataset[_],
      key: Vector,
      numNearestNeighbors: Int,
      singleProbe: Boolean,
      distCol: String): Dataset[_] = {
    require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
    // Get Hash Value of the key
    val keyHash = hashFunction(key)
    val modelDataset = if (!dataset.columns.contains($(outputCol))) {
        transform(dataset)
      } else {
        dataset.toDF()
      }

    val modelSubset = if (singleProbe) {
      def sameBucket(x: Array[Vector], y: Array[Vector]): Boolean = {
        x.iterator.zip(y.iterator).exists(tuple => tuple._1 == tuple._2)
      }

      // In the origin dataset, find the hash value that hash the same bucket with the key
      val sameBucketWithKeyUDF = udf((x: Array[Vector]) => sameBucket(x, keyHash))

      modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
    } else {
      // In the origin dataset, find the hash value that is closest to the key
      // Limit the use of hashDist since it's controversial
      val hashDistUDF = udf((x: Array[Vector]) => hashDistance(x, keyHash))
      val hashDistCol = hashDistUDF(col($(outputCol)))
      val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)

      val relativeError = 0.05
      val summary = modelDatasetWithDist.select(distCol).rdd.mapPartitions { iter =>
        if (iter.hasNext) {
          var s = new QuantileSummaries(
            QuantileSummaries.defaultCompressThreshold, relativeError)
          while (iter.hasNext) {
            val row = iter.next
            if (!row.isNullAt(0)) {
              val v = row.getDouble(0)
              if (!v.isNaN) s = s.insert(v)
            }
          }
          Iterator.single(s.compress)
        } else Iterator.empty
      }.treeReduce((s1, s2) => s1.merge(s2))
      val count = summary.count

      // Compute threshold to get around k elements.
      // To guarantee to have enough neighbors in one pass, we need (p - err) * N >= M
      // so we pick quantile p = M / N + err
      // M: the number of nearest neighbors; N: the number of elements in dataset
      val approxQuantile = numNearestNeighbors.toDouble / count + relativeError

      if (approxQuantile >= 1) {
        modelDatasetWithDist
      } else {
        val hashThreshold = summary.query(approxQuantile).get
        // Filter the dataset where the hash value is less than the threshold.
        modelDatasetWithDist.filter(hashDistCol <= hashThreshold)
      }
    }

    // Get the top k nearest neighbor by their distance to the key
    val keyDistUDF = udf((x: Vector) => keyDistance(x, key))
    val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol))))
    modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors)
  }

变量dataset是储存了海量embedding数据的对象,keyHash是需要请求进来要查询的embedding向量的hash值,modelDataset是dataset经过了哈希映射。modelSubset是生成的候选集,生成候选集有两种方法:

  • 一种是把在同一个桶里的所有元素作为候选集,不在同一个桶里的元素不管他。
  • 另一种方法是找到好几个相似的桶生成候选集。这里使用了spark.sql.catalyst.util里的QuantileSummaries类,这个类的作用是统计hashDistCol的分位数。这个哈希值距离hashDistCol是什么呢?大概可以理解成一个RDD数组,这个数组的第i个元素表示查询数据库里的第i个元素和query元素的哈希距离。对这个数组统计分位数,然后按照分位数算出来要取得topK相似元素必须有多大的哈希距离阈值(K值是调用approxNearestNeighbors方法传入的参数)。

最后是对候选集进行排序,输出topK相似item。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,294评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,493评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,790评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,595评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,718评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,906评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,053评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,797评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,250评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,570评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,711评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,388评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,018评论 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,796评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,023评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,461评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,595评论 2 350

推荐阅读更多精彩内容