专题6: 代码实战──基于语言模型的拼写纠错

一、拼写纠错任务概述
在实现QA系统或者检索系统时,需要用户给出输入,用户在输入问题的时候,不能期待他一定会输入正确,有可能输入的单词的拼写是错误的。在一个完备的系统中,需要后台能够及时捕获拼写错误,并进行纠正,然后再通过修正之后的结果再跟库里的问题进行匹配。这里来实现一个简单的拼写纠错模块,自动去修复错误的单词。

纠错模块是基于Noisy Channel Model噪音通道模型:



其中,candidates指的是针对于错误的单词的候选集,可以通过edit_ distance来获取; valid单词可以定义为存在词典里的单词; c代表的是正确的单词,s代表的是用户错误拼写的单词。

关于Noisy Channel Model的更多说明可参考https://blog.csdn.net/kunpen8944/article/details/83066460和https://www.cnblogs.com/loubin/p/13704777.html。

整个实现思路如下:

通过编辑距离(1到3)获取到candidates集合;

通过词表来过滤candidates,筛选出有效的单词,获取到真正的候选集;

目标:在candidates中找到使得上述条件概率最大的正确写法c

其中,

p(s|c):通过历史数据来获得

也就是对于一个正确的单词c,有多大的概率把它写成了某种错误的形式,一般通过历史数据来获取概率。 拼写错误的历史数据spell_errors.txt没有标记概率,所以可以使用uniform probability来表示,这个也叫channel probability。

p( c ):使用语言模型进行计算

也就是假如把错误的s改造成了c,把它加入到当前的语句之后有多通顺。 假如有两个候选c1、c2,希望分别计算出这个语言模型的概率,使用bigram计算当前词前面和后面词的bigram概率。 给定We are go to school tomorrow,对于这句话我们希望把中间的go替换成正确的形式,假如候选集里{going, went},计算概率p(going|are)p(to|going)和p(went|are)p(to|went) ,记为p( c )的概率,并选择概率最大的候选词作为要替换的正确词。

要计算bigram概率,就需要训练一个语言模型。训练时需要用到文本数据,选择使用nItk自带的reuters的文本类数据来训练一个语言模型。也可以尝试其他的数据。

二、拼写纠错实现
(1)读取词典库,构建词典

# 读取词典库
vocab = set([line.rstrip() for line in open('data/vocab.txt')])
len(vocab)
输出:

48227
(2)实现函数构建候选集

# 1.构建候选集

import string

# 获取一个错误单词的正确单词的候选集
def generate_candicates(word):
    # 为了简化,只生成与给定单词编辑距离为1的单词
    letters = string.ascii_lowercase
    splits = [(word[:i], word[i:]) for i in range(len(word)+1)] # 获取单词所有的拆分组合
    # 插入一个字母
    inserts = [L + c + R for L, R in splits for c in letters]
    # 删除一个字母
    deletes = [L + R[1:] for L, R in splits]
    # 替换一个字母
    replaces = [L + c + R[1:] for L, R in splits for c in letters]
    # 求并集
    candidates = set(inserts + deletes + replaces)
    # 使用词表过滤掉不合法的单词
    candidates = list(filter(lambda word: word in vocab, candidates))
    return candidates

generate_candicates('word')
输出:

['wore',
 'cord',
 'word',
 'ward',
 'sword',
 'wod',
 'worm',
 'wordy',
 'words',
 'work',
 'wood',
 'world',
 'worn',
 'wohd',
 'lord',
 'wold']
(3)基于Bigram计算p(c)

# 2.计算p(c),使用Bigram

import nltk
from nltk.corpus import reuters # 导入语料库

# 下载语料库
nltk.download('reuters')
nltk.download('punkt')

# 读取语料库
categories = reuters.categories()
corpus = reuters.sents(categories=categories)
print(len(corpus), len(corpus[0]))
输出:

54711 49
自己实现Bigram:

# 2.1 自己实现Bigram

from tqdm import tqdm

term_count = {}
bigram_count = {}

