# filter_abstracts.py (saved copy, 5.97 KB)
#from pdb import set_trace as st
from sklearn.cross_validation import train_test_split as splitt
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import GridSearchCV
from sklearn import metrics
from sklearn.svm import SVC
import numpy as np
import argparse
import csv
import os
from sklearn.externals import joblib
from time import time
from scipy.stats import randint as sp_randint
from scipy.stats import expon
from sklearn.preprocessing import label_binarize


def _join_lines(chunk):
    """Collapse a list of raw text lines into one line of text.

    Lines are joined with spaces, newlines become spaces, and a single
    ``replace`` pass turns each pair of spaces into one. Longer runs are
    only partly collapsed; this is left unchanged so text is normalised
    exactly as before.
    """
    return " ".join(chunk).replace("\n", ' ').replace("  ", ' ')


def get_abstracts(file_name, label):
    """Parse a plain-text file of abstracts into a list of record dicts.

    The file is scanned line by line:

    * blank lines are remembered as paragraph boundaries;
    * a line containing `` doi: `` (the citation line) marks that the
      title starts right after the next blank line;
    * a line mentioning ``Copyright ``, ``Publish`` or the copyright sign
      flags that the paragraph just before the ``DOI:`` line is a
      copyright notice rather than the abstract;
    * a line containing ``DOI: `` closes the current record.

    Args:
        file_name: path of the text file to parse.
        label: value stored under the ``'topic'`` key of every record.

    Returns:
        A list of dicts with keys ``'body'``, ``'title'`` and ``'topic'``,
        plus ``'pmid'`` (int) when a ``PMID:`` line follows the ``DOI:``
        line, directly or after a ``PMCID:`` line. The title is an empty
        string when no citation line preceded the record.
    """
    with open(file_name) as f:
        lines = f.readlines()
    n_lines = len(lines)

    docs = []
    extract = {}
    empties = []       # indices of blank lines seen since the last record
    copyright = False  # a copyright/publisher line was seen in this record
    title_idx = None   # index of the first title line of this record

    for i, ln in enumerate(lines):
        if not ln.strip():
            empties.append(i)
            continue
        elif ' doi: ' in ln:
            # Citation line: the title begins after the next blank line,
            # searched within the following 10 lines (clamped at EOF).
            for j in range(i, min(i + 10, n_lines)):
                if not lines[j].strip():
                    title_idx = j + 1
                    break
            continue

        elif 'Copyright ' in ln or 'Publish' in ln or u'\N{COPYRIGHT SIGN}' in ln:
            copyright = True

        elif 'DOI: ' in ln:
            # The PMID line follows the DOI line, possibly after a PMCID line.
            nxt = lines[i + 1] if i + 1 < n_lines else ''
            if 'PMCID: ' in nxt:
                if i + 2 < n_lines:
                    extract['pmid'] = int(lines[i + 2].strip().split()[1])
            elif 'PMID: ' in nxt:
                extract['pmid'] = int(nxt.strip().split()[1])

            # The abstract is the paragraph between the last two blank lines
            # before the DOI line, or one paragraph further back when the
            # last paragraph was a copyright notice.
            back = 3 if copyright else 2
            copyright = False
            if len(empties) >= back:
                body = lines[empties[-back]:empties[-back + 1]]
            else:
                body = []  # too few blank lines to locate the paragraph
            extract['body'] = _join_lines(body)

            # The title is the run of (at most 5) non-blank lines starting
            # at title_idx; empty when no citation line was seen.
            title = []
            if title_idx is not None:
                for j in range(title_idx, min(title_idx + 5, n_lines)):
                    if lines[j].strip():
                        title.append(lines[j])
                    else:
                        break
            extract['title'] = _join_lines(title)
            extract['topic'] = label
            docs.append(extract)

            # Start a fresh record so no state leaks into the next one.
            empties = []
            extract = {}
            title_idx = None

    return docs


# Command-line interface. Training uses --classA and --classB (without
# --input); prediction uses --input together with --svcmodel and --out.
parser = argparse.ArgumentParser(
    description="This script separates abstracts of biomedical papers that "
                "report data from biomedical experiments from those that do not.")
parser.add_argument("--input", help="Input file containing the abstracts to "
                                    "be predicted.")
parser.add_argument("--classA", help="Input file containing the abstracts of "
                                     "class A (labelled 'useless') to be learned.")
parser.add_argument("--classB", help="Input file containing the abstracts of "
                                     "class B (labelled 'useful') to be learned.")
parser.add_argument("--out", help="Path to the output directory "
                                  "(default='./filter_output')",
                    default="filter_output")
