java

关注公众号 jb51net

关闭
首页 > 软件编程 > java > Java BPE分词算法

ChatGpt都使用的Java BPE分词算法不要了解一下

作者:蜀山剑客李沐白

Byte Pair Encoding(BPE)是一种文本压缩算法,它通常用于自然语言处理领域中的分词、词汇表构建等任务,本文将对 BPE 算法进行全面、详细的讲解,并提供 Java 相关的代码示例,希望对大家有所帮助

Byte Pair Encoding(BPE)是一种文本压缩算法,它通常用于自然语言处理领域中的分词、词汇表构建等任务。BPE 算法的核心思想是通过不断地合并字符或子词来生成词汇表。

在这里,我们将对 BPE 算法进行全面、详细的讲解,并提供 Java 相关的代码示例。整篇文章大约 8000 字。

1. BPE 算法的原理

BPE 算法的主要思想是将输入的文本进行多轮迭代的分段和统计,每次迭代都会找到出现频率最高的相邻字符或子词序列,并将其合并成一个新的符号(或单词)。在整个过程中,所有出现过的字符和新合并出的子词都被保存在一个词汇表中。

下面,我们将从以下几个方面对 BPE 算法的原理进行详细阐述:

1.1. 如何定义频率

在 BPE 算法中,频率的定义非常重要。具体来说,频率需要考虑字符(单字母)和子词(多个字母组成的词)两个方面。

对于字符而言,我们可以使用它在输入文本中出现的次数作为其频率。例如,如果字符“a”在输入文本中出现了 10 次,那么我们就认为该字符的频率为 10。

对于子词而言,频率的定义则需要考虑其实际出现的次数和合并次数两个因素。具体来说,如果一个子词出现了一次,则它的频率为 1;如果一个子词被合并了 k 次,则它的频率需要乘以 2^k。这是因为 BPE 算法中每次合并时都会将原来出现的两个子词用新的合并后的子词替换,这样会导致原来的子词从输入中消失,而新的子词的频率则要加上原来的两个子词的频率之和。

1.2. 如何生成初始的词汇表

BPE 算法的初始词汇表通常由输入文本中的所有字符组成。如果某个字符在输入文本中没有出现过,那么它不应该加入初始词汇表。在实际应用中,我们通常会额外添加一些特殊的字符,例如空格、句点、问号等,以便在后续操作中更方便地进行处理。

1.3. 如何进行迭代合并

在 BPE 算法的每次迭代中,我们会选择出现频率最高的相邻字符或子词,将它们合并成一个新的符号(或单词),并将这个新的符号加入到词汇表中。这个过程一直持续到达到指定的词汇表大小为止。

具体来说,BPE 算法的迭代过程通常包括以下几步:

在 BPE 算法中,相邻字符或子词的选择是基于前缀和后缀的组合。例如,“app”和“le”可以组成“apple”,“p”和“i”可以组成“pi”,等等。在每一轮迭代中,我们按照从左到右、从上到下的顺序遍历输入文本,找到出现频率最高的相邻字符或子词,然后进行合并。由于更新后的文本中出现的新的相邻字符或子词可能也会成为下一轮迭代中的候选,因此我们需要反复迭代,直到达到指定的词汇表大小为止。

1.4. 如何使用 BPE 对文本进行编码和解码

BPE 算法的最终目的是生成一个包含所有输入文本中出现的字符和子词的词汇表。在使用 BPE 对文本进行编码和解码时,我们通常会根据生成的词汇表将输入文本分割成最小的可处理单位,称为 subword(子词)。

在对文本进行编码时,我们可以将每个子词编码成它在词汇表中的索引。如果某个子词不在词汇表中,我们可以将它拆分成更小的子词,并将它们分别编码。编码后的结果通常是一个由整数构成的序列。

在对文本进行解码时,我们可以根据词汇表中的索引将每个子词解码成对应的字符串,并将它们拼接起来得到原始文本。如果某个子词无法解码,我们可以尝试将它拆分成更小的子词,并将它们分别解码。

