div tzg
This commit is contained in:
@@ -41,7 +41,7 @@ def store_training_data(i, d,key=b"answered"):
|
||||
|
||||
|
||||
# Lade Trainingsdaten fuer einen angegebenen key (Label/Eigenschaft)
|
||||
def get_training_threads(key="answered", filters=None):
    """Collect labelled training threads for ``key`` and encode the labels.

    Args:
        key: Name of the label/property to train on. Must be contained in
            ``data_types``.
        filters: Optional collection of source selectors. If it contains
            ``"db"``, the threads flagged ``istrained`` in the database are
            used and the label is read from the thread object; otherwise the
            entries of ``train`` that carry ``key`` are used.

    Returns:
        A tuple ``(threads, encoded_labels, encoder)``: the selected
        ``MailThread`` objects, their labels as returned by
        ``LabelEncoder.fit_transform`` and the fitted ``LabelEncoder``.

    Raises:
        ValueError: If ``key`` is not contained in ``data_types``.
    """
    if key not in data_types:
        raise ValueError("Key "+str(key)+" unknown")
    # None replaces the former shared mutable ``[]`` default.
    if filters is None:
        filters = []

    threads = []
    labels = []
    if "db" in filters:
        # SQLAlchemy needs the ``==`` comparison to build the SQL expression.
        trained = db_session.query(MailThread).filter(MailThread.istrained == True).all()  # noqa: E712
        for thread in trained:
            if key == "answered":
                label = thread.answered
            elif key == "maintopic":
                label = thread.maintopic
            else:
                # No label is read from the thread for other keys; skip it so
                # that threads and labels always have the same length.
                continue
            threads.append(thread)
            labels.append(label)
    else:
        for firstmail in train:
            # The training record must contain the requested key.
            if key not in train[firstmail]:
                continue
            thread = db_session.query(MailThread).filter(MailThread.firstmail == firstmail).first()
            # The thread must exist in the database.
            if thread is None:
                continue
            threads.append(thread)
            labels.append(train[firstmail][key])

    encoder = LabelEncoder()
    encoded_labels = encoder.fit_transform(labels)
    return (threads, encoded_labels, encoder)
|
||||
@@ -91,9 +102,9 @@ class ThreadTextExtractor(BaseEstimator, TransformerMixin):
|
||||
def transform(self, X, y=None):
    """Return the result of ``text()`` for every element of ``X``, in order.

    ``y`` is accepted but not used.
    """
    texts = []
    for thread in X:
        texts.append(thread.text())
    return texts
|
||||
|
||||
def get_pipe(p=b"pipe1",k=b"answered",filters=None):
    """Build a pipeline and fit it on the training threads for ``k``.

    Args:
        p: Pipeline identifier handed to ``build_pipe``.
        k: Label/property key handed to ``get_training_threads``.
        filters: Optional source selectors handed to
            ``get_training_threads``; ``None`` means no filters.

    Returns:
        A tuple ``(pipeline, encoder)``: the object returned by
        ``build_pipe`` (fitted only if at least one training thread was
        found) and the label encoder returned by ``get_training_threads``.
    """
    # None replaces the former shared mutable ``[]`` default; always pass a
    # real list on so the callee can run ``in`` checks on it.
    if filters is None:
        filters = []
    pipeline = build_pipe(p)
    threads, labels, encoder = get_training_threads(k, filters)
    # Fitting on an empty training set would fail, so it is skipped.
    if len(threads) > 0:
        pipeline.fit(threads, labels)
    return pipeline, encoder
|
||||
|
||||
@@ -60,7 +60,7 @@ def train_single_thread(tid,p,le,key="answered"):
|
||||
l=le.inverse_transform([ca])[0]
|
||||
if type(l) is numpy.bool_:
|
||||
l=bool(l)
|
||||
if type(l) is numpy.string_:
|
||||
if type(l) is numpy.string_ or type(l) is numpy.unicode_:
|
||||
l=str(l)
|
||||
store_training_data(tid,l, key)
|
||||
elif not ca.strip() == "":
|
||||
|
||||
Reference in New Issue
Block a user