Skip to content

Commit

Permalink
make everything work
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobobryant committed Aug 25, 2017
1 parent 173d577 commit d0a71ec
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 57 deletions.
102 changes: 51 additions & 51 deletions reco/src/reco/reco.clj
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
[get_last_event_id [] long]
[set_last_event_id [long] void]
[session_size [] int]
[update_freshness [java.util.Collection java.util.Collection] void]
^:static [calc_strength [java.util.Collection] double]
^:static [modelify [java.util.Collection long] java.util.Map]
^:static [modelify [java.util.Collection long java.util.Collection String] java.util.Map]
^:static [parse_top_tracks [String] java.util.List]
Expand Down Expand Up @@ -44,9 +46,8 @@
(def g {"artist" "g" "album" "g" "title" "g"})
(def h {"artist" "h" "album" "h" "title" "h"})

(defrecord Event [day skipped])

(defrecord Candidate [event-vec freshness
(defrecord Candidate [freshness
ratio n score
content-ratio content-n content-score])

Expand All @@ -71,7 +72,7 @@
(map (fn [item] [(item :_id)
(map->Song (dissoc item :_id))]))
(into {}))
candidates (repeat (count library) (Candidate. [] 1 0 0 0 1 1 1))]
candidates (repeat (count library) (Candidate. 1 0 0 0 1 1 1))]
[[] (atom {:library library
:candidates (zipmap (keys library) candidates)
:session {}
Expand Down Expand Up @@ -150,36 +151,11 @@
([candidates library model skipped]
(into {} (map #(update-cand % library model skipped) candidates))))

(defn penalty [delta strength]
(- 1 (/ 1 (Math/exp (/ delta strength)))))

(defn predicted [deltas strength]
(reduce * (map #(penalty % strength) deltas)))

(defn total-error [input-data strength]
(reduce + (map #(Math/pow (- (predicted (:deltas %) strength)
(:observed %)) 2) input-data)))

;(def strength-set [0.2 0.5 1 1.5 2 3 5 8 13 21])
(def strength-set [0.5 3 14])
(defn calc-freshness [strength event-vec]
(let [;get-deltas (fn [i event] {:observed (if (:skipped event) 0 1)
; :deltas (map #(- (:day event) (:day %))
; (take i event-vec))})
;input-data (map-indexed get-deltas event-vec)
;strength (apply min-key #(total-error input-data %) strength-set)
day (/ (now) 86400)
deltas (map (fn [e] (max (- day (:day e)) 0)) event-vec)]
(predicted deltas strength)))

; TODO refactor
(defn reset-candidates [candidates library]
(defn reset-candidates [candidates]
(into {} (map (fn [[song-id data]]
[song-id
(assoc data
:freshness (calc-freshness
(get-in library [song-id :mem_strength] 3)
(sort-by :day (:event-vec data)))
:ratio 0
:n 0
:score 0
Expand All @@ -191,9 +167,6 @@
(defn add-event [state model song-id skipped timestamp do-cand-update]
(let [time-delta (- timestamp (:last-time state))
new-session (> time-delta ses-threshold)
; song id -1 represents the empty session. It gives the model
; something to work with when a session is just getting
; started.
session (if new-session
{song-id skipped}
(assoc (:session state) song-id skipped))]
Expand All @@ -204,11 +177,7 @@

(cond-> state
true (assoc :last-time timestamp :session session)
;true (update-in [:candidates song-id :event-vec]
; #(conj % (Event. (/ timestamp 86400) skipped)))
new-session (update :candidates reset-candidates)
;(and new-session do-cand-update) (update-candidates (walk/keywordize-keys model)
; -1 false)
do-cand-update (update-candidates (walk/keywordize-keys model) skipped)
true (assoc :new-model (if new-session
(walk/stringify-keys
Expand Down Expand Up @@ -313,6 +282,51 @@
(defn -session_size [this]
(count (:session @@this)))

(defn parse-timestamp [timestamp]
(let [parser (new java.text.SimpleDateFormat "yyyy-MM-dd HH:mm:ss")]
(quot (.getTime (.parse parser timestamp)) 1000)))

(defn penalty [delta strength]
(- 1 (/ 1 (Math/exp (/ delta strength)))))

(defn predicted [deltas strength]
(reduce * (map #(penalty % strength) deltas)))

(defn total-error [input-data strength]
(reduce + (map #(Math/pow (- (predicted (:deltas %) strength)
(:observed %)) 2) input-data)))

(defn -update_freshness [this raw-events raw-strengths]
(let [event-deltas (->> raw-events
(map (fn [{timestamp "time" song-id "song_id"}]
{song-id [(/ (- (now) (parse-timestamp timestamp)) 86400)]}))
(apply merge-with concat))
strength (into {} (map (fn [{mem-strength "mem_strength" song-id "song_id"}]
[song-id mem-strength]) raw-strengths))]
(swap! (.state this)
(fn [state]
(assoc state :candidates
(into {} (for [[cand-id data] (:candidates state)]
[cand-id
(assoc data :freshness
(predicted
(event-deltas cand-id)
(get strength cand-id 3)))])))))))

;(def strength-set [0.2 0.5 1 1.5 2 3 5 8 13 21])
(defrecord Event [day skipped])
(def strength-set [0.5 3 14])
(defn -calc_strength [raw-events]
(let [event-vec (map #(Event. (/ (parse-timestamp (get % "time"))
86400)
(= (get % "skipped") 1))
raw-events)
get-deltas (fn [i event] {:observed (if (:skipped event) 0 1)
:deltas (map #(- (:day event) (:day %))
(take i event-vec))})
input-data (map-indexed get-deltas event-vec)]
(apply min-key #(total-error input-data %) strength-set)))

(defn -parse_top_tracks [response]
(let [data (json/read-str response)]
(map (fn [item] {"spotify_id" (get item "uri")
Expand Down Expand Up @@ -348,6 +362,7 @@
(get-in data ["tracks" "items" 0 "uri"])))



;(defn demo-real-data []
; (println "Press Enter to start demo")
; (read-line)
Expand Down Expand Up @@ -382,22 +397,7 @@
; (recur (.pick_next rec false) (rest actions))))
; (.pick_random rec false)))

(defn test-serial []
(let [fout (new java.io.FileOutputStream "/tmp/foobar")
oos (new java.io.ObjectOutputStream fout)
rec (new reco.reco [a])]
(.writeObject oos @@rec)
(.close oos)
(.close fout)
(let [fin (new java.io.FileInputStream "/tmp/foobar")
ois (new java.io.ObjectInputStream fin)
state (.readObject ois)]
(.close ois)
(.close fin)
(assert (= state @@rec)))))

(defn -main [& args]
(println "starting up")
(test-serial)
;(demo-real-data)
(println "all tests pass"))
24 changes: 18 additions & 6 deletions src/com/jacobobryant/moody/Moody.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import static android.database.Cursor.FIELD_TYPE_STRING;

public class Moody {
public static final String STATE_FILE = "reco_state";
private static Moody instance;
private Context context;
public static final String AUTHORITY = "com.jacobobryant.moody.vanilla";
Expand Down Expand Up @@ -100,9 +99,16 @@ private void init(InitProgressListener listener) {
add_to_library(context, songs);
result.close();

SQLiteDatabase db = new Database(context).getReadableDatabase();
// update all strengths
//listener.update("updating memory strengths");
//for (Map<String, Object> record : cursor_to_maps(db.rawQuery(
// "select distinct song_id from events", null))) {
// update_strength((int)record.get("song_id"));
//}

// get library
listener.update("setting up recommendation engine");
SQLiteDatabase db = new Database(context).getReadableDatabase();
rec = new reco(cursor_to_maps(
db.rawQuery("SELECT _id, artist, album, title, source, " +
"spotify_id, duration, mem_strength FROM songs", null)));
Expand Down Expand Up @@ -181,9 +187,9 @@ public Map get_model(SQLiteDatabase db, long id, String artist) {
List artist_model = cursor_to_maps(db.rawQuery(
"select artist_a, artist_b, score from artist_model where artist_a = ?1 or artist_b = ?1",
new String[] {artist}));
return rec.modelify(song_model, id, artist_model, artist);
return reco.modelify(song_model, id, artist_model, artist);
} else {
return rec.modelify(song_model, id);
return reco.modelify(song_model, id);
}
}

Expand Down Expand Up @@ -241,16 +247,22 @@ private void add_event(long song_id, Long event_id, boolean skipped, Long second
} else {
rec.add_event(get_model(db, -1, null), -1, false, seconds, do_update);
}
Log.d(C.TAG, "updating freshness");
rec.update_freshness(cursor_to_maps(db.rawQuery(
"select song_id, time from events", null)),
cursor_to_maps(db.rawQuery(
"select mem_strength, _id from songs where mem_strength is not null",
null)));
}
db.close();
}

private void update_strength(int song_id) {
SQLiteDatabase db = new Database(context).getWritableDatabase();
double strength = rec.calc_strength(song_id, cursor_to_maps(db.rawQuery(
double strength = rec.calc_strength(cursor_to_maps(db.rawQuery(
"select time, skipped from events where song_id = ?",
new String[] {String.valueOf(song_id)})));
db.execSQL("update songs set mem_strength = ? where song_id = ?",
db.execSQL("update songs set mem_strength = ? where _id = ?",
new String[] {String.valueOf(strength), String.valueOf(song_id)});
db.close();
}
Expand Down

0 comments on commit d0a71ec

Please sign in to comment.