This commit is contained in:
andis
2017-08-11 20:34:17 +02:00
parent 51acbfbd38
commit 81fa9cc575
10 changed files with 202 additions and 79 deletions

View File

@@ -41,7 +41,7 @@ def store_training_data(i, d,key=b"answered"):
# Lade Trainingsdaten fuer einen angegebenen key (Label/Eigenschaft)
def get_training_threads(key="answered", filter=[]):
def get_training_threads(key="answered", filters=[]):
if not data_types.has_key(key):
raise ValueError("Key "+str(key)+" unknown")
#------------------------------------
@@ -49,12 +49,23 @@ def get_training_threads(key="answered", filter=[]):
d_a=[]
d_a2=[]
#------------------------------------
for i in train:
if train[i].has_key(key): # In den Trainingsdaten muss der relevante Key sein
t=db_session.query(MailThread).filter(MailThread.firstmail==i).first()
if not t is None: # Thread muss in der Datenbank sein
t_a.append(t)
d_a.append(train[i][key])
if "db" in filters:
tt=db_session.query(MailThread).filter(MailThread.istrained==True).all()
for t in tt:
t_a.append(t)
if key =="answered":
d_a.append(t.answered)
elif key=="maintopic":
d_a.append(t.maintopic)
else:
for i in train:
if train[i].has_key(key): # In den Trainingsdaten muss der relevante Key sein
t=db_session.query(MailThread).filter(MailThread.firstmail==i).first()
if not t is None: # Thread muss in der Datenbank sein
t_a.append(t)
d_a.append(train[i][key])
le=LabelEncoder()
d_a2=le.fit_transform(d_a)
return (t_a,d_a2,le)
@@ -91,9 +102,9 @@ class ThreadTextExtractor(BaseEstimator, TransformerMixin):
def transform(self, X,y=None):
return [t.text() for t in X]
def get_pipe(p=b"pipe1",k=b"answered"):
def get_pipe(p=b"pipe1",k=b"answered",filters=[]):
p=build_pipe(p)
tt= get_training_threads(k)
tt= get_training_threads(k,filters)
if len(tt[0]) > 0:
p.fit(tt[0],tt[1])
return p,tt[2]

View File

@@ -60,7 +60,7 @@ def train_single_thread(tid,p,le,key="answered"):
l=le.inverse_transform([ca])[0]
if type(l) is numpy.bool_:
l=bool(l)
if type(l) is numpy.string_:
if type(l) is numpy.string_ or type(l) is numpy.unicode_:
l=str(l)
store_training_data(tid,l, key)
elif not ca.strip() == "":