Skip to content

Commit

Permalink
Render heatmap
Browse files Browse the repository at this point in the history
  • Loading branch information
wilsonzlin committed Apr 28, 2024
1 parent c9cda0d commit f50e866
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 82 deletions.
145 changes: 81 additions & 64 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fastapi import Response
from fastapi.middleware.cors import CORSMiddleware
from FlagEmbedding import BGEM3FlagModel
from io import BytesIO
from PIL import Image
from pydantic import BaseModel
from scipy.ndimage import gaussian_filter
Expand Down Expand Up @@ -43,6 +44,10 @@ class Dataset:
table: pd.DataFrame
# We store this separately from the DataFrame because we need it to be a contiguous matrix, whereas .to_numpy() on the column just creates a NumPy array of NumPy array objects.
emb_mat: np.ndarray
x_min: float
x_max: float
y_min: float
y_max: float
index: hnswlib.Index


Expand Down Expand Up @@ -83,69 +88,74 @@ def load_umap(ids: npt.NDArray[np.uint32], name: str):
def load_jinav2small_umap():
df_posts = load_post_embs_table()
df_comments = load_comment_embs_table()
mat_ids = merge_posts_and_comments(posts=df_posts, comments=df_comments)[
"id"
].to_numpy()
df = merge_posts_and_comments(posts=df_posts, comments=df_comments)
mat_ids = df["id"].to_numpy()
return load_umap(mat_ids, "hnsw_n50_d0.25")


def load_posts_data():
df_posts = load_table("posts", columns=["id", "score", "ts"])
df_posts = df_posts.merge(load_post_embs_table(), on="id", how="inner")
df_posts = df_posts.merge(load_jinav2small_umap(), on="id", how="inner")
df_posts = normalize_dataset(df_posts)
print("Posts:", len(df_posts))
df = load_table("posts", columns=["id", "score", "ts"])
df = df.merge(load_post_embs_table(), on="id", how="inner")
df = df.merge(load_jinav2small_umap(), on="id", how="inner")
df = normalize_dataset(df)
print("Posts:", len(df))
return Dataset(
model=model_jinav2small,
table=df_posts,
emb_mat=np.vstack(df_posts.pop("emb")),
emb_mat=np.vstack(df.pop("emb")),
index=load_hnsw_index("posts", 512),
model=model_jinav2small,
table=df,
x_max=df["x"].max(),
x_min=df["x"].min(),
y_max=df["y"].max(),
y_min=df["y"].min(),
)


def load_posts_bgem3_data():
df_posts_bgem3 = load_table("posts", columns=["id", "score", "ts"])
df_posts_bgem3 = df_posts_bgem3.merge(
load_post_embs_bgem3_table(), on="id", how="inner"
)
df_posts_bgem3 = df_posts_bgem3.merge(
load_umap(df_posts_bgem3["id"].to_numpy(), "hnsw-bgem3_n300_d0.25"),
on="id",
how="inner",
)
df_posts_bgem3 = normalize_dataset(df_posts_bgem3)
print("Posts bgem3:", len(df_posts_bgem3))
df = load_table("posts", columns=["id", "score", "ts"])
df_embs = load_post_embs_bgem3_table()
df = df.merge(df_embs, on="id", how="inner")
df_umap = load_umap(df_embs["id"].to_numpy(), "hnsw-bgem3_n300_d0.25")
df = df.merge(df_umap, on="id", how="inner")
df = normalize_dataset(df)
print("Posts bgem3:", len(df))
return Dataset(
model=model_bgem3,
table=df_posts_bgem3,
emb_mat=np.vstack(df_posts_bgem3.pop("emb")),
emb_mat=np.vstack(df.pop("emb")),
index=load_hnsw_index("posts_bgem3", 1024),
model=model_bgem3,
table=df,
x_max=df["x"].max(),
x_min=df["x"].min(),
y_max=df["y"].max(),
y_min=df["y"].min(),
)


