Skip to content

Commit

Permalink
feat: show read-only prompt (sqlchat#121)
Browse files Browse the repository at this point in the history
* stash

* feat: show prompt and move setting pos

* feat:wrap long line

* refactor generate prompt

* fix typo

* eslint
  • Loading branch information
CorrectRoadH authored Jun 4, 2023
1 parent e376e7b commit f7eabe9
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 41 deletions.
10 changes: 8 additions & 2 deletions src/components/CodeBlock.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ interface Props {
language: string;
value: string;
messageId: Id;
wrapLongLines?: boolean;
}

export const CodeBlock = (props: Props) => {
const { language, value, messageId } = props;
const { language, value, messageId, wrapLongLines } = props;
const { t } = useTranslation();
const connectionStore = useConnectionStore();
const queryStore = useQueryStore();
Expand Down Expand Up @@ -70,7 +71,12 @@ export const CodeBlock = (props: Props) => {
)}
</div>
</div>
<SyntaxHighlighter language={language.toLowerCase()} style={oneDark} customStyle={{ margin: 0 }}>
<SyntaxHighlighter
language={language.toLowerCase()}
wrapLongLines={wrapLongLines || false}
style={oneDark}
customStyle={{ margin: 0 }}
>
{value}
</SyntaxHighlighter>
</div>
Expand Down
10 changes: 0 additions & 10 deletions src/components/ConversationView/Header.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,9 @@ const Header = (props: Props) => {
/>
</a>
</span>
<div className="mr-2 relative flex flex-row justify-end items-center">
{hasFeature("debug") && (
<button className="p-2 rounded cursor-pointer hover:bg-gray-100 dark:hover:bg-zinc-700">
<Icon.FiSettings className="w-4 h-auto" onClick={() => setShowSchemaDrawer(true)} />
</button>
)}
</div>
</div>

<ConversationTabsView />
</div>

{hasFeature("debug") && showSchemaDrawer && <SchemaDrawer close={() => setShowSchemaDrawer(false)} />}
</>
);
};
Expand Down
2 changes: 1 addition & 1 deletion src/components/ConversationView/MessageView.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dayjs from "dayjs";
import { useSession } from "next-auth/react";
import { ReactElement } from "react";
import { ReactElement, useState } from "react";
import { useTranslation } from "react-i18next";
import { toast } from "react-hot-toast";
import ReactMarkdown from "react-markdown";
Expand Down
48 changes: 22 additions & 26 deletions src/components/ConversationView/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ import {
useUserStore,
} from "@/store";
import { Conversation, CreatorRole, Message } from "@/types";
import { countTextTokens, generateUUID, getModel, hasFeature } from "@/utils";
import { countTextTokens, generateUUID, getModel, hasFeature, generateDbPromptFromContext } from "@/utils";
import getEventEmitter from "@/utils/event-emitter";
import Header from "./Header";
import EmptyView from "../EmptyView";
import MessageView from "./MessageView";
import ClearConversationButton from "../ClearConversationButton";
import MessageTextarea from "./MessageTextarea";
import DataStorageBanner from "../DataStorageBanner";
import SchemaDrawer from "../SchemaDrawer";
import Icon from "../Icon";

