diff --git a/classifier/__init__.py b/classifier/__init__.py index 41c689b..78ef9a2 100644 --- a/classifier/__init__.py +++ b/classifier/__init__.py @@ -1,8 +1,43 @@ +""" +This module holds the necessary functions to classify mail threads. +It predicts if a thread has been answered, the language and the topic. +""" from classifier import print_answers from classifier import get_pipe, test_pipe, get_training_threads #from classifier import store_training_data #in_training, from training import train_single_thread +from storage import db_session, MailThread -from prediction import predict_threads + +def predict_threads(): + """ + Predicts the language, topic and if a thread is answered and writes that to the database. This function doesn't have a return value. + """ + # Loading pipes for the prediction of each thread + pipe1,le=get_pipe("pipe1",key=b"answered",filter=["db"]) + pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"]) + pipe3,le3=get_pipe("pipe2b", b"lang",["db"]) + + # Loading untrained MailThreads: + q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True)) + mail_threads=q.all() + + if len(mail_threads) ==0: + raise StandardError("no untrained threads found in database") + + #Run the prediction for each property + answered=le.inverse_transform(pipe1.predict(mail_threads)) + maintopic=le2.inverse_transform(pipe2.predict(mail_threads)) + lang=le3.inverse_transform(pipe3.predict(mail_threads)) + + # Commit the results to the database + for i, t in enumerate(mail_threads): + t.answered, t.opened, t.maintopic, t.lang = ( bool(answered[i]), + bool(answered[i]), + str(maintopic[i]), + str(lang[i]) + ) + db_session.add(t) + db_session.commit() diff --git a/classifier/prediction.py b/classifier/prediction.py index afa6373..6f890e9 100644 --- a/classifier/prediction.py +++ b/classifier/prediction.py @@ -1,31 +1,2 @@ from classifier import get_pipe from storage import db_session, MailThread - -def predict_threads(): - """ - Predicts the language, 
topic and if a thread is anwered and writes that to the database. This function doesn't have a return value. - """ - # Loading pipes for the prediction of each thread - pipe1,le=get_pipe("pipe1",key=b"answered",filter=["db"]) - pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"]) - pipe3,le3=get_pipe("pipe2b", b"lang",["db"]) - - # Loading untrained MailThreads: - q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True)) - mail_threads=q.all() - - if len(mail_threads) ==0: - raise StandardError("no untrained threads found in database") - - answered=le.inverse_transform(pipe1.predict(mail_threads)) - maintopic=le2.inverse_transform(pipe2.predict(mail_threads)) - lang=le3.inverse_transform(pipe3.predict(mail_threads)) - - for i, t in enumerate(mail_threads): - t.answered, t.opened, t.maintopic, t.lang = ( bool(answered[i]), - bool(answered[i]), - str(maintopic[i]), - str(lang[i]) - ) - db_session.add(t) - db_session.commit() diff --git a/classifier/training.py b/classifier/training.py index f39fd0c..47f4c17 100644 --- a/classifier/training.py +++ b/classifier/training.py @@ -5,7 +5,6 @@ from storage import Mail, MailThread, db_session from classifier import print_answers - def train_fit_pipe(): tt= get_training_threads(b"answered") pipe1.fit(tt[0],tt[1]) @@ -33,18 +32,21 @@ def predict_thread(mth,p,le,key): def train_single_thread(tid,p,le,key="answered"): if (not type(tid) is int): raise TypeError("ID must be of type int") + #------------------------------------------------------- + if not type(p) is Pipeline: raise TypeError("Second Argument needs to be type Pipeline") + if not type(le) is LabelEncoder: raise TypeError("Second Argument needs to be type LabelEncoder") + # Load a single Mailthread by firstmail id mth=db_session.query(MailThread).filter(MailThread.firstmail==tid).first() - if mth is None: raise ValueError("Thread with firstmail %d not in Database" %tid) + if mth is None: raise StandardError("Thread with firstmail %d not in Database" 
%tid) + + # Output the mail thread print mth.firstmail print mth.subject() print mth.text() - if not p is None and not le is None: - answ=predict_thread(mth,p,le,key) - else: answ=None - if not le is None: - print_answers(le) + answ=predict_thread(mth,p,le,key) + print_answers(le) ca=raw_input("Correct answer..") try: diff --git a/run.py b/run.py index 7190863..3f3ab5e 100644 --- a/run.py +++ b/run.py @@ -15,7 +15,7 @@ from storage.fetch_mail import fetch_threads, flatten_threads from storage import Mail, MailThread, db_session #import yaml #import email -from classifier import get_training_threads, print_answers, in_training, store_training_data, get_pipe, test_pipe, train_single_thread # , pipe2, pipe2b +from classifier import get_training_threads, print_answers,get_pipe, test_pipe, train_single_thread # , pipe2, pipe2b from classifier import predict_threads maintopic_values=["studium", "information","ausleihen","jobausschreibung", "umfragen"] @@ -30,31 +30,8 @@ if len(sys.argv)>1: if sys.argv[1] == "fetch_threads": print flatten_threads(fetch_threads()) - if sys.argv[1] == "predict_threads2": - predict_threads() if sys.argv[1] == "predict_threads": - print "predicting threads" - pipe1,le=get_pipe("pipe1",b"answered",["db"]) - pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"]) - pipe3,le3=get_pipe("pipe2b", b"lang",["db"]) - q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True)) - - mail_threads=q.all() - - if len(mail_threads) ==0: - raise ValueError("no untrained threads found") - answered=le.inverse_transform(pipe1.predict(mail_threads)) - maintopic=le2.inverse_transform(pipe2.predict(mail_threads)) - lang=le3.inverse_transform(pipe3.predict(mail_threads)) - - for i, t in enumerate(mail_threads): - t.answered=bool(answered[i]) - t.opened=bool(answered[i]) - - t.maintopic=str(maintopic[i]) - t.lang=str(lang[i]) - db_session.add(t) - db_session.commit() + predict_threads() if sys.argv[1]=="stats": for topic in maintopic_values: print topic 
@@ -70,20 +47,6 @@ if len(sys.argv)>1: from flaskapp import app app.run(port=3000,debug=True) - - if sys.argv[1] == "trained_threads_from_yml": - from classifier.classifier import train - for k in train: - print k - t=db_session.query(MailThread).filter(MailThread.firstmail==k).first() - t.istrained=True - db_session.add(t) - db_session.commit() - if sys.argv[1] == "print_threads2": - mth=db_session.query(MailThread).all() - for t in mth: - print t.to_text() - print "---------------\n" if sys.argv[1] == "train_thrd2": p, le=get_pipe("pipe2", "maintopic",["db"]) @@ -207,37 +170,7 @@ if len(sys.argv)>1: db_session.commit() - if sys.argv[1] == "fetch_mail": - print "fetching mail %d " % int(sys.argv[2]) - m=fetch_mail(int(sys.argv[2])) - hd=decode_header(m['ENVELOPE'].subject) - hd2=[] - # print hd - for h in hd: - if not h[1] is None: - hd2.append(h[0].decode(h[1])) - # print h[0].decode(h[1]) - else: - hd2.append(h[0]) - print "\nBetreff:" - for h in hd2: - print h - print "FROM:" - for t in m['ENVELOPE'].from_: - print t - print "TO:" - for t in m['ENVELOPE'].to: - print t - em=email.message_from_string(m['RFC822']) - for p in em.walk(): - if p.get_content_maintype()=="text": - print p.get_payload() - elif p.get_content_maintype()=="multipart": - print p.get_payload() - else: - print p.get_content_maintype() - - + if sys.argv[1] == "initdb": from storage import init_db