small rewrites
This commit is contained in:
@@ -1,8 +1,43 @@
|
||||
"""
|
||||
This module holds the necessary functions to classify mail threads.
|
||||
It predicts if a thread has been answered,the language and the topic.
|
||||
"""
|
||||
from classifier import print_answers
|
||||
from classifier import get_pipe, test_pipe, get_training_threads
|
||||
#from classifier import store_training_data
|
||||
#in_training,
|
||||
|
||||
from training import train_single_thread
|
||||
from storage import db_session, MailThread
|
||||
|
||||
from prediction import predict_threads
|
||||
|
||||
def predict_threads():
|
||||
"""
|
||||
Predicts the language, topic and if a thread is anwered and writes that to the database. This function doesn't have a return value.
|
||||
"""
|
||||
# Loading pipes for the prediction of each thread
|
||||
pipe1,le=get_pipe("pipe1",key=b"answered",filter=["db"])
|
||||
pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"])
|
||||
pipe3,le3=get_pipe("pipe2b", b"lang",["db"])
|
||||
|
||||
# Loading untrained MailThreads:
|
||||
q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True))
|
||||
mail_threads=q.all()
|
||||
|
||||
if len(mail_threads) ==0:
|
||||
raise StandardError("no untrained threads found in database")
|
||||
|
||||
#Run the prediction for each property
|
||||
answered=le.inverse_transform(pipe1.predict(mail_threads))
|
||||
maintopic=le2.inverse_transform(pipe2.predict(mail_threads))
|
||||
lang=le3.inverse_transform(pipe3.predict(mail_threads))
|
||||
|
||||
# Commit the results to the database
|
||||
for i, t in enumerate(mail_threads):
|
||||
t.answered, t.opened, t.maintopic, t.lang = ( bool(answered[i]),
|
||||
bool(answered[i]),
|
||||
str(maintopic[i]),
|
||||
str(lang[i])
|
||||
)
|
||||
db_session.add(t)
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,31 +1,2 @@
|
||||
from classifier import get_pipe
|
||||
from storage import db_session, MailThread
|
||||
|
||||
def predict_threads():
|
||||
"""
|
||||
Predicts the language, topic and if a thread is anwered and writes that to the database. This function doesn't have a return value.
|
||||
"""
|
||||
# Loading pipes for the prediction of each thread
|
||||
pipe1,le=get_pipe("pipe1",key=b"answered",filter=["db"])
|
||||
pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"])
|
||||
pipe3,le3=get_pipe("pipe2b", b"lang",["db"])
|
||||
|
||||
# Loading untrained MailThreads:
|
||||
q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True))
|
||||
mail_threads=q.all()
|
||||
|
||||
if len(mail_threads) ==0:
|
||||
raise StandardError("no untrained threads found in database")
|
||||
|
||||
answered=le.inverse_transform(pipe1.predict(mail_threads))
|
||||
maintopic=le2.inverse_transform(pipe2.predict(mail_threads))
|
||||
lang=le3.inverse_transform(pipe3.predict(mail_threads))
|
||||
|
||||
for i, t in enumerate(mail_threads):
|
||||
t.answered, t.opened, t.maintopic, t.lang = ( bool(answered[i]),
|
||||
bool(answered[i]),
|
||||
str(maintopic[i]),
|
||||
str(lang[i])
|
||||
)
|
||||
db_session.add(t)
|
||||
db_session.commit()
|
||||
|
||||
@@ -5,7 +5,6 @@ from storage import Mail, MailThread, db_session
|
||||
from classifier import print_answers
|
||||
|
||||
|
||||
|
||||
def train_fit_pipe():
|
||||
tt= get_training_threads(b"answered")
|
||||
pipe1.fit(tt[0],tt[1])
|
||||
@@ -33,18 +32,21 @@ def predict_thread(mth,p,le,key):
|
||||
|
||||
def train_single_thread(tid,p,le,key="answered"):
|
||||
if (not type(tid) is int): raise TypeError("ID must be of type int")
|
||||
#-------------------------------------------------------
|
||||
if not type(p) is Pipeline: raise TypeError("Second Argument needs to be type Pipeline")
|
||||
if not type(le) is LabelEncoder: raise TypeError("Second Argument needs to be type LabelEncoder")
|
||||
|
||||
# Load a single Mailthread by firstmail id
|
||||
mth=db_session.query(MailThread).filter(MailThread.firstmail==tid).first()
|
||||
if mth is None: raise ValueError("Thread with firstmail %d not in Database" %tid)
|
||||
if mth is None: raise StandardError("Thread with firstmail %d not in Database" %tid)
|
||||
|
||||
# Output the mail thread
|
||||
print mth.firstmail
|
||||
print mth.subject()
|
||||
print mth.text()
|
||||
|
||||
if not p is None and not le is None:
|
||||
answ=predict_thread(mth,p,le,key)
|
||||
else: answ=None
|
||||
if not le is None:
|
||||
print_answers(le)
|
||||
answ=predict_thread(mth,p,le,key)
|
||||
print_answers(le)
|
||||
|
||||
ca=raw_input("Correct answer..")
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user