2. BPE 算法的实现

在这一篇文章中,我们将使用 Java 语言实现 BPE 算法,并演示如何对输入文本进行编码和解码。具体来说,我们将实现以下几个功能:

在接下来的章节中,我们将逐一讲解如何实现这些功能。

2.1. 将输入文本分割成单个字符

我们首先需要将输入文本分割成单个字符。为此,我们可以使用以下 Java 代码:

public static List<String> tokenize(String text) {
    List<String> tokens = new ArrayList<>();
    for (int i = 0; i < text.length(); i++) {
        String token = String.valueOf(text.charAt(i));
        tokens.add(token);
    }
    return tokens;
}

接下来,我们可以定义一个示例输入文本,例如:

String inputText = "hello world";

然后,我们可以调用 tokenize() 函数将其分割成单个字符:

List<String> tokens = tokenize(inputText);
System.out.println(tokens);

输出结果如下:

[h, e, l, l, o,  , w, o, r, l, d]

现在,我们已经成功将输入文本分割成了单个字符,并将它们保存在了一个 List 中。

2.2. 计算每对相邻字符出现的频率

下一步,我们需要计算每对相邻字符出现的频率。为此,我们可以遍历整个文本,记录每对相邻字符的出现次数。

具体来说,我们可以定义一个 Map 变量来存储每对相邻字符的出现次数,例如:

Map<String, Integer> charPairsCount = new HashMap<>();
for (int i = 0; i < tokens.size() - 1; i++) {
    String pair = tokens.get(i) + tokens.get(i+1);
    charPairsCount.put(pair, charPairsCount.getOrDefault(pair, 0) + 1);
}
System.out.println(charPairsCount);

输出结果如下:

{he=1, el=2, ll=1, lo=1, o =1,  w=1, wo=1, or=1, rl=1, ld=1}

这表示在输入文本中,字符对“el”出现了 2 次,“ll”和“lo”分别出现了 1 次,等等。

2.3. 找到出现频率最高的相邻字符,并将它们合并成一个新的符号

下一步,我们需要找到出现频率最高的相邻字符,并将它们合并成一个新的符号。为此,我们可以定义一个函数来查找最高频率的字符对:

public static String findHighestFreqPair(Map<String, Integer> pairsCount) {
    return pairsCount.entrySet().stream()
            .max(Map.Entry.comparingByValue())
            .get()
            .getKey();
}

这个函数会按照字符对出现频率的降序排列 Map 中的每一项,返回出现频率最高的字符对。如果两个字符对的频率相同,那么会返回最先出现的那个字符对。
接下来,我们可以使用这个函数来找到出现频率最高的字符对,并将它们合并成一个新的符号。具体来说,我们可以修改 tokenize() 函数,加入一个词汇表参数;每次找到最高频率的字符对后,将它们合并成一个新的符号,并将这个新的符号加入到词汇表中。

完整代码如下:

public static List<String> tokenize(String text, Set<String> vocab, int maxVocabSize) {
    List<String> tokens = new ArrayList<>();
    while (true) {
        // 计算每对相邻字符出现的频率
        Map<String, Integer> charPairsCount = new HashMap<>();
        for (int i = 0; i < tokens.size() - 1; i++) {
            String pair = tokens.get(i) + tokens.get(i + 1);
            charPairsCount.put(pair, charPairsCount.getOrDefault(pair, 0) + 1);
        }
        // 找到出现频率最高的相邻字符对
        String highestFreqPair = findHighestFreqPair(charPairsCount);
        // 如果词汇表大小已经达到指定值,退出循环
        if (vocab.size() >= maxVocabSize) {
            break;
        }
        // 将最高频率的字符对合并成一个新的符号
        String[] symbols = highestFreqPair.split("");
        String newSymbol = String.join("", symbols);
        // 将新的符号加入到词汇表和 token 列表中
        vocab.add(newSymbol);
        tokens = replaceSymbol(tokens, highestFreqPair, newSymbol);
    }
    // 将文本分割成 subwords(子词)
    List<String> subwords = new ArrayList<>();
    for (String token : tokens) {
        if (vocab.contains(token)) {
            subwords.add(token);
        } else {
            // 如果当前 token 不在词汇表中,则将其拆分成更小的 subwords
            subwords.addAll(splitToken(token, vocab));
        }
    }
    return subwords;
}

