div tzg
This commit is contained in:
28
run.py
28
run.py
@@ -34,18 +34,18 @@ if len(sys.argv)>1:
|
||||
pipe1,le=get_pipe("pipe1",b"answered")
|
||||
pipe2,le2=get_pipe("pipe2b", b"maintopic")
|
||||
pipe3,le3=get_pipe("pipe2b", b"lang")
|
||||
mail_threads=db_session.query(MailThread).all()
|
||||
mail_threads=db_session.query(MailThread).filter(MailThread.istrained==False).all()
|
||||
|
||||
answered=le.inverse_transform(pipe1.predict(mail_threads))
|
||||
maintopic=le2.inverse_transform(pipe2.predict(mail_threads))
|
||||
lang=le3.inverse_transform(pipe3.predict(mail_threads))
|
||||
|
||||
for i, t in enumerate(mail_threads):
|
||||
t.answered=answered[i]
|
||||
t.opened=answered[i]
|
||||
t.answered=bool(answered[i])
|
||||
t.opened=bool(answered[i])
|
||||
|
||||
t.maintopic=maintopic[i]
|
||||
t.lang=lang[i]
|
||||
t.maintopic=str(maintopic[i])
|
||||
t.lang=str(lang[i])
|
||||
db_session.add(t)
|
||||
db_session.commit()
|
||||
|
||||
@@ -74,8 +74,15 @@ if len(sys.argv)>1:
|
||||
mth=db_session.query(MailThread).all()
|
||||
for t in mth:
|
||||
t.compile()
|
||||
|
||||
|
||||
|
||||
if sys.argv[1] == "trained_threads_from_yml":
|
||||
from classifier.classifier import train
|
||||
for k in train:
|
||||
print k
|
||||
t=db_session.query(MailThread).filter(MailThread.firstmail==k).first()
|
||||
t.istrained=True
|
||||
db_session.add(t)
|
||||
db_session.commit()
|
||||
if sys.argv[1] == "print_threads2":
|
||||
mth=db_session.query(MailThread).all()
|
||||
for t in mth:
|
||||
@@ -83,8 +90,8 @@ if len(sys.argv)>1:
|
||||
print "---------------\n"
|
||||
|
||||
if sys.argv[1] == "train_thrd2":
|
||||
p, le=get_pipe("pipe2", "maintopic")
|
||||
pb, lb =get_pipe("pipe2b", "maintopic")
|
||||
p, le=get_pipe("pipe2", "maintopic",["db"])
|
||||
pb, lb =get_pipe("pipe2b", "maintopic",["db"])
|
||||
|
||||
train_single_thread(int(sys.argv[2]),p,le,b"maintopic")
|
||||
|
||||
@@ -120,7 +127,8 @@ if len(sys.argv)>1:
|
||||
t=db_session.query(MailThread).filter(MailThread.firstmail==sys.argv[2]).first()
|
||||
print t.to_text()
|
||||
print le.inverse_transform(pipe2.predict([t]))
|
||||
|
||||
|
||||
|
||||
|
||||
if sys.argv[1] == "train_thrd":
|
||||
pipe1, labelencoder=train_fit_pipe()
|
||||
|
||||
Reference in New Issue
Block a user