A Hand-Rolled BPE Algorithm

I took the first version from CSDN. It works fine on simple test data, but on a text file with many words it seems to have problems. By contrast, the code DeepSeek generated for me is both cleaner and able to handle more files. Using AI to learn AI: the loop closes on itself.

0 A brief introduction to the BPE algorithm

The BPE algorithm has two parts: word-frequency counting and vocabulary merging. BPE starts from a set of basic symbols (for example letters and boundary characters) and iteratively looks for two adjacent tokens in the corpus, replacing them with a new token; this step is called a merge. The merge criterion is the co-occurrence frequency of two consecutive tokens: in each iteration, the most frequent pair of adjacent tokens is selected and merged. Merging continues until a predefined vocabulary size is reached.

Word-frequency counting relies on a pre-tokenizer (pre-tokenization) to split the training data into words. The pre-tokenizer can be very simple, such as splitting on whitespace, which is how GPT-2 and RoBERTa do it. More advanced pre-tokenizers use rule-based tokenization: for example, XLM and FlauBERT use Moses, and GPT uses spaCy and ftfy to count the frequency of each word in the corpus.
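As a rough illustration of this simplest kind of pre-tokenization (a sketch of whitespace splitting, not the actual GPT-2, Moses, or spaCy pipelines):

from collections import Counter

def pre_tokenize(text):
    """Naive whitespace pre-tokenizer: lowercase, split on spaces, count word frequencies."""
    return Counter(text.lower().split())

# e.g. pre_tokenize("hug hug pug") -> Counter({'hug': 2, 'pug': 1})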

After pre-tokenization, a set of distinct words with their frequencies is created. From this set, a base vocabulary containing all characters is built, and each word is split according to that vocabulary. Pairs are then merged two at a time according to the merge rule: in each round, the most frequent pair of adjacent tokens within words becomes a new token that is added to the vocabulary, until the preset vocabulary size is reached and merging stops.

Suppose that after pre-tokenization (usually word-based tokenization) we obtain the following set of words with their frequencies:

("hug", 10), ("pug", 5), ("pun", 12), ("bun", 4),("hugs", 5)

The base vocabulary is therefore: ["b", "g", "h", "n", "p", "s", "u"].

Splitting every word into the characters of this vocabulary gives:

("h" "u" "g", 10), ("p" "u" "g", 5), ("p" "u" "n", 12), ("b" "u" "n", 4), ("h" "u" "g" "s", 5)

Next, count how often each pair of adjacent characters occurs; the most frequent pair is merged and the new symbol is added to the vocabulary. The number of BPE merges is a hyperparameter; once it is reached, merging stops.
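As a small sketch of this counting step (my own illustration, reusing the word frequencies from the example above):

from collections import Counter

# Word frequencies from the example, each word already split into base characters
word_freqs = {
    ("h", "u", "g"): 10,
    ("p", "u", "g"): 5,
    ("p", "u", "n"): 12,
    ("b", "u", "n"): 4,
    ("h", "u", "g", "s"): 5,
}

pair_freqs = Counter()
for symbols, freq in word_freqs.items():
    for a, b in zip(symbols, symbols[1:]):
        pair_freqs[(a, b)] += freq

best_pair, best_freq = pair_freqs.most_common(1)[0]
print(best_pair, best_freq)  # ('u', 'g') 20 -> merged into the new token "ug"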

[Image in the original post showing the three merge iterations: iteration 1 merges ("u", "g"), frequency 20, into "ug"; iteration 2 merges ("u", "n"), frequency 16, into "un"; iteration 3 merges ("h", "ug"), frequency 15, into "hug".]

After three rounds of iteration, the new vocabulary is: ["b", "g", "h", "n", "p", "s", "u", "ug", "un", "hug"].

At tokenization time, anything not in the vocabulary is replaced by the special token <unk>. For example: bug => ["b", "ug"], mug => ["<unk>", "ug"].
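A minimal sketch of applying the learned merges at tokenization time (my own example; the merge order and vocabulary are taken from the example above, and anything not in the vocabulary becomes <unk>):

def bpe_tokenize(word, merges, vocab):
    """Split a word into characters, apply merges in training order, map unknown symbols to <unk>."""
    symbols = list(word)
    for a, b in merges:
        i = 0
        while i < len(symbols) - 1:
            if symbols[i] == a and symbols[i + 1] == b:
                symbols = symbols[:i] + [a + b] + symbols[i + 2:]
            else:
                i += 1
    return [s if s in vocab else "<unk>" for s in symbols]

merges = [("u", "g"), ("u", "n"), ("h", "ug")]
vocab = ["b", "g", "h", "n", "p", "s", "u", "ug", "un", "hug"]
print(bpe_tokenize("bug", merges, vocab))  # ['b', 'ug']
print(bpe_tokenize("mug", merges, vocab))  # ['<unk>', 'ug']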

For example, GPT's vocabulary size is 40,478: 478 base characters plus 40,000 merges that add high-frequency pairs; this vocabulary is then used for tokenization.

1. The code from CSDN

import re
from collections import defaultdict, Counter


def extract_frequencies(sequence):
    """
    给定一个字符串,计算字符串中的单词出现的频率,并返回词表(一个词到频率的映射-字典)
    """
    token_counter = Counter()
    for item in sequence:
        if item.strip() == '':
            continue
        else:
            tokens = ' '.join(list(item)) + '</w>'
            token_counter[tokens] += 1
    print(token_counter)
    print('number of token: ', len(token_counter))
    return token_counter

def frequency_of_pairs(frequencies):
    """
    给定一个词频字典,返回一个从字符对到频率的映射字典
    """
    pairs_count = Counter()
    for token, count in frequencies.items():
        chars = token.split()
        print('token & chars: ', token, chars)
        for i in range(len(chars)-1):
            print(f'token & chars & count: {token} & {chars} & {count}')
            pair = (chars[i], chars[i+1])
            pairs_count[pair] += count
    return pairs_count

def merge_vocab(merge_pair, vocab):
    """
    给定一对相邻词元和一个词频字典,将相邻词元合并为新的词元,并返回新的词表
    """
    re_pattern = re.escape(' '.join(merge_pair))
    pattern = re.compile(r'(?<!\S)', re_pattern + r'(?!\S)')
    updated_tokens = {pattern.sub(''.join(merge_pair), token): freq for token, freq in vocab.item()}
    return updated_tokens

def encode_with_bpe(texts, iterations):
    """
    给定待分词的数据以及最大合并次数,返回合并后的词表。
    """
    vocab_map = extract_frequencies(texts)
    for i in range(iterations):
        print(i)
        pair_freqs = frequency_of_pairs(vocab_map)
        if not pair_freqs:
            break
        most_common_pair = pair_freqs.most_common(1)[0][0]
        vocab_map = merge_vocab(most_common_pair, vocab_map)
    print(len(vocab_map))
    return vocab_map

if __name__ == '__main__':
    file_path = "the-verdict.txt"
    with open(file_path, "r", encoding="utf-8") as file:
        text_data = file.read()
    text_data = "low low low low low lower newer newest widest"
    num_merges = 1000
    bpe_pairs = encode_with_bpe(text_data, num_merges)

Output demo

D:\Python\Python310\python.exe D:\test\minimind\BPE_basic.py 
Counter({'w</w>': 9, 'l</w>': 6, 'o</w>': 6, 'e</w>': 6, 'r</w>': 2, 'n</w>': 2, 's</w>': 2, 't</w>': 2, 'i</w>': 1, 'd</w>': 1})
number of token:  10
0
token & chars:  l</w> ['l</w>']
token & chars:  o</w> ['o</w>']
token & chars:  w</w> ['w</w>']
token & chars:  e</w> ['e</w>']
token & chars:  r</w> ['r</w>']
token & chars:  n</w> ['n</w>']
token & chars:  s</w> ['s</w>']
token & chars:  t</w> ['t</w>']
token & chars:  i</w> ['i</w>']
token & chars:  d</w> ['d</w>']
10

Process finished with exit code 0
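The demo above shows where this version goes wrong on real text: extract_frequencies iterates over whatever it is given, and the main block passes a whole string, so every item is a single character. The vocabulary then only contains ten one-character tokens, no adjacent pairs exist, and zero merges happen. My guess at a minimal fix (not from the original CSDN post) is to pass a list of words instead of the raw string:

# Hypothetical fix: split the text into words first, so adjacent characters
# inside each word can actually be paired and merged.
bpe_pairs = encode_with_bpe(text_data.split(), num_merges)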

2. The code from DeepSeek

import re
from collections import defaultdict

class BPE:
    def __init__(self, num_merges=100):
        # Number of merges
        self.num_merges = num_merges
        # Word-frequency counts
        self.vocab = defaultdict(int)
        # Store the learned merge rules
        self.merges = []

    def _preprocess(self, text):
        """Preprocess the text: split into words and append the end-of-word marker </w>."""
        words = text.lower().split()
        print('words:', words)
        # e.g. new -> "n e w </w>"
        return [' '.join(list(word) + ['</w>']) for word in words]

    def _get_stats(self):
        """统计相邻字符对的频率"""
        pairs = defaultdict(int)
        for word, freq in self.vocab.items():
            symbols = word.split()
            for i in range(len(symbols)-1):
                # 统计相邻字符对:ml-citation{ref="2" data="citationList"}
                pairs[symbols[i], symbols[i+1]] += freq
        # print('pairs: ', pairs)
        return pairs

    def _merge_pair(self, pair, vocab):
        """合并指定字符对并更新词汇表"""
        bigram = re.escape(' '.join(pair))
        # 精确匹配字符对:ml-citation{ref="2" data="citationList"}
        pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        new_vocab = {}
        for word in vocab:
            # 替换匹配项
            new_word = pattern.sub(''.join(pair), word)
            new_vocab[new_word] = vocab[word]
        # print('new_vocab: ', new_vocab)
        return new_vocab

    def train(self, corpus):
        """训练BPE模型"""
        # 初始化词汇表
        preprocessed = self._pregrocess(corpus)
        print('preprocessed: ',preprocessed)
        for seq in preprocessed:
            self.vocab[seq] += 1
        # 迭代合并字符对
        for _ in range(self.num_merges):
            pairs = self._get_stats()
            if not pairs:
                break
            # 频率最高的对:ml-citation{ref="2" data="citationList"}
            best = max(pairs, key=lambda x: (pairs[x], x))
            self.merges.append(best)
            self.vocab = self._merge_pair(best, self.vocab)

    def tokenize(self, text):
        """应用训练结果进行分词"""
        words = self._pregrocess(text)
        tokens = []
        for word in words:
            seq = word.split()
            # 应用所有合并规则
            for a, b in self.merges:
                i = 0
                while i < len(seq)-1:
                    if seq[i] == a and seq[i+1] == b:
                        seq = seq[:i] + [a+b] + seq[i+2:]
                    else:
                        i += 1
            tokens.extend(seq)
        return [t.replace('</w>', '') for t in tokens if t != '</w>']

if __name__ == '__main__':
    test_corpus = "low low low low low lower newer newest widest"  # defined but not used below; see the sanity check after the output demo
    file_path = "the-verdict.txt"
    with open(file_path, "r", encoding="utf-8") as file:
        text_data = file.read()
    # Initialize the BPE model
    bpe = BPE(num_merges=100)
    bpe.train(text_data)
    # Print the training results
    print('合并规则:', bpe.merges)

    test_data = 'lowest newer'
    # Test tokenization
    print('分词结果:', bpe.tokenize(test_data))

Output demo

D:\Python\Python310\python.exe D:\test\minimind\AIBPE_basic.py 
words: ['i', 'had', 'always', 'thought', 'jack', 'gisburn', 'rather', 'a', 'cheap', 'genius--though', 'a', 'good', 'fellow', 'enough--so', 'it', 'was', 'no', 'great', 'surprise', 'to', 'me', 'to', 'hear', 'that,', 'in', 'the', 'height', 'of', 'his', 'glory,', 'he', 'had', 'dropped', 'his', 'painting,', 'married', 'a', 'rich', 'widow,', 'and', 'established', 'himself', 'in', 'a', 'villa', 'on', 'the', 'riviera.', '(though', 'i', 'rather', 'thought', 'it', 'would', 'have', 'been', 'rome', 'or', 'florence.)', '"the', 'height', 'of', 'his', 'glory"--that', 'was', 'what', 'the', 'women', 'called',..., 'kind', 'of', 'art."']
preprocessed:  ['i </w>', 'h a d </w>', 'a l w a y s </w>', 't h o u g h t </w>', 'j a c k </w>', 'g i s b u r n </w>', 'r a t h e r </w>', 'a </w>', 'c h e a p </w>', 'g e n i u s - - t h o u g h </w>', 'a </w>', 'g o o d </w>', 'f e l l o w </w>', 'e n o u g h - - s o </w>', 'i t </w>', 'w a s </w>', 'n o </w>', 'g r e a t </w>', 's u r p r i s e </w>', 't o </w>', 'm e </w>', 't o </w>', 'h e a r </w>', 't h a t , </w>', 'i n </w>', 't h e </w>', 'h e i g h t </w>', 'o f </w>', 'h i s </w>', 'g l o r y , </w>', 'h e </w>', 'h a d </w>', 'd r o p p e d </w>', 'h i s </w>', 'p a i n t i n g , </w>', 'm a r r i e d </w>', 'a </w>', 'r i c h </w>', 'w i d o w , </w>', 'a n d </w>', 'e s t a b l i s h e d </w>', 'h i m s e l f </w>', 'i n </w>', 'a </w>', 'v i l l a </w>', 'o n </w>',......'k i n d </w>', 'o f </w>', 'a r t . " </w>']
合并规则: [('e', '</w>'), ('t', 'h'), ('d', '</w>'), ('t', '</w>'), ('s', '</w>'), ('i', 'n'), (',', '</w>'), ('o', 'u'), ('a', 'n'), ('e', 'r'), ('.', '</w>'), ('th', 'e</w>'), ('o', 'n'), ('y', '</w>'), ('e', 'n'), ('e', 'd</w>'), ('o', '</w>'), ('f', '</w>'), ('h', 'a'), ('in', 'g'), ('h', 'i'), ('i', '</w>'), ('s', 't'), ('t', 'o</w>'), ('o', 'f</w>'), ('h', 'e</w>'), ('w', 'a'), ('o', 'r'), ('-', '-'), ('e', 'a'), ('an', 'd</w>'), ('a', '</w>'), ('ing', '</w>'), ('u', 'r'), ('i', 't'), ('e', 'l'), ('a', 't</w>'), ('wa', 's</w>'), ('o', 'w'), ('e', 's'), ('hi', 's</w>'), ('er', '</w>'), ('a', 'c'), ('en', '</w>'), ('i', 't</w>'), ('a', 'l'), ('a', 't'), ('i', 's'), ('i', 'c'), ('"', '</w>'), ('b', 'e'), ('ha', 'd</w>'), ('o', 'm'), ('g', 'h'), ('in', '</w>'), ('th', 'at</w>'), ('a', 'r'), ('u', 's'), ('v', 'e</w>'), ('r', 'a'), ('w', 'i'), ('on', '</w>'), ('m', 'y</w>'), ('hi', 'm'), ('r', 'e'), ('th', '</w>'), ('r', 's'), ('r', 'o'), ('c', 'h'), ('a', 'b'), ("'", 's</w>'), ('s', 'a'), ('or', '</w>'), ('l', 'a'), ('e', ',</w>'), ('w', 'h'), ('l', 'i'), ('a', 'in'), ('r', '</w>'), ('y', 'ou'), ('r', 'ou'), ('l', '</w>'), ('wi', 'th</w>'), ('s', 'e'), ('l', 'e</w>'), ('ou', 'l'), ('n', '</w>'), ('k', '</w>'), ('i', 'd'), ('t', 'i'), ('l', 'y</w>'), ('e', 'v'), ('an', '</w>'), ("'", 't</w>'), ('u', 't</w>'), ('on', 'e</w>'), ('n', 'o'), ('u', 'p'), ('t', 'er'), ('s', 'he</w>')]
words: ['lowest', 'newer']
分词结果: ['l', 'ow', 'es', 't', 'n', 'e', 'w', 'er']

Process finished with exit code 0
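The test_corpus defined in the main block is never actually used for training. A quick sanity check on it could look like the sketch below (my own addition, so the printed merges and tokens are not from the original post):

# Train a tiny model on the toy corpus instead of the file; with only a handful
# of distinct words the learned merge rules are easy to inspect by hand.
toy_bpe = BPE(num_merges=10)
toy_bpe.train("low low low low low lower newer newest widest")
print('Merge rules:', toy_bpe.merges)
print('Tokenization result:', toy_bpe.tokenize('lowest newer'))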