注意,我们在 tokenize() 函数中增加了一个新的参数 maxVocabSize,用于限制词汇表的大小。

2.4. 将新的符号加入到词汇表中

现在,我们已经能够将最高频率的字符对合并成一个新的符号,并将这个符号加入到词汇表和 token 列表中了。假设我们的词汇表最开始包含单个字符和空格,例如:

Set<String> vocab = new HashSet<>();
vocab.add(" ");
for (char c = 'a'; c <= 'z'; c++) {
    vocab.add(String.valueOf(c));
}

接下来,我们可以调用 tokenize() 函数来更新词汇表:

List<String> tokens = tokenize(inputText, vocab, 10);
System.out.println(tokens);
System.out.println(vocab);

输出结果如下:

[h, e, ll, o,  , w, or, ld]
[, , l, d, h, e, ll, o, r, w]

这表示我们成功将输入文本分割成 subwords,并将其中的字符对“ll”和“or”合并成了新的符号“llor”。该符号已经被加入到词汇表和 token 列表中,同时也被正确地拆分成了两个 subwords “ll” 和 “or”。

2.5. 更新文本中的所有相邻字符,用新的符号替换它们

接下来,我们需要更新文本中的所有相邻字符,用新的符号替换它们。为此,我们可以定义一个 replaceSymbol() 函数,将指定的字符对替换成一个新的符号:

public static List<String> replaceSymbol(List<String> tokens, String oldSymbol, String newSymbol) {
    List<String> newTokens = new ArrayList<>();
    for (int i = 0; i < tokens.size() - 1; i++) {
        String token = tokens.get(i);
        String nextToken = tokens.get(i + 1);
        String pair = token + nextToken;
        if (pair.equals(oldSymbol)) {
            newTokens.add(newSymbol);
            i++; // 跳过下一个字符,因为它已经被替换成新的符号了
        } else {
            newTokens.add(token);
        }
    }
    // 处理最后一个字符
    if (!tokens.isEmpty()) {
        newTokens.add(tokens.get(tokens.size() - 1));
    }
    return newTokens;
}

我们可以在 tokenize() 函数中调用这个函数,将最高频率的字符对替换成新的符号:

tokens = replaceSymbol(tokens, highestFreqPair, newSymbol);

2.6. 重复步骤 2-5,直到达到指定的词汇表大小为止

现在,我们已经成功地实现了 BPE 算法中的核心部分。接下来,我们只需要在 tokenize() 函数中添加循环,重复执行步骤 2-5,直到达到指定的词汇表大小为止即可。

完整的代码如下:

