""" This module holds the necessary functions to classify mail threads. It predicts if a thread has been answered,the language and the topic. """ from classifier import print_answers from classifier import get_pipe, test_pipe, get_training_threads #from classifier import store_training_data #in_training, from training import train_single_thread from storage import db_session, MailThread def predict_threads(): """ Predicts the language, topic and if a thread is anwered and writes that to the database. This function doesn't have a return value. """ # Loading pipes for the prediction of each thread pipe1,le=get_pipe("pipe1",key=b"answered",filter=["db"]) pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"]) pipe3,le3=get_pipe("pipe2b", b"lang",["db"]) # Loading untrained MailThreads: q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True)) mail_threads=q.all() if len(mail_threads) ==0: raise StandardError("no untrained threads found in database") #Run the prediction for each property answered=le.inverse_transform(pipe1.predict(mail_threads)) maintopic=le2.inverse_transform(pipe2.predict(mail_threads)) lang=le3.inverse_transform(pipe3.predict(mail_threads)) # Commit the results to the database for i, t in enumerate(mail_threads): t.answered, t.opened, t.maintopic, t.lang = ( bool(answered[i]), bool(answered[i]), str(maintopic[i]), str(lang[i]) ) db_session.add(t) db_session.commit()