diff --git a/bayesian.go b/bayesian.go index 0466eb6..af6e05f 100644 --- a/bayesian.go +++ b/bayesian.go @@ -1,100 +1,130 @@ -/* -Part of this code is borrowed from github.com/jbrukh/bayesian published under a BSD3CLAUSE License -*/ - package sisyphus import ( - "math" - "strconv" + "errors" "github.com/boltdb/bolt" + "github.com/gonum/stat" + "github.com/retailnext/hllpp" ) -// classificationPriors returns the prior probabilities for good and junk +// classificationPrior returns the prior probabilities for good and junk // classes. -func classificationPriors(db *bolt.DB) (g, j float64) { +func classificationPrior(db *bolt.DB) (g float64, err error) { - db.View(func(tx *bolt.Tx) error { + err = db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("Wordlists")) + good := b.Bucket([]byte("Good")) gN := float64(good.Stats().KeyN) + junk := b.Bucket([]byte("Junk")) jN := float64(junk.Stats().KeyN) + // division by zero means there are no learned mails so far + if (gN + jN) == 0 { + return errors.New("no mails have been classified so far") + } + g = gN / (gN + jN) - j = jN / (gN + jN) return nil }) - return + return g, err } -// classificationWordProb returns P(W|C_j) -- the probability of seeing -// a particular word W in a document of this class. -func classificationWordProb(db *bolt.DB, word string) (g, j float64) { +// classificationLikelihood returns P(W|C_j) -- the probability of seeing a +// particular word W in a document of this class. +func classificationLikelihood(db *bolt.DB, word string) (g, j float64, err error) { + + err = db.View(func(tx *bolt.Tx) error { + var gN, jN uint64 - db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("Wordlists")) + good := b.Bucket([]byte("Good")) - gNString := string(good.Get([]byte(word))) - gN, _ := strconv.ParseFloat(gNString, 64) + gWordRaw := good.Get([]byte(word)) + if len(gWordRaw) != 0 { + gWordHLL, err := hllpp.Unmarshal(gWordRaw) + if err != nil { + return err + } + gN = gWordHLL.Count() + } junk := b.Bucket([]byte("Junk")) - jNString := string(junk.Get([]byte(word))) - jN, _ := strconv.ParseFloat(jNString, 64) - - p := tx.Bucket([]byte("Processed")) - counters := p.Bucket([]byte("Counters")) - jString := string(counters.Get([]byte("Junk"))) - j, _ = strconv.ParseFloat(jString, 64) - mails := p.Bucket([]byte("Mails")) - pN := mails.Stats().KeyN - - g = gN / (float64(pN) - j) - j = jN / j + jWordRaw := junk.Get([]byte(word)) + if len(jWordRaw) != 0 { + jWordHLL, err := hllpp.Unmarshal(jWordRaw) + if err != nil { + return err + } + jN = jWordHLL.Count() + } + + p := tx.Bucket([]byte("Statistics")) + gHLL, err := hllpp.Unmarshal(p.Get([]byte("ProcessedGood"))) + if err != nil { + return err + } + jHLL, err := hllpp.Unmarshal(p.Get([]byte("ProcessedJunk"))) + if err != nil { + return err + } + + gTotal := gHLL.Count() + if gTotal == 0 { + return errors.New("no good mails have been classified so far") + } + jTotal := jHLL.Count() + if jTotal == 0 { + return errors.New("no junk mails have been classified so far") + } + + g = float64(gN) / float64(gTotal) + j = float64(jN) / float64(jTotal) return nil }) - return g, j + return g, j, nil +} + +// classificationWord produces the conditional probability of a word belonging +// to good or junk using the classic Bayes' rule. +func classificationWord(db *bolt.DB, word string) (g float64, err error) { + + priorG, err := classificationPrior(db) + if err != nil { + return g, err + } + + likelihoodG, likelihoodJ, err := classificationLikelihood(db, word) + if err != nil { + return g, err + } + + g = (likelihoodG * priorG) / (likelihoodG*priorG + likelihoodJ*(1-priorG)) + + return g, nil } -// LogScores produces "log-likelihood"-like scores that can -// be used to classify documents into classes. -// -// The value of the score is proportional to the likelihood, -// as determined by the classifier, that the given document -// belongs to the given class. This is true even when scores -// returned are negative, which they will be (since we are -// taking logs of probabilities). -// -// The index j of the score corresponds to the class given -// by c.Classes[j]. -// -// Additionally returned are "inx" and "strict" values. The -// inx corresponds to the maximum score in the array. If more -// than one of the scores holds the maximum values, then -// strict is false. -// -// Unlike c.Probabilities(), this function is not prone to -// floating point underflow and is relatively safe to use. -func LogScores(db *bolt.DB, wordlist []string) (scoreG, scoreJ float64, junk bool) { - - priorG, priorJ := classificationPriors(db) - - // calculate the scores - scoreG = math.Log(priorG) - scoreJ = math.Log(priorJ) - for _, word := range wordlist { - gP, jP := classificationWordProb(db, word) - scoreG += math.Log(gP) - scoreJ += math.Log(jP) +// Junk returns true if the wordlist is classified as a junk mail using Bayes' +// rule. +func Junk(db *bolt.DB, wordlist []string) (bool, error) { + var probabilities []float64 + + for _, val := range wordlist { + p, err := classificationWord(db, val) + if err != nil { + return false, err + } + probabilities = append(probabilities, p) } - if scoreJ == math.Max(scoreG, scoreJ) { - junk = true + if stat.HarmonicMean(probabilities, nil) < 0.5 { + return true, nil } - return scoreG, scoreJ, junk + return false, nil } diff --git a/database.go b/database.go index b6226b4..317f2b4 100644 --- a/database.go +++ b/database.go @@ -6,40 +6,20 @@ import ( "github.com/boltdb/bolt" ) -// OpenDB creates and opens a new database and its respective buckets (if required) -func OpenDB(maildir string) (db *bolt.DB, err error) { +// openDB creates and opens a new database and its respective buckets (if required) +func openDB(m Maildir) (db *bolt.DB, err error) { - log.Println("loading database") + log.Println("loading database for " + string(m)) // Open the sisyphus.db data file in your current directory. // It will be created if it doesn't exist. - db, err = bolt.Open(maildir+"/sisyphus.db", 0600, nil) + db, err = bolt.Open(string(m)+"/sisyphus.db", 0600, nil) if err != nil { return db, err } // Create DB bucket for the map of processed e-mail IDs err = db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucketIfNotExists([]byte("Processed")) - return err - }) - if err != nil { - return db, err - } - - // Create DB bucket for Mails inside bucket Processed - err = db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("Processed")) - _, err := b.CreateBucketIfNotExists([]byte("Mails")) - return err - }) - if err != nil { - return db, err - } - - // Create DB bucket for Counters inside bucket Processed - err = db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("Processed")) - _, err := b.CreateBucketIfNotExists([]byte("Counters")) + _, err := tx.CreateBucketIfNotExists([]byte("Statistics")) return err }) if err != nil { @@ -75,3 +55,28 @@ func OpenDB(maildir string) (db *bolt.DB, err error) { log.Println("database loaded") return db, err } + +// LoadDatabases loads all databases from a given slice of Maildirs +func LoadDatabases(d []Maildir) (databases map[Maildir]*bolt.DB, err error) { + databases = make(map[Maildir]*bolt.DB) + for _, val := range d { + databases[val], err = openDB(val) + if err != nil { + return databases, err + } + } + + return databases, nil +} + +// CloseDatabases closes all databases from a given slice of Maildirs +func CloseDatabases(databases map[Maildir]*bolt.DB) { + for _, val := range databases { + err := val.Close() + if err != nil { + log.Println(err) + } + } + + return +} diff --git a/glide.lock b/glide.lock index 8a026ae..e13ebea 100644 --- a/glide.lock +++ b/glide.lock @@ -1,10 +1,34 @@ -hash: 3cee04040c8122c41e41716acbf51492c7e42ac3d7dc239556123e00101bfc17 -updated: 2017-05-05T01:56:14.319028504Z +hash: 3c3cc20d232bb926c0eff2f838a575b28c9f8669a04af14079464608f74ea78b +updated: 2017-05-08T03:21:39.033013152Z imports: - name: github.com/boltdb/bolt version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9 - name: github.com/fsnotify/fsnotify version: 4da3e2cfbabc9f751898f250b49f2439785783a1 +- name: github.com/gonum/blas + version: 430a98d0f42f5d7b9a9535580771233386034848 + subpackages: + - blas64 + - native + - native/internal/math32 +- name: github.com/gonum/floats + version: a2cbc5c70616cd18491ef2843231f6ce28b2cb02 +- name: github.com/gonum/internal + version: abbe1115275b8ef5f1c4b73e2360c9e8177edcb2 + subpackages: + - asm/f32 + - asm/f64 +- name: github.com/gonum/lapack + version: b2c45d0a904f3d18ef419fd1c65c22f3749046fe + subpackages: + - lapack64 + - native +- name: github.com/gonum/matrix + version: 496fef53954ba166f5dbc5bbbf90b433c130cfb1 + subpackages: + - mat64 +- name: github.com/gonum/stat + version: cd3537419d189410d273f890c4e8950c64b92df5 - name: github.com/kennygrant/sanitize version: 6a0bfdde8629a3a3a7418a7eae45c54154692514 - name: github.com/luksen/maildir @@ -12,12 +36,13 @@ imports: - name: github.com/retailnext/hllpp version: 9fdfea05b3e55bebe7beb22d16c7db15d46cd518 - name: github.com/urfave/cli - version: ab403a54a148f2d857920810291539e1f817ee7b + version: d70f47eeca3afd795160003bc6e28b001d60c67c - name: golang.org/x/net version: feeb485667d1fdabe727840fe00adc22431bc86e subpackages: - html - html/atom + - html/charset - name: golang.org/x/sys version: 9ccfe848b9db8435a24c424abbc07a921adf1df5 subpackages: @@ -44,7 +69,7 @@ testImports: - reporters/stenographer/support/go-isatty - types - name: github.com/onsi/gomega - version: da367351331c75190afec7ed3a0c61e0016942ed + version: da1d0e434025dec4c843b8f05171992bb6c8dfcc subpackages: - format - internal/assertion @@ -57,5 +82,23 @@ testImports: - matchers/support/goraph/node - matchers/support/goraph/util - types +- name: golang.org/x/text + version: 470f45bf29f4147d6fbd7dfd0a02a848e49f5bf4 + subpackages: + - encoding + - encoding/charmap + - encoding/htmlindex + - encoding/internal + - encoding/internal/identifier + - encoding/japanese + - encoding/korean + - encoding/simplifiedchinese + - encoding/traditionalchinese + - encoding/unicode + - internal/tag + - internal/utf8internal + - language + - runes + - transform - name: gopkg.in/yaml.v2 version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b diff --git a/glide.yaml b/glide.yaml index 62e57be..29df0df 100644 --- a/glide.yaml +++ b/glide.yaml @@ -7,6 +7,7 @@ import: - package: github.com/urfave/cli - package: github.com/fsnotify/fsnotify - package: github.com/retailnext/hllpp +- package: github.com/gonum/stat testImport: - package: github.com/onsi/ginkgo - package: github.com/onsi/gomega diff --git a/gometalinter.out b/gometalinter.out new file mode 100644 index 0000000..9834ee0 --- /dev/null +++ b/gometalinter.out @@ -0,0 +1,21 @@ +[ + {"linter":"gas","severity":"warning","path":"bayesian.go","line":18,"col":0,"message":"Errors unhandled.,LOW,HIGH"}, + {"linter":"gas","severity":"warning","path":"bayesian.go","line":38,"col":0,"message":"Errors unhandled.,LOW,HIGH"}, + {"linter":"gas","severity":"warning","path":"bayesian.go","line":42,"col":0,"message":"Errors unhandled.,LOW,HIGH"}, + {"linter":"gas","severity":"warning","path":"bayesian.go","line":45,"col":0,"message":"Errors unhandled.,LOW,HIGH"}, + {"linter":"gas","severity":"warning","path":"bayesian.go","line":50,"col":0,"message":"Errors unhandled.,LOW,HIGH"}, + {"linter":"gas","severity":"warning","path":"daemon.go","line":45,"col":0,"message":"Subprocess launching with variable.,HIGH,HIGH"}, + {"linter":"gas","severity":"warning","path":"daemon.go","line":115,"col":0,"message":"Subprocess launching with variable.,HIGH,HIGH"}, + {"linter":"gas","severity":"warning","path":"daemon.go","line":122,"col":0,"message":"Subprocess launching with variable.,HIGH,HIGH"}, + {"linter":"gocyclo","severity":"warning","path":"mail.go","line":232,"col":0,"message":"cyclomatic complexity 16 of function (*Mail).Classify() is high (\u003e 10)"}, + {"linter":"dupl","severity":"warning","path":"mail_test.go","line":135,"col":0,"message":"duplicate of mail_test.go:160-183"}, + {"linter":"dupl","severity":"warning","path":"mail_test.go","line":160,"col":0,"message":"duplicate of mail_test.go:185-208"}, + {"linter":"dupl","severity":"warning","path":"mail_test.go","line":185,"col":0,"message":"duplicate of mail_test.go:210-233"}, + {"linter":"dupl","severity":"warning","path":"mail_test.go","line":210,"col":0,"message":"duplicate of mail_test.go:235-258"}, + {"linter":"dupl","severity":"warning","path":"mail_test.go","line":235,"col":0,"message":"duplicate of mail_test.go:260-283"}, + {"linter":"dupl","severity":"warning","path":"mail_test.go","line":260,"col":0,"message":"duplicate of mail_test.go:135-158"}, + {"linter":"errcheck","severity":"warning","path":"bayesian.go","line":18,"col":9,"message":"error return value not checked (db.View(func(tx *bolt.Tx) error {)"}, + {"linter":"errcheck","severity":"warning","path":"bayesian.go","line":38,"col":9,"message":"error return value not checked (db.View(func(tx *bolt.Tx) error {)"}, + {"linter":"errcheck","severity":"warning","path":"daemon.go","line":26,"col":18,"message":"error return value not checked (defer file.Close())"}, + {"linter":"errcheck","severity":"warning","path":"mail.go","line":275,"col":11,"message":"error return value not checked (db.Update(func(tx *bolt.Tx) error {)"} +] diff --git a/mail.go b/mail.go index 9238952..8751dac 100644 --- a/mail.go +++ b/mail.go @@ -16,13 +16,6 @@ import ( "github.com/luksen/maildir" ) -const ( - // Good holds a placeholder string for the database - Good = "0" - // Junk holds a placeholder string for the database - Junk = "1" -) - // Maildir represents the address to a Maildir directory type Maildir string @@ -228,7 +221,10 @@ func (m *Mail) Wordlist() (w []string) { return w } -// Classify analyses the mail and decides whether it is Junk or Good +// Classify analyses a new mail (a mail that arrived in the "new" directory), +// decides whether it is junk and -- if so -- moves it to the Junk folder. If +// it is not junk, the mail is untouched so it can be handled by the mail +// client. func (m *Mail) Classify(db *bolt.DB) error { err := m.Clean() @@ -237,15 +233,16 @@ func (m *Mail) Classify(db *bolt.DB) error { } list := m.Wordlist() - scoreG, scoreJ, ju := LogScores(db, list) + junk, err := Junk(db, list) + if err != nil { + return err + } - log.Print("Classified " + m.Key + " as Junk=" + strconv.FormatBool(m.Junk) + - " (good: " + strconv.FormatFloat(scoreG, 'f', 4, 64) + - ", junk: " + strconv.FormatFloat(scoreJ, 'f', 4, 64) + ")") + log.Print("Classified " + m.Key + " as Junk=" + strconv.FormatBool(m.Junk)) - // Move mails around after classification - if m.New && ju { - m.Junk = ju + // Move mail around if junk. + if junk { + m.Junk = junk err := os.Rename("./new/"+m.Key, "./.Junk/cur/"+m.Key) if err != nil { return err @@ -253,42 +250,6 @@ func (m *Mail) Classify(db *bolt.DB) error { log.Print("Moved " + m.Key + " from new to Junk folder") } - if !m.New && m.Junk && !ju { - err := os.Rename("./.Junk/cur/"+m.Key, "./cur/"+m.Key) - if err != nil { - return err - } - m.Junk = ju - log.Print("Moved " + m.Key + " from Junk to Good folder") - } - - if !m.New && ju && !m.Junk { - err := os.Rename("./cur/"+m.Key, "./.Junk/cur/"+m.Key) - if err != nil { - return err - } - m.Junk = ju - log.Print("Moved " + m.Key + " from Good to Junk folder") - } - - // Inform the DB about a processed mail - db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("Processed")) - bMails := b.Bucket([]byte("Mails")) - if ju { - err := bMails.Put([]byte(m.Key), []byte(Junk)) - if err != nil { - return err - } - } else { - err := bMails.Put([]byte(m.Key), []byte(Good)) - if err != nil { - return err - } - } - return err - }) - return nil } @@ -297,3 +258,27 @@ func (m *Mail) Classify(db *bolt.DB) error { func (m *Mail) Learn(db *bolt.DB) error { return nil } + +// LoadMails creates missing directories and then loads all mails from a given +// slice of Maildirs +func LoadMails(d []Maildir) (mails map[Maildir][]*Mail, err error) { + mails = make(map[Maildir][]*Mail) + + // create missing directories and write index + for _, val := range d { + err := val.CreateDirs() + if err != nil { + return mails, err + } + + var m []*Mail + m, err = val.Index() + if err != nil { + return mails, err + } + + mails[val] = m + } + + return mails, nil +} diff --git a/sisyphus/sisyphus.go b/sisyphus/sisyphus.go index 996ea97..b96b2db 100644 --- a/sisyphus/sisyphus.go +++ b/sisyphus/sisyphus.go @@ -8,13 +8,14 @@ import ( "strings" "syscall" - "github.com/boltdb/bolt" "github.com/carlostrub/sisyphus" "github.com/fsnotify/fsnotify" "github.com/urfave/cli" ) -var version string +var ( + version string +) func main() { @@ -106,61 +107,62 @@ func main() { }() - // Load the Maildir if len(maildirPaths) < 1 { - log.Fatal("No Maildir set.") - } - if len(maildirPaths) > 1 { - log.Fatal("Sorry... only one Maildir supported as of today.") + log.Fatal("No Maildir set. Please check the manual.") } - sisyphus.Maildir(maildirPaths[0]).CreateDirs() + // Populate maildir with the maildirs given by setting the flag. + var maildirs []sisyphus.Maildir + for _, val := range maildirPaths { + maildirs = append(maildirs, sisyphus.Maildir(val)) + } - mails, err := sisyphus.Maildir(maildirPaths[0]).Index() + // Load all mails + mails, err := sisyphus.LoadMails(maildirs) if err != nil { - log.Fatal("Wrong path to Maildir") + log.Fatal(err) } - // Open the database - db, err := sisyphus.OpenDB(maildirPaths[0]) + // Open all databases + dbs, err := sisyphus.LoadDatabases(maildirs) if err != nil { log.Fatal(err) } - defer db.Close() - - // Handle all mails after startup - for i := range mails { - db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("Processed")) - bMails := b.Bucket([]byte("Mails")) - v := bMails.Get([]byte(mails[i].Key)) - if len(v) == 0 { - err = mails[i].Classify(db) - if err != nil { - log.Print(err) - } - err = mails[i].Learn(db) - if err != nil { - log.Print(err) - } - } - if string(v) == sisyphus.Good && mails[i].Junk == true { - err = mails[i].Learn(db) - if err != nil { - log.Print(err) - } - } - if string(v) == sisyphus.Junk && mails[i].Junk == false { - err = mails[i].Learn(db) - if err != nil { - log.Print(err) - } - } - return nil - }) - } - - // Handle mails as the arrive + defer sisyphus.CloseDatabases(dbs) + + // Learn at startup + // for i := range mails { + // db.View(func(tx *bolt.Tx) error { + // b := tx.Bucket([]byte("Processed")) + // bMails := b.Bucket([]byte("Mails")) + // v := bMails.Get([]byte(mails[i].Key)) + // if len(v) == 0 { + // err = mails[i].Classify(db) + // if err != nil { + // log.Print(err) + // } + // err = mails[i].Learn(db) + // if err != nil { + // log.Print(err) + // } + // } + // if string(v) == sisyphus.Good && mails[i].Junk == true { + // err = mails[i].Learn(db) + // if err != nil { + // log.Print(err) + // } + // } + // if string(v) == sisyphus.Junk && mails[i].Junk == false { + // err = mails[i].Learn(db) + // if err != nil { + // log.Print(err) + // } + // } + // return nil + // }) + // } + + // Classify on arrival watcher, err := fsnotify.NewWatcher() if err != nil { log.Fatal(err)