Skip to content

Commit

Permalink
add profiling display to frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
tzwm committed Jan 4, 2024
1 parent 6f3721d commit 08d8cc8
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
21 changes: 18 additions & 3 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import time
import execution
import asyncio

import server

exist_recursive_execute = execution.recursive_execute
exist_PromptExecutor_execute = execution.PromptExecutor.execute

profiler_data = {}
profiler_outputs = []

async def send_message(data):
s = server.PromptServer.instance
await s.send_json('profiler', data)

def get_input_unique_ids(inputs):
ret = []
for key in inputs:
Expand All @@ -19,7 +26,7 @@ def get_input_unique_ids(inputs):

def get_total_inputs_time(current_item, prompt, calculated_inputs):
input_unique_ids = get_input_unique_ids(prompt[current_item]['inputs'])
total_time = profiler_data['nodes'][current_item]
total_time = profiler_data['nodes'].get(current_item, 0)
calculated_nodes = calculated_inputs + [current_item]
for id in input_unique_ids:
if id in calculated_inputs:
Expand All @@ -36,6 +43,7 @@ def new_recursive_execute(server, prompt, outputs, current_item, extra_data, exe
if not profiler_data.get('prompt_id') or profiler_data.get('prompt_id') != prompt_id:
profiler_data['prompt_id'] = prompt_id
profiler_data['nodes'] = {}
profiler_outputs.clear()

inputs = prompt[current_item]['inputs']
input_unique_ids = get_input_unique_ids(inputs)
Expand All @@ -50,6 +58,12 @@ def new_recursive_execute(server, prompt, outputs, current_item, extra_data, exe
profiler_data['nodes'][current_item] = end_time - start_time - this_time_nodes_time
total_inputs_time, _ = get_total_inputs_time(current_item, prompt, [])

asyncio.run(send_message({
'node': current_item,
'current_time': profiler_data['nodes'][current_item],
'total_inputs_time': total_inputs_time
}))

inputs_str = ''
if len(input_unique_ids) > 0:
inputs_str = '('
Expand All @@ -64,11 +78,12 @@ def new_recursive_execute(server, prompt, outputs, current_item, extra_data, exe


def new_prompt_executor_execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
exist_PromptExecutor_execute(self, prompt, prompt_id, extra_data=extra_data, execute_outputs=execute_outputs)
ret = exist_PromptExecutor_execute(self, prompt, prompt_id, extra_data=extra_data, execute_outputs=execute_outputs)
print('\n'.join(profiler_outputs))
return ret

execution.recursive_execute = new_recursive_execute
execution.PromptExecutor.execute = new_prompt_executor_execute

WEB_DIRECTORY = ""
WEB_DIRECTORY = "."
NODE_CLASS_MAPPINGS = {}
46 changes: 46 additions & 0 deletions index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";

function drawText(ctx, text) {
if (!text) {
return;
}

const fgColor = "white";
const bgColor = "#0F1F0F";
const px = 6;
const py = 10;

ctx.save();
ctx.font = "12px sans-serif";
const sz = ctx.measureText(text);
ctx.fillStyle = bgColor;
ctx.beginPath();
ctx.roundRect(0, -LiteGraph.NODE_TITLE_HEIGHT - py * 2, sz.width + px * 2, px * 2, 5);
ctx.fill();

ctx.fillStyle = fgColor;
ctx.fillText(text, px, -LiteGraph.NODE_TITLE_HEIGHT - px);
ctx.restore();
}

app.registerExtension({
name: "ComfyUI.Profiler",
async loadedGraphNode(node, app) {
const orig = node.onDrawForeground;
node.onDrawForeground = function (ctx) {
const ret = orig(ctx, arguments);
drawText(ctx, node.profilingTime || '');
api.addEventListener("profiler", (event) => {
const data = event.detail;
if (data.node != node.id.toString()) {
return;
}

node.profilingTime = `${data.current_time.toFixed(2)}s`;
});

return ret;
};
},
});

0 comments on commit 08d8cc8

Please sign in to comment.