71 lines
2.4 KiB
Python
71 lines
2.4 KiB
Python
from sklearn.pipeline import Pipeline
|
|
from sklearn.preprocessing import LabelEncoder
|
|
import numpy
|
|
from storage import Mail, MailThread, db_session
|
|
from classifier import store_training_data, print_answers
|
|
|
|
|
|
|
|
def train_fit_pipe():
    """Fit the module-level ``pipe1`` on the b"answered" training threads.

    Returns:
        tuple: ``(pipe1, extra)``, where ``extra`` is the third element of the
        value returned by ``get_training_threads`` (presumably the
        LabelEncoder for the targets -- TODO confirm against that helper).
    """
    # NOTE(review): get_training_threads and pipe1 are neither defined nor
    # imported in this file; confirm where they are provided.
    training = get_training_threads(b"answered")
    # Elements 0 and 1 are the samples and targets handed to fit().
    pipe1.fit(training[0], training[1])
    return pipe1, training[2]
|
|
def train_fit_pipe2():
    """Fit the module-level ``pipe2`` on the b"maintopic" training threads.

    Returns:
        tuple: ``(pipe2, extra)``, where ``extra`` is the third element of the
        value returned by ``get_training_threads`` (presumably the
        LabelEncoder for the targets -- TODO confirm against that helper).
    """
    # NOTE(review): get_training_threads and pipe2 are neither defined nor
    # imported in this file; confirm where they are provided.
    training = get_training_threads(b"maintopic")
    # Elements 0 and 1 are the samples and targets handed to fit().
    pipe2.fit(training[0], training[1])
    return pipe2, training[2]
|
|
|
|
def train_fit_pipe2b():
    """Fit the module-level ``pipe2b`` on the b"maintopic" training threads.

    Same training data as ``train_fit_pipe2``, but a different pipeline.

    Returns:
        tuple: ``(pipe2b, extra)``, where ``extra`` is the third element of
        the value returned by ``get_training_threads`` (presumably the
        LabelEncoder for the targets -- TODO confirm against that helper).
    """
    # NOTE(review): get_training_threads and pipe2b are neither defined nor
    # imported in this file; confirm where they are provided.
    training = get_training_threads(b"maintopic")
    # Elements 0 and 1 are the samples and targets handed to fit().
    pipe2b.fit(training[0], training[1])
    return pipe2b, training[2]
|
|
|
|
|
|
def predict_thread(mth, p, le, key):
    """Predict the label of one mail thread, print it, and return the prediction.

    Args:
        mth: the thread to classify; it is passed to ``p.predict`` wrapped in
            a one-element list.
        p (Pipeline): fitted scikit-learn pipeline used for the prediction.
        le (LabelEncoder): encoder used only to turn the encoded prediction
            back into a readable label for the printout.
        key: name of the attribute being predicted (e.g. "answered"); used in
            the printed message.

    Returns:
        The first element of ``p.predict([mth])``, i.e. the *encoded* label,
        not the decoded one.

    Raises:
        TypeError: if ``p`` is not a Pipeline or ``le`` is not a LabelEncoder.
    """
    # isinstance (rather than an exact type check) also accepts subclasses.
    if not isinstance(p, Pipeline):
        raise TypeError("Second Argument needs to be type Pipeline")
    if not isinstance(le, LabelEncoder):
        # le is the third positional argument.
        raise TypeError("Third Argument needs to be type LabelEncoder")

    prediction = p.predict([mth])
    answ = prediction[0]

    # Report the decoded label for the key actually being predicted instead
    # of always claiming it is the "answered" status.
    decoded = le.inverse_transform(prediction)[0]
    print("Status %s is estimated to be: %s" % (key, str(decoded)))

    return answ
|
|
|
|
def train_single_thread(tid, p, le, key="answered"):
    """Interactively label one mail thread and store the label as training data.

    The function:
      * looks up the MailThread whose ``firstmail`` equals ``tid``;
      * prints its firstmail id, subject and text;
      * prints the model's prediction when both ``p`` and ``le`` are given;
      * lists the possible answers when ``le`` is given;
      * reads the correct answer from stdin and stores it with
        ``store_training_data(tid, label, key)``.

    Answer input:
      * an integer, when ``le`` is given, is an index into the encoder's
        classes. It is decoded with ``le.inverse_transform``, compared
        against the prediction (if any), and the decoded label is stored;
      * any other non-blank text, including a number when no encoder was
        supplied, is stored verbatim;
      * blank input is reported and nothing is stored.

    Args:
        tid (int): ``firstmail`` id of the thread to label.
        p: fitted Pipeline, or None to skip the prediction.
        le: LabelEncoder, or None. Without it no prediction is made and
            numeric input cannot be decoded, so it is stored as text.
        key: name of the attribute being labelled; forwarded to
            ``predict_thread`` and ``store_training_data``.

    Raises:
        TypeError: if ``tid`` is not an integer.
        ValueError: if no thread with that ``firstmail`` exists.
    """
    import numbers

    # Accept any integral id (e.g. Python 2 ``long`` or numpy integers), but
    # not bool, which is an int subclass.
    if not isinstance(tid, numbers.Integral) or isinstance(tid, bool):
        raise TypeError("ID must be of type int")

    mth = db_session.query(MailThread).filter(MailThread.firstmail == tid).first()
    if mth is None:
        raise ValueError("Thread with firstmail %d not in Database" % tid)

    print(mth.firstmail)
    print(mth.subject())
    print(mth.text())

    # A prediction needs both the fitted pipeline and the encoder.
    if p is not None and le is not None:
        answ = predict_thread(mth, p, le, key)
    else:
        answ = None

    if le is not None:
        print_answers(le)

    ca = raw_input("Correct answer..")

    # Parse separately from `ca` so the raw text stays available. A flag is
    # used instead of an exact `type(...) == int` check, which would misroute
    # Python 2 `long` results.
    try:
        idx = int(ca)
    except ValueError:
        print("String Data")
        idx = None

    if idx is not None and le is not None:
        # Only judge the model when it actually produced a prediction.
        if answ is not None:
            if idx == answ:
                print("Yes I got it right")
            else:
                print("Oh no...!")
        label = le.inverse_transform([idx])[0]
        # inverse_transform yields numpy scalars; convert them (bool_,
        # string_, integer types, ...) to the equivalent builtin value.
        if isinstance(label, numpy.generic):
            label = label.item()
        store_training_data(tid, label, key)
    elif ca.strip() != "":
        # Free-text answer, or a number that cannot be decoded because no
        # encoder was supplied: store it verbatim.
        store_training_data(tid, ca, key)
    else:
        print("couldn't handle %s" % ca)
|
|
|