Skip to content

Commit

Permalink
Use new API worker types
Browse files Browse the repository at this point in the history
  • Loading branch information
wilsonzlin committed May 4, 2024
1 parent 331910b commit 91daadc
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 39 deletions.
6 changes: 5 additions & 1 deletion Dockerfile.runpod-base
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ RUN curl -fLSs https://deb.nodesource.com/setup_21.x | bash - && apt install -yq

WORKDIR /app

# https://docs.cupy.dev/en/stable/install.html#installing-cupy-from-pypi
RUN pip install cupy-cuda11x
RUN python -m cupyx.tools.install_library --cuda 11.x --library cutensor
# https://docs.rapids.ai/install
RUN pip install --extra-index-url=https://pypi.nvidia.com cudf-cu11==24.4.*
COPY requirements.txt .
RUN pip install cupy
RUN pip install -r requirements.txt
# TODO HACK to disable TQDM from trashing our logs. (TQDM_DISABLE=1 doesn't seem to work.)
RUN sed -i 's%tqdm(%(lambda x, **o: x)(%' /usr/local/lib/python3.10/dist-packages/FlagEmbedding/bge_m3.py
Expand Down
41 changes: 12 additions & 29 deletions api-worker-broker/main.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
import { decode, encode } from "@msgpack/msgpack";
import {
VBytes,
VFiniteNumber,
VInteger,
VObjectMap,
VString,
VStruct,
Valid,
} from "@wzlin/valid";
import { VInteger, VString, VStruct, VUnknown, Valid } from "@wzlin/valid";
import Dict from "@xtjs/lib/Dict";
import assertExists from "@xtjs/lib/assertExists";
import assertInstanceOf from "@xtjs/lib/assertInstanceOf";
import assertState from "@xtjs/lib/assertState";
import randomPick from "@xtjs/lib/randomPick";
import readStringStream from "@xtjs/lib/readStringStream";
import readBufferStream from "@xtjs/lib/readBufferStream";
import http from "http";
import https from "https";
import { WebSocket, WebSocketServer } from "ws";
Expand All @@ -25,31 +17,25 @@ let nextReqId = 0;
const reqs = new Dict<
number,
{
resolve: (res: ApiResponseBody) => void;
resolve: (res: any) => void;
reject: (err: any) => void;
}
>();
const connToReq = new WeakMap<WebSocket, Set<number>>();

type ApiResponseBody = {
embeddingDense: Uint8Array;
embeddingSparse: Record<string, number>;
};

// Validator for a node's init message: a struct with two string fields, `ip` and `token`.
// NOTE(review): where this validator is applied is not visible in this hunk — presumably
// it checks the first message a node sends after connecting; confirm against the handler.
const vNodeInitMessage = new VStruct({
ip: new VString(),
token: new VString(),
});

const vMessageToNode = new VStruct({
id: new VInteger(0),
text: new VString(),
input: new VUnknown(),
});

const vMessageToBroker = new VStruct({
id: new VInteger(0),
emb_dense: new VBytes(1024 * 4, 1024 * 4),
emb_sparse: new VObjectMap(new VFiniteNumber()),
output: new VUnknown(),
});

const wsServer = https.createServer({
Expand Down Expand Up @@ -91,10 +77,7 @@ ws.on("connection", (conn) => {
decode(assertInstanceOf(raw, Buffer)),
);
connToReq.get(conn)!.delete(msg.id);
reqs.remove(msg.id)?.resolve({
embeddingDense: msg.emb_dense,
embeddingSparse: msg.emb_sparse,
});
reqs.remove(msg.id)?.resolve(msg.output);
});
});
conn.on("close", () => {
Expand All @@ -105,16 +88,16 @@ ws.on("connection", (conn) => {
});
});

