refactor1
This commit is contained in:
@@ -2,7 +2,41 @@
|
||||
from database import db_session, init_db
|
||||
from database import Base
|
||||
metadata=Base.metadata
|
||||
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
from collections import namedtuple
|
||||
# Two main objects: Mail & MailThread
|
||||
from mail_model import Mail
|
||||
from thread_model import MailThread
|
||||
def str_bool(s):
    """Interpret a textual flag as a boolean.

    Returns True only when ``s`` equals one of the accepted truthy
    spellings ("1", "true", "True", "t", "T"); every other value gives False.
    """
    accepted_spellings = ("1", "true", "True", "t", "T")
    if s in accepted_spellings:
        return True
    return False
|
||||
|
||||
# Load training data for a given key (label/property)
|
||||
def get_training_threads(key="answered", filters=()):
    """Load the trained mail threads together with encoded labels for ``key``.

    Args:
        key: Thread property used as the label. One of "answered" (the
            result of ``is_answered()``), "lang" or "maintopic".
        filters: Collection of filter names. "de" restricts the query to
            threads whose ``lang`` is "de"; otherwise "en" restricts it to
            "en". "de" wins when both are present.

    Returns:
        A ``TrainingThreads`` namedtuple ``(MailThreads, EncodedLabels,
        LabelEncoder)``: the queried threads, their labels after
        ``LabelEncoder.fit_transform`` and the fitted encoder.

    Raises:
        ValueError: If ``key`` is not one of the supported labels.
    """
    label_getters = {
        "answered": lambda t: t.is_answered(),
        "lang": lambda t: t.lang,
        "maintopic": lambda t: t.maintopic,
    }

    # `in` instead of dict.has_key(): has_key only exists on Python 2.
    if key not in label_getters:
        raise ValueError("Key "+str(key)+" unknown")

    # Only threads flagged as trained are used as training data.
    q = db_session.query(MailThread)
    q = q.filter(MailThread.istrained.is_(True))

    if "de" in filters:
        q = q.filter(MailThread.lang == "de")
    elif "en" in filters:
        q = q.filter(MailThread.lang == "en")

    # Load the threads and extract the requested field from each one.
    # A list (not a lazy map object) so the encoder always gets a sequence.
    threads = q.all()
    get_label = label_getters[key]
    labels = [get_label(thread) for thread in threads]

    # Encode the raw labels using LabelEncoder.
    le = LabelEncoder()
    labels = le.fit_transform(labels)

    TrainingThreads = namedtuple(
        "TrainingThreads", ["MailThreads", "EncodedLabels", "LabelEncoder"]
    )
    return TrainingThreads(threads, labels, le)
|
||||
|
||||
@@ -128,10 +128,14 @@ class Mail(Base):
|
||||
self.text= yaml.dump(b4.get_text())
|
||||
else:
|
||||
self.text =yaml.dump( pl)
|
||||
def print_head(self):
    """Return the header line "Gesendet von: <mail>@<host> am <date>".

    The sender is the first entry of the YAML-decoded ``from_`` field;
    the date is ``self.date``.
    """
    # NOTE(review): yaml.load can construct arbitrary objects; confirm that
    # from_ only ever holds data serialized by this application.
    sender = yaml.load(self.from_)[0]
    address = str(sender["mail"]) + "@" + str(sender["host"])
    return "Gesendet von: " + address + " am " + str(self.date)
|
||||
|
||||
def print_text(self):
|
||||
txt=""
|
||||
fr=yaml.load(self.from_)
|
||||
txt= txt+ "Gesendet von: "+str(fr[0]["mail"])+"@"+str(fr[0]["host"])+" am "+ str(self.date) + "\n"
|
||||
# txt= txt+ "Gesendet von: "+str(fr[0]["mail"])+"@"+str(fr[0]["host"])+" am "+ str(self.date) + "\n"
|
||||
t=yaml.load(self.text)
|
||||
if type(t) is unicode:
|
||||
#txt=txt
|
||||
|
||||
@@ -105,6 +105,14 @@ class MailThread(Base):
|
||||
db_session.commit()
|
||||
self.date=self.mails()[0].date
|
||||
|
||||
def print_mail(self, filter="all"):
    """Collect a ``(head, text)`` pair for every mail of this thread.

    With ``filter == "all"`` each mail returned by ``self.mails()``
    contributes ``(mail.print_head(), mail.print_text())``. Any other
    filter value gives an empty list.
    """
    # NOTE(review): the original indentation was lost; this assumes the
    # final return sat at function level, as in the sibling methods.
    if filter != "all":
        return []
    return [(mail.print_head(), mail.print_text()) for mail in self.mails()]
|
||||
|
||||
def print_text(self,filter="all"):
|
||||
a=[]
|
||||
if filter=="all":
|
||||
@@ -114,6 +122,16 @@ class MailThread(Base):
|
||||
elif filter=="first":
|
||||
a.append(m[0].print_text())
|
||||
return a
|
||||
def print_head(self,filter="all"):
    """Return the header lines of this thread's mails.

    Args:
        filter: "all" returns the head of every mail from ``self.mails()``;
            "first" returns only the head of the first mail. Any other
            value gives an empty list.

    Returns:
        A list of the values produced by each selected mail's
        ``print_head()``.
    """
    if filter == "all":
        return [mail.print_head() for mail in self.mails()]
    if filter == "first":
        # The previous code indexed the loop variable of the "all" branch
        # here, which is unbound in this branch. Slicing takes the first
        # mail explicitly and gives [] for a thread without mails.
        return [mail.print_head() for mail in self.mails()[:1]]
    return []
|
||||
|
||||
def text(self,filter="all"):
|
||||
a=u""
|
||||
def mail_txt(m):
|
||||
|
||||
Reference in New Issue
Block a user