放牧代码和思想
专注自然语言处理、机器学习算法
    This thing called love. Know I would've. Thrown it all away. Wouldn't hesitate.

Michael Collins NLP公开课任务2 PCFG

目录

Micheal Collins在Coursera上的自然语言处理公开课,第二次任务。自然语言中的歧义令人忍俊不禁,只要你或者你的模型脑洞足够大。

hankcs.com 2017-01-04 下午9.29.44.png

hankcs.com 2017-01-04 下午9.34.41.png

语料库来自WSJ,但并不是乔姆斯基范式:

hankcs.com 2016-12-30 下午1.06.21.png

因为乔姆斯基范式中一元rule必须是叶子节点,修正方案是折叠过长的路径:

hankcs.com 2016-12-30 下午1.12.31.png

不可能出现多于2元的rule,修正方案是将多余的分支统一移入右子树:

hankcs.com 2016-12-30 下午1.14.13.png

assignment中附带的语料已经经过了上述修正处理,无需担心。语料格式被预处理为多级嵌套的json数组:

二元——

hankcs.com 2016-12-30 下午1.15.53.png

一元——

hankcs.com 2016-12-30 下午1.16.38.png

整棵树——

["S", ["NP", ["DET", "There"]], ["S", ["VP", ["VERB", "is"], ["VP", ["NP", ["DET", "no"], ["NOUN", "asbestos"]], ["VP", ["PP", ["ADP", "in"], ["NP", ["PRON", "our"], ["NOUN", "products"]]], ["ADVP", ["ADV", "now"]]]]], [".", "."]]]

assignment提供了一个pretty_print_tree.py脚本来打印短语结构树,但不支持不规范的结构树,会抛出如下异常:

Traceback (most recent call last):
  File "pretty_print_tree.py", line 50, in <module>
    main(sys.argv[1])
  File "pretty_print_tree.py", line 38, in main
    pretty_print_tree(json.loads(l))
  File "pretty_print_tree.py", line 34, in pretty_print_tree
    print pprint.pformat(tree)
  File "python2.7/pprint.py", line 63, in pformat
    return PrettyPrinter(indent=indent, width=width, depth=depth).pformat(object)
  File "python2.7/pprint.py", line 122, in pformat
    self._format(object, sio, 0, 0, {}, 0)
  File "python2.7/pprint.py", line 140, in _format
    rep = self._repr(object, context, level - 1)
  File "python2.7/pprint.py", line 226, in _repr
    self._depth, level)
  File "python2.7/pprint.py", line 238, in format
    return _safe_repr(object, context, maxlevels, level)
  File "python2.7/pprint.py", line 314, in _safe_repr
    orepr, oreadable, orecur = _safe_repr(o, context, maxlevels, level)
  File "python2.7/pprint.py", line 314, in _safe_repr
    orepr, oreadable, orecur = _safe_repr(o, context, maxlevels, level)
  File "python2.7/pprint.py", line 323, in _safe_repr
    rep = repr(object)
TypeError: __repr__ returned non-string (type list)

需要作如下修改:

class Node:
  """
  Dummy class for python's pretty printer.
  """
  def __init__(self, name):
      if isinstance(name, list):
          name = ','.join(name)
      self.name = name
  def __repr__(self): return self.name

然后就可以打印出来了:

[S,
 [NP, DET,There],
 [S,
  [VP,
   [VERB, is],
   [VP,
    [NP, [DET, no], [NOUN, asbestos]],
    [VP,
     [PP, [ADP, in], [NP, [PRON, our], [NOUN, products]]],
     [ADVP, ADV,now]]]],
  [., .]]]

问题1

如同标注模型一样,我们需要统计训练语料中的词语概率,于是就会遇到低频词问题。请将词频小于5的低频词统一替换为一个特殊符号。

算是热身运动,有助于理解语料格式,通过assignment中附带的Counts对象统计规则和non-terminate的频次:

"""
Count rule frequencies in a binarized CFG.
"""


class Counts(object):
    def __init__(self):
        self.unary = {}
        self.binary = {}
        self.nonterm = {}

    def show(self):
        for symbol, count in self.nonterm.iteritems():
            print count, "NONTERMINAL", symbol

        for (sym, word), count in self.unary.iteritems():
            print count, "UNARYRULE", sym, word

        for (sym, y1, y2), count in self.binary.iteritems():
            print count, "BINARYRULE", sym, y1, y2

    def count(self, tree):
        """
        Count the frequencies of non-terminals and rules in the tree.
        """
        if not isinstance(tree, list): return

        # Count the non-terminal symbol.
        symbol = tree[0]
        self.nonterm.setdefault(symbol, 0)
        self.nonterm[symbol] += 1

        if len(tree) == 3:
            # It is a binary rule.
            y1, y2 = (tree[1][0], tree[2][0])
            key = (symbol, y1, y2)
            self.binary.setdefault(key, 0)
            self.binary[(symbol, y1, y2)] += 1

            # Recursively count the children.
            self.count(tree[1])
            self.count(tree[2])
        elif len(tree) == 2:
            # It is a unary rule.
            y1 = tree[1]
            key = (symbol, y1)
            self.unary.setdefault(key, 0)
            self.unary[key] += 1

