small fixes and improvements

This commit is contained in:
Andreas Stephanides
2017-08-21 23:55:43 +02:00
parent 81fa9cc575
commit 1645221f93
10 changed files with 380 additions and 111 deletions

44
run.py
View File

@@ -16,7 +16,8 @@ from storage import Mail, MailThread, db_session
#import yaml
#import email
from classifier import get_training_threads, print_answers, in_training, store_training_data, get_pipe, test_pipe, train_single_thread # , pipe2, pipe2b
from classifier import predict_threads
maintopic_values=["studium", "information","ausleihen","jobausschreibung", "umfragen"]
def predict_thread(p,l,t):
pre=p.predict([t])
@@ -29,13 +30,19 @@ if len(sys.argv)>1:
if sys.argv[1] == "fetch_threads":
print flatten_threads(fetch_threads())
if sys.argv[1] == "predict_threads2":
predict_threads()
if sys.argv[1] == "predict_threads":
pipe1,le=get_pipe("pipe1",b"answered")
pipe2,le2=get_pipe("pipe2b", b"maintopic")
pipe3,le3=get_pipe("pipe2b", b"lang")
mail_threads=db_session.query(MailThread).filter(MailThread.istrained==False).all()
print "predicting threads"
pipe1,le=get_pipe("pipe1",b"answered",["db"])
pipe2,le2=get_pipe("pipe2g", b"maintopic",["db"])
pipe3,le3=get_pipe("pipe2b", b"lang",["db"])
q=db_session.query(MailThread).filter(MailThread.istrained.op("IS NOT")(True))
mail_threads=q.all()
if len(mail_threads) ==0:
raise ValueError("no untrained threads found")
answered=le.inverse_transform(pipe1.predict(mail_threads))
maintopic=le2.inverse_transform(pipe2.predict(mail_threads))
lang=le3.inverse_transform(pipe3.predict(mail_threads))
@@ -48,7 +55,17 @@ if len(sys.argv)>1:
t.lang=str(lang[i])
db_session.add(t)
db_session.commit()
if sys.argv[1]=="stats":
for topic in maintopic_values:
print topic
n_answ=db_session.query(MailThread).filter(MailThread.maintopic==topic).filter(MailThread.answered.op("IS")(True)).count()
n_nansw=db_session.query(MailThread).filter(MailThread.maintopic==topic).filter(MailThread.answered.op("IS NOT")(True)).count()
n_ges=db_session.query(MailThread).filter(MailThread.maintopic==topic).count()
print "%d answered and %d not answered of %d(%d) that are %d percent answerd" % (n_answ,n_nansw, n_ges,n_answ+n_nansw, float(n_answ)/float(n_ges)*100.0)
if sys.argv[1] == "run_server":
from flaskapp import app
app.run(port=3000,debug=True)
@@ -72,7 +89,9 @@ if len(sys.argv)>1:
print t.text()
if sys.argv[1] == "compile_threads":
mth=db_session.query(MailThread).all()
for t in mth:
l=len(mth)
for i,t in enumerate(mth):
print "%d/%d" % (i,l)
t.compile()
if sys.argv[1] == "trained_threads_from_yml":
@@ -115,9 +134,16 @@ if len(sys.argv)>1:
print t.text()
predict_thread(pb,lb,t)
train_single_thread(t.firstmail, p, labelencoder, b"maintopic")
if sys.argv[1] == "benchpipe3":
test_pipe(["pipe2d","pipe2e","pipe2e1","pipe2f","pipe2g"],"maintopic",["db","de"])
if sys.argv[1] == "benchpipe2":
test_pipe(["pipe2","pipe2b","pipe2c"],"maintopic")
test_pipe(["pipe2","pipe2b","pipe2c","pipe2d"],"maintopic",["db","de"])
# print "testing with db training data:"
# test_pipe(["pipe2b"],"maintopic",["db"])
# test_pipe(["pipe2b"],"maintopic",["db"])
# print "testing only with german data"
# test_pipe(["pipe2b"],"maintopic",["db","de"])
if sys.argv[1] == "testpipe2":
from classifier import ThreadSubjectExtractor, ThreadTextExtractor