From f50e866526af0948dacfc8f860992dd5c5819930 Mon Sep 17 00:00:00 2001 From: Wilson Lin Date: Mon, 29 Apr 2024 06:01:19 +1000 Subject: [PATCH] Render heatmap --- api/main.py | 145 +++++++++++++++++++++---------------- app/component/App.tsx | 54 ++++++++++---- app/component/PointMap.tsx | 5 +- app/util/map.ts | 41 ++++++++++- 4 files changed, 163 insertions(+), 82 deletions(-) diff --git a/api/main.py b/api/main.py index ef4988b..6bb7903 100644 --- a/api/main.py +++ b/api/main.py @@ -10,6 +10,7 @@ from fastapi import Response from fastapi.middleware.cors import CORSMiddleware from FlagEmbedding import BGEM3FlagModel +from io import BytesIO from PIL import Image from pydantic import BaseModel from scipy.ndimage import gaussian_filter @@ -43,6 +44,10 @@ class Dataset: table: pd.DataFrame # We store this separately from the DataFrame because we need it to be a continguous matrix, and .to_numpy() just creates a NumPy array of NumPy array objects. emb_mat: np.ndarray + x_min: float + x_max: float + y_min: float + y_max: float index: hnswlib.Index @@ -83,69 +88,74 @@ def load_umap(ids: npt.NDArray[np.uint32], name: str): def load_jinav2small_umap(): df_posts = load_post_embs_table() df_comments = load_comment_embs_table() - mat_ids = merge_posts_and_comments(posts=df_posts, comments=df_comments)[ - "id" - ].to_numpy() + df = merge_posts_and_comments(posts=df_posts, comments=df_comments) + mat_ids = df["id"].to_numpy() return load_umap(mat_ids, "hnsw_n50_d0.25") def load_posts_data(): - df_posts = load_table("posts", columns=["id", "score", "ts"]) - df_posts = df_posts.merge(load_post_embs_table(), on="id", how="inner") - df_posts = df_posts.merge(load_jinav2small_umap(), on="id", how="inner") - df_posts = normalize_dataset(df_posts) - print("Posts:", len(df_posts)) + df = load_table("posts", columns=["id", "score", "ts"]) + df = df.merge(load_post_embs_table(), on="id", how="inner") + df = df.merge(load_jinav2small_umap(), on="id", how="inner") + df = normalize_dataset(df) + print("Posts:", len(df)) return Dataset( - model=model_jinav2small, - table=df_posts, - emb_mat=np.vstack(df_posts.pop("emb")), + emb_mat=np.vstack(df.pop("emb")), index=load_hnsw_index("posts", 512), + model=model_jinav2small, + table=df, + x_max=df["x"].max(), + x_min=df["x"].min(), + y_max=df["y"].max(), + y_min=df["y"].min(), ) def load_posts_bgem3_data(): - df_posts_bgem3 = load_table("posts", columns=["id", "score", "ts"]) - df_posts_bgem3 = df_posts_bgem3.merge( - load_post_embs_bgem3_table(), on="id", how="inner" - ) - df_posts_bgem3 = df_posts_bgem3.merge( - load_umap(df_posts_bgem3["id"].to_numpy(), "hnsw-bgem3_n300_d0.25"), - on="id", - how="inner", - ) - df_posts_bgem3 = normalize_dataset(df_posts_bgem3) - print("Posts bgem3:", len(df_posts_bgem3)) + df = load_table("posts", columns=["id", "score", "ts"]) + df_embs = load_post_embs_bgem3_table() + df = df.merge(df_embs, on="id", how="inner") + df_umap = load_umap(df_embs["id"].to_numpy(), "hnsw-bgem3_n300_d0.25") + df = df.merge(df_umap, on="id", how="inner") + df = normalize_dataset(df) + print("Posts bgem3:", len(df)) return Dataset( - model=model_bgem3, - table=df_posts_bgem3, - emb_mat=np.vstack(df_posts_bgem3.pop("emb")), + emb_mat=np.vstack(df.pop("emb")), index=load_hnsw_index("posts_bgem3", 1024), + model=model_bgem3, + table=df, + x_max=df["x"].max(), + x_min=df["x"].min(), + y_max=df["y"].max(), + y_min=df["y"].min(), ) def load_comments_data(): - df_comments = load_table("comments", columns=["id", "score", "ts"]) - df_comments = df_comments.merge(load_comment_embs_table(), on="id", how="inner") - df_comments = df_comments.merge(load_jinav2small_umap(), on="id", how="inner") - df_comments = df_comments.merge( - load_table("comment_sentiments"), on="id", how="inner" - ) - df_comments["sentiment_weight"] = np.where( - df_comments["negative"] > df_comments[["neutral", "positive"]].max(axis=1), - -df_comments["negative"], + df = load_table("comments", columns=["id", "score", "ts"]) + df = df.merge(load_comment_embs_table(), on="id", how="inner") + df = df.merge(load_jinav2small_umap(), on="id", how="inner") + df = df.merge(load_table("comment_sentiments"), on="id", how="inner") + df["sentiment_weight"] = np.where( + df["negative"] > df[["neutral", "positive"]].max(axis=1), + -df["negative"], np.where( - df_comments["neutral"] > df_comments[["positive"]].max(axis=1), + df["neutral"] > df[["positive"]].max(axis=1), 0, - df_comments["positive"], + df["positive"], ), ) - df_comments = normalize_dataset(df_comments) - print("Comments:", len(df_comments)) + df = normalize_dataset(df) + print("Comments:", len(df)) return Dataset( - model=model_jinav2small, - table=df_comments, - emb_mat=np.vstack(df_comments.pop("emb")), + emb_mat=np.vstack(df.pop("emb")), index=load_hnsw_index("comments", 512), + model=model_jinav2small, + table=df, + x_max=df["x"].max(), + x_min=df["x"].min(), + y_max=df["y"].max(), + y_min=df["y"].min(), ) @@ -192,29 +202,32 @@ class Clip(BaseModel): class HeatmapOutput(BaseModel): - width: int # Max 1024. - height: int # Max 1024. + density: float color: Tuple[int, int, int] alpha_min: float = 0.0 alpha_max: float = 1.0 - sigma: int = 1 # Max 4. + sigma: int = 1 upscale: int = 1 # Max 4. - def calculate(self, df: pd.DataFrame): - xmin, xmax = df["x"].min(), df["x"].max() - x_range = xmax - xmin - ymin, ymax = df["y"].min(), df["y"].max() - y_range = ymax - ymin + def calculate(self, d: Dataset, df: pd.DataFrame): + # Make sure to use the range of the whole dataset, not just this subset. + x_range = d.x_max - d.x_min + y_range = d.y_max - d.y_min + + grid_width = int(x_range * self.density) + grid_height = int(y_range * self.density) - scale_x = self.width / x_range - scale_y = self.height / y_range df = df.assign( - grid_x=((df["x"] - xmin) * scale_x).clip(upper=self.width - 1).astype(int), + grid_x=((df["x"] - d.x_min) * self.density) + .clip(upper=grid_width - 1) + .astype(int), # Images are stored top-to-bottom, so we need to flip the y-axis - grid_y=((ymax - df["y"]) * scale_y).clip(upper=self.height - 1).astype(int), + grid_y=((d.y_max - df["y"]) * self.density) + .clip(upper=grid_height - 1) + .astype(int), ) - alpha_grid = np.zeros((self.height, self.width), dtype=np.float32) + alpha_grid = np.zeros((grid_height, grid_width), dtype=np.float32) alpha_grid[df["grid_y"], df["grid_x"]] = df["final_score"] alpha_grid = alpha_grid.repeat(self.upscale, axis=0).repeat( self.upscale, axis=1 @@ -223,25 +236,28 @@ def calculate(self, df: pd.DataFrame): blur = blur * (self.alpha_max - self.alpha_min) + self.alpha_min img = np.full( - (self.height * self.upscale, self.width * self.upscale, 4), + (grid_height * self.upscale, grid_width * self.upscale, 4), (*self.color, 0), dtype=np.uint8, ) img[:, :, 3] = (blur * 255).astype(np.uint8) - webp = Image.fromarray(img, "RGBA").tobytes("webp") + webp_out = BytesIO() + Image.fromarray(img, "RGBA").save(webp_out, format="webp") + webp = webp_out.getvalue() return struct.pack(" 0.0] + df["final_score"] = df["final_score"].astype("float32") df = df.sort_values("group", ascending=True) return pack_rows(df, ["group", "final_score"]) @@ -272,13 +289,13 @@ class Output(BaseModel): heatmap: Optional[HeatmapOutput] = None items: Optional[ItemsOutput] = None - def calculate(self, df: pd.DataFrame): + def calculate(self, d: Dataset, df: pd.DataFrame): if self.group_by is not None: - return self.group_by.calculate(df) + return self.group_by.calculate(d, df) if self.heatmap is not None: - return self.heatmap.calculate(df) + return self.heatmap.calculate(d, df) if self.items is not None: - return self.items.calculate(df) + return self.items.calculate(d, df) assert False @@ -383,5 +400,5 @@ def query(input: QueryInput): out = b"" for o in input.outputs: - out += o.calculate(df) + out += o.calculate(d, df) return Response(out) diff --git a/app/component/App.tsx b/app/component/App.tsx index becb455..a6720e6 100644 --- a/app/component/App.tsx +++ b/app/component/App.tsx @@ -79,13 +79,12 @@ const apiCall = async ( } | { heatmap: { - width: number; - height: number; + density: number; color: [number, number, number]; - alpha_min: number; - alpha_max: number; - sigma: number; - upscale: number; + alpha_min?: number; + alpha_max?: number; + sigma?: number; + upscale?: number; }; } | { @@ -112,7 +111,7 @@ const apiCall = async ( const payload = await res.arrayBuffer(); const dv = new DataView(payload); let i = 0; - return req.outputs.map((out) => { + const out = req.outputs.map((out) => { if ("group_by" in out || "items" in out) { const count = dv.getUint32(i, true); i += 4; @@ -130,6 +129,8 @@ const apiCall = async ( } throw new UnreachableError(); }); + assertState(i === payload.byteLength); + return out; }; type ApiResponse = Awaited>; @@ -156,16 +157,21 @@ export const App = () => { apiCall(signal, { dataset: "posts_bgem3", queries: [query.query], - sim_scale: { min: 0, max: 1 }, + sim_scale: { min: 0.55, max: 0.75 }, ts_weight_decay: query.decayTimestamp, outputs: [ { items: { - order_by: "final_score", - order_asc: false, limit: 10, }, }, + { + heatmap: { + density: 100, // TODO Dynamic based on viewport size. + color: [66, 207, 115], + upscale: 1, + }, + }, ], weights: { sim: query.weightSimilarity, @@ -185,6 +191,22 @@ export const App = () => { ]), [queryReq.data], ); + const [heatmap, setHeatmap] = useState(); + + useEffect(() => { + if (!queryReq.data) { + setHeatmap(undefined); + return; + } + const blob = assertInstanceOf(queryReq.data[1], ApiHeatmapOutput).blob(); + const ac = new AbortController(); + (async () => { + const heatmap = await createImageBitmap(blob); + ac.signal.throwIfAborted(); + setHeatmap(heatmap); + })(); + return () => ac.abort(); + }, [queryReq.data]); const [items, setItems] = useState>({}); useEffect(() => { @@ -203,7 +225,11 @@ export const App = () => { return (
- +
{ @@ -265,7 +291,7 @@ export const App = () => { @@ -276,7 +302,7 @@ export const App = () => { diff --git a/app/component/PointMap.tsx b/app/component/PointMap.tsx index 1678e29..ad39ba9 100644 --- a/app/component/PointMap.tsx +++ b/app/component/PointMap.tsx @@ -31,14 +31,17 @@ const EDGES = [ ] as const; export const PointMap = ({ + heatmap, height: vpHeightPx, width: vpWidthPx, }: { + heatmap: ImageBitmap | undefined; height: number; width: number; }) => { const $canvas = useRef(null); const [map, setMap] = useState>(); + useEffect(() => map?.setHeatmap(heatmap), [map, heatmap]); const [meta, setMeta] = useState(); useEffect(() => { const ac = new AbortController(); @@ -59,7 +62,7 @@ export const PointMap = ({ const fetchMeta = async () => { const res = await fetch( - `https://us-ashburn-1.edge-hndr.wilsonl.in/map/hnsw/meta`, + `https://us-ashburn-1.edge-hndr.wilsonl.in/map/hnsw-bgem3/meta`, { signal: ac.signal }, ); const raw = await res.arrayBuffer(); diff --git a/app/util/map.ts b/app/util/map.ts index c3ecc3f..a8b625f 100644 --- a/app/util/map.ts +++ b/app/util/map.ts @@ -146,12 +146,14 @@ export const parseTileData = (raw: ArrayBuffer) => { i += 4 * count; const scores = new Int16Array(raw, i, count); i += 2 * count; - return Array.from({ length: count }, (_, j) => ({ + const res = Array.from({ length: count }, (_, j) => ({ id: ids[j], x: xs[j], y: ys[j], score: scores[j], })); + assertState(i === raw.byteLength); + return res; }; export const cachedFetchTile = async ( @@ -162,7 +164,7 @@ export const cachedFetchTile = async ( y: number, ) => { const res = await cachedFetch( - `https://${edge}.edge-hndr.wilsonl.in/map/hnsw/tile/${lod}/${x}-${y}`, + `https://${edge}.edge-hndr.wilsonl.in/map/hnsw-bgem3/tile/${lod}/${x}-${y}`, signal, // Not all tiles exist (i.e. no points exist). "except-404", @@ -246,6 +248,7 @@ export const createCanvasPointMap = ({ let latestRenderRequestId = 0; let curPoints = Array(); // This must always be sorted by score descending. let curViewport: ViewportState | undefined; + let heatmap: ImageBitmap | undefined; // Zoom (integer) level => point IDs. const labelledPoints = new Dict>(); @@ -309,11 +312,39 @@ export const createCanvasPointMap = ({ if (!vp) { return; } + const scale = viewportScale(vp); + const scaled = scale.scaled(vp); const ctx = assertExists(canvas.getContext("2d")); ctx.clearRect(0, 0, canvas.width, canvas.height); ctx.fillStyle = "#fcfcfc"; ctx.fillRect(0, 0, canvas.width, canvas.height); - const scale = viewportScale(vp); + if (heatmap) { + let dx = 0; + let dy = 0; + let sx = (heatmap.width * (vp.x0Pt - map.xMinPt)) / map.xRangePt; + let sy = (heatmap.height * (vp.y0Pt - map.yMinPt)) / map.yRangePt; + let sWidth = (heatmap.width * (scaled.x1Pt - vp.x0Pt)) / map.xRangePt; + let sHeight = (heatmap.height * (scaled.y1Pt - vp.y0Pt)) / map.yRangePt; + if (sx < 0) { + dx = -sx; + sx = 0; + } + if (sy < 0) { + dy = -sy; + sy = 0; + } + ctx.drawImage( + heatmap, + sx, + sy, + sWidth, + sHeight, + dx, + dy, + canvas.width, + canvas.height, + ); + } const lp = labelledPoints.get(Math.floor(vp.zoom)); for (const p of curPoints) { const scoreWeight = Math.max( @@ -361,6 +392,10 @@ export const createCanvasPointMap = ({ destroy: () => { abortController.abort(); }, + setHeatmap: (hm: ImageBitmap | undefined) => { + heatmap = hm; + renderPoints(); + }, // Render the points at LOD `lod` from (ptX0, ptY0) to (ptX1, ptY1) (inclusive) on the canvas. render: async (newViewport: ViewportState) => { const requestId = ++latestRenderRequestId;