Merge pull request a16z-infra#184 from a16z-infra/add-ollama-back
Adding an option to run inference through Ollama
ianmacartney authored Nov 10, 2023
2 parents 6786a90 + b6b67a5 commit 2997569
Showing 11 changed files with 1,361 additions and 145 deletions.
30 changes: 30 additions & 0 deletions README.md
@@ -18,6 +18,7 @@ a simple project to play around with to a scalable, multi-player game. A seconda
- 💻 [Stack](#stack)
- 🧠 [Installation](#installation)
- 👤 [Customize - run YOUR OWN simulated world](#customize-your-own-simulation)
- 👩‍💻 [Setting up local inference](#setting-up-local-inference)
- 🏆 [Credits](#credits)

## Stack
@@ -28,6 +29,7 @@ a simple project to play around with to a scalable, multi-player game. A seconda
- Deployment: [Vercel](https://vercel.com/)
- Pixel Art Generation: [Replicate](https://replicate.com/), [Fal.ai](https://serverless.fal.ai/lora)
- Background Music Generation: [Replicate](https://replicate.com/) using [MusicGen](https://huggingface.co/spaces/facebook/MusicGen)
- Local inference: [Ollama](https://github.com/jmorganca/ollama)

## Installation

@@ -215,6 +217,34 @@ You should find a sprite sheet for your character, and define sprite motion / as
4. Change the background music by modifying the prompt in `convex/music.ts`
5. Change how often to generate new music at `convex/crons.ts` by modifying the `generate new background music` job

## Setting up local inference

We support using [Ollama](https://github.com/jmorganca/ollama) for conversation generation, but don't yet support using a local model for generating embeddings.

Steps to switch to using Ollama:

1. [Install Ollama](https://github.com/jmorganca/ollama#macos)
When Ollama runs on your laptop, it uses http://localhost:11434 as its generation endpoint by default. Next we need to set up an ngrok tunnel so Convex can access it:
2. [Install Ngrok](https://ngrok.com/docs/getting-started/)
Once ngrok is installed and authenticated, run the following command:

```
ngrok http http://localhost:11434
```

Ngrok should output a unique URL once you run this command.

3. Add Ollama endpoint to Convex

- Run `npx convex dashboard` to bring up the Convex dashboard
- Go to Settings > Environment Variables
- Add `OLLAMA_HOST = [your ngrok unique URL from the previous step]`
- You might also want to set:
  - `ACTION_TIMEOUT` to `100000` or more, to give your model more time to run.
  - `NUM_MEMORIES_TO_SEARCH` to `1`, to reduce the size of conversation prompts.

By default, we use the `llama2-7b` model on Ollama. If you want to customize which model to use, set the `OLLAMA_MODEL` variable under Environment Variables. Ollama model options can be found [here](https://ollama.ai/library).
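
Before pointing Convex at the tunnel, it can help to confirm that Ollama answers through it. Below is a minimal smoke-test sketch, assuming the `llama2` model has already been pulled (`ollama pull llama2`) and your Ollama version supports non-streaming responses; substitute your ngrok URL for the placeholder:

```ts
// Smoke-test sketch: verify an Ollama endpoint responds before setting
// OLLAMA_HOST in Convex. The URL and model name below are placeholders.
const OLLAMA_URL = 'http://localhost:11434'; // or your ngrok URL

async function smokeTest() {
  const resp = await fetch(`${OLLAMA_URL}/api/generate`, {
    method: 'POST',
    body: JSON.stringify({
      model: 'llama2', // must already be pulled locally
      prompt: 'Say hello in one short sentence.',
      stream: false, // request a single JSON response instead of a stream
    }),
  });
  const { response } = await resp.json();
  console.log(response); // the generated text
}

smokeTest();
```

If this prints a reply, the same URL should work as `OLLAMA_HOST` in the Convex environment variables.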

## Credits

- All interactions, background music and rendering on the <Game/> component in the project are powered by [PixiJS](https://pixijs.com/).
2 changes: 2 additions & 0 deletions convex/_generated/api.d.ts
@@ -52,6 +52,7 @@ import type * as util_geometry from "../util/geometry";
import type * as util_isSimpleObject from "../util/isSimpleObject";
import type * as util_minheap from "../util/minheap";
import type * as util_object from "../util/object";
import type * as util_ollama from "../util/ollama";
import type * as util_openai from "../util/openai";
import type * as util_sleep from "../util/sleep";
import type * as util_types from "../util/types";
@@ -105,6 +106,7 @@ declare const fullApi: ApiFromModules<{
"util/isSimpleObject": typeof util_isSimpleObject;
"util/minheap": typeof util_minheap;
"util/object": typeof util_object;
"util/ollama": typeof util_ollama;
"util/openai": typeof util_openai;
"util/sleep": typeof util_sleep;
"util/types": typeof util_types;
22 changes: 17 additions & 5 deletions convex/agent/conversation.ts
@@ -1,13 +1,16 @@
import { v } from 'convex/values';
import { Id } from '../_generated/dataModel';
import { ActionCtx, internalQuery } from '../_generated/server';
import { LLMMessage, chatCompletion } from '../util/openai';
import { LLMMessage, chatCompletion, ChatCompletionContent } from '../util/openai';
import { UseOllama, ollamaChatCompletion } from '../util/ollama';
import * as memory from './memory';
import { api, internal } from '../_generated/api';
import * as embeddingsCache from './embeddingsCache';
import { GameId, conversationId, playerId } from '../aiTown/ids';
import { NUM_MEMORIES_TO_SEARCH } from '../constants';

const selfInternal = internal.agent.conversation;
const completionFn = UseOllama ? ollamaChatCompletion : chatCompletion;

export async function startConversationMessage(
ctx: ActionCtx,
@@ -29,7 +32,14 @@ export async function startConversationMessage(
ctx,
`What do you think about ${otherPlayer.name}?`,
);
const memories = await memory.searchMemories(ctx, player.id as GameId<'players'>, embedding, 3);

const memories = await memory.searchMemories(
ctx,
player.id as GameId<'players'>,
embedding,
NUM_MEMORIES_TO_SEARCH(),
);

const memoryWithOtherPlayer = memories.find(
(m) => m.data.type === 'conversation' && m.data.playerIds.includes(otherPlayerId),
);
@@ -46,7 +56,7 @@ export async function startConversationMessage(
}
prompt.push(`${player.name}:`);

const { content } = await chatCompletion({
const { content } = await completionFn({
messages: [
{
role: 'user',
@@ -108,7 +118,8 @@ export async function continueConversationMessage(
)),
];
llmMessages.push({ role: 'user', content: `${player.name}:` });
const { content } = await chatCompletion({

const { content } = await completionFn({
messages: llmMessages,
max_tokens: 300,
stream: true,
@@ -156,7 +167,8 @@ export async function leaveConversationMessage(
)),
];
llmMessages.push({ role: 'user', content: `${player.name}:` });
const { content } = await chatCompletion({

const { content } = await completionFn({
messages: llmMessages,
max_tokens: 300,
stream: true,
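
The pattern above is the heart of the change: each direct `chatCompletion` call becomes a call to `completionFn`, which is bound once at module load to either the OpenAI or the Ollama implementation depending on `UseOllama`. The new `convex/util/ollama.ts` itself isn't shown in this view; the sketch below is a rough guess at the surface these call sites depend on, with the implementation details (endpoint, prompt flattening, return shape) as assumptions rather than the committed code:

```ts
// Hypothetical sketch of the util/ollama surface implied by the call sites
// above; not the file committed in this PR.
import type { LLMMessage } from './openai';

// Local inference is opted into by setting OLLAMA_HOST in Convex env vars.
export const UseOllama = !!process.env.OLLAMA_HOST;

export async function ollamaChatCompletion(args: {
  messages: LLMMessage[];
  max_tokens?: number;
  stream?: boolean;
}) {
  const host = process.env.OLLAMA_HOST ?? 'http://localhost:11434';
  const model = process.env.OLLAMA_MODEL ?? 'llama2';
  // Flatten the chat history into one prompt for Ollama's /api/generate.
  const prompt = args.messages.map((m) => `${m.role}: ${m.content}`).join('\n');
  const resp = await fetch(`${host}/api/generate`, {
    method: 'POST',
    body: JSON.stringify({ model, prompt, stream: false }),
  });
  const { response } = await resp.json();
  // Mirror the shape callers destructure: `content.readAll()` yields the
  // full completion text (streaming is ignored in this sketch).
  return { content: { readAll: async () => response as string } };
}
```

Because the selection happens at module scope, setting or clearing `OLLAMA_HOST` flips every caller between the two backends without touching the call sites.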
55 changes: 8 additions & 47 deletions convex/agent/memory.ts
@@ -1,4 +1,3 @@
import { defineTable } from 'convex/server';
import { v } from 'convex/values';
import { ActionCtx, DatabaseReader, internalMutation, internalQuery } from '../_generated/server';
import { Doc, Id } from '../_generated/dataModel';
@@ -7,41 +6,18 @@ import { LLMMessage, chatCompletion, fetchEmbedding } from '../util/openai';
import { asyncMap } from '../util/asyncMap';
import { GameId, agentId, conversationId, playerId } from '../aiTown/ids';
import { SerializedPlayer } from '../aiTown/player';
import { UseOllama, ollamaChatCompletion } from '../util/ollama';
import { memoryFields } from './schema';

const completionFn = UseOllama ? ollamaChatCompletion : chatCompletion;

// How long to wait before updating a memory's last access time.
export const MEMORY_ACCESS_THROTTLE = 300_000; // In ms
// We fetch 10x the number of memories by relevance, to have more candidates
// for sorting by relevance + recency + importance.
const MEMORY_OVERFETCH = 10;

const selfInternal = internal.agent.memory;

const memoryFields = {
playerId,
description: v.string(),
embeddingId: v.id('memoryEmbeddings'),
importance: v.number(),
lastAccess: v.number(),
data: v.union(
// Setting up dynamics between players
v.object({
type: v.literal('relationship'),
// The player this memory is about, from the perspective of the player
// whose memory this is.
playerId,
}),
v.object({
type: v.literal('conversation'),
conversationId,
// The other player(s) in the conversation.
playerIds: v.array(playerId),
}),
v.object({
type: v.literal('reflection'),
relatedMemoryIds: v.array(v.id('memories')),
}),
),
};
export type Memory = Doc<'memories'>;
export type MemoryType = Memory['data']['type'];
export type MemoryOfType<T extends MemoryType> = Omit<Memory, 'data'> & {
@@ -86,7 +62,7 @@ export async function rememberConversation(
});
}
llmMessages.push({ role: 'user', content: 'Summary:' });
const { content } = await chatCompletion({
const { content } = await completionFn({
messages: llmMessages,
max_tokens: 500,
});
@@ -267,13 +243,13 @@ export const loadMessages = internalQuery({
});

async function calculateImportance(description: string) {
const { content: importanceRaw } = await chatCompletion({
const { content: importanceRaw } = await completionFn({
messages: [
{
role: 'user',
content: `On the scale of 0 to 9, where 0 is purely mundane (e.g., brushing teeth, making bed) and 9 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following piece of memory.
Memory: ${description}
Answer on a scale of 0 to 9. Respond with number only, e.g. "5"`,
Memory: ${description}
Answer on a scale of 0 to 9. Respond with number only, e.g. "5"`,
},
],
temperature: 0.0,
@@ -471,18 +447,3 @@ export async function latestMemoryOfType<T extends MemoryType>(
if (!entry) return null;
return entry as MemoryOfType<T>;
}

export const memoryTables = {
memories: defineTable(memoryFields)
.index('embeddingId', ['embeddingId'])
.index('playerId_type', ['playerId', 'data.type'])
.index('playerId', ['playerId']),
memoryEmbeddings: defineTable({
playerId,
embedding: v.array(v.float64()),
}).vectorIndex('embedding', {
vectorField: 'embedding',
filterFields: ['playerId'],
dimensions: 1536,
}),
};
45 changes: 44 additions & 1 deletion convex/agent/schema.ts
@@ -1,5 +1,48 @@
import { memoryTables } from './memory';
import { embeddingsCacheTables } from './embeddingsCache';
import { v } from 'convex/values';
import { playerId, conversationId } from '../aiTown/ids';
import { defineTable } from 'convex/server';

export const memoryFields = {
playerId,
description: v.string(),
embeddingId: v.id('memoryEmbeddings'),
importance: v.number(),
lastAccess: v.number(),
data: v.union(
// Setting up dynamics between players
v.object({
type: v.literal('relationship'),
// The player this memory is about, from the perspective of the player
// whose memory this is.
playerId,
}),
v.object({
type: v.literal('conversation'),
conversationId,
// The other player(s) in the conversation.
playerIds: v.array(playerId),
}),
v.object({
type: v.literal('reflection'),
relatedMemoryIds: v.array(v.id('memories')),
}),
),
};
export const memoryTables = {
memories: defineTable(memoryFields)
.index('embeddingId', ['embeddingId'])
.index('playerId_type', ['playerId', 'data.type'])
.index('playerId', ['playerId']),
memoryEmbeddings: defineTable({
playerId,
embedding: v.array(v.float64()),
}).vectorIndex('embedding', {
vectorField: 'embedding',
filterFields: ['playerId'],
dimensions: 1536,
}),
};

export const agentTables = {
...memoryTables,
2 changes: 1 addition & 1 deletion convex/aiTown/agent.ts
@@ -55,7 +55,7 @@ export class Agent {
throw new Error(`Invalid player ID ${this.playerId}`);
}
if (this.inProgressOperation) {
if (now < this.inProgressOperation.started + ACTION_TIMEOUT) {
if (now < this.inProgressOperation.started + ACTION_TIMEOUT()) {
// Wait on the operation to finish.
return;
}
2 changes: 2 additions & 0 deletions convex/aiTown/agentOperations.ts
@@ -78,7 +78,9 @@ export const agentGenerateMessage = internalAction({
args.playerId as GameId<'players'>,
args.otherPlayerId as GameId<'players'>,
);
// TODO: stream in the text instead of reading it all at once.
const text = await completion.readAll();

await ctx.runMutation(internal.aiTown.agent.agentSendMessage, {
worldId: args.worldId,
conversationId: args.conversationId,
10 changes: 9 additions & 1 deletion convex/constants.ts
@@ -43,8 +43,16 @@ export const MAX_CONVERSATION_MESSAGES = 8;
// once we can await on an input being processed.
export const INPUT_DELAY = 1000;

// How many memories to get from the agent's memory.
// This is over-fetched by 10x so we can prioritize memories by more than relevance.
export function NUM_MEMORIES_TO_SEARCH() {
return Number(process.env.NUM_MEMORIES_TO_SEARCH) || 3;
}

// Timeout a request to the conversation layer after a minute.
export const ACTION_TIMEOUT = 60 * 1000;
export function ACTION_TIMEOUT() {
return Number(process.env.ACTION_TIMEOUT) || 60 * 1000;
}

// Wait for at least two seconds before sending another message.
export const MESSAGE_COOLDOWN = 2000;
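
Turning these values from constants into functions means the environment variables are read at call time, so the `ACTION_TIMEOUT` and `NUM_MEMORIES_TO_SEARCH` overrides suggested in the README take effect without a code change. A small illustrative usage sketch follows; the import path and variable names are assumptions, mirroring how `convex/aiTown/agent.ts` now calls `ACTION_TIMEOUT()`:

```ts
import { ACTION_TIMEOUT, NUM_MEMORIES_TO_SEARCH } from '../constants';

// Each call reads the current env var, so an override set in the Convex
// dashboard (e.g. ACTION_TIMEOUT=100000) is picked up on the next use.
const timeoutMs = ACTION_TIMEOUT(); // 100000 when overridden, else 60 * 1000
const numMemories = NUM_MEMORIES_TO_SEARCH(); // 1 when overridden, else 3
```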
