from classifier import get_pipe
from storage import db_session, MailThread


def predict_threads():
    """Predict labels for all untrained mail threads and store them.

    For every ``MailThread`` whose ``istrained`` flag is not ``True``, this
    predicts whether the thread is answered, its main topic and its
    language, writes the results onto the thread and commits them to the
    database. ``opened`` is set to the same value as ``answered``.

    Returns:
        None. The results are only written to the database.

    Raises:
        RuntimeError: If the database contains no untrained threads.
    """
    # Each get_pipe call yields a (pipeline, encoder) pair: the pipeline
    # predicts encoded labels and the encoder maps them back to the
    # original label values.
    answered_pipe, answered_enc = get_pipe("pipe1", key=b"answered", filter=["db"])
    topic_pipe, topic_enc = get_pipe("pipe2g", b"maintopic", ["db"])
    lang_pipe, lang_enc = get_pipe("pipe2b", b"lang", ["db"])

    # "IS NOT TRUE" also matches rows where istrained is NULL, not only
    # rows where it is explicitly False.
    query = db_session.query(MailThread).filter(
        MailThread.istrained.op("IS NOT")(True)
    )
    mail_threads = query.all()
    if not mail_threads:
        # RuntimeError exists on both Python 2 and 3; StandardError was
        # removed in Python 3 and would surface here as a NameError.
        raise RuntimeError("no untrained threads found in database")

    answered = answered_enc.inverse_transform(answered_pipe.predict(mail_threads))
    maintopic = topic_enc.inverse_transform(topic_pipe.predict(mail_threads))
    lang = lang_enc.inverse_transform(lang_pipe.predict(mail_threads))

    # The prediction arrays are parallel to mail_threads, so walk them in
    # lockstep instead of indexing.
    for thread, is_answered, topic, language in zip(
        mail_threads, answered, maintopic, lang
    ):
        thread.answered = bool(is_answered)
        # A thread predicted as answered is also marked as opened.
        thread.opened = bool(is_answered)
        thread.maintopic = str(topic)
        thread.lang = str(language)
        db_session.add(thread)

    # Persist all updated threads in a single transaction.
    db_session.commit()