Skip to content

Commit

Permalink
cloudflare[minor]: Adds Cloudflare D1 checkpointer (langchain-ai#6212)
Browse files Browse the repository at this point in the history
* Adds Cloudflare D1 checkpointer

* Fix lint + format
  • Loading branch information
jacoblee93 authored Jul 25, 2024
1 parent 58da38f commit 019423e
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 3 deletions.
4 changes: 4 additions & 0 deletions libs/langchain-cloudflare/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ index.cjs
index.js
index.d.ts
index.d.cts
langgraph/checkpointers.cjs
langgraph/checkpointers.js
langgraph/checkpointers.d.ts
langgraph/checkpointers.d.cts
node_modules
dist
.yarn
3 changes: 2 additions & 1 deletion libs/langchain-cloudflare/langchain.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ function abs(relativePath) {


export const config = {
internals: [/node\:/, /@langchain\/core\//],
internals: [/node\:/, /@langchain\/core\//, /@langchain\/langgraph\/web/],
entrypoints: {
index: "index",
"langgraph/checkpointers": "langgraph/checkpointers",
},
tsConfigPath: resolve("./tsconfig.json"),
cjsSource: "./dist-cjs",
Expand Down
26 changes: 24 additions & 2 deletions libs/langchain-cloudflare/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@langchain/cloudflare",
"version": "0.0.6",
"version": "0.0.7-rc.0",
"description": "Cloudflare integration for LangChain.js",
"type": "module",
"engines": {
Expand Down Expand Up @@ -42,6 +42,7 @@
"devDependencies": {
"@cloudflare/workers-types": "^4.20231218.0",
"@jest/globals": "^29.5.0",
"@langchain/langgraph": "~0.0.31",
"@langchain/scripts": "~0.0.20",
"@langchain/standard-tests": "0.0.0",
"@swc/core": "^1.3.90",
Expand All @@ -66,6 +67,14 @@
"ts-jest": "^29.1.0",
"typescript": "<5.2.0"
},
"peerDependencies": {
"@langchain/langgraph": "<0.1.0"
},
"peerDependenciesMeta": {
"@langchain/langgraph": {
"optional": true
}
},
"publishConfig": {
"access": "public"
},
Expand All @@ -79,13 +88,26 @@
"import": "./index.js",
"require": "./index.cjs"
},
"./langgraph/checkpointers": {
"types": {
"import": "./langgraph/checkpointers.d.ts",
"require": "./langgraph/checkpointers.d.cts",
"default": "./langgraph/checkpointers.d.ts"
},
"import": "./langgraph/checkpointers.js",
"require": "./langgraph/checkpointers.cjs"
},
"./package.json": "./package.json"
},
"files": [
"dist/",
"index.cjs",
"index.js",
"index.d.ts",
"index.d.cts"
"index.d.cts",
"langgraph/checkpointers.cjs",
"langgraph/checkpointers.js",
"langgraph/checkpointers.d.ts",
"langgraph/checkpointers.d.cts"
]
}
212 changes: 212 additions & 0 deletions libs/langchain-cloudflare/src/langgraph/checkpointers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import { D1Database } from "@cloudflare/workers-types";

import { RunnableConfig } from "@langchain/core/runnables";
import {
BaseCheckpointSaver,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
SerializerProtocol,
} from "@langchain/langgraph/web";

// snake_case is used to match Python implementation
interface Row {
checkpoint: string;
metadata: string;
parent_id?: string;
thread_id: string;
checkpoint_id: string;
}

export type CloudflareD1SaverFields = {
db: D1Database;
};

export class CloudflareD1Saver extends BaseCheckpointSaver {
db: D1Database;

protected isSetup: boolean;

constructor(
fields: CloudflareD1SaverFields,
serde?: SerializerProtocol<Checkpoint>
) {
super(serde);
this.db = fields.db;
this.isSetup = false;
}

private async setup() {
if (this.isSetup) {
return;
}

try {
await this.db.exec(`
CREATE TABLE IF NOT EXISTS checkpoints (thread_id TEXT NOT NULL, checkpoint_id TEXT NOT NULL, parent_id TEXT, checkpoint BLOB, metadata BLOB, PRIMARY KEY (thread_id, checkpoint_id));`);
} catch (error) {
console.log("Error creating checkpoints table", error);
throw error;
}

this.isSetup = true;
}

async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
await this.setup();
const thread_id = config.configurable?.thread_id;
const checkpoint_id = config.configurable?.checkpoint_id;

if (checkpoint_id) {
try {
const row: Row | null = await this.db
.prepare(
`SELECT checkpoint, parent_id, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_id = ?`
)
.bind(thread_id, checkpoint_id)
.first();

if (row) {
return {
config,
checkpoint: (await this.serde.parse(row.checkpoint)) as Checkpoint,
metadata: (await this.serde.parse(
row.metadata
)) as CheckpointMetadata,
parentConfig: row.parent_id
? {
configurable: {
thread_id,
checkpoint_id: row.parent_id,
},
}
: undefined,
};
}
} catch (error) {
console.log("Error retrieving checkpoint", error);
throw error;
}
} else {
const row: Row | null = await this.db
.prepare(
`SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
)
.bind(thread_id)
.first();

if (row) {
return {
config: {
configurable: {
thread_id: row.thread_id,
checkpoint_id: row.checkpoint_id,
},
},
checkpoint: (await this.serde.parse(row.checkpoint)) as Checkpoint,
metadata: (await this.serde.parse(
row.metadata
)) as CheckpointMetadata,
parentConfig: row.parent_id
? {
configurable: {
thread_id: row.thread_id,
checkpoint_id: row.parent_id,
},
}
: undefined,
};
}
}

return undefined;
}

async *list(
config: RunnableConfig,
limit?: number,
before?: RunnableConfig
): AsyncGenerator<CheckpointTuple> {
await this.setup();
const thread_id = config.configurable?.thread_id;
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ${
before ? "AND checkpoint_id < ?" : ""
} ORDER BY checkpoint_id DESC`;
if (limit) {
sql += ` LIMIT ${limit}`;
}
const args = [thread_id, before?.configurable?.checkpoint_id].filter(
Boolean
);

try {
const { results: rows }: { results: Row[] } = await this.db
.prepare(sql)
.bind(...args)
.all();

if (rows) {
for (const row of rows) {
yield {
config: {
configurable: {
thread_id: row.thread_id,
checkpoint_id: row.checkpoint_id,
},
},
checkpoint: (await this.serde.parse(row.checkpoint)) as Checkpoint,
metadata: (await this.serde.parse(
row.metadata
)) as CheckpointMetadata,
parentConfig: row.parent_id
? {
configurable: {
thread_id: row.thread_id,
checkpoint_id: row.parent_id,
},
}
: undefined,
};
}
}
} catch (error) {
console.log("Error listing checkpoints", error);
throw error;
}
}

async put(
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata
): Promise<RunnableConfig> {
await this.setup();

try {
const row = [
config.configurable?.thread_id ?? null,
checkpoint.id,
config.configurable?.checkpoint_id ?? null,
this.serde.stringify(checkpoint),
this.serde.stringify(metadata),
];

await this.db
.prepare(
`INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_id, parent_id, checkpoint, metadata) VALUES (?, ?, ?, ?, ?)`
)
.bind(...row)
.run();
} catch (error) {
console.log("Error saving checkpoint", error);
throw error;
}

return {
configurable: {
thread_id: config.configurable?.thread_id,
checkpoint_id: checkpoint.id,
},
};
}
}
22 changes: 22 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -10999,6 +10999,7 @@ __metadata:
"@cloudflare/workers-types": ^4.20231218.0
"@jest/globals": ^29.5.0
"@langchain/core": ">0.1.0 <0.3.0"
"@langchain/langgraph": ~0.0.31
"@langchain/scripts": ~0.0.20
"@langchain/standard-tests": 0.0.0
"@swc/core": ^1.3.90
Expand All @@ -11023,6 +11024,11 @@ __metadata:
ts-jest: ^29.1.0
typescript: <5.2.0
uuid: ^10.0.0
peerDependencies:
"@langchain/langgraph": <0.1.0
peerDependenciesMeta:
"@langchain/langgraph":
optional: true
languageName: unknown
linkType: soft

Expand Down Expand Up @@ -11960,6 +11966,22 @@ __metadata:
languageName: node
linkType: hard

"@langchain/langgraph@npm:~0.0.31":
version: 0.0.31
resolution: "@langchain/langgraph@npm:0.0.31"
dependencies:
"@langchain/core": ">=0.2.18 <0.3.0"
uuid: ^10.0.0
zod: ^3.23.8
peerDependencies:
better-sqlite3: ^9.5.0
peerDependenciesMeta:
better-sqlite3:
optional: true
checksum: 74c0af490dab5c1f38d426cdeb0530fd300606bd28bb099d27b0ace029a02800a75fcc047f6755d853b485e78728b472170a19173803014dcc54bafe85939d9f
languageName: node
linkType: hard

"@langchain/mistralai@^0.0.26, @langchain/mistralai@workspace:*, @langchain/mistralai@workspace:libs/langchain-mistralai":
version: 0.0.0-use.local
resolution: "@langchain/mistralai@workspace:libs/langchain-mistralai"
Expand Down

0 comments on commit 019423e

Please sign in to comment.