这个递归函数执行先序遍历,通过root的长度判断是一元还是二元,并对二元执行递归。

可见nonterm中存放所有symbol,unary存放一元rule,binary存放二元rule。要统计词频的话,遍历unary这个dict即可:

def count_word(self):
    '''
    count emitted words and find rare words
    '''
    # count emitted word
    for (sym, word), count in self.unary.iteritems():
        self.word[word] += count
    # find rare word
    for word, count in self.word.iteritems():
        if count < RARE_WORD_THRESHOLD:
            self.rare_words.append(word)

有了词频,替换低频词就好说了,依然是一个递归调用:

def process_rare_words(input_file, output_file, rare_words, processer):
    """
    替换低频词,并输出到文件
    :param input_file:
    :param output_file:
    :param rare_words:
    :param processer:
    """
    for line in input_file:
        tree = json.loads(line)
        replace(tree, rare_words, processer)
        output = json.dumps(tree)
        output_file.write(output)
        output_file.write('\n')


def replace(tree, rare_words, processer):
    """
    替换一棵树中的低频词
    :param tree:
    :param rare_words:
    :param processer:
    :return:
    """
    if isinstance(tree, basestring):
        return

    if len(tree) == 3:
        # Recursively count the children.
        replace(tree[1], rare_words, processer)
        replace(tree[2], rare_words, processer)
    elif len(tree) == 2:
        if tree[1] in rare_words:
            tree[1] = processer(tree[1])

最终输出parser_train.counts.out文件:

1 NONTERMINAL NP+ADVP
254 NONTERMINAL VP+VERB
65 NONTERMINAL SBAR
81 NONTERMINAL ADJP
30 NONTERMINAL WHADVP

问题2

请通过最大似然估计PCFG的参数:

hankcs.com 2017-01-01 下午9.41.28.png

再简单不过了:

def cal_rule_params(self):
    """
    统计uni和bin rule的频率
    """
    # q(X->Y1Y2) = Count(X->Y1Y2) / Count(X)
    for (x, y1, y2), count in self.binary.iteritems():
        key = (x, y1, y2)
        self.q_x_y1y2[key] = float(count) / float(self.nonterm[x])
    # q(X->w) = Count(X->w) / Count(X)
    for (x, w), count in self.unary.iteritems():
        key = (x, w)
        self.q_x_w[key] = float(count) / float(self.nonterm[x])

有了模型参数,请实现CKY算法以计算hankcs.com 2017-01-01 下午9.47.49.png

先引入记号:

n为句子中的单词数

w_i表示第i个单词

N表示语法中的non-terminal个数

S表示语法中的start symbol

定义动态规划表格:

hankcs.com 2017-01-01 下午10.08.11.png表示从单词i到j的这部分句子构成以X作为根节点的最大概率。那么我们的目标就是要计算hankcs.com 2017-01-01 下午10.09.58.png

对于pi的定义,进一步说明如下:

当i=j时,有hankcs.com 2017-01-01 下午10.11.17.png

否则hankcs.com 2017-01-01 下午10.12.01.png

完整的CKY算法如下

hankcs.com 2017-01-01 下午10.13.05.png

第一层循环选定pi的长度,从1开始,第二层循环选取pi的起点i,于是终点j就确定下来了。

第三层循环遍历X的所有可能情况,也就是选定X。然后这层循环里面干了最重要的事情,继续遍历X->YZ的所有情况,选定X->YZ,遍历i到j之间的每个位置s,切一刀,算出概率取最大值赋给pi;同时记录得到最大值时的X->YZ和s。

Python实现如下:

def CKY(self, sentence):
    pi = defaultdict()
    bp = defaultdict()
    N = self.nonterm.keys()

    words = sentence.strip().split(' ')
    n = len(words)

    # process rare word,测试文件中的未登录词按照相同的规则预处理
    for i in xrange(0, n):
        if words[i] not in self.word.keys():
            words[i] = self.rare_words_rule(words[i])

    log('Sentence to process: {sent}'.format(sent=' '.join(words)))
    log('n = {n}, len(N) = {ln}'.format(n=n, ln=len(N)))

    # reduce X, Y and Z searching space,剪枝策略,过滤掉那些不合法的rule
    SET_X = defaultdict()
    for (X, Y, Z) in self.binary.keys():
        if X in SET_X:
            SET_X[X].append((Y, Z))
        else:
            SET_X[X] = []

    # init, unary rule
    for i in xrange(1, n + 1):
        w = words[i - 1]
        for X in N:
            if (X, w) in self.unary.keys():
                pi[(i, i, X)] = Decimal(self.q_x_w[(X, w)])
            else:
                pi[(i, i, X)] = Decimal(0.0)

    # dp
    for l in xrange(1, n):
        for i in xrange(1, n - l + 1):
            j = i + l

            for X, YZPairs in SET_X.iteritems():
                cur_pi, max_pi = 0.0, -1.0
                for (Y, Z) in YZPairs:
                    for s in xrange(i, j):
                        # 由于我们用SET_X做了剪枝,所以需要检查是否属于被过滤掉的非法rule
                        if (i, s, Y) not in pi or (s + 1, j, Z) not in pi:
                            continue
                        cur_pi = Decimal(self.q_x_y1y2[(X, Y, Z)]) \
                              * pi[(i, s, Y)] \
                              * pi[(s + 1, j, Z)]
                        if cur_pi > max_pi:
                            max_pi = cur_pi
                            max_Y, max_Z, max_s = Y, Z, s
                            pi[(i, j, X)] = max_pi
                            bp[(i, j, X)] = (max_Y, max_Z, max_s)

    if (1, n, ROOT) not in bp:
        max_pi = 0.0
        max_X = ''
        for X, YZPairs in SET_X.iteritems():
            if (1, n, X) in pi and pi[(1, n, X)] > max_pi:
                max_pi = pi[(1, n, X)]
                max_X = X
    else:
        max_X = ROOT
    result = self.traceback(pi, bp, sentence, 1, n, max_X)
    return result