const ConversationView = () => {
const { data: session } = useSession();
Expand All @@ -39,7 +41,8 @@ const ConversationView = () => {
? messageStore.messageList.filter((message: Message) => message.conversationId === currentConversation.id)
: [];
const lastMessage = last(messageList);

const [showSchemaDrawer, setShowSchemaDrawer] = useState<boolean>(false);
console.log(showHeaderShadow);
useEffect(() => {
messageStore.messageList.map((message: Message) => {
if (message.status === "LOADING") {
Expand Down Expand Up @@ -145,34 +148,19 @@ const ConversationView = () => {

// Augument with database schema if available
if (connectionStore.currentConnectionCtx?.database) {
let schema = "";
const schemaList = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database);
try {
const schemaList = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database);
// Empty table name(such as []) denote all table. [] and `undefined` both are false in `if`
const tableList: string[] = [];
const selectedSchema = schemaList.find((schema) => schema.name == (currentConversation.selectedSchemaName || ""));
if (currentConversation.selectedTablesName) {
currentConversation.selectedTablesName.forEach((tableName: string) => {
const table = selectedSchema?.tables.find((table) => table.name == tableName);
tableList.push(table!.structure);
});
} else {
for (const table of selectedSchema?.tables || []) {
tableList.push(table!.structure);
}
}
if (tableList) {
for (const table of tableList) {
if (tokens < maxToken / 2) {
tokens += countTextTokens(table);
schema += table;
}
}
}
dbPrompt = generateDbPromptFromContext(
promptGenerator,
schemaList,
currentConversation.selectedSchemaName || "",
currentConversation.selectedTablesName || [],
maxToken,
userPrompt
);
} catch (error: any) {
toast.error(error.message);
}
dbPrompt = promptGenerator(schema);
}

// Sliding window to add messages with DONE status all the way back up until we reach the token
Expand Down Expand Up @@ -342,6 +330,14 @@ const ConversationView = () => {
<div className="sticky bottom-0 flex flex-row justify-center items-center w-full max-w-4xl py-2 pb-4 px-4 sm:px-8 mx-auto bg-white dark:bg-zinc-800 bg-opacity-80 backdrop-blur">
<ClearConversationButton />
<MessageTextarea disabled={lastMessage?.status === "LOADING"} sendMessage={sendMessageToCurrentConversation} />
<div className="mr-2 relative flex flex-row justify-end items-center" onClick={() => setShowSchemaDrawer(true)}>
{hasFeature("debug") && (
<button className="p-2 rounded cursor-pointer hover:bg-gray-100 dark:hover:bg-zinc-700">
<Icon.FiSettings className="w-4 h-auto" />
</button>
)}
</div>
{hasFeature("debug") && showSchemaDrawer && <SchemaDrawer close={() => setShowSchemaDrawer(false)} />}
</div>
</div>
);
Expand Down
45 changes: 43 additions & 2 deletions src/components/SchemaDrawer.tsx
Original file line number Diff line number Diff line change
@@ -1,25 +1,66 @@
import { Drawer } from "@mui/material";
import { useEffect } from "react";
import { useEffect, useState } from "react";
import Icon from "./Icon";
import { getAssistantById, getPromptGeneratorOfAssistant, useConnectionStore, useConversationStore, useSettingStore } from "@/store";
import { getModel, generateDbPromptFromContext } from "@/utils";
import toast from "react-hot-toast";
import { CodeBlock } from "./CodeBlock";

interface Props {
close: () => void;
}

const SchemaDrawer = (props: Props) => {
const conversationStore = useConversationStore();
const connectionStore = useConnectionStore();
const settingStore = useSettingStore();

const currentConversation = conversationStore.getConversationById(conversationStore.currentConversationId);
const [prompt, setPrompt] = useState<string>("");

const getPrompt = async () => {
if (!currentConversation) return;
if (!connectionStore.currentConnectionCtx?.database) return;
const promptGenerator = getPromptGeneratorOfAssistant(getAssistantById(currentConversation.assistantId)!);
let dbPrompt = promptGenerator();
const maxToken = getModel(settingStore.setting.openAIApiConfig?.model || "").max_token;
const schemaList = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database);

if (connectionStore.currentConnectionCtx?.database) {
try {
dbPrompt = generateDbPromptFromContext(
promptGenerator,
schemaList,
currentConversation.selectedSchemaName || "",
currentConversation.selectedTablesName || [],
maxToken
);
setPrompt(dbPrompt);
} catch (error: any) {
toast.error(error.message);
}
}
};

useEffect(() => {
// TODO: initial state with current conversation.
}, []);

const close = () => props.close();
useEffect(() => {
getPrompt();
}, []);

const close = () => props.close();
return (
<Drawer open={true} anchor="right" className="w-full" onClose={close}>
<div className="dark:text-gray-300 w-screen sm:w-[calc(40vw)] max-w-full flex flex-col justify-start items-start p-4">
<button className="w-8 h-8 p-1 bg-zinc-600 text-gray-100 rounded-full hover:opacity-80" onClick={close}>
<Icon.IoMdClose className="w-full h-auto" />
</button>
<h3 className="font-bold text-2xl mt-4">Current conversation related schema</h3>
<div>
<CodeBlock language="Prompt" value={prompt} messageId={currentConversation?.id || ""} wrapLongLines={true} />
</div>
</div>
</Drawer>
);
Expand Down
39 changes: 39 additions & 0 deletions src/utils/openai.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { encode } from "@nem035/gpt-3-encoder";
import { Schema, Table } from "@/types";

// openAIApiKey is the API key for OpenAI API.
export const openAIApiKey = process.env.OPENAI_API_KEY;
Expand All @@ -9,3 +10,41 @@ export const openAIApiEndpoint = process.env.OPENAI_API_ENDPOINT || "https://api
export const countTextTokens = (text: string) => {
return encode(text).length;
};

export function generateDbPromptFromContext(
promptGenerator: (input: string | undefined) => string,
schemaList: any,
selectedSchemaName: string,
selectedTablesName: string[],
maxToken: number,
userPrompt?: string
): string {
let schema = "";
// userPrompt is the message that user want to send to bot. When to look prompt in drawer, userPrompt is undefined.
let tokens = countTextTokens(userPrompt || "");

// Empty table name(such as []) denote all table. [] and `undefined` both are false in `if`
// The above comment is out of date. [] is true in `if` now. And no selected table should not denote all table now.
// Because in have Token custom number in connectionSidebar. If [] denote all table. the Token will be inconsistent.
const tableList: string[] = [];
const selectedSchema = schemaList.find((schema: Schema) => schema.name == (selectedSchemaName || ""));
if (selectedTablesName) {
selectedTablesName.forEach((tableName: string) => {
const table = selectedSchema?.tables.find((table: Table) => table.name == tableName);
tableList.push(table!.structure);
});
} else {
for (const table of selectedSchema?.tables || []) {
tableList.push(table!.structure);
}
}
if (tableList) {
for (const table of tableList) {
if (tokens < maxToken / 2) {
tokens += countTextTokens(table);
schema += table;
}
}
}
return promptGenerator(schema);
}

0 comments on commit f7eabe9

Please sign in to comment.