44 lines
1.7 KiB
Python
44 lines
1.7 KiB
Python
"""
|
|
This module holds the functions needed to classify mail threads.

It predicts whether a thread has been answered, as well as its language and topic.
|
|
"""
|
|
from classifier import print_answers
|
|
from classifier import get_pipe, test_pipe, get_training_threads
|
|
#from classifier import store_training_data
|
|
#in_training,
|
|
|
|
from training import train_single_thread
|
|
from storage import db_session, MailThread
|
|
|
|
|
|
def predict_threads():
    """Predict whether each untrained thread is answered, plus its topic and language.

    Loads three pipelines (and their matching encoders) via ``get_pipe``.
    It applies them to every ``MailThread`` whose ``istrained`` column is not
    ``True``, writes the decoded predictions onto each thread and commits
    them in one ``db_session.commit()``.

    This function doesn't have a return value.

    Raises:
        RuntimeError: if no untrained threads are found in the database.
    """
    # Load one (pipeline, encoder) pair per predicted property.
    pipe1, le = get_pipe("pipe1", key=b"answered", filter=["db"])
    pipe2, le2 = get_pipe("pipe2g", b"maintopic", ["db"])
    pipe3, le3 = get_pipe("pipe2b", b"lang", ["db"])

    # Load untrained MailThreads. The raw "IS NOT" operator is presumably used
    # instead of "!=" so that rows with a NULL istrained value are included too.
    q = db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True))
    mail_threads = q.all()

    if not mail_threads:
        # StandardError only exists on Python 2 (NameError on Python 3).
        # RuntimeError subclasses StandardError on Python 2, so existing
        # handlers keep catching it, and it is also valid on Python 3.
        raise RuntimeError("no untrained threads found in database")

    # Run the prediction for each property. Each encoder returned by get_pipe
    # maps the raw pipeline output back to labels.
    answered = le.inverse_transform(pipe1.predict(mail_threads))
    maintopic = le2.inverse_transform(pipe2.predict(mail_threads))
    lang = le3.inverse_transform(pipe3.predict(mail_threads))

    # Write the results onto each thread and commit them to the database.
    for thread, is_answered, topic, language in zip(mail_threads, answered, maintopic, lang):
        thread.answered = bool(is_answered)
        # NOTE(review): "opened" is derived from the answered prediction, exactly
        # as in the original code. Confirm this is intended.
        thread.opened = bool(is_answered)
        thread.maintopic = str(topic)
        thread.lang = str(language)
        db_session.add(thread)

    db_session.commit()
|