parser.add_argument("--svcmodel", help="Path to custom pretrained svc model "
                                       "used for prediction "
                                       "(default='./model/svm_model.pkl')",
                    default="model/svm_model.pkl")

args = parser.parse_args()

# Class names indexed by the integer training/prediction target: abstracts
# from --classA are labelled 0 ('useless'), those from --classB 1 ('useful').
labels = {0: 'useless', 1: 'useful'}


def _load_param_grid(path):
    """Read an SVC hyper-parameter grid from a CSV file.

    The header row holds parameter names. Each following row contributes
    one candidate value per column. Values that parse as numbers become
    ``float``; anything else is kept as a (stripped) string. Empty cells,
    and cells missing from short rows (which ``csv.DictReader`` reports as
    ``None``), are skipped, so columns may list different numbers of
    candidates. Duplicates are dropped in first-seen order, which keeps
    the grid order deterministic between runs.

    Returns:
        dict mapping each parameter name to its list of candidate values.
    """
    grid = {}
    with open(path) as conf:
        reader = csv.DictReader(conf)
        for row in reader:
            for name in reader.fieldnames:
                raw = row.get(name)
                if raw is None or not raw.strip():
                    continue
                raw = raw.strip()
                try:
                    value = float(raw)
                except ValueError:
                    value = raw
                candidates = grid.setdefault(name, [])
                if value not in candidates:
                    candidates.append(value)
    return grid


if args.classA and args.classB and not args.input:
    # Training mode: fit TF-IDF -> TruncatedSVD -> grid-searched SVC on the
    # two labelled files and save the fitted models under ./model/.
    vectorizer = TfidfVectorizer(binary=True)
    print(vectorizer)
    model_params = _load_param_grid("model_params.conf")

    abstracs = get_abstracts(file_name=args.classA, label=labels[0])
    abstracs += get_abstracts(file_name=args.classB, label=labels[1])
    bodies = [x['body'] for x in abstracs]

    tfidf_model = vectorizer.fit(bodies)
    X = tfidf_model.transform(bodies)
    # Reduce the sparse TF-IDF matrix to 200 SVD components.
    svd = TruncatedSVD(n_components=200, random_state=42, n_iter=20)
    svd_model = svd.fit(X)
    X = svd_model.transform(X)
    y = [0 if x['topic'] == labels[0] else 1 for x in abstracs]

    clf = GridSearchCV(SVC(), param_grid=model_params, cv=3,
                       n_jobs=-1, scoring='f1')
    start = time()
    clf.fit(X, y)
    elapsed = time() - start
    # Report how many settings the grid search actually evaluated.
    n_candidates = len(clf.cv_results_['params'])
    print("GridSearch took %.2f seconds for %d candidates"
          " parameter settings." % (elapsed, n_candidates))

    print(clf.best_estimator_)
    print()
    print(clf.best_score_)

    # Make sure the target directory exists so the dump cannot fail after
    # the (possibly long) grid search above.
    if not os.path.exists('model'):
        os.makedirs('model')
    joblib.dump(clf.best_estimator_, 'model/svm_model.pkl')
    joblib.dump(tfidf_model, 'model/tfidf_model.pkl')
    joblib.dump(svd_model, 'model/svd_model.pkl')

elif args.input:
    # Prediction mode: vectorise --input with the saved TF-IDF and SVD
    # models and classify it with the SVC model from --svcmodel.
    # NOTE(review): joblib.load unpickles its input, which can execute
    # arbitrary code; only load model files from a trusted source.
    clf = joblib.load(args.svcmodel)
    vectorizer = joblib.load('model/tfidf_model.pkl')
    svd = joblib.load('model/svd_model.pkl')
    abstracs = get_abstracts(file_name=args.input, label='unknown')
    if abstracs:
        X = svd.transform(vectorizer.transform([x['body'] for x in abstracs]))
        classes = clf.predict(X)
    else:
        # Nothing to classify; skip the models, since sklearn estimators
        # generally reject zero-sample input.
        classes = []

    if not os.path.exists(args.out):
        os.makedirs(args.out)
    # One "<pmid>\t<body>" line per abstract, written to <out>/useless.out
    # or <out>/useful.out according to the predicted class.
    with open(os.path.join(args.out, labels[0] + ".out"), 'w') as f0, \
            open(os.path.join(args.out, labels[1] + ".out"), 'w') as f1:
        for c, a in zip(classes, abstracs):
            # Records whose PMID line was not found get an empty first field.
            row = "%s\t%s\n" % (a.get('pmid', ''), a['body'])
            if c == 0:
                f0.write(row)
            elif c == 1:
                f1.write(row)

else:
    parser.error("use --input FILE to classify abstracts, or --classA FILE "
                 "and --classB FILE (without --input) to train a model")