refactor1

This commit is contained in:
Andreas Stephanides
2017-08-28 09:08:47 +02:00
parent 699f4f6546
commit 630b982502
14 changed files with 274 additions and 230 deletions

View File

@@ -1,22 +1,31 @@
from classifier import get_training_threads, print_answers, in_training, store_training_data, get_pipe
from classifier import get_pipe
from storage import db_session, MailThread
def predict_threads():
pipe1,le=get_pipe("pipe1",b"answered",["db"])
"""
Predicts the language, topic and if a thread is anwered and writes that to the database. This function doesn't have a return value.
"""
# Loading pipes for the prediction of each thread
pipe1,le=get_pipe("pipe1",key=b"answered",filter=["db"])
pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"])
pipe3,le3=get_pipe("pipe2b", b"lang",["db"])
# Loading untrained MailThreads:
q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True))
mail_threads=q.all()
if len(mail_threads) ==0:
raise ValueError("no untrained threads found")
raise StandardError("no untrained threads found in database")
answered=le.inverse_transform(pipe1.predict(mail_threads))
maintopic=le2.inverse_transform(pipe2.predict(mail_threads))
lang=le3.inverse_transform(pipe3.predict(mail_threads))
for i, t in enumerate(mail_threads):
t.answered=bool(answered[i])
t.opened=bool(answered[i])
t.maintopic=str(maintopic[i])
t.lang=str(lang[i])
t.answered, t.opened, t.maintopic, t.lang = ( bool(answered[i]),
bool(answered[i]),
str(maintopic[i]),
str(lang[i])
)
db_session.add(t)
db_session.commit()