Carlos-Francisco Méndez-Cruz

Conditional Random Fields

# -*- coding: UTF-8 -*-

import os
from itertools import chain
from optparse import OptionParser
from time import time
from collections import Counter
import re

import nltk
import sklearn
import scipy.stats
import sys

import joblib  # formerly sklearn.externals.joblib (removed in scikit-learn >= 0.23)
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score  # formerly sklearn.cross_validation
from sklearn.model_selection import RandomizedSearchCV  # formerly sklearn.grid_search

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics

from nltk.corpus import stopwords
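
# NOTE (added): the stop word list used below requires the NLTK 'stopwords'
# corpus to be installed locally; if it is missing, run once:
#   import nltk; nltk.download('stopwords')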

# Objective
# Training and evaluation of CRFs with sklearn-crfsuite.
#
# Input parameters
# --inputPath=PATH      Path of the training and test data sets
# --trainingFile        File with the training data set
# --testFile            File with the test data set
# --outputPath=PATH     Output path to place output files
# --excludeStopWords    Exclude stop words
# --excludeSymbols      Exclude punctuation marks

# Output
# 1) Best model

# Examples
# python3.4 training-validation-v1.py
#   --inputPath /export/space1/users/compu2/bionlp/conditional-random-fields/data-sets
#   --trainingFile training-data-set-70.txt
#   --testFile test-data-set-30.txt
#   --outputPath /export/space1/users/compu2/bionlp/conditional-random-fields
# python3.4 training-validation-v1.py --inputPath /export/space1/users/compu2/bionlp/conditional-random-fields/data-sets --trainingFile training-data-set-70.txt --testFile test-data-set-30.txt --outputPath /export/space1/users/compu2/bionlp/conditional-random-fields
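
# Expected corpus format (inferred from word2features/sent2labels below):
# one sentence per line, tokens separated by whitespace, each token encoded
# as word|lemma|postag|label, e.g. (hypothetical line):
#   Fis|fis|NN|GENE activates|activate|VBZ|O transcription|transcription|NN|O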

#################################
#           FUNCTIONS           #
#################################
def endsConLow(word):
    # True when the last character is not a lowercase vowel, an uppercase
    # letter, or a digit (i.e. a lowercase consonant or a symbol)
    miregex = re.compile(r'[^aeiouA-Z0-9]$')
    return bool(miregex.search(word))
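
# Illustrative behavior (added examples):
#   endsConLow('protein') -> True   (lowercase consonant 'n')
#   endsConLow('alpha')   -> False  (lowercase vowel)
#   endsConLow('DNA')     -> False  (uppercase letter)
#   endsConLow('p53')     -> False  (digit)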

def word2features(sent, i):
    # Each token is encoded as word|lemma|postag|label
    listElem = sent[i].split('|')
    word = listElem[0]
    lemma = listElem[1]
    postag = listElem[2]

    features = {
        # Suffixes
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word[-1:]': word[-1:],
        #'word.isupper()': word.isupper(),
        'word': word,
        'lemma': lemma,
        'postag': postag,
        'lemma[-3:]': lemma[-3:],
        'lemma[-2:]': lemma[-2:],
        'lemma[-1:]': lemma[-1:],
        # Prefixes
        'word[:3]': word[:3],
        'word[:2]': word[:2],
        'word[:1]': word[:1],
        'endsConLow()={}'.format(endsConLow(word)): endsConLow(word),
    }
    # Features from the previous token
    if i > 0:
        listElem = sent[i - 1].split('|')
        word1 = listElem[0]
        lemma1 = listElem[1]
        postag1 = listElem[2]
        features.update({
            '-1:word': word1,
            '-1:lemma': lemma1,
            '-1:postag': postag1,
        })

    # Features from the following token
    if i < len(sent) - 1:
        listElem = sent[i + 1].split('|')
        word1 = listElem[0]
        lemma1 = listElem[1]
        postag1 = listElem[2]
        features.update({
            '+1:word': word1,
            '+1:lemma': lemma1,
            '+1:postag': postag1,
        })

    # Wider-context (±2, ±3) features, currently disabled:
    '''
    if i > 1:
        listElem = sent[i - 2].split('|')
        word2 = listElem[0]
        lemma2 = listElem[1]
        postag2 = listElem[2]
        features.update({
            '-2:word': word2,
            '-2:lemma': lemma2,
        })

    if i < len(sent) - 2:
        listElem = sent[i + 2].split('|')
        word2 = listElem[0]
        lemma2 = listElem[1]
        postag2 = listElem[2]
        features.update({
            '+2:word': word2,
            '+2:lemma': lemma2,
        })

    trigrams = False
    if trigrams:
        if i > 2:
            listElem = sent[i - 3].split('|')
            word3 = listElem[0]
            lemma3 = listElem[1]
            postag3 = listElem[2]
            features.update({
                '-3:word': word3,
                '-3:lemma': lemma3,
            })

        if i < len(sent) - 3:
            listElem = sent[i + 3].split('|')
            word3 = listElem[0]
            lemma3 = listElem[1]
            postag3 = listElem[2]
            features.update({
                '+3:word': word3,
                '+3:lemma': lemma3,
            })
    '''
    return features

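# Illustrative (hypothetical) output of word2features for the token
# 'Fis|fis|NN' at position i=0 of a sentence:
#   {'word[-3:]': 'Fis', 'word[-2:]': 'is', 'word[-1:]': 's',
#    'word': 'Fis', 'lemma': 'fis', 'postag': 'NN', ...,
#    'endsConLow()=True': True, '+1:word': ..., '+1:lemma': ..., ...}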

def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]


def sent2labels(sent):
    # The label is the fourth field of word|lemma|postag|label
    return [elem.split('|')[3] for elem in sent]


def sent2tokens(sent):
    # Tokens are word|lemma|postag|label strings, so take the first field
    return [elem.split('|')[0] for elem in sent]


def print_transitions(trans_features, f):
    for (label_from, label_to), weight in trans_features:
        f.write("{:6} -> {:7} {:0.6f}\n".format(label_from, label_to, weight))


def print_state_features(state_features, f):
    for (attr, label), weight in state_features:
        # attr is already a str in Python 3; writing it directly avoids the
        # b'...' artifacts that attr.encode("utf-8") would leave in the report
        f.write("{:0.6f} {:8} {}\n".format(weight, label, attr))
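
# Illustrative report lines produced by the helpers above (hypothetical values):
#   GENE   -> GENE    2.397462        (print_transitions)
#   1.945736 GENE     lemma:fnr       (print_state_features)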

__author__ = 'CMendezC'

##########################################
#              MAIN PROGRAM              #
##########################################

if __name__ == "__main__":
    # Defining parameters
    parser = OptionParser()
    parser.add_option("--inputPath", dest="inputPath",
                      help="Path of training data set", metavar="PATH")
    parser.add_option("--outputPath", dest="outputPath",
                      help="Output path to place output files",
                      metavar="PATH")
    parser.add_option("--trainingFile", dest="trainingFile",
                      help="File with training data set", metavar="FILE")
    parser.add_option("--testFile", dest="testFile",
                      help="File with test data set", metavar="FILE")
    parser.add_option("--excludeStopWords", default=False,
                      action="store_true", dest="excludeStopWords",
                      help="Exclude stop words")
    parser.add_option("--excludeSymbols", default=False,
                      action="store_true", dest="excludeSymbols",
                      help="Exclude punctuation marks")

    (options, args) = parser.parse_args()
    if len(args) > 0:
        # parser.error() prints the message and exits, so no sys.exit() is needed
        parser.error("This script takes no positional arguments.")

    print('-------------------------------- PARAMETERS --------------------------------')
    print("Path of training data set: " + options.inputPath)
    print("File with training data set: " + str(options.trainingFile))
    print("Path of test data set: " + options.inputPath)
    print("File with test data set: " + str(options.testFile))
    print("Exclude stop words: " + str(options.excludeStopWords))
    symbols = ['.', ',', ':', ';', '?', '!', '\'', '"', '<', '>', '(', ')', '-', '_', '/', '\\', '¿', '¡', '+', '{',
               '}', '[', ']', '*', '%', '$', '#', '&', '°', '`', '...']
    # print("Exclude symbols " + str(symbols) + ': ' + str(options.excludeSymbols))
    print("Exclude symbols: " + str(options.excludeSymbols))

    print('-------------------------------- PROCESSING --------------------------------')
    print('Reading corpus...')
    t0 = time()

    sentencesTrainingData = []
    sentencesTestData = []

    # Build the stop word list once; a set makes membership tests O(1) and
    # avoids shadowing the imported nltk.corpus.stopwords module
    stopWords = set(stopwords.words('english'))

    with open(os.path.join(options.inputPath, options.trainingFile), "r") as iFile:
        for line in iFile:
            listLine = []
            line = line.strip('\n')
            for token in line.split():
                if options.excludeStopWords:
                    listToken = token.split('|')
                    lemma = listToken[1]
                    if lemma in stopWords:
                        continue
                if options.excludeSymbols:
                    listToken = token.split('|')
                    lemma = listToken[1]
                    if lemma in symbols:
                        continue
                listLine.append(token)
            sentencesTrainingData.append(listLine)
    print(" Sentences training data: " + str(len(sentencesTrainingData)))

    with open(os.path.join(options.inputPath, options.testFile), "r") as iFile:
        for line in iFile:
            listLine = []
            line = line.strip('\n')
            for token in line.split():
                if options.excludeStopWords:
                    listToken = token.split('|')
                    lemma = listToken[1]
                    if lemma in stopWords:
                        continue
                if options.excludeSymbols:
                    listToken = token.split('|')
                    lemma = listToken[1]
                    if lemma in symbols:
                        continue
                listLine.append(token)
            sentencesTestData.append(listLine)
    print(" Sentences test data: " + str(len(sentencesTestData)))

260 + print("Reading corpus done in: %fs" % (time() - t0))
261 +
262 + print(sent2features(sentencesTrainingData[0])[0])
263 + print(sent2features(sentencesTestData[0])[0])
264 + t0 = time()
265 +
266 + X_train = [sent2features(s) for s in sentencesTrainingData]
267 + y_train = [sent2labels(s) for s in sentencesTrainingData]
268 +
269 + X_test = [sent2features(s) for s in sentencesTestData]
270 + # print X_test
271 + y_test = [sent2labels(s) for s in sentencesTestData]
272 +
    # Fixed parameters
    # crf = sklearn_crfsuite.CRF(
    #     algorithm='lbfgs',
    #     c1=0.1,
    #     c2=0.1,
    #     max_iterations=100,
    #     all_possible_transitions=True
    # )

    # Hyperparameter optimization
    crf = sklearn_crfsuite.CRF(
        algorithm='lbfgs',
        max_iterations=100,
        all_possible_transitions=True
    )
    params_space = {
        'c1': scipy.stats.expon(scale=0.5),
        'c2': scipy.stats.expon(scale=0.05),
    }
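
    # NOTE (added): the exponential priors concentrate the random search on
    # small regularization values; scipy.stats.expon(scale=0.5) samples c1
    # with mean 0.5, and expon(scale=0.05) samples c2 with mean 0.05.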

    # Original: labels = list(crf.classes_)
    # Original: labels.remove('O')
    # Evaluate only the GENE label; the majority 'O' label is excluded
    labels = ['GENE']

    # Use the same metric for evaluation
    f1_scorer = make_scorer(metrics.flat_f1_score,
                            average='weighted', labels=labels)

    # Randomized search over the c1/c2 space with 10-fold cross-validation
    rs = RandomizedSearchCV(crf, params_space,
                            cv=10,
                            verbose=3,
                            n_jobs=-1,
                            n_iter=20,
                            # n_iter=50,
                            scoring=f1_scorer)
    rs.fit(X_train, y_train)

    # Fixed parameters
    # crf.fit(X_train, y_train)

    # Best hyperparameters
    # crf = rs.best_estimator_
    nameReport = options.trainingFile.replace('.txt', '.fStopWords_' + str(options.excludeStopWords) + '.fSymbols_' + str(
        options.excludeSymbols) + '.txt')
    with open(os.path.join(options.outputPath, "reports", "report_" + nameReport), mode="w") as oFile:
        oFile.write("********** TRAINING AND TESTING REPORT **********\n")
        oFile.write("Training file: " + options.trainingFile + '\n')
        oFile.write('\n')
        oFile.write('best params:' + str(rs.best_params_) + '\n')
        oFile.write('best CV score:' + str(rs.best_score_) + '\n')
        oFile.write('model size: {:0.2f}M\n'.format(rs.best_estimator_.size_ / 1000000))

326 + print("Training done in: %fs" % (time() - t0))
327 + t0 = time()
328 +
329 + # Update best crf
330 + crf = rs.best_estimator_
331 +
332 + # Saving model
333 + print(" Saving training model...")
334 + t1 = time()
335 + nameModel = options.trainingFile.replace('.txt', '.fStopWords_' + str(options.excludeStopWords) + '.fSymbols_' + str(
336 + options.excludeSymbols) + '.mod')
337 + joblib.dump(crf, os.path.join(options.outputPath, "models", nameModel))
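
    # NOTE (added): the persisted model can later be reloaded for tagging
    # new text, e.g.:
    #   crf = joblib.load(os.path.join(options.outputPath, "models", nameModel))
    #   y_new = crf.predict(X_new)  # X_new: list of sent2features(...) outputs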
338 + print(" Saving training model done in: %fs" % (time() - t1))
339 +
340 + # Evaluation against test data
341 + y_pred = crf.predict(X_test)
342 + print("*********************************")
343 + name = options.trainingFile.replace('.txt', '.fStopWords_' + str(options.excludeStopWords) + '.fSymbols_' + str(
344 + options.excludeSymbols) + '.txt')
345 + with open(os.path.join(options.outputPath, "reports", "y_pred_" + name), "w") as oFile:
346 + for y in y_pred:
347 + oFile.write(str(y) + '\n')
348 +
349 + print("*********************************")
350 + name = options.trainingFile.replace('.txt', '.fStopWords_' + str(options.excludeStopWords) + '.fSymbols_' + str(
351 + options.excludeSymbols) + '.txt')
352 + with open(os.path.join(options.outputPath, "reports", "y_test_" + name), "w") as oFile:
353 + for y in y_test:
354 + oFile.write(str(y) + '\n')
355 +
356 + print("Prediction done in: %fs" % (time() - t0))
357 +
    # labels = list(crf.classes_)
    # labels.remove('O')

    with open(os.path.join(options.outputPath, "reports", "report_" + nameReport), mode="a") as oFile:
        oFile.write('\n')
        oFile.write("Flat F1: " + str(metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels)))
        oFile.write('\n')
        # labels = list(crf.classes_)
        sorted_labels = sorted(
            labels,
            key=lambda name: (name[1:], name[0])
        )
        oFile.write(metrics.flat_classification_report(
            y_test, y_pred, labels=sorted_labels, digits=3
        ))
        oFile.write('\n')

        oFile.write("\nTop likely transitions:\n")
        print_transitions(Counter(crf.transition_features_).most_common(50), oFile)
        oFile.write('\n')

        oFile.write("\nTop unlikely transitions:\n")
        print_transitions(Counter(crf.transition_features_).most_common()[-50:], oFile)
        oFile.write('\n')

        oFile.write("\nTop positive:\n")
        print_state_features(Counter(crf.state_features_).most_common(200), oFile)
        oFile.write('\n')

        oFile.write("\nTop negative:\n")
        print_state_features(Counter(crf.state_features_).most_common()[-200:], oFile)
        oFile.write('\n')