for doc in tqdm(corpus):
    doc = ['<s>'] + doc
    for i in range(len(doc) - 1):
        term = doc[i]
        bigram = ''.join(doc[i:i+2]) # bigram: [i, i+1]
        if term in term_count:
            term_count[term] += 1
        else:
            term_count[term] = 1
        if bigram in bigram_count:
            bigram_count[bigram] += 1
        else:
            bigram_count[bigram] = 1

bigram_count
输出:

{'<s>ASIAN': 4,
 'ASIANEXPORTERS': 1,
 'EXPORTERSFEAR': 1,
 'FEARDAMAGE': 1,
 'DAMAGEFROM': 2,
 'FROMU': 4,
 ...
 'estimated27': 1,
 '27pct': 33,
 'pcton': 121,
 'ona': 521,
 'akilowatt': 1,
 ...}
使用nltk库自带的模块实现Bigram:

# 2.2 使用nltk提供的模块实现bigram

from nltk.util import ngrams
from nltk.lm import NgramCounter

unigrams = [list(ngrams(['<s>'] + doc, 1)) for doc in tqdm(corpus)]
bigrams = [list(ngrams(['<s>'] + doc, 2)) for doc in tqdm(corpus)]
ngram_counts = NgramCounter(unigrams + bigrams)
ngram_counts[['we']]['are']
输出:

115
实现计算两个给定单词的Bigram得分:

# 计算给定两个单词的bigram得分

import numpy as np

vocab_count = len(vocab)

def get_bigram(word1, word2):
    join_count = ngram_counts[[word1]][word2]
    word1_count = ngram_counts[word1]
    bigram_probability = (join_count + 1) / (word1_count + vocab_count)
    return np.log(bigram_probability)

get_bigram('we', 'are')
输出:






-6.04867564155262
nltk库关于Ngram的更多用法可参考https://www.nltk.org/api/nltk.util.html#nltk.util.ngrams和https://www.nltk.org/api/nltk.lm.html#nltk.lm.NgramCounter。

(4)基于历史数据计算p(s|c):

# 3.计算p(s|c)
# 用户打错的概率统计,即 计算channel probability
channel_prob = {} # channel_prob[c][s]表示正确的单词c被错写成s的概率

for line in tqdm(open('data/spell-errors.txt')):
    items = line.split(':')
    correct = items[0].strip()
    mistakes = [item.strip() for item in items[1].strip().split(',')]
    channel_prob[correct] = {}
    prob = 1 / len(mistakes)
    for mistake in mistakes:
        channel_prob[correct][mistake] = prob

channel_prob['apple']
输出:

{'alipple': 0.047619047619047616,
 'apoll': 0.047619047619047616,
 'alpper': 0.047619047619047616,
 'appy': 0.047619047619047616,
 'alpple': 0.047619047619047616,
 'ait': 0.047619047619047616,
 'appel': 0.047619047619047616,
 'appre': 0.047619047619047616,
 'abuol': 0.047619047619047616,
 'apelle': 0.047619047619047616,
 'appple': 0.047619047619047616,
 'alploo': 0.047619047619047616,
 'alppe': 0.047619047619047616,
 'apl': 0.047619047619047616,
 'apll': 0.047619047619047616,
 'apply': 0.047619047619047616,
 'alppel': 0.047619047619047616,
 'aplep': 0.047619047619047616,
 'apoler': 0.047619047619047616,
 'appe': 0.047619047619047616,
 'aple': 0.047619047619047616}
(5)实现拼写纠错主函数:

# 4.拼写纠错主函数

