Showing
1 changed file
with
66 additions
and
3 deletions
... | @@ -21,6 +21,11 @@ import os | ... | @@ -21,6 +21,11 @@ import os |
21 | from sklearn.preprocessing import LabelEncoder, OneHotEncoder | 21 | from sklearn.preprocessing import LabelEncoder, OneHotEncoder |
22 | import numpy as np | 22 | import numpy as np |
23 | from sklearn.model_selection import train_test_split | 23 | from sklearn.model_selection import train_test_split |
24 | +from tensorflow.keras.layers import Conv1D, Dense, MaxPooling1D, Flatten | ||
25 | +from tensorflow.keras.models import Sequential | ||
26 | +import matplotlib.pyplot as plt | ||
27 | +from sklearn.metrics import confusion_matrix | ||
28 | +import itertools | ||
24 | 29 | ||
25 | if __name__ == "__main__": | 30 | if __name__ == "__main__": |
26 | parser = argparse.ArgumentParser(description='Get training and test data sets for Human Genome Annotation.') | 31 | parser = argparse.ArgumentParser(description='Get training and test data sets for Human Genome Annotation.') |
... | @@ -79,10 +84,68 @@ if __name__ == "__main__": | ... | @@ -79,10 +84,68 @@ if __name__ == "__main__": |
79 | train_features, test_features, train_labels, test_labels = train_test_split( | 84 | train_features, test_features, train_labels, test_labels = train_test_split( |
80 | input_features, input_labels, test_size=0.25, random_state=42) | 85 | input_features, input_labels, test_size=0.25, random_state=42) |
81 | 86 | ||
87 | + # Model definition | ||
88 | + model = Sequential() | ||
89 | + model.add(Conv1D(filters=32, kernel_size=12, | ||
90 | + input_shape=(train_features.shape[1], 4))) | ||
91 | + model.add(MaxPooling1D(pool_size=4)) | ||
92 | + model.add(Flatten()) | ||
93 | + model.add(Dense(16, activation='relu')) | ||
94 | + model.add(Dense(2, activation='softmax')) | ||
95 | + | ||
96 | + model.compile(loss='binary_crossentropy', optimizer='adam', | ||
97 | + metrics=['binary_accuracy']) | ||
98 | + model.summary() | ||
99 | + | ||
100 | + # Model training and validation | ||
101 | + history = model.fit(train_features, train_labels, | ||
102 | + epochs=50, verbose=0, validation_split=0.25) | ||
103 | + | ||
104 | + # Plot training-validation loss | ||
105 | + plt.figure() | ||
106 | + plt.plot(history.history['loss']) | ||
107 | + plt.plot(history.history['val_loss']) | ||
108 | + plt.title('model loss') | ||
109 | + plt.ylabel('loss') | ||
110 | + plt.xlabel('epoch') | ||
111 | + plt.legend(['train', 'validation']) | ||
112 | + # plt.show() | ||
113 | + plt.savefig('training-validation-loss.png') | ||
114 | + | ||
115 | + # Plot training-validation accuracy | ||
116 | + plt.figure() | ||
117 | + plt.plot(history.history['binary_accuracy']) | ||
118 | + plt.plot(history.history['val_binary_accuracy']) | ||
119 | + plt.title('model accuracy') | ||
120 | + plt.ylabel('accuracy') | ||
121 | + plt.xlabel('epoch') | ||
122 | + plt.legend(['train', 'validation']) | ||
123 | + # plt.show() | ||
124 | + plt.savefig('training-validation-binary-accuracy.png') | ||
125 | + | ||
126 | + # Predict labels for the held-out test set | ||
127 | + predicted_labels = model.predict(np.stack(test_features)) | ||
128 | + # Print confusion matrix | ||
129 | + cm = confusion_matrix(np.argmax(test_labels, axis=1), | ||
130 | + np.argmax(predicted_labels, axis=1)) | ||
131 | + print('Confusion matrix:\n', cm) | ||
132 | + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] | ||
133 | + | ||
134 | + # Plot confusion matrix | ||
135 | + plt.imshow(cm, cmap=plt.cm.Blues) | ||
136 | + plt.title('Normalized confusion matrix') | ||
137 | + plt.colorbar() | ||
138 | + plt.xlabel('Predicted label') | ||
139 | + plt.ylabel('True label') | ||
140 | + plt.xticks([0, 1]) | ||
141 | + plt.yticks([0, 1]) | ||
142 | + plt.grid(False) | ||
143 | + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): | ||
144 | + plt.text(j, i, format(cm[i, j], '.2f'), | ||
145 | + horizontalalignment='center', | ||
146 | + color='white' if cm[i, j] > 0.5 else 'black') | ||
147 | + | ||
82 | 148 | ||
83 | - with open(os.path.join(args.outputPath, args.outputFile), mode="w") as oFile: | ||
84 | - for elem in list_rows: | ||
85 | - oFile.write(elem) | ||
86 | 149 | ||
87 | 150 | ||
88 | 151 | ... | ... |
-
Please register or login to post a comment