1
1
# coding=utf-8
2
- # Copyright 2018 The Google Research Authors.
2
+ # Copyright 2019 The Google Research Authors.
3
3
#
4
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
5
# you may not use this file except in compliance with the License.
13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
15
16
+ # Lint as: python2, python3
16
17
"""Computes rouge scores between two text blobs.
17
-
18
18
Implementation replicates the functionality in the original ROUGE package. See:
19
-
20
19
Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In
21
20
Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004),
22
21
Barcelona, Spain, July 25 - 26, 2004.
23
-
24
22
Default options are equivalent to running:
25
23
ROUGE-1.5.5.pl -e data -n 2 -a settings.xml
26
-
27
24
Or with use_stemmer=True:
28
25
ROUGE-1.5.5.pl -m -e data -n 2 -a settings.xml
29
-
30
26
In these examples settings.xml lists input files and formats.
31
27
"""
32
28
38
34
import re
39
35
40
36
from nltk .stem import porter
41
- import numpy as np
42
37
import six
43
- from six .moves import xrange # pylint: disable=redefined-builtin
38
+ from six .moves import map
39
+ from six .moves import range
44
40
from compare_mt .rouge import scoring
45
- from compare_mt .rouge import tokenizer
41
+ from compare_mt .rouge import tokenize
46
42
47
43
48
44
class RougeScorer (scoring .BaseScorer ):
49
45
"""Calculate rouges scores between two blobs of text.
50
-
51
46
Sample usage:
52
47
scorer = RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
53
48
scores = scorer.score('The quick brown fox jumps over the lazy dog',
@@ -56,11 +51,9 @@ class RougeScorer(scoring.BaseScorer):
56
51
57
52
def __init__ (self , rouge_types , use_stemmer = False ):
58
53
"""Initializes a new RougeScorer.
59
-
60
54
Valid rouge types that can be computed are:
61
55
rougen (e.g. rouge1, rouge2): n-gram based scoring.
62
56
rougeL: Longest common subsequence based scoring.
63
-
64
57
Args:
65
58
rouge_types: A list of rouge types to calculate.
66
59
use_stemmer: Bool indicating whether Porter stemmer should be used to
@@ -74,7 +67,6 @@ def __init__(self, rouge_types, use_stemmer=False):
74
67
75
68
def score (self , target , prediction ):
76
69
"""Calculates rouge scores between the target and prediction.
77
-
78
70
Args:
79
71
target: Text containing the target (ground truth) text.
80
72
prediction: Text containing the predicted text.
@@ -84,15 +76,29 @@ def score(self, target, prediction):
84
76
ValueError: If an invalid rouge type is encountered.
85
77
"""
86
78
87
- target_tokens = tokenizer .tokenize (target , self ._stemmer )
88
- prediction_tokens = tokenizer .tokenize (prediction , self ._stemmer )
79
+ target_tokens = tokenize .tokenize (target , self ._stemmer )
80
+ prediction_tokens = tokenize .tokenize (prediction , self ._stemmer )
89
81
result = {}
90
82
91
83
for rouge_type in self .rouge_types :
92
84
if rouge_type == "rougeL" :
93
85
# Rouge from longest common subsequences.
94
86
scores = _score_lcs (target_tokens , prediction_tokens )
95
- elif re .match (r"rouge[0-9]$" , rouge_type ):
87
+ elif rouge_type == "rougeLsum" :
88
+ # Note: Does not support multi-line text.
89
+ def get_sents (text ):
90
+ # Assume sentences are separated by newline.
91
+ sents = six .ensure_str (text ).split ("\n " )
92
+ sents = [x for x in sents if len (x )]
93
+ return sents
94
+
95
+ target_tokens_list = [
96
+ tokenize .tokenize (s , self ._stemmer ) for s in get_sents (target )]
97
+ prediction_tokens_list = [
98
+ tokenize .tokenize (s , self ._stemmer ) for s in get_sents (prediction )]
99
+ scores = _summary_level_lcs (target_tokens_list ,
100
+ prediction_tokens_list )
101
+ elif re .match (r"rouge[0-9]$" , six .ensure_str (rouge_type )):
96
102
# Rouge from n-grams.
97
103
n = int (rouge_type [5 :])
98
104
if n <= 0 :
@@ -109,7 +115,6 @@ def score(self, target, prediction):
109
115
110
116
def _create_ngrams (tokens , n ):
111
117
"""Creates ngrams from the given list of tokens.
112
-
113
118
Args:
114
119
tokens: A list of tokens from which ngrams are created.
115
120
n: Number of tokens to use, e.g. 2 for bigrams.
@@ -118,14 +123,13 @@ def _create_ngrams(tokens, n):
118
123
"""
119
124
120
125
ngrams = collections .Counter ()
121
- for ngram in (tuple (tokens [i :i + n ]) for i in xrange (len (tokens ) - n + 1 )):
126
+ for ngram in (tuple (tokens [i :i + n ]) for i in range (len (tokens ) - n + 1 )):
122
127
ngrams [ngram ] += 1
123
128
return ngrams
124
129
125
130
126
131
def _score_lcs(target_tokens, prediction_tokens):
  """Computes LCS (Longest Common Subsequence) rouge scores.

  Args:
    target_tokens: Tokens from the target text.
    prediction_tokens: Tokens from the predicted text.

  Returns:
    A Score object with precision, recall and fmeasure based on the length of
    the LCS between target and prediction.
  """
  # Either side empty means no overlap is possible.
  if not target_tokens or not prediction_tokens:
    return scoring.Score(precision=0, recall=0, fmeasure=0)

  # Compute length of LCS from the bottom up in a table (DP approach); the
  # bottom-right cell holds the LCS length of the full sequences.
  table = _lcs_table(target_tokens, prediction_tokens)
  lcs_length = table[-1][-1]

  precision = lcs_length / len(prediction_tokens)
  recall = lcs_length / len(target_tokens)
  fmeasure = scoring.fmeasure(precision, recall)

  return scoring.Score(precision=precision, recall=recall, fmeasure=fmeasure)
