Showing
1 changed file
with
8 additions
and
1 deletions
... | @@ -111,8 +111,10 @@ if __name__ == "__main__": | ... | @@ -111,8 +111,10 @@ if __name__ == "__main__": |
111 | joblib.dump(X_train, os.path.join(args.outputModelPath, args.inputTrainingData + '.jlb')) | 111 | joblib.dump(X_train, os.path.join(args.outputModelPath, args.inputTrainingData + '.jlb')) |
112 | joblib.dump(y_train, os.path.join(args.outputModelPath, args.inputTrainingData + '.class.jlb')) | 112 | joblib.dump(y_train, os.path.join(args.outputModelPath, args.inputTrainingData + '.class.jlb')) |
113 | else: | 113 | else: |
114 | + print(" Saving matrix and classes...") | ||
114 | X_train = joblib.load(os.path.join(args.outputModelPath, args.inputTrainingData + '.jlb')) | 115 | X_train = joblib.load(os.path.join(args.outputModelPath, args.inputTrainingData + '.jlb')) |
115 | y_train = joblib.load(os.path.join(args.outputModelPath, args.inputTrainingData + '.class.jlb')) | 116 | y_train = joblib.load(os.path.join(args.outputModelPath, args.inputTrainingData + '.class.jlb')) |
117 | + print(" Done!") | ||
116 | 118 | ||
117 | print(" Number of training classes: {}".format(len(y_train))) | 119 | print(" Number of training classes: {}".format(len(y_train))) |
118 | print(" Number of training class A: {}".format(y_train.count('A'))) | 120 | print(" Number of training class A: {}".format(y_train.count('A'))) |
... | @@ -139,20 +141,25 @@ if __name__ == "__main__": | ... | @@ -139,20 +141,25 @@ if __name__ == "__main__": |
139 | joblib.dump(X_test, os.path.join(args.outputModelPath, args.inputTestingData + '.jlb')) | 141 | joblib.dump(X_test, os.path.join(args.outputModelPath, args.inputTestingData + '.jlb')) |
140 | joblib.dump(y_test, os.path.join(args.outputModelPath, args.inputTestingClasses + '.class.jlb')) | 142 | joblib.dump(y_test, os.path.join(args.outputModelPath, args.inputTestingClasses + '.class.jlb')) |
141 | else: | 143 | else: |
144 | + print(" Saving matrix and classes...") | ||
142 | X_test = joblib.load(os.path.join(args.outputModelPath, args.inputTestingData + '.jlb')) | 145 | X_test = joblib.load(os.path.join(args.outputModelPath, args.inputTestingData + '.jlb')) |
143 | y_test = joblib.load(os.path.join(args.outputModelPath, args.inputTestingClasses + '.class.jlb')) | 146 | y_test = joblib.load(os.path.join(args.outputModelPath, args.inputTestingClasses + '.class.jlb')) |
147 | + print(" Done!") | ||
144 | 148 | ||
145 | print(" Number of testing classes: {}".format(len(y_test))) | 149 | print(" Number of testing classes: {}".format(len(y_test))) |
146 | print(" Number of testing class A: {}".format(y_test.count('A'))) | 150 | print(" Number of testing class A: {}".format(y_test.count('A'))) |
147 | print(" Number of testing class I: {}".format(y_test.count('I'))) | 151 | print(" Number of testing class I: {}".format(y_test.count('I'))) |
148 | print(" Shape of testing matrix: {}".format(X_test.shape)) | 152 | print(" Shape of testing matrix: {}".format(X_test.shape)) |
149 | 153 | ||
150 | - if args.classifier == "MultinomialNB": | 154 | + if args.classifier == "BernoulliNB": |
151 | classifier = BernoulliNB() | 155 | classifier = BernoulliNB() |
152 | elif args.classifier == "SVM": | 156 | elif args.classifier == "SVM": |
153 | classifier = SVC() | 157 | classifier = SVC() |
154 | elif args.classifier == "NearestCentroid": | 158 | elif args.classifier == "NearestCentroid": |
155 | classifier = NearestCentroid() | 159 | classifier = NearestCentroid() |
160 | + else: | ||
161 | + print("Bad classifier") | ||
162 | + exit() | ||
156 | 163 | ||
157 | print("Training...") | 164 | print("Training...") |
158 | classifier.fit(X_train, y_train) | 165 | classifier.fit(X_train, y_train) | ... | ... |
-
Please register or login to post a comment