Skip to content

Commit

Permalink
Merge pull request erictik#68 from erictik:once-image-event-bug
Browse files Browse the repository at this point in the history
fix bug
  • Loading branch information
zcpua authored May 29, 2023
2 parents 7b759a7 + e67b936 commit 1075e3b
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 106 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/releases.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Node.js Package
env:
APPVERSION: v2.2.${{ github.run_number }}
APPVERSION: v2.3.${{ github.run_number }}
on:
workflow_dispatch:
push:
Expand Down
45 changes: 45 additions & 0 deletions example/imagine-ws-m.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import "dotenv/config";
import { Midjourney } from "../src";
/**
*
* a simple example of using the imagine api with ws
* ```
* npx tsx example/imagine-ws-m.ts
* ```
*/
async function main() {
const client = new Midjourney({
ServerId: <string>process.env.SERVER_ID,
ChannelId: <string>process.env.CHANNEL_ID,
SalaiToken: <string>process.env.SALAI_TOKEN,
HuggingFaceToken: <string>process.env.HUGGINGFACE_TOKEN,
Debug: true,
Ws: true,
});
await client.init();
client
.Imagine("A little pink elephant", (uri) => {
console.log("loading123---", uri);
})
.then(function (msg) {
console.log("msg123", msg);
});

client
.Imagine("A little pink dog", (uri) => {
console.log("loading234---", uri);
})
.then(function (msg) {
console.log("msg234", msg);
});
}
main()
.then(() => {
console.log("finished");
// process.exit(0);
})
.catch((err) => {
console.log("finished");
console.error(err);
process.exit(1);
});
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "midjourney",
"version": "2.2.0",
"version": "2.3.0",
"description": "Node.js client for the unofficial MidJourney API.",
"main": "libs/index.js",
"types": "libs/index.d.ts",
Expand Down
2 changes: 0 additions & 2 deletions src/interfaces/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ export interface MJMessage {
export type LoadingHandler = (uri: string, progress: string) => void;

export interface WaitMjEvent {
type: "imagine" | "upscale" | "variation" | "info";
nonce: string;
prompt?: string;
id?: string;
Expand All @@ -19,4 +18,3 @@ export interface WsEventMsg {
error?: Error;
message?: MJMessage;
}
export type ImageEventType = "imagine" | "upscale" | "variation";
6 changes: 3 additions & 3 deletions src/midjourney.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export class Midjourney extends MidjourneyMessage {
throw new Error(`ImagineApi failed with status ${httpStatus}`);
}
if (this.wsClient) {
return await this.wsClient.waitMessage("imagine", nonce, loading);
return await this.wsClient.waitMessage(nonce, loading);
} else {
this.log(`await generate image`);
const msg = await this.WaitMessage(prompt, loading);
Expand Down Expand Up @@ -157,7 +157,7 @@ export class Midjourney extends MidjourneyMessage {
throw new Error(`VariationApi failed with status ${httpStatus}`);
}
if (this.wsClient) {
return await this.wsClient.waitMessage("variation", nonce, loading);
return await this.wsClient.waitMessage(nonce, loading);
} else {
return await this.WaitOptionMessage(content, `Variations`, loading);
}
Expand Down Expand Up @@ -203,7 +203,7 @@ export class Midjourney extends MidjourneyMessage {
}
this.log(`await generate image`);
if (this.wsClient) {
return await this.wsClient.waitMessage("upscale", nonce, loading);
return await this.wsClient.waitMessage(nonce, loading);
}
return await this.WaitUpscaledMessage(content, index, loading);
}
Expand Down
203 changes: 104 additions & 99 deletions src/ws.message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import {
MJMessage,
LoadingHandler,
WsEventMsg,
ImageEventType,
} from "./interfaces";
import { VerifyHuman } from "./verify.human";

Expand All @@ -22,7 +21,7 @@ export class WsMessage {
private inflate: Inflate;
private event: Array<{ event: string; callback: (message: any) => void }> =
[];
private waitMjEvent: Array<WaitMjEvent> = [];
private waitMjEvents: Map<string, WaitMjEvent> = new Map();
private reconnectTime: boolean[] = [];
private heartbeatInterval = 0;

Expand Down Expand Up @@ -116,33 +115,11 @@ export class WsMessage {
this.zlibChunks = [];
this.parseMessage(data);
}
// parse message from ws
private parseMessage(data: Buffer) {
var jsonString = data.toString();
const msg = JSON.parse(jsonString);
if (msg.t === null || msg.t === "READY_SUPPLEMENTAL") return;
if (msg.t === "READY") {
this.emit("ready", null);
return;
}
if (!(msg.t === "MESSAGE_CREATE" || msg.t === "MESSAGE_UPDATE")) return;
const message = msg.d;
const {
channel_id,
content,
application_id,
embeds,
id,
nonce,
author,
attachments,
} = message;
if (!(author && author.id === this.MJBotId)) return;
if (channel_id !== this.config.ChannelId) return;
this.log("has message", content, nonce, id);

//waiting start image or info or error
if (nonce && msg.t === "MESSAGE_CREATE") {
private async messageCreate(message: any) {
// this.log("messageCreate", message);
const { application_id, embeds, id, nonce } = message;
if (nonce) {
this.log("waiting start image or info or error");
this.updateMjEventIdByNonce(id, nonce);
if (embeds && embeds.length > 0) {
Expand All @@ -158,7 +135,8 @@ export class WsMessage {
if (embeds[0].title.includes("continue")) {
if (embeds[0].description.includes("verify you're human")) {
//verify human
this.verifyHuman(message);
await this.verifyHuman(message);
return;
}
}
if (embeds[0].title.includes("Invalid")) {
Expand All @@ -170,39 +148,72 @@ export class WsMessage {
}
}
//done image
if (msg.t === "MESSAGE_CREATE" && !nonce && !application_id) {
if (!nonce && !application_id) {
this.log("done image");
this.done(message);
return;
}
this.processingImage(message);
}
private messageUpdate(message: any) {
this.processingImage(message);
}
private processingImage(message: any) {
const { content, id, nonce, attachments } = message;
const event = this.getEventById(id);
if (!event) {
return;
}
event.prompt = content;
//not image
if (!attachments || attachments.length === 0) {
// this.log("no image waiting", { id, nonce, content, event });
return;
}
const MJmsg: MJMessage = {
uri: attachments[0].url,
content: content,
progress: this.content2progress(content),
};
const eventMsg: WsEventMsg = {
message: MJmsg,
};
this.emitImage(event.nonce, eventMsg);
}

//processing image
{
this.log("processing image", jsonString);
const index = this.waitMjEvent.findIndex((e) => e.id === id);
if (index < 0 || !this.waitMjEvent[index]) {
return;
}
const event = this.waitMjEvent[index];
this.waitMjEvent[index].prompt = content;
if (!attachments || attachments.length === 0) {
this.log("wait", {
id,
nonce,
content,
event,
});
return;
}
const MJmsg: MJMessage = {
uri: attachments[0].url,
content: content,
progress: this.content2progress(content),
};
const eventMsg: WsEventMsg = {
message: MJmsg,
};
this.emitImage(<ImageEventType>event.type, eventMsg);
// parse message from ws
private parseMessage(data: Buffer) {
var jsonString = data.toString();
const msg = JSON.parse(jsonString);
if (msg.t === null || msg.t === "READY_SUPPLEMENTAL") return;
if (msg.t === "READY") {
this.emit("ready", null);
return;
}
if (!(msg.t === "MESSAGE_CREATE" || msg.t === "MESSAGE_UPDATE")) return;

const message = msg.d;
const {
channel_id,
content,
application_id,
embeds,
id,
nonce,
author,
attachments,
} = message;
if (!(author && author.id === this.MJBotId)) return;
if (channel_id !== this.config.ChannelId) return;
this.log("has message", content, nonce, id);

if (msg.t === "MESSAGE_CREATE") {
this.messageCreate(message);
return;
}
if (msg.t === "MESSAGE_UPDATE") {
this.messageUpdate(message);
return;
}
}
private async verifyHuman(message: any) {
Expand Down Expand Up @@ -273,16 +284,14 @@ export class WsMessage {
}
}
private EventError(id: string, error: Error) {
this.log("EventError", id, error);
const index = this.waitMjEvent.findIndex((e) => e.id === id);
if (index < 0 || !this.waitMjEvent[index]) {
const event = this.getEventById(id);
if (!event) {
return;
}
const event = this.waitMjEvent[index];
const eventMsg: WsEventMsg = {
error,
};
this.emit(event.type, eventMsg);
this.emit(event.nonce, eventMsg);
}

private done(message: any) {
Expand Down Expand Up @@ -321,31 +330,38 @@ export class WsMessage {
}

private filterMessages(MJmsg: MJMessage) {
// this.log("filterMessages", MJmsg, this.waitMjEvent);
const index = this.waitMjEvent.findIndex(
(e) =>
this.content2prompt(e.prompt) === this.content2prompt(MJmsg.content)
);
if (index < 0) {
this.log("FilterMessages not found", MJmsg, this.waitMjEvent);
return;
}
const event = this.waitMjEvent[index];
const event = this.getEventByContent(MJmsg.content);
if (!event) {
this.log("FilterMessages not found", MJmsg, this.waitMjEvent);
this.log("FilterMessages not found", MJmsg, this.waitMjEvents);
return;
}
const eventMsg: WsEventMsg = {
message: MJmsg,
};
this.emitImage(<ImageEventType>event.type, eventMsg);
this.emitImage(event.nonce, eventMsg);
}
private getEventByContent(content: string) {
const prompt = this.content2prompt(content);
for (const [key, value] of this.waitMjEvents.entries()) {
if (prompt === this.content2prompt(value.prompt)) {
return value;
}
}
}

private getEventById(id: string) {
for (const [key, value] of this.waitMjEvents.entries()) {
if (value.id === id) {
return value;
}
}
}
private updateMjEventIdByNonce(id: string, nonce: string) {
const index = this.waitMjEvent.findIndex((e) => e.nonce === nonce);
if (index < 0) return;
this.waitMjEvent[index].id = id;
this.log("updateMjEventIdByNonce success", this.waitMjEvent[index]);
if (nonce === "" || id === "") return;
let event = this.waitMjEvents.get(nonce);
if (!event) return;
event.id = id;
this.log("updateMjEventIdByNonce success", this.waitMjEvents.get(nonce));
}
uriToHash(uri: string) {
return uri.split("_").pop()?.split(".")[0] ?? "";
Expand Down Expand Up @@ -390,43 +406,32 @@ export class WsMessage {
this.remove("info", callback);
}
private removeWaitMjEvent(nonce: string) {
this.waitMjEvent = this.waitMjEvent.filter((e) => e.nonce !== nonce);
this.waitMjEvents.delete(nonce);
}

private emitImage(type: ImageEventType, message: WsEventMsg) {
private emitImage(type: string, message: WsEventMsg) {
this.emit(type, message);
}
onceImage(
type: ImageEventType,
nonce: string,
callback: (data: WsEventMsg) => void
) {
onceImage(nonce: string, callback: (data: WsEventMsg) => void) {
const once = (data: WsEventMsg) => {
const { message, error } = data;
if (message) {
message.content = this.content2prompt(message.content);
// message.content = this.content2prompt(message.content);
}
if (error || (message && message.progress === "done")) {
this.log("onceImage", type, "done", data, error);
this.remove(type, once);
// this.log("onceImage", type, "done", data, error);
this.remove(nonce, once);
this.removeWaitMjEvent(nonce);
}
callback(data);
};
this.waitMjEvent.push({ type, nonce });
this.event.push({ event: type, callback: once });
}
onceImagine(nonce: string, callback: (data: WsEventMsg) => void) {
this.onceImage("imagine", nonce, callback);
this.waitMjEvents.set(nonce, { nonce });
this.event.push({ event: nonce, callback: once });
}

async waitMessage(
type: ImageEventType,
nonce: string,
loading?: LoadingHandler
) {
async waitMessage(nonce: string, loading?: LoadingHandler) {
return new Promise<MJMessage | null>((resolve, reject) => {
this.onceImage(type, nonce, ({ message, error }) => {
this.onceImage(nonce, ({ message, error }) => {
if (error) {
reject(error);
return;
Expand Down

0 comments on commit 1075e3b

Please sign in to comment.