156
152
157
153
154
+ def _lcs_table (ref , can ):
155
+ """Create 2-d LCS score table."""
156
+ rows = len (ref )
157
+ cols = len (can )
158
+ lcs_table = [[0 ] * (cols + 1 ) for _ in range (rows + 1 )]
159
+ for i in range (1 , rows + 1 ):
160
+ for j in range (1 , cols + 1 ):
161
+ if ref [i - 1 ] == can [j - 1 ]:
162
+ lcs_table [i ][j ] = lcs_table [i - 1 ][j - 1 ] + 1
163
+ else :
164
+ lcs_table [i ][j ] = max (lcs_table [i - 1 ][j ], lcs_table [i ][j - 1 ])
165
+ return lcs_table
166
+
167
+
168
+ def _backtrack_norec (t , ref , can ):
169
+ """Read out LCS."""
170
+ i = len (ref )
171
+ j = len (can )
172
+ lcs = []
173
+ while i > 0 and j > 0 :
174
+ if ref [i - 1 ] == can [j - 1 ]:
175
+ lcs .insert (0 , i - 1 )
176
+ i -= 1
177
+ j -= 1
178
+ elif t [i ][j - 1 ] > t [i - 1 ][j ]:
179
+ j -= 1
180
+ else :
181
+ i -= 1
182
+ return lcs
183
+
184
+
185
def _summary_level_lcs(ref_sent, can_sent):
  """ROUGE: Summary-level LCS, section 3.2 in ROUGE paper.

  Args:
    ref_sent: list of tokenized reference sentences.
    can_sent: list of tokenized candidate sentences.

  Returns:
    summary level ROUGE score (a scoring.Score).
  """
  if not ref_sent or not can_sent:
    return scoring.Score(precision=0, recall=0, fmeasure=0)

  # Total token counts; guard against all-empty sentence lists too.
  m = sum(len(s) for s in ref_sent)
  n = sum(len(s) for s in can_sent)
  if not n or not m:
    return scoring.Score(precision=0, recall=0, fmeasure=0)

  # Per-token budgets to prevent double counting across sentences.
  ref_budget = collections.Counter()
  can_budget = collections.Counter()
  for sent in ref_sent:
    ref_budget.update(sent)
  for sent in can_sent:
    can_budget.update(sent)

  hits = 0
  for sent in ref_sent:
    # Prevent double-counting:
    # The paper describes just computing hits += len(_union_lcs()),
    # but the implementation prevents double counting. We also
    # implement this as in version 1.5.5.
    for token in _union_lcs(sent, can_sent):
      if can_budget[token] > 0 and ref_budget[token] > 0:
        hits += 1
        can_budget[token] -= 1
        ref_budget[token] -= 1

  recall = hits / m
  precision = hits / n
  fmeasure = scoring.fmeasure(precision, recall)
  return scoring.Score(precision=precision, recall=recall, fmeasure=fmeasure)
227
+
228
+
229
def _union_lcs(ref, c_list):
  """Find union LCS between a ref sentence and list of candidate sentences.

  Args:
    ref: list of tokens in one reference sentence.
    c_list: list of tokenized candidate sentences.

  Returns:
    List of tokens in ref representing the union LCS, i.e. the tokens of ref
    that appear in the LCS with at least one candidate sentence.
  """
  # Union the LCS index sets computed against each candidate sentence, then
  # map the sorted indices back to tokens of ref.
  per_candidate = [lcs_ind(ref, candidate) for candidate in c_list]
  return [ref[idx] for idx in _find_union(per_candidate)]
239
+
240
+
241
+ def _find_union (lcs_list ):
242
+ """Finds union LCS given a list of LCS."""
243
+ return sorted (list (set ().union (* lcs_list )))
244
+
245
+
246
def lcs_ind(ref, can):
  """Returns one of the longest lcs.

  Args:
    ref: list of reference tokens.
    can: list of candidate tokens.

  Returns:
    Sorted list of indices into ref forming one LCS of ref and can.
  """
  # Build the DP table and immediately trace one LCS back out of it.
  return _backtrack_norec(_lcs_table(ref, can), ref, can)
250
+
251
+
158
252
def _score_ngrams (target_ngrams , prediction_ngrams ):
159
253
"""Compute n-gram based rouge scores.
160
-
161
254
Args:
162
255
target_ngrams: A Counter object mapping each ngram to number of
163
256
occurrences for the target text.
0 commit comments