small fixes and improvements
This commit is contained in:
22
classifier/prediction.py
Normal file
22
classifier/prediction.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from classifier import get_training_threads, print_answers, in_training, store_training_data, get_pipe
|
||||
from storage import db_session, MailThread
|
||||
|
||||
def predict_threads():
|
||||
pipe1,le=get_pipe("pipe1",b"answered",["db"])
|
||||
pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"])
|
||||
pipe3,le3=get_pipe("pipe2b", b"lang",["db"])
|
||||
q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True))
|
||||
mail_threads=q.all()
|
||||
if len(mail_threads) ==0:
|
||||
raise ValueError("no untrained threads found")
|
||||
answered=le.inverse_transform(pipe1.predict(mail_threads))
|
||||
maintopic=le2.inverse_transform(pipe2.predict(mail_threads))
|
||||
lang=le3.inverse_transform(pipe3.predict(mail_threads))
|
||||
|
||||
for i, t in enumerate(mail_threads):
|
||||
t.answered=bool(answered[i])
|
||||
t.opened=bool(answered[i])
|
||||
t.maintopic=str(maintopic[i])
|
||||
t.lang=str(lang[i])
|
||||
db_session.add(t)
|
||||
db_session.commit()
|
||||
Reference in New Issue
Block a user