def load_comments_data():
df_comments = load_table("comments", columns=["id", "score", "ts"])
df_comments = df_comments.merge(load_comment_embs_table(), on="id", how="inner")
df_comments = df_comments.merge(load_jinav2small_umap(), on="id", how="inner")
df_comments = df_comments.merge(
load_table("comment_sentiments"), on="id", how="inner"
)
df_comments["sentiment_weight"] = np.where(
df_comments["negative"] > df_comments[["neutral", "positive"]].max(axis=1),
-df_comments["negative"],
df = load_table("comments", columns=["id", "score", "ts"])
df = df.merge(load_comment_embs_table(), on="id", how="inner")
df = df.merge(load_jinav2small_umap(), on="id", how="inner")
df = df.merge(load_table("comment_sentiments"), on="id", how="inner")
df["sentiment_weight"] = np.where(
df["negative"] > df[["neutral", "positive"]].max(axis=1),
-df["negative"],
np.where(
df_comments["neutral"] > df_comments[["positive"]].max(axis=1),
df["neutral"] > df[["positive"]].max(axis=1),
0,
df_comments["positive"],
df["positive"],
),
)
df_comments = normalize_dataset(df_comments)
print("Comments:", len(df_comments))
df = normalize_dataset(df)
print("Comments:", len(df))
return Dataset(
model=model_jinav2small,
table=df_comments,
emb_mat=np.vstack(df_comments.pop("emb")),
emb_mat=np.vstack(df.pop("emb")),
index=load_hnsw_index("comments", 512),
model=model_jinav2small,
table=df,
x_max=df["x"].max(),
x_min=df["x"].min(),
y_max=df["y"].max(),
y_min=df["y"].min(),
)


Expand Down Expand Up @@ -192,29 +202,32 @@ class Clip(BaseModel):


class HeatmapOutput(BaseModel):
width: int # Max 1024.
height: int # Max 1024.
density: float
color: Tuple[int, int, int]
alpha_min: float = 0.0
alpha_max: float = 1.0
sigma: int = 1 # Max 4.
sigma: int = 1
upscale: int = 1 # Max 4.

def calculate(self, df: pd.DataFrame):
xmin, xmax = df["x"].min(), df["x"].max()
x_range = xmax - xmin
ymin, ymax = df["y"].min(), df["y"].max()
y_range = ymax - ymin
def calculate(self, d: Dataset, df: pd.DataFrame):
# Make sure to use the range of the whole dataset, not just this subset.
x_range = d.x_max - d.x_min
y_range = d.y_max - d.y_min

grid_width = int(x_range * self.density)
grid_height = int(y_range * self.density)

scale_x = self.width / x_range
scale_y = self.height / y_range
df = df.assign(
grid_x=((df["x"] - xmin) * scale_x).clip(upper=self.width - 1).astype(int),
grid_x=((df["x"] - d.x_min) * self.density)
.clip(upper=grid_width - 1)
.astype(int),
# Images are stored top-to-bottom, so we need to flip the y-axis
grid_y=((ymax - df["y"]) * scale_y).clip(upper=self.height - 1).astype(int),
grid_y=((d.y_max - df["y"]) * self.density)
.clip(upper=grid_height - 1)
.astype(int),
)

alpha_grid = np.zeros((self.height, self.width), dtype=np.float32)
alpha_grid = np.zeros((grid_height, grid_width), dtype=np.float32)
alpha_grid[df["grid_y"], df["grid_x"]] = df["final_score"]
alpha_grid = alpha_grid.repeat(self.upscale, axis=0).repeat(
self.upscale, axis=1
Expand All @@ -223,25 +236,28 @@ def calculate(self, df: pd.DataFrame):
blur = blur * (self.alpha_max - self.alpha_min) + self.alpha_min

img = np.full(
(self.height * self.upscale, self.width * self.upscale, 4),
(grid_height * self.upscale, grid_width * self.upscale, 4),
(*self.color, 0),
dtype=np.uint8,
)
img[:, :, 3] = (blur * 255).astype(np.uint8)

webp = Image.fromarray(img, "RGBA").tobytes("webp")
webp_out = BytesIO()
Image.fromarray(img, "RGBA").save(webp_out, format="webp")
webp = webp_out.getvalue()
return struct.pack("<I", len(webp)) + webp


class ItemsOutput(BaseModel):
order_by: str = "id"
order_by: str = "final_score"
order_asc: bool = False
limit: Optional[int] = None

def calculate(self, df: pd.DataFrame):
def calculate(self, d: Dataset, df: pd.DataFrame):
df = df.sort_values(self.order_by, ascending=self.order_asc)
if self.limit is not None:
df = df[: self.limit]
df["final_score"] = df["final_score"].astype("float32")
return pack_rows(df, ["id", "final_score"])


Expand All @@ -254,14 +270,15 @@ class GroupByOutput(BaseModel):
# mean, min, max, sum, count
group_final_score_agg: str = "sum"

def calculate(self, df: pd.DataFrame):
def calculate(self, d: Dataset, df: pd.DataFrame):
df = df.assign(
group=(df[self.group_by] // (self.group_bucket or 1.0)).astype("int32")
)
df = df.groupby("group", as_index=False).agg(
{"final_score": self.group_final_score_agg}
)
df = df[df["final_score"] > 0.0]
df["final_score"] = df["final_score"].astype("float32")
df = df.sort_values("group", ascending=True)
return pack_rows(df, ["group", "final_score"])

Expand All @@ -272,13 +289,13 @@ class Output(BaseModel):
heatmap: Optional[HeatmapOutput] = None
items: Optional[ItemsOutput] = None

def calculate(self, df: pd.DataFrame):
def calculate(self, d: Dataset, df: pd.DataFrame):
if self.group_by is not None:
return self.group_by.calculate(df)
return self.group_by.calculate(d, df)
if self.heatmap is not None:
return self.heatmap.calculate(df)
return self.heatmap.calculate(d, df)
if self.items is not None:
return self.items.calculate(df)
return self.items.calculate(d, df)
assert False


Expand Down Expand Up @@ -383,5 +400,5 @@ def query(input: QueryInput):

out = b""
for o in input.outputs:
out += o.calculate(df)
out += o.calculate(d, df)
return Response(out)
54 changes: 40 additions & 14 deletions app/component/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,12 @@ const apiCall = async (
}
| {
heatmap: {
width: number;
height: number;
density: number;
color: [number, number, number];
alpha_min: number;
alpha_max: number;
sigma: number;
upscale: number;
alpha_min?: number;
alpha_max?: number;
sigma?: number;
upscale?: number;
};
}
| {
Expand All @@ -112,7 +111,7 @@ const apiCall = async (
const payload = await res.arrayBuffer();
const dv = new DataView(payload);
let i = 0;
return req.outputs.map((out) => {
const out = req.outputs.map((out) => {
if ("group_by" in out || "items" in out) {
const count = dv.getUint32(i, true);
i += 4;
Expand All @@ -130,6 +129,8 @@ const apiCall = async (
}
throw new UnreachableError();
});
assertState(i === payload.byteLength);
return out;
};
type ApiResponse = Awaited<ReturnType<typeof apiCall>>;

Expand All @@ -156,16 +157,21 @@ export const App = () => {
apiCall(signal, {
dataset: "posts_bgem3",
queries: [query.query],
sim_scale: { min: 0, max: 1 },
sim_scale: { min: 0.55, max: 0.75 },
ts_weight_decay: query.decayTimestamp,
outputs: [
{
items: {
order_by: "final_score",
order_asc: false,
limit: 10,
},
},
{
heatmap: {
density: 100, // TODO Dynamic based on viewport size.
color: [66, 207, 115],
upscale: 1,
},
},
],
weights: {
sim: query.weightSimilarity,
Expand All @@ -185,6 +191,22 @@ export const App = () => {
]),
[queryReq.data],
);
const [heatmap, setHeatmap] = useState<ImageBitmap>();

useEffect(() => {
if (!queryReq.data) {
setHeatmap(undefined);
return;
}
const blob = assertInstanceOf(queryReq.data[1], ApiHeatmapOutput).blob();
const ac = new AbortController();
(async () => {
const heatmap = await createImageBitmap(blob);
ac.signal.throwIfAborted();
setHeatmap(heatmap);
})();
return () => ac.abort();
}, [queryReq.data]);

const [items, setItems] = useState<Record<number, Item>>({});
useEffect(() => {
Expand All @@ -203,7 +225,11 @@ export const App = () => {

return (
<div ref={setRootElem} className="App">
<PointMap height={rootDim?.height ?? 0} width={rootDim?.width ?? 0} />
<PointMap
heatmap={heatmap}
height={rootDim?.height ?? 0}
width={rootDim?.width ?? 0}
/>

<form
ref={$form}
Expand Down Expand Up @@ -254,7 +280,7 @@ export const App = () => {
<input
name="w_sim"
type="number"
defaultValue={0.4}
defaultValue={0.8}
step={0.00001}
/>
</label>
Expand All @@ -265,7 +291,7 @@ export const App = () => {
<input
name="w_score"
type="number"
defaultValue={0.4}
defaultValue={0.1}
step={0.00001}
/>
</label>
Expand All @@ -276,7 +302,7 @@ export const App = () => {
<input
name="w_ts"
type="number"
defaultValue={0.2}
defaultValue={0.1}
step={0.00001}
/>
</label>
Expand Down
Loading

0 comments on commit f50e866

Please sign in to comment.