init learning cats
This commit is contained in:
238
run.py
Normal file
238
run.py
Normal file
@@ -0,0 +1,238 @@
|
||||
from __future__ import unicode_literals
|
||||
import imapclient
|
||||
from config import Config
|
||||
import sys
|
||||
from email.header import decode_header
|
||||
import email
|
||||
import codecs
|
||||
import sys
|
||||
import bs4
|
||||
#sys.stdout = codecs.getwriter('utf8')(sys.stdout)
|
||||
from storage.fetch_mail import fetch_mail
|
||||
from storage.fetch_mail import fetch_threads, flatten_threads
|
||||
|
||||
from storage import Mail, MailThread, db_session
|
||||
import yaml
|
||||
import email
|
||||
from classifier import get_training_threads, ThreadDictExtractor, pipe1, print_answers, in_training, store_training_data, pipe2
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
import numpy
|
||||
|
||||
|
||||
|
||||
def train_fit_pipe():
|
||||
tt= get_training_threads(b"answered")
|
||||
print tt[1]
|
||||
print tt[0]
|
||||
pipe1.fit(tt[0],tt[1])
|
||||
return pipe1,tt[2]
|
||||
def train_fit_pipe2():
|
||||
tt= get_training_threads(b"maintopic")
|
||||
pipe2.fit(tt[0],tt[1])
|
||||
return pipe2,tt[2]
|
||||
|
||||
|
||||
def train_single_thread(tid,p,le,key="answered"):
|
||||
if (not type(tid) is int): raise TypeError("ID must be of type int")
|
||||
if not type(p) is Pipeline: raise TypeError("Second Argument needs to be type Pipeline")
|
||||
if not type(le) is LabelEncoder: raise TypeError("Second Argument needs to be type LabelEncoder")
|
||||
mth=db_session.query(MailThread).filter(MailThread.firstmail==tid).first()
|
||||
if mth is None: raise ValueError("Thread with firstmail %d not in Database" %tid)
|
||||
# Predict the value
|
||||
pre=p.predict([mth])
|
||||
answ=pre[0]
|
||||
#
|
||||
print mth.to_text()
|
||||
print mth.text()
|
||||
print "Status is answered is estimated to be: " + str(le.inverse_transform(pre)[0])
|
||||
print_answers(le)
|
||||
|
||||
ca=raw_input("Correct answer..")
|
||||
try:
|
||||
ca=int(ca)
|
||||
|
||||
except ValueError:
|
||||
print "String Data"
|
||||
if type(ca)==int:
|
||||
if ca == answ:
|
||||
print ("Yes I got it right")
|
||||
else:
|
||||
print("Oh no...!")
|
||||
l=le.inverse_transform([ca])[0]
|
||||
if type(l) is numpy.bool_:
|
||||
l=bool(l)
|
||||
if type(l) is numpy.string_:
|
||||
l=str(l)
|
||||
store_training_data(tid,l, key)
|
||||
elif not ca.strip() == "":
|
||||
store_training_data(tid, ca, key)
|
||||
else:
|
||||
print "couldn't handle %s" % ca
|
||||
|
||||
|
||||
#print "arg1:"+sys.argv[1]
|
||||
if len(sys.argv)>1:
|
||||
if sys.argv[1] == "fetch_threads":
|
||||
print flatten_threads(fetch_threads())
|
||||
|
||||
|
||||
if sys.argv[1] == "print_threads":
|
||||
mth=db_session.query(MailThread).all()
|
||||
for t in mth:
|
||||
print t.firstmail
|
||||
print t.mail_flat_dict()
|
||||
|
||||
if sys.argv[1] == "print_thrd":
|
||||
if len(sys.argv)<3:
|
||||
mth=db_session.query(MailThread).all()
|
||||
for t in mth:
|
||||
print t.firstmail
|
||||
else:
|
||||
t=db_session.query(MailThread).filter(MailThread.firstmail==sys.argv[2]).first()
|
||||
|
||||
print t.firstmail
|
||||
print t.subject()
|
||||
print t.text()
|
||||
|
||||
|
||||
if sys.argv[1] == "print_threads2":
|
||||
mth=db_session.query(MailThread).all()
|
||||
for t in mth:
|
||||
print t.to_text()
|
||||
print "---------------\n"
|
||||
|
||||
if sys.argv[1] == "train_thrd2":
|
||||
p, le=train_fit_pipe2()
|
||||
train_single_thread(int(sys.argv[2]),p,le,b"maintopic")
|
||||
if sys.argv[1] == "train_all2":
|
||||
p, labelencoder=train_fit_pipe2()
|
||||
mth=db_session.query(MailThread).all()
|
||||
print mth
|
||||
for t in mth:
|
||||
if not in_training(t.firstmail,"maintopic"):
|
||||
print "---------------------------------------------------"
|
||||
print "---------------------------------------------------"
|
||||
print t.firstmail
|
||||
print t.text()
|
||||
train_single_thread(t.firstmail, p, labelencoder, b"maintopic")
|
||||
|
||||
|
||||
if sys.argv[1] == "testpipe2":
|
||||
from classifier import ThreadSubjectExtractor, ThreadTextExtractor
|
||||
pipe2,le=train_fit_pipe2()
|
||||
|
||||
if len(sys.argv)>2:
|
||||
t=db_session.query(MailThread).filter(MailThread.firstmail==sys.argv[2]).first()
|
||||
print t.to_text()
|
||||
print le.inverse_transform(pipe2.predict([t]))
|
||||
|
||||
|
||||
if sys.argv[1] == "train_thrd":
|
||||
pipe1, labelencoder=train_fit_pipe()
|
||||
train_single_thread(int(sys.argv[2]),pipe1,labelencoder)
|
||||
|
||||
if sys.argv[1] == "train_all":
|
||||
pipe1, labelencoder=train_fit_pipe()
|
||||
mth=db_session.query(MailThread).all()
|
||||
print mth
|
||||
for t in mth:
|
||||
if not in_training(t.firstmail):
|
||||
print "---------------------------------------------------"
|
||||
print "---------------------------------------------------"
|
||||
print t.firstmail
|
||||
train_single_thread(t.firstmail,pipe1,labelencoder)
|
||||
|
||||
if sys.argv[1] == "print_thread":
|
||||
mth=db_session.query(MailThread).filter(MailThread.firstmail==int(sys.argv[2])).first()
|
||||
print mth.mail_dicts()
|
||||
print mth.mail_flat_dict()
|
||||
|
||||
if sys.argv[1] == "store_threads":
|
||||
thrds=flatten_threads(fetch_threads())
|
||||
for t in thrds:
|
||||
if type(t[0]) is int:
|
||||
th=db_session.query(MailThread).filter(MailThread.firstmail==t[0]).first()
|
||||
if th == None:
|
||||
th=MailThread()
|
||||
th.firstmail=t[0]
|
||||
if not th.body == yaml.dump(t):
|
||||
th.body=yaml.dump(t)
|
||||
th.islabeled=False
|
||||
th.opened=True
|
||||
else:
|
||||
th.body=yaml.dump(t)
|
||||
db_session.add(th)
|
||||
db_session.commit()
|
||||
print thrds
|
||||
|
||||
|
||||
|
||||
if sys.argv[1] == "print_mail":
|
||||
mm=db_session.query(Mail).filter(Mail.id==int(sys.argv[2])).first()
|
||||
mm.compile_text()
|
||||
mm.compile_envelope()
|
||||
print mm.subject
|
||||
print "----------"
|
||||
print mm.text
|
||||
|
||||
if sys.argv[1] == "mail_dict_test":
|
||||
mm=db_session.query(Mail).filter(Mail.id==int(sys.argv[2])).first()
|
||||
mm.compile_envelope()
|
||||
print mm.dict_envelope()
|
||||
|
||||
|
||||
if sys.argv[1] == "load_mail":
|
||||
mm=db_session.query(Mail).filter(Mail.id==int(sys.argv[2])).first()
|
||||
mm.compile_text()
|
||||
print mm.text
|
||||
env=yaml.load(mm.envelope)
|
||||
print env.subject
|
||||
print env
|
||||
|
||||
|
||||
if sys.argv[1] == "store_mail":
|
||||
m=fetch_mail(int(sys.argv[2]))
|
||||
mm=Mail()
|
||||
mm.envelope=yaml.dump(m['ENVELOPE'])
|
||||
mm.body=yaml.dump(m['RFC822'])
|
||||
mm.id=m['id']
|
||||
db_session.add(mm)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
if sys.argv[1] == "fetch_mail":
|
||||
print "fetching mail %d " % int(sys.argv[2])
|
||||
m=fetch_mail(int(sys.argv[2]))
|
||||
hd=decode_header(m['ENVELOPE'].subject)
|
||||
hd2=[]
|
||||
# print hd
|
||||
for h in hd:
|
||||
if not h[1] is None:
|
||||
hd2.append(h[0].decode(h[1]))
|
||||
# print h[0].decode(h[1])
|
||||
else:
|
||||
hd2.append(h[0])
|
||||
print "\nBetreff:"
|
||||
for h in hd2:
|
||||
print h
|
||||
print "FROM:"
|
||||
for t in m['ENVELOPE'].from_:
|
||||
print t
|
||||
print "TO:"
|
||||
for t in m['ENVELOPE'].to:
|
||||
print t
|
||||
em=email.message_from_string(m['RFC822'])
|
||||
for p in em.walk():
|
||||
if p.get_content_maintype()=="text":
|
||||
print p.get_payload()
|
||||
elif p.get_content_maintype()=="multipart":
|
||||
print p.get_payload()
|
||||
else:
|
||||
print p.get_content_maintype()
|
||||
|
||||
|
||||
|
||||
if sys.argv[1] == "initdb":
|
||||
from storage import init_db
|
||||
init_db()
|
||||
Reference in New Issue
Block a user