def spell_correct(query):
    words = query.split()
    word_count = len(words)
    correct_query = []
    for word in words:
        if word not in vocab: # 不存在于词库,则认为拼写错误,需要替换为正确的单词
            # 4.1 生成当前单词的有效候选集
            candidates = generate_candicates(word)
            # 4.2 对于每一个candidate,计算它的score
            probs = []
            for candidate in candidates:
                prob = 0
                # a.计算channel probability
                if candidate in channel_prob and word in channel_prob[candidate]:
                    prob += np.log(channel_prob[candidate][word])
                else:
                    prob += np.log(0.0001)
                # b.计算语言模型概率
                idx = words.index(word)
                if idx > 0:
                    pre_score = get_bigram(words[idx-1], candidate) # [pre_word, candidate]
                    prob += pre_score
                if idx < word_count - 1:
                    after_score = get_bigram(candidate, words[idx+1]) # [candidate, post_word]
                    prob += after_score
                probs.append(prob)
            if len(probs) > 0:  # 有合适的候选词
                max_idx = probs.index(max(probs))
                best_candidate = candidates[max_idx]
                correct_query.append(best_candidate)
            else: # 没有合适的词,认为是正确的,直接加入
                correct_query.append(word)
        else: # 词存在于词表,则认为是正确的词
            correct_query.append(word)
    return ' '.join(correct_query)

spell_correct('toke this away')
输出:

'take this away'
拼写纠错的思路是对于query(待检测的句子)进行分词,然后把分词后的每一个单词在词库里面进行搜索, 如果搜不到,则可以认为是拼写错误的。如果拼写错误,再通过channel和bigram来计算最适合的候选。

对当前单词来说,需要计算其对应的每一个候选词的得分,计算方式如下:

score = p(correct) * p(mistake/correct)

\= log p(correct) + log p(mistake/correct)

计算好之后,再将最大的candidate作为当前错误单词对应的正确单词。

(6)基于测试数据进行测试

# 进行测试
for line in open('data/testdata.txt'):
    items = line.rstrip().split('\t')
    print(spell_correct(items[2]))
输出:

They told Reuter correspondents in Asian capitals a U.S. Move against Japan might boost protectionist sentiment in the U.S. And lead to curbs on American imports of their products
But some exporters said that while the conflict would hurt them in the long-run in the short-term Tkyo's loss might be their gain
The U.S. Has said it will impose 300 mln dlrs of tariffs on imports of Japanese electronics goods on
...
FSIS inspects an estimates 127 mln head of cattle and 4.5 billion chicnek and turkeys every year
Houston said inspection programs have kept pace with change but conceded that the danger of chemical residues in the meat and uopltry supply has increased
He also said that although he was confident the bacterium salmonella eventually could be eradicated it would take time and much money to contain the growing problem
三、语法纠错的应用
文中用到了语言模型,如下:

Unigram



Bigram



N-gram

当N=3时,同理。但是有时$\left(w_{1}, w_{2}, w_{3}, w_{4}, w_{5}\right) \text { 到 } w_{5}$时,可能语料库中没有出现过这个短语。

可以看到,Bigram可以看作是Unigram和N-gram之间的权衡:

Ungram只看一个单词,可能不够准确,因此将当前单词前后的单词结合考虑会更准确;

当N很大时,会存在语料库中可能不包含当前的长度为N的短语,因此N应该取小一点;

所以一般情况下,N取2或3比较理想,即Bigram和Trigram。

错误检测用到的公式为:



这个公式和错误检测的应用场景比较广泛,包括语音识别、机器翻译、拼写纠错、OCR、密码破解等,基本原理都是信号到文本的转换。如下:

机器翻译

公式:$P(\text {中文} \mid \text {英文}) \propto P(\text {英文} \mid \text {中文}) * P(\text {中文})$

其中,P(英文|中文)定义为翻译模型,P(中文)定义为语言模型

拼写纠错

公式:$P(\text {正确的写法}|\text{错误的写法}) \propto P(\text {错误的写法|正确的写法) * } P(\text {正确的写法) }$

其中,P(错误的写法|正确的写法)定义为编辑距离,P(正确的写法)定义为语言模型。

语音识别

公式:$P(\text {文本} \mid \text {语音信号}) \propto P(\text {语音信号|文本}) * P(\text {文本})$

其中,P(语音信号|文本)定义为翻译模型或Recognition Model,P(文本)定义为语言模型。

密码破解

公式:$P(\text {明文} \mid \text {密文}) \propto P(\text {密文} \mid \text {明文}) * P(\text {明文})$

其中,P(密文|明文)定义为翻译模型,P(明文)定义为语言模型。






扫码进群:
Python极客部落群聊二维码