"""Helpers for loading labelled MailThread training data from the database."""

from collections import namedtuple

from sklearn.preprocessing import LabelEncoder

# SQL Database session
from database import db_session, init_db
from database import Base

# Two main objects: Mail & MailThread
from mail_model import Mail
from thread_model import MailThread

metadata = Base.metadata

# Result record returned by get_training_threads(). Defined once at module
# level so every call returns instances of the same class.
TrainingThreads = namedtuple(
    "TrainingThreads", ["MailThreads", "EncodedLabels", "LabelEncoder"]
)

# Maps each supported label key to a callable extracting that label from a
# MailThread instance.
_LABEL_EXTRACTORS = {
    "answered": lambda t: t.is_answered(),
    "lang": lambda t: t.lang,
    "maintopic": lambda t: t.maintopic,
}


def str_bool(s):
    """Return True if *s* is one of the accepted "true" strings.

    Accepted spellings are "1", "true", "True", "t" and "T"; anything else
    yields False.
    """
    return s in ("1", "true", "True", "t", "T")


def get_training_threads(key="answered", filters=()):
    """Load training data for the given key (label/property).

    Queries every MailThread whose ``istrained`` flag is true, extracts the
    label selected by *key* from each thread and encodes the labels with a
    freshly fitted ``LabelEncoder``.

    Args:
        key: Label to extract; one of "answered", "lang" or "maintopic".
        filters: Container of filter names. If it contains "de", only
            threads with ``lang == "de"`` are loaded; otherwise, if it
            contains "en", only threads with ``lang == "en"``. "de" takes
            precedence when both are present.

    Returns:
        A ``TrainingThreads`` namedtuple ``(MailThreads, EncodedLabels,
        LabelEncoder)``: the loaded threads, their encoded labels in the
        same order, and the fitted encoder.

    Raises:
        ValueError: If *key* is not a supported label key.
    """
    # `in` instead of dict.has_key(): has_key() does not exist on Python 3.
    if key not in _LABEL_EXTRACTORS:
        raise ValueError("Key " + str(key) + " unknown")
    extract_label = _LABEL_EXTRACTORS[key]

    q = db_session.query(MailThread)
    q = q.filter(MailThread.istrained.is_(True))

    # Optional language restriction.
    if "de" in filters:
        q = q.filter(MailThread.lang == "de")
    elif "en" in filters:
        q = q.filter(MailThread.lang == "en")

    # Load and extract thread fields. A list (not a lazy map object) is built
    # so the encoder receives a concrete sequence on Python 2 and 3 alike.
    threads = q.all()
    labels = [extract_label(thread) for thread in threads]

    # Encode using LabelEncoder.
    le = LabelEncoder()
    encoded_labels = le.fit_transform(labels)

    return TrainingThreads(threads, encoded_labels, le)