Skip to content

Commit

Permalink
take prompt from user input and clean up some things
Browse files Browse the repository at this point in the history
  • Loading branch information
newhouseb committed Apr 24, 2023
1 parent 461dffc commit 7e6b125
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 109 deletions.
209 changes: 110 additions & 99 deletions main.ts
Original file line number Diff line number Diff line change
@@ -1,105 +1,33 @@
import { Var, addMatrix, causalMask, copy, gelu, getSlice, layerNorm, linear, mapInPlace, merge, multiplyMatrix, softmax, split, tensor, transposeMatrix, unsqueeze } from "./math";
import { Var, addMatrix, causalMask, copy, gelu, getSlice, layerNorm, linear, mapInPlace, merge, multiplyMatrix, softmax, split, tensor, transposeMatrix } from "./math";
import * as fs from 'fs';
import type { Tensor } from "./math";
import { createInterface } from 'readline';
import { stdin as input, stdout as output } from 'process';

function bytesToUnicode(): [{ [key: number]: string }, { [key: string]: number }] {
const bs: number[] = [
...Array.from({ length: '~'.charCodeAt(0) - '!'.charCodeAt(0) + 1 }, (_, i) => '!'.charCodeAt(0) + i),
...Array.from({ length: '¬'.charCodeAt(0) - '¡'.charCodeAt(0) + 1 }, (_, i) => '¡'.charCodeAt(0) + i),
...Array.from({ length: 'ÿ'.charCodeAt(0) - '®'.charCodeAt(0) + 1 }, (_, i) => '®'.charCodeAt(0) + i),
];
const cs: number[] = [...bs];
let n = 0;
for (let b = 0; b < 2 ** 8; b++) {
if (!bs.includes(b)) {
bs.push(b);
cs.push(2 ** 8 + n);
n += 1;
}
}
const csStr: string[] = cs.map((n) => String.fromCharCode(n));
const lookupTable: { [key: number]: string } = {};
const unlookupTable: { [key: string]: number } = {};
bs.forEach((key, index) => {
lookupTable[key] = csStr[index];
unlookupTable[csStr[index]] = key;
async function main() {
// Read prompt from stdin
const rl = createInterface({ input, output });
let prompt = await new Promise<string>((resolve) => {
rl.question('Please enter a prompt: ', (input: string) => {
resolve(input);
rl.close();
});
});
return [lookupTable, unlookupTable];
}

// Convert a text prompt into a list of GPT-2 token ids.
function encodeString(str: string) {
// This is a giant dict of strings that map to token IDs
const encoder = JSON.parse(fs.readFileSync('weights/encoder.json', 'utf8'));

// A weird quirk of GPT2's tokenization is that they map control and whitespace characters up by 255 to make them printable, not entirely
// clear why this is but perhaps so that everything can confidently be printable while debugging without things (for example) clearing your terminal
const [byteMapping, byteUnmapping] = bytesToUnicode();
str = str.split('').map((c) => byteMapping[c.charCodeAt(0)]).join('');

const tokens = Object.keys(encoder);
let out = [] as number[];

// Greedy longest-prefix match: repeatedly consume the longest vocab entry
// that prefixes the remaining string. NOTE(review): this scans the whole
// vocabulary once per emitted token, and if NO entry matches, bestToken
// stays '' so slice(0) leaves str unchanged — the loop never terminates.
while (str.length) {
let bestToken = '';
for (const token of tokens) {
if (str.startsWith(token) && token.length > bestToken.length) {
bestToken = token;
}
}
out.push(encoder[bestToken])
str = str.slice(bestToken.length);
}

return out;
}

// Map a token string back to raw characters by inverting the byte→unicode
// remapping produced by bytesToUnicode.
// NOTE(review): byteMapping is destructured but unused here.
function decodeString(str: string) {
const [byteMapping, byteUnmapping] = bytesToUnicode();
return str.split('').map((c) => String.fromCharCode(byteUnmapping[c])).join('');
}

// Turn token ids back into human-readable text: invert the encoder table,
// concatenate the token strings, then undo the byte remapping.
function decodeTokens(tokens: number[]) {
// Re-reads and re-parses encoder.json on every call (as encodeString does).
const encoder = JSON.parse(fs.readFileSync('weights/encoder.json', 'utf8'));

// Build the reverse map: token id -> token string.
const decoder = {} as { [key: number]: string };
for (const key in encoder) {
decoder[encoder[key]] = key;
}

return decodeString(tokens.map((token) => decoder[token]).join(''));
}

// Build a GPT model with the GPT-2 "small" hyperparameters: 12 layers,
// 12 attention heads, 768-dim embeddings, 50257-token vocabulary.
// NOTE(review): async but contains no await — presumably weight loading
// happens inside GPT(); confirm.
const loadSmallGPT = async () => {
const model = GPT({
VocabularySize: 50257,
SequenceLength: 1024,
EmbeddingDimensions: 768,
AttentionHeads: 12,
Layers: 12
}, null)

return model;
}

// Branded number types: intersecting number with a { label: ... } template
// literal records, at the type level, which dimensions a value was derived
// from. The runtime helpers do the real arithmetic and cast to the brand.
type Multiply<A extends number, B extends number> = number & { label: `${A} * ${B}` }
const Multiply = <A extends number, B extends number>(a: A, b: B) => a * b as Multiply<A, B>;
type Divide<A extends number, B extends number> = number & { label: `${A} / ${B}` }
const Divide = <A extends number, B extends number>(a: A, b: B) => a / b as Divide<A, B>;

async function main() {
let prompt = 'peanut butter and';
// Map the prompt to tokens
let tokens = encodeString(prompt)

console.log('loading gpt')
console.time('Loading Model')
const gpt = await loadSmallGPT();
console.log('done')
console.timeEnd('Loading Model')

// Generate (up to) 100 tokens
let toGenerate = 100;

while (toGenerate > 0) {
let x = tensor([Var(tokens.length, 'Sequence Length'), gpt.EmbeddingDimensions])
const start = (new Date()).getTime();

// Map each token into an embedding + position vector
let x = tensor([Var(tokens.length, 'Sequence Length'), gpt.EmbeddingDimensions])
tokens.map((token, i) => {
const slice = getSlice(x, i)
const embedding = getSlice(gpt.weights.wte, token);
Expand All @@ -114,6 +42,7 @@ async function main() {
i += 1
process.stdout.write('\rBlock ' + i + ' of ' + gpt.weights.blocks.length);

// The start of a block kicks off with a layer normalization
const nx1 = layerNorm(x, block.ln_1.g, block.ln_1.b)

// We weight the inputs to self attention
Expand All @@ -133,16 +62,13 @@ async function main() {
const sqrtD = Math.sqrt(gpt.EmbeddingDimensions / gpt.AttentionHeads);
const mask = causalMask(kHeads[0].shape[0]);
for (let h = 0; h < gpt.AttentionHeads; h++) {
const inner = addMatrix(
aHeads.push(multiplyMatrix(
softmax(addMatrix(
mapInPlace(
multiplyMatrix(qHeads[h], transposeMatrix(kHeads[h])),
(n) => n / sqrtD),
mask);
const smax = softmax(inner);
const outer = multiplyMatrix(
smax,
vHeads[h]);
aHeads.push(outer);
mask)),
vHeads[h]));
}

// Next merge the heads all back together
Expand All @@ -165,9 +91,8 @@ async function main() {
const transposed = transposeMatrix(gpt.weights.wte)
const final = multiplyMatrix(x, transposed);

// Greedily choose the highest logit and use that as our next token
let logits = getSlice(final, final.shape[0] - 1);

// argmax over the logits
let bestToken = 0;
let bestScore = -Infinity;
for (let j = 0; j < logits.data.length; j++) {
Expand All @@ -177,7 +102,9 @@ async function main() {
}
}

console.log('\nChosen token:', [prompt, decodeTokens([bestToken])]);
// Log the chosen token and prepare to loop again
const duration = ((new Date()).getTime() - start) / 1000;
console.log(`\nChose token in ${duration.toFixed(2)}s:`, [prompt, decodeTokens([bestToken])]);
prompt += '' + decodeTokens([bestToken]);
tokens.push(bestToken);

Expand All @@ -186,6 +113,23 @@ async function main() {
}
main();

// Instantiate a GPT model using the GPT-2 "small" hyperparameters:
// 12 layers, 12 attention heads, 768-dim embeddings, 50257-token vocabulary.
const loadSmallGPT = async () =>
  GPT({
    VocabularySize: 50257,
    SequenceLength: 1024,
    EmbeddingDimensions: 768,
    AttentionHeads: 12,
    Layers: 12
  }, null);

// Branded numeric types that record the arithmetic provenance of a value in
// the type system (used to label tensor dimensions). The runtime helpers
// perform plain arithmetic and cast the result onto the branded type.
type Multiply<A extends number, B extends number> = number & { label: `${A} * ${B}` }
const Multiply = <A extends number, B extends number>(a: A, b: B) =>
  (a * b) as Multiply<A, B>;

type Divide<A extends number, B extends number> = number & { label: `${A} / ${B}` }
const Divide = <A extends number, B extends number>(a: A, b: B) =>
  (a / b) as Divide<A, B>;

function GPT<
SequenceLength extends number,
VocabSize extends number,
Expand Down Expand Up @@ -271,3 +215,70 @@ function GPT<
)}
}
}

// GPT-2 maps every byte 0-255 to a printable unicode character so that token
// strings never contain raw control or whitespace bytes. Returns a pair of
// lookup tables: byte -> character and character -> byte.
function bytesToUnicode(): [{ [key: number]: string }, { [key: string]: number }] {
  // Inclusive range of char codes between two single-character bounds.
  const range = (from: string, to: string): number[] =>
    Array.from(
      { length: to.charCodeAt(0) - from.charCodeAt(0) + 1 },
      (_, i) => from.charCodeAt(0) + i,
    );

  // Bytes that are already printable keep their own code point.
  const printable = [...range('!', '~'), ...range('¡', '¬'), ...range('®', 'ÿ')];

  const byteValues = [...printable];
  const codePoints = [...printable];
  let shifted = 0;
  for (let byte = 0; byte < 256; byte++) {
    if (printable.includes(byte)) continue;
    // Every other byte is shifted into the 256+ range to stay printable.
    byteValues.push(byte);
    codePoints.push(256 + shifted);
    shifted += 1;
  }

  const forward: { [key: number]: string } = {};
  const backward: { [key: string]: number } = {};
  byteValues.forEach((byte, i) => {
    const ch = String.fromCharCode(codePoints[i]);
    forward[byte] = ch;
    backward[ch] = byte;
  });
  return [forward, backward];
}

// Convert a text prompt into a list of GPT-2 token ids.
function encodeString(str: string) {
  // This is a giant dict of strings that map to token IDs
  const encoder = JSON.parse(fs.readFileSync('weights/encoder.json', 'utf8'));

  // A weird quirk of GPT2's tokenization is that they map control and whitespace characters up by 255 to make them printable, not entirely
  // clear why this is but perhaps so that everything can confidently be printable while debugging without things (for example) clearing your terminal
  const [byteMapping] = bytesToUnicode();
  str = str.split('').map((c) => byteMapping[c.charCodeAt(0)]).join('');

  const tokens = Object.keys(encoder);
  const out: number[] = [];

  while (str.length) {
    // Greedy longest-prefix match against the vocabulary.
    let bestToken = '';
    for (const token of tokens) {
      if (str.startsWith(token) && token.length > bestToken.length) {
        bestToken = token;
      }
    }
    // Bug fix: if nothing matched, bestToken stayed '' and slice(0) left
    // str unchanged, looping forever. Fail loudly instead.
    if (!bestToken.length) {
      throw new Error(`Unable to tokenize input near: ${str.slice(0, 10)}`);
    }
    out.push(encoder[bestToken]);
    str = str.slice(bestToken.length);
  }

  return out;
}

// Map a token string back to raw characters by inverting the byte→unicode
// remapping produced by bytesToUnicode.
function decodeString(str: string) {
  const byteUnmapping = bytesToUnicode()[1];
  let decoded = '';
  for (const ch of str.split('')) {
    decoded += String.fromCharCode(byteUnmapping[ch]);
  }
  return decoded;
}

// Turn token ids back into human-readable text: invert the encoder table,
// concatenate the token strings, then undo the byte remapping.
function decodeTokens(tokens: number[]) {
  const encoder = JSON.parse(fs.readFileSync('weights/encoder.json', 'utf8'));

  // Reverse map: token id -> token string.
  const decoder: { [key: number]: string } = {};
  Object.keys(encoder).forEach((key) => {
    decoder[encoder[key]] = key;
  });

  const joined = tokens.map((t) => decoder[t]).join('');
  return decodeString(joined);
}
12 changes: 11 additions & 1 deletion math.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { causalMask, gelu, getSlice, layerNorm, linear, merge, multiplyMatrix, softmax, split, tensor, transposeMatrix } from "./math"
import { Var, causalMask, gelu, getSlice, layerNorm, linear, merge, multiplyMatrix, softmax, split, tensor, transposeMatrix } from "./math"

//const out = multiplyMatrix(tensor1, tensor2);
test('Multiplication', () => {
Expand Down Expand Up @@ -156,6 +156,16 @@ test("Slicing", () => {

})

// Exercise tensor shapes whose leading dimension is a runtime Var built from
// an (empty) sequence length. NOTE(review): tensorC is never asserted on —
// presumably this exists so the Var-dimensioned multiply type-checks; confirm
// intent or wrap it in a test() with an expectation.
const seq = [];

const a = Var(seq.length, 'Sequence Length');

const tensorA = tensor([a, 4]);
const tensorB = tensor([4, 5]);


const tensorC = multiplyMatrix(tensorA, tensorB);

// Fixture tensors shared by the tests below.
const tensor1 = tensor([3, 2]);
const tensor2 = tensor([4, 5]);

Expand Down
8 changes: 0 additions & 8 deletions math.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,6 @@ export const transposeMatrix = <X extends number, Y extends number>
return output;
};

// Prepend a length-1 leading dimension without copying: the returned tensor
// shares a.data with the input.
export const unsqueeze = <D extends readonly Dim[]>
(a: Tensor<D>): Tensor<PushHead<D, 1>> => {
// The casts are needed because PushHead is a type-level-only construction
// the checker cannot relate to the runtime spread below.
return {
data: a.data,
shape: [1, ...a.shape] as any
} as any;
}

export function mapInPlace<D extends readonly Dim[]>(a: Tensor<D>, fn: (n: number) => number): Tensor<D> {
a.data.set(a.data.map(i => fn(i)))
return a;
Expand Down
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"devDependencies": {
"@types/jest": "^29.5.0",
"esbuild": "^0.17.18",
"jest": "^29.5.0",
"ts-jest": "^29.1.0",
"ts-node": "^10.9.1",
Expand All @@ -9,6 +10,7 @@
"scripts": {
"test": "jest",
"test:watch": "jest --watchAll",
"start": "ts-node main.ts"
"start": "ts-node main.ts",
"build": "yarn esbuild main.ts --bundle --platform=node --outfile=main.js"
}
}

0 comments on commit 7e6b125

Please sign in to comment.