This commit is contained in:
andis
2017-08-11 20:34:17 +02:00
parent 51acbfbd38
commit 81fa9cc575
10 changed files with 202 additions and 79 deletions

28
run.py
View File

@@ -34,18 +34,18 @@ if len(sys.argv)>1:
pipe1,le=get_pipe("pipe1",b"answered")
pipe2,le2=get_pipe("pipe2b", b"maintopic")
pipe3,le3=get_pipe("pipe2b", b"lang")
mail_threads=db_session.query(MailThread).all()
mail_threads=db_session.query(MailThread).filter(MailThread.istrained==False).all()
answered=le.inverse_transform(pipe1.predict(mail_threads))
maintopic=le2.inverse_transform(pipe2.predict(mail_threads))
lang=le3.inverse_transform(pipe3.predict(mail_threads))
for i, t in enumerate(mail_threads):
t.answered=answered[i]
t.opened=answered[i]
t.answered=bool(answered[i])
t.opened=bool(answered[i])
t.maintopic=maintopic[i]
t.lang=lang[i]
t.maintopic=str(maintopic[i])
t.lang=str(lang[i])
db_session.add(t)
db_session.commit()
@@ -74,8 +74,15 @@ if len(sys.argv)>1:
mth=db_session.query(MailThread).all()
for t in mth:
t.compile()
if sys.argv[1] == "trained_threads_from_yml":
from classifier.classifier import train
for k in train:
print k
t=db_session.query(MailThread).filter(MailThread.firstmail==k).first()
t.istrained=True
db_session.add(t)
db_session.commit()
if sys.argv[1] == "print_threads2":
mth=db_session.query(MailThread).all()
for t in mth:
@@ -83,8 +90,8 @@ if len(sys.argv)>1:
print "---------------\n"
if sys.argv[1] == "train_thrd2":
p, le=get_pipe("pipe2", "maintopic")
pb, lb =get_pipe("pipe2b", "maintopic")
p, le=get_pipe("pipe2", "maintopic",["db"])
pb, lb =get_pipe("pipe2b", "maintopic",["db"])
train_single_thread(int(sys.argv[2]),p,le,b"maintopic")
@@ -120,7 +127,8 @@ if len(sys.argv)>1:
t=db_session.query(MailThread).filter(MailThread.firstmail==sys.argv[2]).first()
print t.to_text()
print le.inverse_transform(pipe2.predict([t]))
if sys.argv[1] == "train_thrd":
pipe1, labelencoder=train_fit_pipe()