small rewrites
This commit is contained in:
@@ -1,8 +1,43 @@
|
|||||||
|
"""
|
||||||
|
This module holds the necessary functions to classify mail threads.
|
||||||
|
It predicts if a thread has been answered,the language and the topic.
|
||||||
|
"""
|
||||||
from classifier import print_answers
|
from classifier import print_answers
|
||||||
from classifier import get_pipe, test_pipe, get_training_threads
|
from classifier import get_pipe, test_pipe, get_training_threads
|
||||||
#from classifier import store_training_data
|
#from classifier import store_training_data
|
||||||
#in_training,
|
#in_training,
|
||||||
|
|
||||||
from training import train_single_thread
|
from training import train_single_thread
|
||||||
|
from storage import db_session, MailThread
|
||||||
|
|
||||||
from prediction import predict_threads
|
|
||||||
|
def predict_threads():
|
||||||
|
"""
|
||||||
|
Predicts the language, topic and if a thread is anwered and writes that to the database. This function doesn't have a return value.
|
||||||
|
"""
|
||||||
|
# Loading pipes for the prediction of each thread
|
||||||
|
pipe1,le=get_pipe("pipe1",key=b"answered",filter=["db"])
|
||||||
|
pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"])
|
||||||
|
pipe3,le3=get_pipe("pipe2b", b"lang",["db"])
|
||||||
|
|
||||||
|
# Loading untrained MailThreads:
|
||||||
|
q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True))
|
||||||
|
mail_threads=q.all()
|
||||||
|
|
||||||
|
if len(mail_threads) ==0:
|
||||||
|
raise StandardError("no untrained threads found in database")
|
||||||
|
|
||||||
|
#Run the prediction for each property
|
||||||
|
answered=le.inverse_transform(pipe1.predict(mail_threads))
|
||||||
|
maintopic=le2.inverse_transform(pipe2.predict(mail_threads))
|
||||||
|
lang=le3.inverse_transform(pipe3.predict(mail_threads))
|
||||||
|
|
||||||
|
# Commit the results to the database
|
||||||
|
for i, t in enumerate(mail_threads):
|
||||||
|
t.answered, t.opened, t.maintopic, t.lang = ( bool(answered[i]),
|
||||||
|
bool(answered[i]),
|
||||||
|
str(maintopic[i]),
|
||||||
|
str(lang[i])
|
||||||
|
)
|
||||||
|
db_session.add(t)
|
||||||
|
db_session.commit()
|
||||||
|
|||||||
@@ -1,31 +1,2 @@
|
|||||||
from classifier import get_pipe
|
from classifier import get_pipe
|
||||||
from storage import db_session, MailThread
|
from storage import db_session, MailThread
|
||||||
|
|
||||||
def predict_threads():
|
|
||||||
"""
|
|
||||||
Predicts the language, topic and if a thread is anwered and writes that to the database. This function doesn't have a return value.
|
|
||||||
"""
|
|
||||||
# Loading pipes for the prediction of each thread
|
|
||||||
pipe1,le=get_pipe("pipe1",key=b"answered",filter=["db"])
|
|
||||||
pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"])
|
|
||||||
pipe3,le3=get_pipe("pipe2b", b"lang",["db"])
|
|
||||||
|
|
||||||
# Loading untrained MailThreads:
|
|
||||||
q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True))
|
|
||||||
mail_threads=q.all()
|
|
||||||
|
|
||||||
if len(mail_threads) ==0:
|
|
||||||
raise StandardError("no untrained threads found in database")
|
|
||||||
|
|
||||||
answered=le.inverse_transform(pipe1.predict(mail_threads))
|
|
||||||
maintopic=le2.inverse_transform(pipe2.predict(mail_threads))
|
|
||||||
lang=le3.inverse_transform(pipe3.predict(mail_threads))
|
|
||||||
|
|
||||||
for i, t in enumerate(mail_threads):
|
|
||||||
t.answered, t.opened, t.maintopic, t.lang = ( bool(answered[i]),
|
|
||||||
bool(answered[i]),
|
|
||||||
str(maintopic[i]),
|
|
||||||
str(lang[i])
|
|
||||||
)
|
|
||||||
db_session.add(t)
|
|
||||||
db_session.commit()
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from storage import Mail, MailThread, db_session
|
|||||||
from classifier import print_answers
|
from classifier import print_answers
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train_fit_pipe():
|
def train_fit_pipe():
|
||||||
tt= get_training_threads(b"answered")
|
tt= get_training_threads(b"answered")
|
||||||
pipe1.fit(tt[0],tt[1])
|
pipe1.fit(tt[0],tt[1])
|
||||||
@@ -33,17 +32,20 @@ def predict_thread(mth,p,le,key):
|
|||||||
|
|
||||||
def train_single_thread(tid,p,le,key="answered"):
|
def train_single_thread(tid,p,le,key="answered"):
|
||||||
if (not type(tid) is int): raise TypeError("ID must be of type int")
|
if (not type(tid) is int): raise TypeError("ID must be of type int")
|
||||||
|
#-------------------------------------------------------
|
||||||
|
if not type(p) is Pipeline: raise TypeError("Second Argument needs to be type Pipeline")
|
||||||
|
if not type(le) is LabelEncoder: raise TypeError("Second Argument needs to be type LabelEncoder")
|
||||||
|
|
||||||
|
# Load a single Mailthread by firstmail id
|
||||||
mth=db_session.query(MailThread).filter(MailThread.firstmail==tid).first()
|
mth=db_session.query(MailThread).filter(MailThread.firstmail==tid).first()
|
||||||
if mth is None: raise ValueError("Thread with firstmail %d not in Database" %tid)
|
if mth is None: raise StandardError("Thread with firstmail %d not in Database" %tid)
|
||||||
|
|
||||||
|
# Output the mail thread
|
||||||
print mth.firstmail
|
print mth.firstmail
|
||||||
print mth.subject()
|
print mth.subject()
|
||||||
print mth.text()
|
print mth.text()
|
||||||
|
|
||||||
if not p is None and not le is None:
|
|
||||||
answ=predict_thread(mth,p,le,key)
|
answ=predict_thread(mth,p,le,key)
|
||||||
else: answ=None
|
|
||||||
if not le is None:
|
|
||||||
print_answers(le)
|
print_answers(le)
|
||||||
|
|
||||||
ca=raw_input("Correct answer..")
|
ca=raw_input("Correct answer..")
|
||||||
|
|||||||
71
run.py
71
run.py
@@ -15,7 +15,7 @@ from storage.fetch_mail import fetch_threads, flatten_threads
|
|||||||
from storage import Mail, MailThread, db_session
|
from storage import Mail, MailThread, db_session
|
||||||
#import yaml
|
#import yaml
|
||||||
#import email
|
#import email
|
||||||
from classifier import get_training_threads, print_answers, in_training, store_training_data, get_pipe, test_pipe, train_single_thread # , pipe2, pipe2b
|
from classifier import get_training_threads, print_answers,get_pipe, test_pipe, train_single_thread # , pipe2, pipe2b
|
||||||
from classifier import predict_threads
|
from classifier import predict_threads
|
||||||
maintopic_values=["studium", "information","ausleihen","jobausschreibung", "umfragen"]
|
maintopic_values=["studium", "information","ausleihen","jobausschreibung", "umfragen"]
|
||||||
|
|
||||||
@@ -30,31 +30,8 @@ if len(sys.argv)>1:
|
|||||||
|
|
||||||
if sys.argv[1] == "fetch_threads":
|
if sys.argv[1] == "fetch_threads":
|
||||||
print flatten_threads(fetch_threads())
|
print flatten_threads(fetch_threads())
|
||||||
if sys.argv[1] == "predict_threads2":
|
|
||||||
predict_threads()
|
|
||||||
if sys.argv[1] == "predict_threads":
|
if sys.argv[1] == "predict_threads":
|
||||||
print "predicting threads"
|
predict_threads()
|
||||||
pipe1,le=get_pipe("pipe1",b"answered",["db"])
|
|
||||||
pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"])
|
|
||||||
pipe3,le3=get_pipe("pipe2b", b"lang",["db"])
|
|
||||||
q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True))
|
|
||||||
|
|
||||||
mail_threads=q.all()
|
|
||||||
|
|
||||||
if len(mail_threads) ==0:
|
|
||||||
raise ValueError("no untrained threads found")
|
|
||||||
answered=le.inverse_transform(pipe1.predict(mail_threads))
|
|
||||||
maintopic=le2.inverse_transform(pipe2.predict(mail_threads))
|
|
||||||
lang=le3.inverse_transform(pipe3.predict(mail_threads))
|
|
||||||
|
|
||||||
for i, t in enumerate(mail_threads):
|
|
||||||
t.answered=bool(answered[i])
|
|
||||||
t.opened=bool(answered[i])
|
|
||||||
|
|
||||||
t.maintopic=str(maintopic[i])
|
|
||||||
t.lang=str(lang[i])
|
|
||||||
db_session.add(t)
|
|
||||||
db_session.commit()
|
|
||||||
if sys.argv[1]=="stats":
|
if sys.argv[1]=="stats":
|
||||||
for topic in maintopic_values:
|
for topic in maintopic_values:
|
||||||
print topic
|
print topic
|
||||||
@@ -71,20 +48,6 @@ if len(sys.argv)>1:
|
|||||||
app.run(port=3000,debug=True)
|
app.run(port=3000,debug=True)
|
||||||
|
|
||||||
|
|
||||||
if sys.argv[1] == "trained_threads_from_yml":
|
|
||||||
from classifier.classifier import train
|
|
||||||
for k in train:
|
|
||||||
print k
|
|
||||||
t=db_session.query(MailThread).filter(MailThread.firstmail==k).first()
|
|
||||||
t.istrained=True
|
|
||||||
db_session.add(t)
|
|
||||||
db_session.commit()
|
|
||||||
if sys.argv[1] == "print_threads2":
|
|
||||||
mth=db_session.query(MailThread).all()
|
|
||||||
for t in mth:
|
|
||||||
print t.to_text()
|
|
||||||
print "---------------\n"
|
|
||||||
|
|
||||||
if sys.argv[1] == "train_thrd2":
|
if sys.argv[1] == "train_thrd2":
|
||||||
p, le=get_pipe("pipe2", "maintopic",["db"])
|
p, le=get_pipe("pipe2", "maintopic",["db"])
|
||||||
pb, lb =get_pipe("pipe2b", "maintopic",["db"])
|
pb, lb =get_pipe("pipe2b", "maintopic",["db"])
|
||||||
@@ -207,36 +170,6 @@ if len(sys.argv)>1:
|
|||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
if sys.argv[1] == "fetch_mail":
|
|
||||||
print "fetching mail %d " % int(sys.argv[2])
|
|
||||||
m=fetch_mail(int(sys.argv[2]))
|
|
||||||
hd=decode_header(m['ENVELOPE'].subject)
|
|
||||||
hd2=[]
|
|
||||||
# print hd
|
|
||||||
for h in hd:
|
|
||||||
if not h[1] is None:
|
|
||||||
hd2.append(h[0].decode(h[1]))
|
|
||||||
# print h[0].decode(h[1])
|
|
||||||
else:
|
|
||||||
hd2.append(h[0])
|
|
||||||
print "\nBetreff:"
|
|
||||||
for h in hd2:
|
|
||||||
print h
|
|
||||||
print "FROM:"
|
|
||||||
for t in m['ENVELOPE'].from_:
|
|
||||||
print t
|
|
||||||
print "TO:"
|
|
||||||
for t in m['ENVELOPE'].to:
|
|
||||||
print t
|
|
||||||
em=email.message_from_string(m['RFC822'])
|
|
||||||
for p in em.walk():
|
|
||||||
if p.get_content_maintype()=="text":
|
|
||||||
print p.get_payload()
|
|
||||||
elif p.get_content_maintype()=="multipart":
|
|
||||||
print p.get_payload()
|
|
||||||
else:
|
|
||||||
print p.get_content_maintype()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if sys.argv[1] == "initdb":
|
if sys.argv[1] == "initdb":
|
||||||
|
|||||||
Reference in New Issue
Block a user