const sendToNode = (text: string) =>
new Promise<ApiResponseBody>((resolve, reject) => {
const sendToNode = (input: any) =>
new Promise((resolve, reject) => {
const id = nextReqId++;
const conn = randomPick([...ws.clients]);
if (!conn) {
return reject(new Error("No node available"));
}
reqs.set(id, { resolve, reject });
connToReq.get(conn)!.add(id);
const msg: Valid<typeof vMessageToNode> = { id, text };
const msg: Valid<typeof vMessageToNode> = { id, input };
conn.send(encode(msg), { binary: true });
});

Expand All @@ -123,15 +106,15 @@ http
if (req.method !== "POST") {
return res.writeHead(405).end();
}
let text;
let input;
try {
text = await readStringStream(req);
input = decode(await readBufferStream(req));
} catch (err) {
return res.writeHead(400).end(err.message);
}
let resBody;
try {
resBody = await sendToNode(text);
resBody = await sendToNode(input);
} catch (err) {
return res.writeHead(500).end(err.message);
}
Expand Down
16 changes: 11 additions & 5 deletions api-worker-node/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,6 @@ def calculate(self, d: ApiDataset, df: pd.DataFrame):
@serde
@dataclass
class QueryInput:
id: int

dataset: str
outputs: List[Output]
queries: List[str] = field(default_factory=list)
Expand Down Expand Up @@ -191,12 +189,20 @@ class QueryInput:
post_filter_clip: Dict[str, Clip] = field(default_factory=dict)


# Envelope for a single request received from the broker. on_message decodes each raw
# message into this type via from_msgpack, runs the query in `input`, and echoes `id`
# back in the reply.
@serde
@dataclass
class BrokerRequest:
# Request ID; sent back under the "id" key of the response.
id: int
# The query to execute.
input: QueryInput


# Error callback: prints the error prefixed with "WS error:" and does nothing else.
# The `ws` argument is accepted but unused.
# NOTE(review): presumably registered as the WebSocket client's on_error handler — the
# registration is not visible here; confirm.
def on_error(ws, error):
print("WS error:", error)


def on_message(ws, raw):
input = from_msgpack(QueryInput, raw)
req = from_msgpack(BrokerRequest, raw)
input = req.input

# Perform checks before expensive embedding encoding.
d, model, ann_idx = datasets[input.dataset]
Expand Down Expand Up @@ -270,8 +276,8 @@ def assign_then_post_filter(col: str, new):
ws.send(
msgpack.packb(
{
"id": input.id,
"outputs": out,
"id": req.id,
"output": out,
}
),
opcode=websocket.ABNF.OPCODE_BINARY,
Expand Down
10 changes: 8 additions & 2 deletions common/api_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from common.data import dump_mmap_matrix
from common.data import load_mmap_matrix
from common.data import load_mmap_matrix_to_gpu
from dataclasses import dataclass
import json
import numpy as np
Expand Down Expand Up @@ -37,15 +38,20 @@ def dump(self):
)

@staticmethod
def load(name: str):
def load(name: str, *, to_gpu=False):
pfx = f"/hndr-data/api-{name}"
with open(f"{pfx}-meta.json", "r") as f:
meta = json.load(f)
count = meta.pop("count")
emb_dim = meta.pop("emb_dim")
table = pyarrow.feather.read_feather(f"{pfx}-table.feather", memory_map=True)
assert type(table) == pd.DataFrame
emb_mat = load_mmap_matrix(f"api-{name}-emb", (count, emb_dim), np.float32)
if to_gpu:
emb_mat = load_mmap_matrix_to_gpu(
f"api-{name}-emb", (count, emb_dim), np.float32, np.float16
)
else:
emb_mat = load_mmap_matrix(f"api-{name}-emb", (count, emb_dim), np.float32)
return ApiDataset(
name=name,
table=table,
Expand Down
2 changes: 1 addition & 1 deletion common/terrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def render_terrain(
"y": ((ys - y_min) * dpi).clip(0, grid_height - 1).astype("int32"),
}
)
gv = gv.groupby(["x", "y"]).value_counts().reset_index(name="density")
gv = gv.groupby(["x", "y"]).size().reset_index(name="density")
if use_log_scale:
gv["density"] = np.log(gv["density"] + 1)

Expand Down
6 changes: 5 additions & 1 deletion runpod/docker.entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ service ssh start
# Wait for log collector to start, as it won't export existing log entries before it starts.
sleep 5

# When /workspace exists as a directory, expose /workspace/hndr-data at /hndr-data via a symlink.
# NOTE(review): if /hndr-data already exists (e.g. this entrypoint runs again in the same
# container), plain `ln -s` will either fail or create the link inside the existing
# directory; `ln -sfn` may be the safer form — confirm.
if [[ -d /workspace ]]; then
ln -s /workspace/hndr-data /hndr-data
fi

if [[ -f /app/$MAIN/main.py ]]; then
python3 /app/$MAIN/main.py |& tee /app.log
python3 -m cudf.pandas /app/$MAIN/main.py |& tee /app.log
else
# We cannot use ts-node as it doesn't support node:worker.
node /app/$MAIN/main.js |& tee /app.log
Expand Down

0 comments on commit 91daadc

Please sign in to comment.