diff --git a/__init__.py b/__init__.py index 58af61e..8371bea 100644 --- a/__init__.py +++ b/__init__.py @@ -1,5 +1,8 @@ import time import execution +import asyncio + +import server exist_recursive_execute = execution.recursive_execute exist_PromptExecutor_execute = execution.PromptExecutor.execute @@ -7,6 +10,10 @@ profiler_data = {} profiler_outputs = [] +async def send_message(data): + s = server.PromptServer.instance + await s.send_json('profiler', data) + def get_input_unique_ids(inputs): ret = [] for key in inputs: @@ -19,7 +26,7 @@ def get_input_unique_ids(inputs): def get_total_inputs_time(current_item, prompt, calculated_inputs): input_unique_ids = get_input_unique_ids(prompt[current_item]['inputs']) - total_time = profiler_data['nodes'][current_item] + total_time = profiler_data['nodes'].get(current_item, 0) calculated_nodes = calculated_inputs + [current_item] for id in input_unique_ids: if id in calculated_inputs: @@ -36,6 +43,7 @@ def new_recursive_execute(server, prompt, outputs, current_item, extra_data, exe if not profiler_data.get('prompt_id') or profiler_data.get('prompt_id') != prompt_id: profiler_data['prompt_id'] = prompt_id profiler_data['nodes'] = {} + profiler_outputs.clear() inputs = prompt[current_item]['inputs'] input_unique_ids = get_input_unique_ids(inputs) @@ -50,6 +58,12 @@ def new_recursive_execute(server, prompt, outputs, current_item, extra_data, exe profiler_data['nodes'][current_item] = end_time - start_time - this_time_nodes_time total_inputs_time, _ = get_total_inputs_time(current_item, prompt, []) + asyncio.run(send_message({ + 'node': current_item, + 'current_time': profiler_data['nodes'][current_item], + 'total_inputs_time': total_inputs_time + })) + inputs_str = '' if len(input_unique_ids) > 0: inputs_str = '(' @@ -64,11 +78,12 @@ def new_recursive_execute(server, prompt, outputs, current_item, extra_data, exe def new_prompt_executor_execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): - exist_PromptExecutor_execute(self, prompt, prompt_id, extra_data=extra_data, execute_outputs=execute_outputs) + ret = exist_PromptExecutor_execute(self, prompt, prompt_id, extra_data=extra_data, execute_outputs=execute_outputs) print('\n'.join(profiler_outputs)) + return ret execution.recursive_execute = new_recursive_execute execution.PromptExecutor.execute = new_prompt_executor_execute -WEB_DIRECTORY = "" +WEB_DIRECTORY = "." NODE_CLASS_MAPPINGS = {} diff --git a/index.js b/index.js new file mode 100644 index 0000000..15fea27 --- /dev/null +++ b/index.js @@ -0,0 +1,46 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; + +function drawText(ctx, text) { + if (!text) { + return; + } + + const fgColor = "white"; + const bgColor = "#0F1F0F"; + const px = 6; + const py = 10; + + ctx.save(); + ctx.font = "12px sans-serif"; + const sz = ctx.measureText(text); + ctx.fillStyle = bgColor; + ctx.beginPath(); + ctx.roundRect(0, -LiteGraph.NODE_TITLE_HEIGHT - py * 2, sz.width + px * 2, px * 2, 5); + ctx.fill(); + + ctx.fillStyle = fgColor; + ctx.fillText(text, px, -LiteGraph.NODE_TITLE_HEIGHT - px); + ctx.restore(); +} + +app.registerExtension({ + name: "ComfyUI.Profiler", + async loadedGraphNode(node, app) { + const orig = node.onDrawForeground; + node.onDrawForeground = function (ctx) { + const ret = orig(ctx, arguments); + drawText(ctx, node.profilingTime || ''); + api.addEventListener("profiler", (event) => { + const data = event.detail; + if (data.node != node.id.toString()) { + return; + } + + node.profilingTime = `${data.current_time.toFixed(2)}s`; + }); + + return ret; + }; + }, +});