Showing 3 changed files with 789 additions and 71 deletions
training_validation_v1-1.py
new file mode 100644
# -*- coding: UTF-8 -*-

import os
from itertools import chain
from optparse import OptionParser
from time import time
from collections import Counter
import re
import sys

import nltk
import sklearn
import scipy.stats

# Note: joblib and the model-selection utilities moved in scikit-learn 0.20+;
# the original imports (sklearn.externals.joblib, sklearn.cross_validation,
# sklearn.grid_search) only exist in older versions.
import joblib
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score, RandomizedSearchCV

import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics

from nltk.corpus import stopwords


# Objective
# Training and evaluation of CRFs with sklearn-crfsuite.
#
# Input parameters
# --inputPath=PATH      Path of training and test data sets
# --trainingFile        File with training data set
# --testFile            File with test data set
# --outputPath=PATH     Output path to place output files
# --excludeStopWords    Filter out stop words
# --excludeSymbols      Filter out punctuation marks
#
# Output
# 1) Best model
#
# Examples
# python3.4 training-validation-v1.py
#   --inputPath /export/space1/users/compu2/bionlp/conditional-random-fields/data-sets
#   --trainingFile training-data-set-70.txt
#   --testFile test-data-set-30.txt
#   --outputPath /export/space1/users/compu2/bionlp/conditional-random-fields
# python3.4 training-validation-v1.py --inputPath /export/space1/users/compu2/bionlp/conditional-random-fields/data-sets --trainingFile training-data-set-70.txt --testFile test-data-set-30.txt --outputPath /export/space1/users/compu2/bionlp/conditional-random-fields

#################################
#           FUNCTIONS           #
#################################
def endsConLow(word):
    # True when the last character is neither a lowercase vowel, an
    # uppercase letter, nor a digit (typically a lowercase consonant).
    miregex = re.compile(r'[^aeiouA-Z0-9]$')
    return miregex.search(word) is not None
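    # For example (hypothetical words): endsConLow("fis") is True ('s' is a
    # lowercase consonant), while endsConLow("dnaA") is False ('A' is uppercase).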

def word2features(sent, i):
    listElem = sent[i].split('|')
    word = listElem[0]
    lemma = listElem[1]
    postag = listElem[2]
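    # Each token is a pipe-separated string; for a hypothetical token
    # "Fis|fis|NN|GENE": word="Fis", lemma="fis", postag="NN" (the label in
    # field 4 is consumed by sent2labels(), not here).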

    features = {
        # Suffixes
        #'word[-3:]': word[-3:],
        #'word[-2:]': word[-2:],
        #'word[-1:]': word[-1:],
        #'word.isupper()': word.isupper(),
        #'word': word,
        #'lemma': lemma,
        #'postag': postag,
        # Lemma suffixes and prefixes; note the '[+n:]' strings are only
        # feature names -- their values are lemma prefixes, lemma[:n].
        'lemma[-3:]': lemma[-3:],
        'lemma[-2:]': lemma[-2:],
        'lemma[-1:]': lemma[-1:],
        'lemma[+3:]': lemma[:3],
        'lemma[+2:]': lemma[:2],
        'lemma[+1:]': lemma[:1],
        #'word[:3]': word[:3],
        #'word[:2]': word[:2],
        #'word[:1]': word[:1],
        #'endsConLow()={}'.format(endsConLow(word)): endsConLow(word),
    }
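    # Worked example: for lemma "polymerase" the six active features are
    # {'lemma[-3:]': 'ase', 'lemma[-2:]': 'se', 'lemma[-1:]': 'e',
    #  'lemma[+3:]': 'pol', 'lemma[+2:]': 'po', 'lemma[+1:]': 'p'}.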
    if i > 0:
        listElem = sent[i - 1].split('|')
        word1 = listElem[0]
        lemma1 = listElem[1]
        postag1 = listElem[2]
        features.update({
            #'-1:word': word1,
            '-1:lemma': lemma1,
            '-1:postag': postag1,
        })

    if i < len(sent) - 1:
        listElem = sent[i + 1].split('|')
        word1 = listElem[0]
        lemma1 = listElem[1]
        postag1 = listElem[2]
        features.update({
            #'+1:word': word1,
            '+1:lemma': lemma1,
            '+1:postag': postag1,
        })

    '''
    if i > 1:
        listElem = sent[i - 2].split('|')
        word2 = listElem[0]
        lemma2 = listElem[1]
        postag2 = listElem[2]
        features.update({
            '-2:word': word2,
            '-2:lemma': lemma2,
        })

    if i < len(sent) - 2:
        listElem = sent[i + 2].split('|')
        word2 = listElem[0]
        lemma2 = listElem[1]
        postag2 = listElem[2]
        features.update({
            '+2:word': word2,
            '+2:lemma': lemma2,
        })

    trigrams = False
    if trigrams:
        if i > 2:
            listElem = sent[i - 3].split('|')
            word3 = listElem[0]
            lemma3 = listElem[1]
            postag3 = listElem[2]
            features.update({
                '-3:word': word3,
                '-3:lemma': lemma3,
            })

        if i < len(sent) - 3:
            listElem = sent[i + 3].split('|')
            word3 = listElem[0]
            lemma3 = listElem[1]
            postag3 = listElem[2]
            features.update({
                '+3:word': word3,
                '+3:lemma': lemma3,
            })
    '''
    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]


def sent2labels(sent):
    return [elem.split('|')[3] for elem in sent]
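    # The class label is the fourth pipe-separated field, e.g. "GENE" for the
    # hypothetical token "Fis|fis|NN|GENE" and "O" for non-entity tokens.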


def sent2tokens(sent):
    # Each element of sent is a pipe-separated string, so the original tuple
    # unpacking ("for token, postag, label in sent") would fail; return the
    # surface word (first field) instead.
    return [elem.split('|')[0] for elem in sent]


def print_transitions(trans_features, f):
    for (label_from, label_to), weight in trans_features:
        f.write("{:6} -> {:7} {:0.6f}\n".format(label_from, label_to, weight))


def print_state_features(state_features, f):
    for (attr, label), weight in state_features:
        # Write the attribute as text: calling attr.encode("utf-8") here
        # would print a b'...' bytes literal into the text-mode report file.
        f.write("{:0.6f} {:8} {}\n".format(weight, label, attr))

__author__ = 'CMendezC'

##########################################
#               MAIN PROGRAM             #
##########################################

if __name__ == "__main__":
    # Defining parameters
    parser = OptionParser()
    parser.add_option("--inputPath", dest="inputPath",
                      help="Path of training and test data sets", metavar="PATH")
    parser.add_option("--outputPath", dest="outputPath",
                      help="Output path to place output files",
                      metavar="PATH")
    parser.add_option("--trainingFile", dest="trainingFile",
                      help="File with training data set", metavar="FILE")
    parser.add_option("--testFile", dest="testFile",
                      help="File with test data set", metavar="FILE")
    parser.add_option("--excludeStopWords", default=False,
                      action="store_true", dest="excludeStopWords",
                      help="Exclude stop words")
    parser.add_option("--excludeSymbols", default=False,
                      action="store_true", dest="excludeSymbols",
                      help="Exclude punctuation marks")

    (options, args) = parser.parse_args()
    if len(args) > 0:
        # parser.error() prints the message and exits (status 2), so no
        # separate sys.exit() call is needed.
        parser.error("No positional arguments are expected.")

    print('-------------------------------- PARAMETERS --------------------------------')
    print("Path of training data set: " + options.inputPath)
    print("File with training data set: " + str(options.trainingFile))
    print("Path of test data set: " + options.inputPath)
    print("File with test data set: " + str(options.testFile))
    print("Exclude stop words: " + str(options.excludeStopWords))
    symbols = ['.', ',', ':', ';', '?', '!', '\'', '"', '<', '>', '(', ')', '-', '_', '/', '\\', '¿', '¡', '+', '{',
               '}', '[', ']', '*', '%', '$', '#', '&', '°', '`', '...']
    #print("Exclude symbols " + str(symbols) + ': ' + str(options.excludeSymbols))
    print("Exclude symbols: " + str(options.excludeSymbols))

    print('-------------------------------- PROCESSING --------------------------------')
    print('Reading corpus...')
    t0 = time()

    sentencesTrainingData = []
    sentencesTestData = []

    # NLTK English stop-word list; note this rebinding shadows the imported
    # stopwords module from here on (the module is only needed for this call).
    stopwords = stopwords.words('english')

    with open(os.path.join(options.inputPath, options.trainingFile), "r") as iFile:
        for line in iFile.readlines():
            listLine = []
            line = line.strip('\n')
            for token in line.split():
                if options.excludeStopWords:
                    listToken = token.split('|')
                    lemma = listToken[1]
                    if lemma in stopwords:
                        continue
                if options.excludeSymbols:
                    listToken = token.split('|')
                    lemma = listToken[1]
                    if lemma in symbols:
                        continue
                listLine.append(token)
            sentencesTrainingData.append(listLine)
    print("   Sentences training data: " + str(len(sentencesTrainingData)))

    with open(os.path.join(options.inputPath, options.testFile), "r") as iFile:
        for line in iFile.readlines():
            listLine = []
            line = line.strip('\n')
            for token in line.split():
                if options.excludeStopWords:
                    listToken = token.split('|')
                    lemma = listToken[1]
                    if lemma in stopwords:
                        continue
                if options.excludeSymbols:
                    listToken = token.split('|')
                    lemma = listToken[1]
                    if lemma in symbols:
                        continue
                listLine.append(token)
            sentencesTestData.append(listLine)
    print("   Sentences test data: " + str(len(sentencesTestData)))

    print("Reading corpus done in: %fs" % (time() - t0))

    #print(sent2features(sentencesTrainingData[0])[0])
    #print(sent2features(sentencesTestData[0])[0])
    t0 = time()

    X_train = [sent2features(s) for s in sentencesTrainingData]
    y_train = [sent2labels(s) for s in sentencesTrainingData]

    X_test = [sent2features(s) for s in sentencesTestData]
    # print(X_test)
    y_test = [sent2labels(s) for s in sentencesTestData]

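    # X_train / X_test are lists of sentences, each sentence a list of
    # per-token feature dicts, and y_train / y_test the matching lists of
    # label sequences -- the input format sklearn-crfsuite expects.
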
    # Fixed parameters
    # crf = sklearn_crfsuite.CRF(
    #     algorithm='lbfgs',
    #     c1=0.1,
    #     c2=0.1,
    #     max_iterations=100,
    #     all_possible_transitions=True
    # )

    # Hyperparameter optimization
    crf = sklearn_crfsuite.CRF(
        algorithm='lbfgs',
        max_iterations=100,
        all_possible_transitions=True
    )
    params_space = {
        'c1': scipy.stats.expon(scale=0.5),
        'c2': scipy.stats.expon(scale=0.05),
    }
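    # scipy.stats.expon(scale=s) is an exponential distribution with mean s:
    # the search samples nonnegative regularization strengths concentrated
    # near zero (mean 0.5 for the L1 term c1, 0.05 for the L2 term c2).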

    # Original: labels = list(crf.classes_)
    # Original: labels.remove('O')
    labels = ['GENE']
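    # Restricting evaluation to the GENE class (rather than all classes minus
    # 'O') keeps the dominant non-entity tokens from inflating the F1 score.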

    # use the same metric for evaluation
    f1_scorer = make_scorer(metrics.flat_f1_score,
                            average='weighted', labels=labels)
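    # flat_f1_score flattens the per-sentence label sequences and computes a
    # weighted F1 over the flattened tokens, restricted here to GENE.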

    # search
    rs = RandomizedSearchCV(crf, params_space,
                            cv=10,
                            verbose=3,
                            n_jobs=-1,
                            n_iter=20,
                            # n_iter=50,
                            scoring=f1_scorer)
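    # 20 sampled (c1, c2) settings x 10 CV folds = 200 model fits, run in
    # parallel on all available cores (n_jobs=-1).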
    rs.fit(X_train, y_train)

    # Fixed parameters
    # crf.fit(X_train, y_train)

    # Best hyperparameters
    # crf = rs.best_estimator_
    nameReport = options.trainingFile.replace('.txt', '.fStopWords_' + str(options.excludeStopWords) + '.fSymbols_' + str(
        options.excludeSymbols) + '.txt')
    with open(os.path.join(options.outputPath, "reports", "report_" + nameReport), mode="w") as oFile:
        oFile.write("********** TRAINING AND TESTING REPORT **********\n")
        oFile.write("Training file: " + options.trainingFile + '\n')
        oFile.write('\n')
        oFile.write('best params:' + str(rs.best_params_) + '\n')
        oFile.write('best CV score:' + str(rs.best_score_) + '\n')
        oFile.write('model size: {:0.2f}M\n'.format(rs.best_estimator_.size_ / 1000000))

    print("Training done in: %fs" % (time() - t0))
    t0 = time()

    # Update best crf
    crf = rs.best_estimator_

    # Saving model
    print("  Saving training model...")
    t1 = time()
    nameModel = options.trainingFile.replace('.txt', '.fStopWords_' + str(options.excludeStopWords) + '.fSymbols_' + str(
        options.excludeSymbols) + '.mod')
    joblib.dump(crf, os.path.join(options.outputPath, "models", nameModel))
    print("  Saving training model done in: %fs" % (time() - t1))
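    # The dumped model can later be restored with joblib.load() and used
    # directly for prediction on newly featurized sentences.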

    # Evaluation against test data
    y_pred = crf.predict(X_test)
    print("*********************************")
    name = options.trainingFile.replace('.txt', '.fStopWords_' + str(options.excludeStopWords) + '.fSymbols_' + str(
        options.excludeSymbols) + '.txt')
    with open(os.path.join(options.outputPath, "reports", "y_pred_" + name), "w") as oFile:
        for y in y_pred:
            oFile.write(str(y) + '\n')

    print("*********************************")
    # Reuse the same suffixed file name for the gold-label dump.
    with open(os.path.join(options.outputPath, "reports", "y_test_" + name), "w") as oFile:
        for y in y_test:
            oFile.write(str(y) + '\n')

    print("Prediction done in: %fs" % (time() - t0))

    # labels = list(crf.classes_)
    # labels.remove('O')

    with open(os.path.join(options.outputPath, "reports", "report_" + nameReport), mode="a") as oFile:
        oFile.write('\n')
        oFile.write("Flat F1: " + str(metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels)))
        oFile.write('\n')
        # labels = list(crf.classes_)
        sorted_labels = sorted(
            labels,
            key=lambda name: (name[1:], name[0])
        )
        oFile.write(metrics.flat_classification_report(
            y_test, y_pred, labels=sorted_labels, digits=3
        ))
        oFile.write('\n')

        oFile.write("\nTop likely transitions:\n")
        print_transitions(Counter(crf.transition_features_).most_common(50), oFile)
        oFile.write('\n')

        oFile.write("\nTop unlikely transitions:\n")
        print_transitions(Counter(crf.transition_features_).most_common()[-50:], oFile)
        oFile.write('\n')

        oFile.write("\nTop positive:\n")
        print_state_features(Counter(crf.state_features_).most_common(200), oFile)
        oFile.write('\n')

        oFile.write("\nTop negative:\n")
        print_state_features(Counter(crf.state_features_).most_common()[-200:], oFile)
        oFile.write('\n')
@@ -49,42 +49,21 @@ from nltk.corpus import stopwords
 #################################
 #           FUNCTIONS           #
 #################################
-def endsConLow(word):
-    miregex = re.compile(r'[^aeiouA-Z0-9]$')
-    if miregex.search(word):
-        return True
-    else:
-        return False
-
 def word2features(sent, i):
     listElem = sent[i].split('|')
     word = listElem[0]
+    #print("word: {}".format(word))
     lemma = listElem[1]
     postag = listElem[2]
 
     features = {
-        # Suffixes
-        #'word[-3:]': word[-3:],
-        #'word[-2:]': word[-2:],
-        #'word[-1:]': word[-1:],
-        #'word.isupper()': word.isupper(),
         #'word': word,
-        #'lemma': lemma,
-        #'postag': postag,
-        'lemma[-3:]': lemma[-3:],
-        'lemma[-2:]': lemma[-2:],
-        'lemma[-1:]': lemma[-1:],
-        'lemma[+3:]': lemma[:3],
-        'lemma[+2:]': lemma[:2],
-        'lemma[+1:]': lemma[:1],
-        #'word[:3]': word[:3],
-        #'word[:2]': word[:2],
-        #'word[:1]': word[:1],
-        #'endsConLow()={}'.format(endsConLow(word)): endsConLow(word),
+        'lemma': lemma,
+        'postag': postag,
     }
     if i > 0:
         listElem = sent[i - 1].split('|')
-        word1 = listElem[0]
+        #word1 = listElem[0]
         lemma1 = listElem[1]
         postag1 = listElem[2]
         features.update({
@@ -95,7 +74,7 @@ def word2features(sent, i):
 
     if i < len(sent) - 1:
         listElem = sent[i + 1].split('|')
-        word1 = listElem[0]
+        #word1 = listElem[0]
         lemma1 = listElem[1]
         postag1 = listElem[2]
         features.update({
@@ -103,53 +82,8 @@ def word2features(sent, i):
             '+1:lemma': lemma1,
             '+1:postag': postag1,
         })
-
-    '''
-    if i > 1:
-        listElem = sent[i - 2].split('|')
-        word2 = listElem[0]
-        lemma2 = listElem[1]
-        postag2 = listElem[2]
-        features.update({
-            '-2:word': word2,
-            '-2:lemma': lemma2,
-        })
-
-    if i < len(sent) - 2:
-        listElem = sent[i + 2].split('|')
-        word2 = listElem[0]
-        lemma2 = listElem[1]
-        postag2 = listElem[2]
-        features.update({
-            '+2:word': word2,
-            '+2:lemma': lemma2,
-        })
-
-    trigrams = False
-    if trigrams:
-        if i > 2:
-            listElem = sent[i - 3].split('|')
-            word3 = listElem[0]
-            lemma3 = listElem[1]
-            postag3 = listElem[2]
-            features.update({
-                '-3:word': word3,
-                '-3:lemma': lemma3,
-            })
-
-        if i < len(sent) - 3:
-            listElem = sent[i + 3].split('|')
-            word3 = listElem[0]
-            lemma3 = listElem[1]
-            postag3 = listElem[2]
-            features.update({
-                '+3:word': word3,
-                '+3:lemma': lemma3,
-            })
-    '''
     return features
 
-
 def sent2features(sent):
     return [word2features(sent, i) for i in range(len(sent))]
 
training_validation_v3.py
new file mode 100644
(contents as shown are identical to training_validation_v1-1.py above)