refactor1

This commit is contained in:
Andreas Stephanides
2017-08-28 09:08:47 +02:00
parent 699f4f6546
commit 630b982502
14 changed files with 274 additions and 230 deletions

View File

@@ -2,7 +2,41 @@
from database import db_session, init_db
from database import Base
metadata=Base.metadata
from sklearn.preprocessing import LabelEncoder
from collections import namedtuple
# Two main objects: Mail & MailThread
from mail_model import Mail
from thread_model import MailThread
def str_bool(s):
return s in ["1", "true", "True", "t","T"]
# Lade Trainingsdaten fuer einen angegebenen key (Label/Eigenschaft)
def get_training_threads(key="answered", filters=[]):
#------------------------------------
db_fields= {"answered": lambda t: t.is_answered(),
"lang": lambda t: t.lang,
"maintopic": lambda t: t.maintopic}
if not db_fields.has_key(key):
raise ValueError("Key "+str(key)+" unknown")
q=db_session.query(MailThread)
q=q.filter(MailThread.istrained.is_(True))
if "de" in filters:
q=q.filter(MailThread.lang=="de")
elif "en" in filters:
q=q.filter(MailThread.lang=="en")
# load and extract thread fields
threads=q.all()
labels = map(db_fields[key], threads)
# encode using LabelEncoder
le=LabelEncoder()
labels=le.fit_transform(labels)
TrainingThreads=namedtuple("TrainingThreads", ["MailThreads","EncodedLabels","LabelEncoder"])
return TrainingThreads(threads,labels,le)

View File

@@ -128,10 +128,14 @@ class Mail(Base):
self.text= yaml.dump(b4.get_text())
else:
self.text =yaml.dump( pl)
def print_head(self):
fr=yaml.load(self.from_)
return "Gesendet von: "+str(fr[0]["mail"])+"@"+str(fr[0]["host"])+" am "+ str(self.date)
def print_text(self):
txt=""
fr=yaml.load(self.from_)
txt= txt+ "Gesendet von: "+str(fr[0]["mail"])+"@"+str(fr[0]["host"])+" am "+ str(self.date) + "\n"
# txt= txt+ "Gesendet von: "+str(fr[0]["mail"])+"@"+str(fr[0]["host"])+" am "+ str(self.date) + "\n"
t=yaml.load(self.text)
if type(t) is unicode:
#txt=txt

View File

@@ -105,6 +105,14 @@ class MailThread(Base):
db_session.commit()
self.date=self.mails()[0].date
def print_mail(self, filter="all"):
a=[]
if filter=="all":
mm=self.mails()
for m in mm:
a.append((m.print_head(), m.print_text()))
return a
def print_text(self,filter="all"):
a=[]
if filter=="all":
@@ -114,6 +122,16 @@ class MailThread(Base):
elif filter=="first":
a.append(m[0].print_text())
return a
def print_head(self,filter="all"):
a=[]
if filter=="all":
mm=self.mails()
for m in mm:
a.append(m.print_head())
elif filter=="first":
a.append(m[0].print_head())
return a
def text(self,filter="all"):
a=u""
def mail_txt(m):