43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
# SQL Database session
|
|
from database import db_session, init_db
|
|
from database import Base
|
|
# Module-level alias for the MetaData object attached to the database Base.
metadata = Base.metadata
|
|
from sklearn.preprocessing import LabelEncoder
|
|
from collections import namedtuple
|
|
# Two main objects: Mail & MailThread
|
|
from mail_model import Mail
|
|
from thread_model import MailThread
|
|
def str_bool(s):
    """Interpret a string flag as a boolean.

    Returns True only for the exact spellings "1", "true", "True", "t" and
    "T"; any other value (e.g. "TRUE", "yes", "" or a non-string) yields
    False.
    """
    truthy_spellings = ("1", "true", "True", "t", "T")
    return s in truthy_spellings
|
|
|
|
# Result record returned by get_training_threads(). Defined once at module
# level so every call returns the same (picklable) type instead of creating a
# fresh namedtuple class per call.
TrainingThreads = namedtuple(
    "TrainingThreads", ["MailThreads", "EncodedLabels", "LabelEncoder"]
)


def get_training_threads(key="answered", filters=()):
    """Load training data (threads plus encoded labels) for the given label key.

    Only threads whose ``istrained`` column is true are loaded.

    Args:
        key: Which thread property to use as the label. One of
            ``"answered"`` (``thread.is_answered()``), ``"lang"``
            (``thread.lang``) or ``"maintopic"`` (``thread.maintopic``).
        filters: Optional collection of filter tokens. If it contains
            ``"de"``, only threads with ``lang == "de"`` are loaded; otherwise,
            if it contains ``"en"``, only threads with ``lang == "en"``.
            ``"de"`` takes precedence when both are present; other tokens are
            ignored. ``None`` is treated as "no filters".

    Returns:
        TrainingThreads(MailThreads, EncodedLabels, LabelEncoder): the loaded
        threads, their labels encoded by a freshly fitted ``LabelEncoder``
        (one entry per thread, in the same order as ``MailThreads``), and the
        fitted encoder itself (e.g. for ``inverse_transform``).

    Raises:
        ValueError: if *key* is not one of the supported label names.
    """
    # Accessor per supported label key: maps a MailThread to its raw label.
    db_fields = {
        "answered": lambda t: t.is_answered(),
        "lang": lambda t: t.lang,
        "maintopic": lambda t: t.maintopic,
    }

    # `dict.has_key` only exists on Python 2; `in` works on both 2 and 3.
    if key not in db_fields:
        raise ValueError("Key " + str(key) + " unknown")
    get_label = db_fields[key]

    if filters is None:
        filters = ()

    q = db_session.query(MailThread)
    q = q.filter(MailThread.istrained.is_(True))

    # Language filter: "de" wins if both "de" and "en" are requested.
    if "de" in filters:
        q = q.filter(MailThread.lang == "de")
    elif "en" in filters:
        q = q.filter(MailThread.lang == "en")

    # Load threads and extract the raw label of each one.
    threads = q.all()
    # Materialize as a list: on Python 3 `map` returns a lazy iterator, which
    # LabelEncoder cannot consume as a 1-d label array.
    labels = [get_label(t) for t in threads]

    # Encode the raw labels as integers using a freshly fitted LabelEncoder.
    le = LabelEncoder()
    encoded_labels = le.fit_transform(labels)

    return TrainingThreads(threads, encoded_labels, le)
|