
Commit 0c56a5c

Added summary-level rougeL scorer
1 parent a1e5096 commit 0c56a5c

5 files changed: +151 -41 lines changed

compare_mt/formatting.py

Lines changed: 3 additions & 3 deletions

@@ -22,15 +22,15 @@ def escape_latex(self, x):
     x = pat.sub(replace_with, x)
     return x
 
-  def __call__(self, x):
+  def __call__(self, x, latex=True):
     """Convert object to string with controlled decimals"""
     if isinstance(x, str):
-      return self.escape_latex(x)
+      return self.escape_latex(x) if latex else x
     elif isinstance(x, int):
       return f"{x:d}"
     elif isinstance(x, float):
       return f"{x:.{self.decimals}f}"
     else:
       str(x)
 
-fmt = Formatter(decimals=4)
+fmt = Formatter(decimals=4)

compare_mt/reporters.py

Lines changed: 1 addition & 1 deletion

@@ -132,7 +132,7 @@ def print_header(self, header):
 
   def print_tabbed_table(self, tab):
     for x in tab:
-      print('\t'.join([fmt(y) if y else '' for y in x]))
+      print('\t'.join([fmt(y, latex=False) if y else '' for y in x]))
     print()
 
   def generate_report(self, output_fig_file=None, output_fig_format=None, output_directory=None):
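
Together with the Formatter change above, this keeps the one shared fmt instance serving both output paths: LaTeX escaping stays on by default, and print_tabbed_table now opts out for plain tab-separated text. A minimal sketch of the intended behavior (which characters escape_latex rewrites is an assumption here, illustrated with '&'):

from compare_mt.formatting import fmt

print(fmt(0.123456))                    # '0.1235' -- numeric formatting ignores the flag
print(fmt('tom & jerry'))               # assumed: 'tom \& jerry', escaped for LaTeX tables
print(fmt('tom & jerry', latex=False))  # 'tom & jerry', raw text for tabbed output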

compare_mt/rouge/rouge_scorer.py

Lines changed: 123 additions & 30 deletions

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2018 The Google Research Authors.
+# Copyright 2019 The Google Research Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,20 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Lint as: python2, python3
 """Computes rouge scores between two text blobs.
-
 Implementation replicates the functionality in the original ROUGE package. See:
-
 Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In
 Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004),
 Barcelona, Spain, July 25 - 26, 2004.
-
 Default options are equivalent to running:
 ROUGE-1.5.5.pl -e data -n 2 -a settings.xml
-
 Or with use_stemmer=True:
 ROUGE-1.5.5.pl -m -e data -n 2 -a settings.xml
-
 In these examples settings.xml lists input files and formats.
 """
 
@@ -38,16 +34,15 @@
 import re
 
 from nltk.stem import porter
-import numpy as np
 import six
-from six.moves import xrange  # pylint: disable=redefined-builtin
+from six.moves import map
+from six.moves import range
 from compare_mt.rouge import scoring
-from compare_mt.rouge import tokenizer
+from compare_mt.rouge import tokenize
 
 
 class RougeScorer(scoring.BaseScorer):
   """Calculate rouges scores between two blobs of text.
-
   Sample usage:
     scorer = RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
     scores = scorer.score('The quick brown fox jumps over the lazy dog',
@@ -56,11 +51,9 @@ class RougeScorer(scoring.BaseScorer):
 
   def __init__(self, rouge_types, use_stemmer=False):
     """Initializes a new RougeScorer.
-
     Valid rouge types that can be computed are:
       rougen (e.g. rouge1, rouge2): n-gram based scoring.
       rougeL: Longest common subsequence based scoring.
-
     Args:
       rouge_types: A list of rouge types to calculate.
       use_stemmer: Bool indicating whether Porter stemmer should be used to
@@ -74,7 +67,6 @@ def __init__(self, rouge_types, use_stemmer=False):
 
   def score(self, target, prediction):
     """Calculates rouge scores between the target and prediction.
-
     Args:
       target: Text containing the target (ground truth) text.
       prediction: Text containing the predicted text.
@@ -84,15 +76,29 @@ def score(self, target, prediction):
       ValueError: If an invalid rouge type is encountered.
     """
 
-    target_tokens = tokenizer.tokenize(target, self._stemmer)
-    prediction_tokens = tokenizer.tokenize(prediction, self._stemmer)
+    target_tokens = tokenize.tokenize(target, self._stemmer)
+    prediction_tokens = tokenize.tokenize(prediction, self._stemmer)
     result = {}
 
     for rouge_type in self.rouge_types:
       if rouge_type == "rougeL":
         # Rouge from longest common subsequences.
         scores = _score_lcs(target_tokens, prediction_tokens)
-      elif re.match(r"rouge[0-9]$", rouge_type):
+      elif rouge_type == "rougeLsum":
+        # Note: Does not support multi-line text.
+        def get_sents(text):
+          # Assume sentences are separated by newline.
+          sents = six.ensure_str(text).split("\n")
+          sents = [x for x in sents if len(x)]
+          return sents
+
+        target_tokens_list = [
+            tokenize.tokenize(s, self._stemmer) for s in get_sents(target)]
+        prediction_tokens_list = [
+            tokenize.tokenize(s, self._stemmer) for s in get_sents(prediction)]
+        scores = _summary_level_lcs(target_tokens_list,
+                                    prediction_tokens_list)
+      elif re.match(r"rouge[0-9]$", six.ensure_str(rouge_type)):
         # Rouge from n-grams.
         n = int(rouge_type[5:])
         if n <= 0:
@@ -109,7 +115,6 @@ def score(self, target, prediction):
 
 def _create_ngrams(tokens, n):
   """Creates ngrams from the given list of tokens.
-
   Args:
     tokens: A list of tokens from which ngrams are created.
     n: Number of tokens to use, e.g. 2 for bigrams.
@@ -118,14 +123,13 @@ def _create_ngrams(tokens, n):
   """
 
   ngrams = collections.Counter()
-  for ngram in (tuple(tokens[i:i + n]) for i in xrange(len(tokens) - n + 1)):
+  for ngram in (tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)):
     ngrams[ngram] += 1
   return ngrams
 
 
 def _score_lcs(target_tokens, prediction_tokens):
   """Computes LCS (Longest Common Subsequence) rouge scores.
-
   Args:
     target_tokens: Tokens from the target text.
     prediction_tokens: Tokens from the predicted text.
@@ -137,16 +141,8 @@ def _score_lcs(target_tokens, prediction_tokens):
     return scoring.Score(precision=0, recall=0, fmeasure=0)
 
   # Compute length of LCS from the bottom up in a table (DP appproach).
-  cols = len(prediction_tokens) + 1
-  rows = len(target_tokens) + 1
-  lcs_table = np.zeros((rows, cols))
-  for i in xrange(1, rows):
-    for j in xrange(1, cols):
-      if target_tokens[i - 1] == prediction_tokens[j - 1]:
-        lcs_table[i, j] = lcs_table[i - 1, j - 1] + 1
-      else:
-        lcs_table[i, j] = max(lcs_table[i - 1, j], lcs_table[i, j - 1])
-  lcs_length = lcs_table[-1, -1]
+  lcs_table = _lcs_table(target_tokens, prediction_tokens)
+  lcs_length = lcs_table[-1][-1]
 
   precision = lcs_length / len(prediction_tokens)
   recall = lcs_length / len(target_tokens)
@@ -155,9 +151,106 @@ def _score_lcs(target_tokens, prediction_tokens):
   return scoring.Score(precision=precision, recall=recall, fmeasure=fmeasure)
 
 
+def _lcs_table(ref, can):
+  """Create 2-d LCS score table."""
+  rows = len(ref)
+  cols = len(can)
+  lcs_table = [[0] * (cols + 1) for _ in range(rows + 1)]
+  for i in range(1, rows + 1):
+    for j in range(1, cols + 1):
+      if ref[i - 1] == can[j - 1]:
+        lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1
+      else:
+        lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1])
+  return lcs_table
+
+
+def _backtrack_norec(t, ref, can):
+  """Read out LCS."""
+  i = len(ref)
+  j = len(can)
+  lcs = []
+  while i > 0 and j > 0:
+    if ref[i - 1] == can[j - 1]:
+      lcs.insert(0, i-1)
+      i -= 1
+      j -= 1
+    elif t[i][j - 1] > t[i - 1][j]:
+      j -= 1
+    else:
+      i -= 1
+  return lcs
+
+
+def _summary_level_lcs(ref_sent, can_sent):
+  """ROUGE: Summary-level LCS, section 3.2 in ROUGE paper.
+  Args:
+    ref_sent: list of tokenized reference sentences
+    can_sent: list of tokenized candidate sentences
+  Returns:
+    summary level ROUGE score
+  """
+  if not ref_sent or not can_sent:
+    return scoring.Score(precision=0, recall=0, fmeasure=0)
+
+  m = sum(map(len, ref_sent))
+  n = sum(map(len, can_sent))
+  if not n or not m:
+    return scoring.Score(precision=0, recall=0, fmeasure=0)
+
+  # get token counts to prevent double counting
+  token_cnts_r = collections.Counter()
+  token_cnts_c = collections.Counter()
+  for s in ref_sent:
+    # s is a list of tokens
+    token_cnts_r.update(s)
+  for s in can_sent:
+    token_cnts_c.update(s)
+
+  hits = 0
+  for r in ref_sent:
+    lcs = _union_lcs(r, can_sent)
+    # Prevent double-counting:
+    # The paper describes just computing hits += len(_union_lcs()),
+    # but the implementation prevents double counting. We also
+    # implement this as in version 1.5.5.
+    for t in lcs:
+      if token_cnts_c[t] > 0 and token_cnts_r[t] > 0:
+        hits += 1
+        token_cnts_c[t] -= 1
+        token_cnts_r[t] -= 1
+
+  recall = hits / m
+  precision = hits / n
+  fmeasure = scoring.fmeasure(precision, recall)
+  return scoring.Score(precision=precision, recall=recall, fmeasure=fmeasure)
+
+
+def _union_lcs(ref, c_list):
+  """Find union LCS between a ref sentence and list of candidate sentences.
+  Args:
+    ref: list of tokens
+    c_list: list of list of indices for LCS into reference summary
+  Returns:
+    List of tokens in ref representing union LCS.
+  """
+  lcs_list = [lcs_ind(ref, c) for c in c_list]
+  return [ref[i] for i in _find_union(lcs_list)]
+
+
+def _find_union(lcs_list):
+  """Finds union LCS given a list of LCS."""
+  return sorted(list(set().union(*lcs_list)))
+
+
+def lcs_ind(ref, can):
+  """Returns one of the longest lcs."""
+  t = _lcs_table(ref, can)
+  return _backtrack_norec(t, ref, can)
+
+
 def _score_ngrams(target_ngrams, prediction_ngrams):
   """Compute n-gram based rouge scores.
-
   Args:
     target_ngrams: A Counter object mapping each ngram to number of
       occurrences for the target text.
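
Taken together, the new helpers implement the summary-level measure from section 3.2 of Lin (2004): lcs_ind backtracks through the _lcs_table DP table to recover LCS token indices for one reference/candidate sentence pair, _union_lcs and _find_union merge those index sets across all candidate sentences, and _summary_level_lcs turns the clipped union-LCS hits into precision, recall, and F-measure. A usage sketch of the public entry point (assuming score returns a dict keyed by rouge type, as the result accumulator suggests; the printed values are not verified here):

from compare_mt.rouge.rouge_scorer import RougeScorer

# rougeLsum expects one sentence per newline-separated line in each blob.
scorer = RougeScorer(['rougeL', 'rougeLsum'], use_stemmer=True)
scores = scorer.score('the cat was found under the bed\nthe cat was not found',
                      'the cat was under the bed\nthe tiny little cat was found')
print(scores['rougeL'])     # LCS over the whole blob as one token sequence
print(scores['rougeLsum'])  # union LCS per reference sentence, then aggregated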
compare_mt/rouge/tokenize.py

Lines changed: 5 additions & 6 deletions

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2018 The Google Research Authors.
+# Copyright 2019 The Google Research Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,40 +13,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Lint as: python2, python3
 """A library for tokenizing text."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import re
+import six
 
 
 def tokenize(text, stemmer):
   """Tokenize input text into a list of tokens.
-
   This approach aims to replicate the approach taken by Chin-Yew Lin in
   the original ROUGE implementation.
-
   Args:
     text: A text blob to tokenize.
    stemmer: An optional stemmer.
-
   Returns:
     A list of string tokens extracted from input text.
   """

   # Convert everything to lowercase.
   text = text.lower()
   # Replace any non-alpha-numeric characters with spaces.
-  text = re.sub(r"[^a-z0-9]+", " ", text)
+  text = re.sub(r"[^a-z0-9]+", " ", six.ensure_str(text))
 
   tokens = re.split(r"\s+", text)
   if stemmer:
     # Only stem words more than 3 characters long.
     tokens = [stemmer.stem(x) if len(x) > 3 else x for x in tokens]
 
   # One final check to drop any empty or invalid tokens.
-  tokens = [x for x in tokens if re.match(r"^[a-z0-9]+$", x)]
+  tokens = [x for x in tokens if re.match(r"^[a-z0-9]+$", six.ensure_str(x))]
 
   return tokens
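
As a quick illustration of these rules (lowercase everything, split on non-alphanumeric characters, stem only tokens longer than 3 characters), a sketch of expected output; the exact stems come from NLTK's Porter implementation:

from nltk.stem import porter
from compare_mt.rouge import tokenize

print(tokenize.tokenize('The foxes JUMPED over... the dog!', None))
# ['the', 'foxes', 'jumped', 'over', 'the', 'dog']
print(tokenize.tokenize('The foxes JUMPED over... the dog!', porter.PorterStemmer()))
# ['the', 'fox', 'jump', 'over', 'the', 'dog'] -- 'the' and 'dog' are too short to stem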

compare_mt/scorers.py

Lines changed: 19 additions & 1 deletion

@@ -544,8 +544,14 @@ def score_sentence(self, ref, out):
       out = [self._stemmer.stem(x) if len(x) > 3 else x for x in out]
 
     if self.rouge_type == 'rougeL':
+      ref, out = self.tokenize(" ".join(ref)), self.tokenize(" ".join(out))
       scores = rouge_scorer._score_lcs(ref, out)
+    elif self.rouge_type == 'rougeLsum':
+      refs = [self.tokenize(s) for s in self.get_sents(ref)]
+      outs = [self.tokenize(s) for s in self.get_sents(out)]
+      scores = rouge_scorer._summary_level_lcs(refs, outs)
     elif re.match(r"rouge[0-9]$", self.rouge_type):
+      ref, out = self.tokenize(" ".join(ref)), self.tokenize(" ".join(out))
       n = int(self.rouge_type[5:])
       if n <= 0:
         raise ValueError(f"rougen requires positive n: {self.rouge_type}")
@@ -567,6 +573,18 @@ def score_sentence(self, ref, out):
 
     return self.scale * score_value, None
 
+  def get_sents(self, tokens):
+    # assume sentences are separated by "."
+    sents = " ".join(tokens).split(".")
+    sents = [x for x in sents if len(x)]
+    return sents
+
+  def tokenize(self, tokens):
+    text = re.sub(r"[^a-zA-Z0-9]+", " ", tokens)
+    tokens = re.split(r"\s+", text)
+    tokens = [x for x in tokens if len(x)]
+    return tokens
+
   def name(self):
     return self.rouge_type
 
@@ -859,7 +877,7 @@ def create_scorer_from_profile(profile, case_insensitive=False, meteor_directory
     return RibesScorer(case_insensitive=case_insensitive)
   elif profile == 'chrf':
     return ChrFScorer(case_insensitive=case_insensitive)
-  elif re.match(r"rouge[0-9L]$", profile):
+  elif re.match(r"rouge[0-9L](sum)?$", profile):
     return RougeScorer(rouge_type=profile, case_insensitive=case_insensitive)
   elif profile == 'wer':
     return WERScorer(case_insensitive=case_insensitive)
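
With the widened profile regex, rougeLsum is selectable like any other metric. A usage sketch, assuming score_sentence takes pre-tokenized reference and output token lists (as the " ".join(ref) calls above imply); note that this scorer's get_sents splits sentences on '.', unlike the newline convention in rouge_scorer.py:

from compare_mt import scorers

scorer = scorers.create_scorer_from_profile('rougeLsum')
ref = 'the cat was found under the bed . the cat was not found .'.split()
out = 'the cat was under the bed . the tiny little cat was found .'.split()
score, _ = scorer.score_sentence(ref, out)
print(scorer.name(), score)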
