Carlos-Francisco Méndez-Cruz

Conditional Random Fields

# -*- coding: UTF-8 -*-

import os
from itertools import chain
from optparse import OptionParser
from time import time
from collections import Counter
import re
import sys

import nltk
import sklearn
import scipy.stats

import joblib
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score, RandomizedSearchCV

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics

from nltk.corpus import stopwords

# Objective
# Training and evaluation of CRFs with sklearn-crfsuite.
#
# Input parameters
# --inputPath=PATH      Path of training and test data sets
# --trainingFile        File with training data set
# --testFile            File with test data set
# --outputPath=PATH     Output path to place output files
# --excludeStopWords    Exclude stop words
# --excludeSymbols      Exclude punctuation marks

# Output
# 1) Best model

# Examples
# python3.4 training-validation-v1.py
# --inputPath /export/space1/users/compu2/bionlp/conditional-random-fields/data-sets
# --trainingFile training-data-set-70.txt
# --testFile test-data-set-30.txt
# --outputPath /export/space1/users/compu2/bionlp/conditional-random-fields
# python3.4 training-validation-v1.py --inputPath /export/space1/users/compu2/bionlp/conditional-random-fields/data-sets --trainingFile training-data-set-70.txt --testFile test-data-set-30.txt --outputPath /export/space1/users/compu2/bionlp/conditional-random-fields

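# Assumed input corpus format (inferred from word2features/sent2labels below):
# one sentence per line, whitespace-separated tokens, each token encoded as
# word|lemma|postag|tag. A hypothetical line:
# Two|two|CD|O genes|gene|NNS|GENE were|be|VBD|O identified|identify|VBN|O
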
#################################
#           FUNCTIONS           #
#################################

def endsConLow(word):
    # True if the word ends in a lowercase consonant (i.e., the last
    # character is not a vowel, an uppercase letter, or a digit).
    miregex = re.compile(r'[^aeiouA-Z0-9]$')
    return bool(miregex.search(word))

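# Hypothetical examples: endsConLow('operon') is True ('n' is a lowercase
# consonant), while endsConLow('DNA') and endsConLow('p53') are False.
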
def word2features(sent, i):
    # Tokens are expected as 'word|lemma|postag|tag'.
    listElem = sent[i].split('|')
    word = listElem[0]
    lemma = listElem[1]
    postag = listElem[2]

    features = {
        # Suffixes
        #'word[-3:]': word[-3:],
        #'word[-2:]': word[-2:],
        #'word[-1:]': word[-1:],
        #'word.isupper()': word.isupper(),
        #'word': word,
        #'lemma': lemma,
        #'postag': postag,
        'lemma[-3:]': lemma[-3:],
        'lemma[-2:]': lemma[-2:],
        'lemma[-1:]': lemma[-1:],
        # Prefixes
        'lemma[:3]': lemma[:3],
        'lemma[:2]': lemma[:2],
        'lemma[:1]': lemma[:1],
        #'word[:3]': word[:3],
        #'word[:2]': word[:2],
        #'word[:1]': word[:1],
        #'endsConLow()={}'.format(endsConLow(word)): endsConLow(word),
    }
    # Features of the previous token
    if i > 0:
        listElem = sent[i - 1].split('|')
        word1 = listElem[0]
        lemma1 = listElem[1]
        postag1 = listElem[2]
        features.update({
            #'-1:word': word1,
            '-1:lemma': lemma1,
            '-1:postag': postag1,
        })
    # Features of the next token
    if i < len(sent) - 1:
        listElem = sent[i + 1].split('|')
        word1 = listElem[0]
        lemma1 = listElem[1]
        postag1 = listElem[2]
        features.update({
            #'+1:word': word1,
            '+1:lemma': lemma1,
            '+1:postag': postag1,
        })

    # Disabled wider-context features (window of +/-2 and +/-3 tokens)
    '''
    if i > 1:
        listElem = sent[i - 2].split('|')
        word2 = listElem[0]
        lemma2 = listElem[1]
        postag2 = listElem[2]
        features.update({
            '-2:word': word2,
            '-2:lemma': lemma2,
        })

    if i < len(sent) - 2:
        listElem = sent[i + 2].split('|')
        word2 = listElem[0]
        lemma2 = listElem[1]
        postag2 = listElem[2]
        features.update({
            '+2:word': word2,
            '+2:lemma': lemma2,
        })

    trigrams = False
    if trigrams:
        if i > 2:
            listElem = sent[i - 3].split('|')
            word3 = listElem[0]
            lemma3 = listElem[1]
            postag3 = listElem[2]
            features.update({
                '-3:word': word3,
                '-3:lemma': lemma3,
            })

        if i < len(sent) - 3:
            listElem = sent[i + 3].split('|')
            word3 = listElem[0]
            lemma3 = listElem[1]
            postag3 = listElem[2]
            features.update({
                '+3:word': word3,
                '+3:lemma': lemma3,
            })
    '''
    return features

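# For a hypothetical token 'genes|gene|NNS|GENE', word2features would emit,
# among others: {'lemma[-3:]': 'ene', 'lemma[-2:]': 'ne', 'lemma[-1:]': 'e',
# 'lemma[:3]': 'gen', 'lemma[:2]': 'ge', 'lemma[:1]': 'g'}, plus the
# '-1:'/'+1:' lemma and POS-tag features of the neighboring tokens.
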

def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]


def sent2labels(sent):
    # The tag is the fourth '|'-separated field of each token.
    return [elem.split('|')[3] for elem in sent]


def sent2tokens(sent):
    # The word is the first '|'-separated field of each token.
    return [elem.split('|')[0] for elem in sent]


def print_transitions(trans_features, f):
    for (label_from, label_to), weight in trans_features:
        f.write("{:6} -> {:7} {:0.6f}\n".format(label_from, label_to, weight))


def print_state_features(state_features, f):
    for (attr, label), weight in state_features:
        f.write("{:0.6f} {:8} {}\n".format(weight, label, attr))

__author__ = 'CMendezC'

##########################################
#              MAIN PROGRAM              #
##########################################

if __name__ == "__main__":
    # Defining parameters
    parser = OptionParser()
    parser.add_option("--inputPath", dest="inputPath",
                      help="Path of training and test data sets", metavar="PATH")
    parser.add_option("--outputPath", dest="outputPath",
                      help="Output path to place output files",
                      metavar="PATH")
    parser.add_option("--trainingFile", dest="trainingFile",
                      help="File with training data set", metavar="FILE")
    parser.add_option("--testFile", dest="testFile",
                      help="File with test data set", metavar="FILE")
    parser.add_option("--excludeStopWords", default=False,
                      action="store_true", dest="excludeStopWords",
                      help="Exclude stop words")
    parser.add_option("--excludeSymbols", default=False,
                      action="store_true", dest="excludeSymbols",
                      help="Exclude punctuation marks")

    (options, args) = parser.parse_args()
    if len(args) > 0:
        parser.error("Unexpected positional arguments given.")

    print('-------------------------------- PARAMETERS --------------------------------')
    print("Path of training data set: " + options.inputPath)
    print("File with training data set: " + str(options.trainingFile))
    print("Path of test data set: " + options.inputPath)
    print("File with test data set: " + str(options.testFile))
    print("Exclude stop words: " + str(options.excludeStopWords))
    symbols = ['.', ',', ':', ';', '?', '!', '\'', '"', '<', '>', '(', ')', '-', '_', '/', '\\', '¿', '¡', '+', '{',
               '}', '[', ']', '*', '%', '$', '#', '&', '°', '`', '...']
    #print("Exclude symbols " + str(symbols) + ': ' + str(options.excludeSymbols))
    print("Exclude symbols: " + str(options.excludeSymbols))

    print('-------------------------------- PROCESSING --------------------------------')
    print('Reading corpus...')
    t0 = time()

    sentencesTrainingData = []
    sentencesTestData = []

    # NLTK English stop-word list, as a set for fast membership tests.
    stopwords = set(stopwords.words('english'))

    with open(os.path.join(options.inputPath, options.trainingFile), "r") as iFile:
        for line in iFile:
            listLine = []
            line = line.strip('\n')
            for token in line.split():
                lemma = token.split('|')[1]
                if options.excludeStopWords and lemma in stopwords:
                    continue
                if options.excludeSymbols and lemma in symbols:
                    continue
                listLine.append(token)
            sentencesTrainingData.append(listLine)
    print("  Sentences training data: " + str(len(sentencesTrainingData)))

    with open(os.path.join(options.inputPath, options.testFile), "r") as iFile:
        for line in iFile:
            listLine = []
            line = line.strip('\n')
            for token in line.split():
                lemma = token.split('|')[1]
                if options.excludeStopWords and lemma in stopwords:
                    continue
                if options.excludeSymbols and lemma in symbols:
                    continue
                listLine.append(token)
            sentencesTestData.append(listLine)
    print("  Sentences test data: " + str(len(sentencesTestData)))

    print("Reading corpus done in: %fs" % (time() - t0))

    #print(sent2features(sentencesTrainingData[0])[0])
    #print(sent2features(sentencesTestData[0])[0])
    t0 = time()

    X_train = [sent2features(s) for s in sentencesTrainingData]
    y_train = [sent2labels(s) for s in sentencesTrainingData]

    X_test = [sent2features(s) for s in sentencesTestData]
    #print(X_test)
    y_test = [sent2labels(s) for s in sentencesTestData]

    # Fixed parameters
    # crf = sklearn_crfsuite.CRF(
    #     algorithm='lbfgs',
    #     c1=0.1,
    #     c2=0.1,
    #     max_iterations=100,
    #     all_possible_transitions=True
    # )

    # Hyperparameter optimization
    crf = sklearn_crfsuite.CRF(
        algorithm='lbfgs',
        max_iterations=100,
        all_possible_transitions=True
    )
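    # Search space for the randomized search below: c1 and c2 are the L1/L2
    # regularization weights of the 'lbfgs' trainer; the exponential priors
    # follow the sklearn-crfsuite tutorial recipe.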
    params_space = {
        'c1': scipy.stats.expon(scale=0.5),
        'c2': scipy.stats.expon(scale=0.05),
    }

    # Original: labels = list(crf.classes_)
    # Original: labels.remove('O')
    # Evaluate only the GENE class; the 'O' class is ignored.
    labels = ['GENE']

    # Use the same metric for evaluation
    f1_scorer = make_scorer(metrics.flat_f1_score,
                            average='weighted', labels=labels)

    # Randomized search
    rs = RandomizedSearchCV(crf, params_space,
                            cv=10,
                            verbose=3,
                            n_jobs=-1,
                            n_iter=20,
                            # n_iter=50,
                            scoring=f1_scorer)
    rs.fit(X_train, y_train)

    # Fixed parameters
    # crf.fit(X_train, y_train)

    # Best hyperparameters
    # crf = rs.best_estimator_
    nameReport = options.trainingFile.replace('.txt', '.fStopWords_' + str(options.excludeStopWords) + '.fSymbols_' + str(
        options.excludeSymbols) + '.txt')
    with open(os.path.join(options.outputPath, "reports", "report_" + nameReport), mode="w") as oFile:
        oFile.write("********** TRAINING AND TESTING REPORT **********\n")
        oFile.write("Training file: " + options.trainingFile + '\n')
        oFile.write('\n')
        oFile.write('Best params: ' + str(rs.best_params_) + '\n')
        oFile.write('Best CV score: ' + str(rs.best_score_) + '\n')
        oFile.write('Model size: {:0.2f}M\n'.format(rs.best_estimator_.size_ / 1000000))

    print("Training done in: %fs" % (time() - t0))
    t0 = time()

    # Update best CRF
    crf = rs.best_estimator_

    # Saving model
    print("  Saving training model...")
    t1 = time()
    nameModel = options.trainingFile.replace('.txt', '.fStopWords_' + str(options.excludeStopWords) + '.fSymbols_' + str(
        options.excludeSymbols) + '.mod')
    joblib.dump(crf, os.path.join(options.outputPath, "models", nameModel))
    print("  Saving training model done in: %fs" % (time() - t1))
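    # The saved model can be reloaded later with, e.g.:
    # crf = joblib.load(os.path.join(options.outputPath, "models", nameModel))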

    # Evaluation against test data
    y_pred = crf.predict(X_test)
    print("*********************************")
    name = options.trainingFile.replace('.txt', '.fStopWords_' + str(options.excludeStopWords) + '.fSymbols_' + str(
        options.excludeSymbols) + '.txt')
    with open(os.path.join(options.outputPath, "reports", "y_pred_" + name), "w") as oFile:
        for y in y_pred:
            oFile.write(str(y) + '\n')

    print("*********************************")
    with open(os.path.join(options.outputPath, "reports", "y_test_" + name), "w") as oFile:
        for y in y_test:
            oFile.write(str(y) + '\n')

    print("Prediction done in: %fs" % (time() - t0))

    # labels = list(crf.classes_)
    # labels.remove('O')

    # The flat_* metrics flatten the per-sentence label sequences and score
    # the predictions at the token level.
    with open(os.path.join(options.outputPath, "reports", "report_" + nameReport), mode="a") as oFile:
        oFile.write('\n')
        oFile.write("Flat F1: " + str(metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels)))
        oFile.write('\n')
        # labels = list(crf.classes_)
        sorted_labels = sorted(
            labels,
            key=lambda name: (name[1:], name[0])
        )
        oFile.write(metrics.flat_classification_report(
            y_test, y_pred, labels=sorted_labels, digits=3
        ))
        oFile.write('\n')

        oFile.write("\nTop likely transitions:\n")
        print_transitions(Counter(crf.transition_features_).most_common(50), oFile)
        oFile.write('\n')

        oFile.write("\nTop unlikely transitions:\n")
        print_transitions(Counter(crf.transition_features_).most_common()[-50:], oFile)
        oFile.write('\n')

        oFile.write("\nTop positive:\n")
        print_state_features(Counter(crf.state_features_).most_common(200), oFile)
        oFile.write('\n')

        oFile.write("\nTop negative:\n")
        print_state_features(Counter(crf.state_features_).most_common()[-200:], oFile)
        oFile.write('\n')