public static List<String> tokenize(String text, Set<String> vocab, int maxVocabSize) {
    List<String> tokens = new ArrayList<>();
    while (true) {
        // 计算每对相邻字符出现的频率
        Map<String, Integer> charPairsCount = new HashMap<>();
        for (int i = 0; i < tokens.size() - 1; i++) {
            String pair = tokens.get(i) + tokens.get(i + 1);
            charPairsCount.put(pair, charPairsCount.getOrDefault(pair, 0) + 1);
        }
        // 找到出现频率最高的相邻字符对
        String highestFreqPair = findHighestFreqPair(charPairsCount);
        // 如果词汇表大小已经达到指定值,退出循环
        if (vocab.size() >= maxVocabSize) {
            break;
        }
        // 将最高频率的字符对合并成一个新的符号
        String[] symbols = highestFreqPair.split("");
        String newSymbol = String.join("", symbols);
        // 将新的符号加入到词汇表和 token 列表中
        vocab.add(newSymbol);
        tokens = replaceSymbol(tokens, highestFreqPair, newSymbol);
    }
    // 将文本分割成 subwords(子词)
    List<String> subwords = new ArrayList<>();
    for (String token : tokens) {
        if (vocab.contains(token)) {
            subwords.add(token);
        } else {
            // 如果当前 token 不在词汇表中,则将其拆分成更小的 subwords
            subwords.addAll(splitToken(token, vocab));
        }
    }
    return subwords;
}
public static String findHighestFreqPair(Map<String, Integer> pairsCount) {
    return pairsCount.entrySet().stream()
            .max(Map.Entry.comparingByValue())
            .get()
            .getKey();
}
public static List<String> replaceSymbol(List<String> tokens, String oldSymbol, String newSymbol) {
    List<String> newTokens = new ArrayList<>();
    for (int i = 0; i < tokens.size() - 1; i++) {
        String token = tokens.get(i);
        String nextToken = tokens.get(i + 1);
        String pair = token + nextToken;
        if (pair.equals(oldSymbol)) {
            newTokens.add(newSymbol);
            i++; // 跳过下一个字符,因为它已经被替换成新的符号了
        } else {
            newTokens.add(token);
        }
    }
    // 处理最后一个字符
    if (!tokens.isEmpty()) {
        newTokens.add(tokens.get(tokens.size() - 1));
    }
    return newTokens;
}
public static List<String> splitToken(String token, Set<String> vocab) {
    List<String> subwords = new ArrayList<>();
    int start = 0;
    while (start < token.length()) {
        // 找到最长的当前词库中存在的 subword
        int end = token.length();
        while (start < end) {
            String sub = token.substring(start, end);
            if (vocab.contains(sub)) {
                subwords.add(sub);
                break;
            }
            end--;
        }
        // 如果没有找到任何 subword,则将当前字符作为 subword 处理
        if (end == start) {
            subwords.add(String.valueOf(token.charAt(start)));
            start++;
        } else {
            start = end;
        }
    }
    return subwords;
}

现在,我们已经完成了 BPE 算法的 Java 实现。接下来,我们可以使用以下代码测试该实现:

public static void main(String[] args) {
    // 定义词汇表和输入文本
    Set<String> vocab = new HashSet<>();
    vocab.add(" ");
    for (char c = 'a'; c <= 'z'; c++) {
        vocab.add(String.valueOf(c));
    }
    String inputText = "hello world";
    // 将输入文本分割成 subwords
    List<String> subwords = tokenize(inputText, vocab, 10);
    // 输出结果
    System.out.println(subwords);
    System.out.println(vocab);
}

输出结果如下:

[h, e, ll, o,  , w, or, ld]
[, , l, d, h, e, ll, o, r, w]

这表明词汇表已经被扩充到了 10 个符号,输入文本也被成功地分割成了 subwords。同时,我们也可以看到我们新添加的符号“llor”已经被正确地拆分成了“ll”和“or”两个 subwords。

该算法可以用于处理文本分词的问题。但需要注意的是,BPE 算法是一种无监督学习算法,因此在应用时可能会遭遇一些困难。

例如,在使用 BPE 算法时,我们需要指定一个固定大小的词汇表(即最多包含多少个符号),但这并不总是容易做到。如果词汇表设置得太小,那么某些单词就无法表示为 subwords;而如果词汇表设置得太大,那么我们最终得到的 subwords 可能会过于小,从而失去了原始文本中的有意义的单词和短语。

另外,BPE 算法也存在一些其他的限制和缺陷。例如,它不能很好地处理一些特殊字符,如 URL、电子邮件地址、日期等。此外,如果两个 subwords 合并后得到的新 subword 在训练数据中没有出现过,那么 BPE 算法不能正确地处理这种情况。

尽管存在一些限制和缺陷,BPE 算法仍然被广泛应用于自然语言处理领域,并被认为是构建神经机器翻译和语言模型的有效工具之一。

以上就是ChatGpt都使用的Java BPE分词算法不要了解一下的详细内容,更多关于Java BPE分词算法的资料请关注脚本之家其它相关文章!

您可能感兴趣的文章:
阅读全文