如果不使用SET_X剪枝策略的话,代码外观可能跟伪码更加匹配,或者说更有简洁的美感,但很多unary的X会降低效率。

bp是一个回溯的数据结构,通过回溯还原一棵由嵌套数组组成的短语句法树:

def traceback(self, pi, bp, sentence, i, j, X):
    words = sentence.strip().split(' ')
    tree = []
    tree.append(X)
    if i == j:
        tree.append(words[i - 1])
    else:
        Y1, Y2, s = bp[(i, j, X)]
        # print Y1, Y2, s
        tree.append(self.traceback(pi, bp, sentence, i, s, Y1))
        tree.append(self.traceback(pi, bp, sentence, s + 1, j, Y2))
    return tree

通过如下脚本运行评估脚本:

#!/usr/bin/env bash
python eval_parser.py parse_dev.key p2.result

得到

      Type       Total   Precision      Recall     F1-Score
===============================================================
      ADJP          13     0.231        0.231        0.231
      ADVP          20     0.800        0.200        0.320
        NP        1081     0.714        0.759        0.736
        PP         326     0.775        0.794        0.785
       PRT           6     0.500        0.333        0.400
        QP           2     0.000        0.000        0.000
         S          45     0.632        0.267        0.375
      SBAR          15     0.500        0.333        0.400
     SBARQ         488     0.974        0.998        0.986
        SQ         488     0.846        0.867        0.856
        VP         305     0.701        0.407        0.515
    WHADJP          43     0.375        0.070        0.118
    WHADVP         125     0.774        0.960        0.857
      WHNP         372     0.885        0.825        0.854
      WHPP          10     0.000        0.000        0.000

     total        3339     0.798        0.770        0.783

如此简单朴素的CKY算法也能获得78%的F1值,挺不错了。

问题3

考虑如下句法树:

hankcs.com 2017-01-04 下午11.28.22.png

PCFG假设NP → DT NOUN 与父节点VP没有任何关系:

p(VP → VERB NP PP ADVP | VP) × p(NP → DET NOUN | NP)

这个假设实在太强烈了,父节点non-terminal对子节点rule的推断也应当有帮助:

p(VP → VERB NP PP ADVP | VP) × p(NP → DET NOUN | NP, parent=VP)

这种思想类似于trigram语言模型,要实现它只需改动一点点,即将子non-terminal加一个上标,上标为父节点:

hankcs.com 2017-01-04 下午11.35.19.png

虽然这增加了non-terminal的数量,使得每个rule更加稀疏,但实践证明这种做法是有效的。请实现它,其中句法树的改进版本已经提供,直接使用即可。

没什么难度,命令行参数改改跑跑,得到的结果果然好一些:

      Type       Total   Precision      Recall     F1-Score
===============================================================
      ADJP          13     0.375        0.231        0.286
      ADVP          20     0.714        0.250        0.370
        NP        1081     0.741        0.807        0.773
        PP         326     0.780        0.816        0.798
       PRT           6     0.250        0.167        0.200
        QP           2     0.000        0.000        0.000
         S          45     0.542        0.289        0.377
      SBAR          15     0.250        0.200        0.222
     SBARQ         488     0.976        0.998        0.987
        SQ         488     0.948        0.969        0.958
        VP         305     0.612        0.413        0.493
    WHADJP          43     0.931        0.628        0.750
    WHADVP         125     0.873        0.992        0.929
      WHNP         372     0.957        0.901        0.928
      WHPP          10     0.000        0.000        0.000

     total        3339     0.830        0.819        0.824

Reference

https://github.com/hankcs/Coursera_NLP_MC 

知识共享许可协议 知识共享署名-非商业性使用-相同方式共享码农场 » Michael Collins NLP公开课任务2 PCFG

评论 欢迎留言

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址

我的作品

HanLP自然语言处理包《自然语言处理入门》