Files
service_mail/classifier/training.py
2017-08-28 15:06:54 +02:00

73 lines
2.6 KiB
Python

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
import numpy
from storage import Mail, MailThread, db_session
from classifier import print_answers
def train_fit_pipe():
    """Fit ``pipe1`` on the training threads selected by ``b"answered"``.

    Returns:
        A ``(pipe1, extra)`` tuple, where ``extra`` is the third item
        returned by ``get_training_threads`` (presumably the label encoder
        that belongs to the targets -- TODO confirm against
        ``get_training_threads``).
    """
    training = get_training_threads(b"answered")
    samples, targets = training[0], training[1]
    pipe1.fit(samples, targets)
    return pipe1, training[2]
def train_fit_pipe2():
    """Fit ``pipe2`` on the training threads selected by ``b"maintopic"``.

    Returns:
        A ``(pipe2, extra)`` tuple, where ``extra`` is the third item
        returned by ``get_training_threads`` (presumably the label encoder
        that belongs to the targets -- TODO confirm against
        ``get_training_threads``).
    """
    training = get_training_threads(b"maintopic")
    samples, targets = training[0], training[1]
    pipe2.fit(samples, targets)
    return pipe2, training[2]
def train_fit_pipe2b():
    """Fit ``pipe2b`` on the training threads selected by ``b"maintopic"``.

    Uses the same training data key as ``train_fit_pipe2`` but fits the
    separate ``pipe2b`` pipeline.

    Returns:
        A ``(pipe2b, extra)`` tuple, where ``extra`` is the third item
        returned by ``get_training_threads`` (presumably the label encoder
        that belongs to the targets -- TODO confirm against
        ``get_training_threads``).
    """
    training = get_training_threads(b"maintopic")
    samples, targets = training[0], training[1]
    pipe2b.fit(samples, targets)
    return pipe2b, training[2]
def predict_thread(mth, p, le, key):
    """Predict the encoded label of one mail thread and print the result.

    Args:
        mth: The thread object; it is passed to ``p.predict`` wrapped in a
            one-element list.
        p: The ``Pipeline`` used for the prediction.
        le: The ``LabelEncoder`` used to translate the encoded prediction
            back into a readable label for the printed message.
        key: Name of the property being predicted (e.g. ``"answered"``);
            only used in the printed message.

    Returns:
        The first element of ``p.predict([mth])``, i.e. the still-encoded
        prediction.

    Raises:
        TypeError: If ``p`` is not a ``Pipeline`` or ``le`` is not a
            ``LabelEncoder``.
    """
    if not isinstance(p, Pipeline):
        raise TypeError("Second Argument needs to be type Pipeline")
    if not isinstance(le, LabelEncoder):
        # ``le`` is the third positional parameter, so say so in the message.
        raise TypeError("Third Argument needs to be type LabelEncoder")

    prediction = p.predict([mth])
    answ = prediction[0]
    label = le.inverse_transform(prediction)[0]
    # Single-argument print() behaves the same as the print statement on
    # Python 2. The message names the actual key instead of always claiming
    # "answered"; for key="answered" the output is unchanged.
    print("Status is %s is estimated to be: %s" % (key, str(label)))
    return answ
def train_single_thread(tid, p, le, key="answered"):
    """Interactively label one mail thread and store the answer.

    Loads the thread whose first mail has id ``tid``, prints it, prints the
    pipeline's prediction and the possible answers, then asks the user for
    the correct answer on stdin. The answer is handed to
    ``store_training_data`` together with ``tid`` and ``key``.

    The user may enter either an integer (an encoded label, which is decoded
    through ``le``) or free text (stored as typed). Blank input stores
    nothing.

    Args:
        tid: Id of the thread's first mail; must be exactly of type ``int``.
        p: The ``Pipeline`` used for the prediction.
        le: The ``LabelEncoder`` matching ``p``'s targets.
        key: Name of the property being trained. Defaults to ``"answered"``.

    Raises:
        TypeError: If ``tid``, ``p`` or ``le`` has the wrong type.
        StandardError: If no thread with that first mail id is found.
    """
    if type(tid) is not int:
        raise TypeError("ID must be of type int")
    if not isinstance(p, Pipeline):
        raise TypeError("Second Argument needs to be type Pipeline")
    if not isinstance(le, LabelEncoder):
        # ``le`` is the third positional parameter, so say so in the message.
        raise TypeError("Third Argument needs to be type LabelEncoder")

    # Load a single mail thread by the id of its first mail.
    mth = db_session.query(MailThread).filter(MailThread.firstmail == tid).first()
    if mth is None:
        raise StandardError("Thread with firstmail %d not in Database" % tid)

    # Show the thread so the user can judge it.
    print(mth.firstmail)
    print(mth.subject())
    print(mth.text())

    answ = predict_thread(mth, p, le, key)
    print_answers(le)

    reply = raw_input("Correct answer..")
    try:
        choice = int(reply)
    except ValueError:
        print("String Data")
        choice = None

    if choice is not None:
        if choice == answ:
            print("Yes I got it right")
        else:
            print("Oh no...!")
        # Decode the chosen class index and convert numpy scalars to plain
        # Python values before storing them.
        label = le.inverse_transform([choice])[0]
        if type(label) is numpy.bool_:
            label = bool(label)
        elif type(label) in (numpy.string_, numpy.unicode_):
            label = str(label)
        store_training_data(tid, label, key)
    elif reply.strip() != "":
        # Non-numeric, non-blank input is stored verbatim as the label.
        store_training_data(tid, reply, key)
    else:
        print("couldn't handle %s" % reply)