Skip to content

Commit

Permalink
Revert "Add openpose json to API response (Mikubill#2033)"
Browse files Browse the repository at this point in the history
This reverts commit e2cd3b9.
  • Loading branch information
lllyasviel authored Sep 4, 2023
1 parent e2cd3b9 commit 7bdcb90
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 77 deletions.
8 changes: 4 additions & 4 deletions annotator/openpose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def create_keypoint(x, y, c):
)


def encode_poses_as_json(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> dict:
""" Encode the pose as a JSON compatible dict following openpose JSON output format:
def encode_poses_as_json(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> str:
""" Encode the pose as a JSON string following openpose JSON output format:
https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md
"""
def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[float], None]:
Expand All @@ -137,7 +137,7 @@ def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[flo
)
]

return {
return json.dumps({
'people': [
{
'pose_keypoints_2d': compress_keypoints(pose.body.keypoints),
Expand All @@ -149,7 +149,7 @@ def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[flo
],
'canvas_height': canvas_height,
'canvas_width': canvas_width,
}
}, indent=4)

class OpenposeDetector:
"""
Expand Down
82 changes: 22 additions & 60 deletions scripts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@ def encode_to_base64(image):
else:
return ""


def encode_np_to_base64(image):
pil = Image.fromarray(image)
return api.encode_pil_to_base64(pil)


def controlnet_api(_: gr.Blocks, app: FastAPI):
@app.get("/controlnet/version")
async def version():
Expand All @@ -44,12 +42,12 @@ async def model_list(update: bool = True):
async def module_list(alias_names: bool = False):
_module_list = external_code.get_modules(alias_names)
logger.debug(_module_list)

return {
"module_list": _module_list,
"module_detail": external_code.get_modules_detail(alias_names),
"module_detail": external_code.get_modules_detail(alias_names)
}

@app.get("/controlnet/control_types")
async def control_types():
def format_control_type(
Expand All @@ -64,88 +62,52 @@ def format_control_type(
"default_option": default_option,
"default_model": default_model,
}

return {
"control_types": {
control_type: format_control_type(
*global_state.select_control_type(control_type)
)
'control_types': {
control_type: format_control_type(*global_state.select_control_type(control_type))
for control_type in preprocessor_filters.keys()
}
}


@app.get("/controlnet/settings")
async def settings():
max_models_num = external_code.get_max_models_num()
return {"control_net_max_models_num": max_models_num}

cached_cn_preprocessors = global_state.cache_preprocessors(
global_state.cn_preprocessor_modules
)
return {"control_net_max_models_num":max_models_num}

cached_cn_preprocessors = global_state.cache_preprocessors(global_state.cn_preprocessor_modules)
@app.post("/controlnet/detect")
async def detect(
controlnet_module: str = Body("none", title="Controlnet Module"),
controlnet_input_images: List[str] = Body([], title="Controlnet Input Images"),
controlnet_processor_res: int = Body(
512, title="Controlnet Processor Resolution"
),
controlnet_threshold_a: float = Body(64, title="Controlnet Threshold a"),
controlnet_threshold_b: float = Body(64, title="Controlnet Threshold b"),
controlnet_module: str = Body("none", title='Controlnet Module'),
controlnet_input_images: List[str] = Body([], title='Controlnet Input Images'),
controlnet_processor_res: int = Body(512, title='Controlnet Processor Resolution'),
controlnet_threshold_a: float = Body(64, title='Controlnet Threshold a'),
controlnet_threshold_b: float = Body(64, title='Controlnet Threshold b')
):
controlnet_module = global_state.reverse_preprocessor_aliases.get(
controlnet_module, controlnet_module
)
controlnet_module = global_state.reverse_preprocessor_aliases.get(controlnet_module, controlnet_module)

if controlnet_module not in cached_cn_preprocessors:
raise HTTPException(status_code=422, detail="Module not available")
raise HTTPException(
status_code=422, detail="Module not available")

if len(controlnet_input_images) == 0:
raise HTTPException(status_code=422, detail="No image selected")
raise HTTPException(
status_code=422, detail="No image selected")

logger.info(
f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module."
)
logger.info(f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module.")

results = []
poses = []

processor_module = cached_cn_preprocessors[controlnet_module]

for input_image in controlnet_input_images:
img = external_code.to_base64_nparray(input_image)

class JsonAcceptor:
def __init__(self) -> None:
self.value = None

def accept(self, json_dict: dict) -> None:
self.value = json_dict

json_acceptor = JsonAcceptor()

results.append(
processor_module(
img,
res=controlnet_processor_res,
thr_a=controlnet_threshold_a,
thr_b=controlnet_threshold_b,
json_pose_callback=json_acceptor.accept,
)[0]
)

if "openpose" in controlnet_module:
assert json_acceptor.value is not None
poses.append(json_acceptor.value)
results.append(processor_module(img, res=controlnet_processor_res, thr_a=controlnet_threshold_a, thr_b=controlnet_threshold_b)[0])

global_state.cn_preprocessor_unloadable.get(controlnet_module, lambda: None)()
results64 = list(map(encode_to_base64, results))
res = {"images": results64, "info": "Success"}
if poses:
res["poses"] = poses

return res

return {"images": results64, "info": "Success"}

try:
import modules.script_callbacks as script_callbacks
Expand Down
7 changes: 3 additions & 4 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import gradio as gr
import functools
from copy import copy
Expand Down Expand Up @@ -653,10 +652,10 @@ def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm):

class JsonAcceptor:
def __init__(self) -> None:
self.value = None
self.value = ""

def accept(self, json_dict: dict) -> None:
self.value = json.dumps(json_dict)
def accept(self, json_string: str) -> None:
self.value = json_string

json_acceptor = JsonAcceptor()

Expand Down
19 changes: 10 additions & 9 deletions tests/annotator_tests/openpose_tests/json_encode_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import unittest
import numpy as np

Expand All @@ -14,19 +15,19 @@ def test_empty_list(self):
canvas_height = 1080
canvas_width = 1920
result = encode_poses_as_json(poses, canvas_height, canvas_width)
expected = {
expected = json.dumps({
'people': [],
'canvas_height': canvas_height,
'canvas_width': canvas_width,
}
self.assertDictEqual(result, expected)
}, indent=4)
self.assertEqual(result, expected)

def test_single_pose_no_keypoints(self):
poses = [PoseResult(BodyResult(None, 0, 0), None, None, None)]
canvas_height = 1080
canvas_width = 1920
result = encode_poses_as_json(poses, canvas_height, canvas_width)
expected = {
expected = json.dumps({
'people': [
{
'pose_keypoints_2d': None,
Expand All @@ -37,16 +38,16 @@ def test_single_pose_no_keypoints(self):
],
'canvas_height': canvas_height,
'canvas_width': canvas_width,
}
self.assertDictEqual(result, expected)
}, indent=4)
self.assertEqual(result, expected)

def test_single_pose_with_keypoints(self):
keypoints = [Keypoint(np.float32(0.5), np.float32(0.5)), None, Keypoint(0.6, 0.6)]
poses = [PoseResult(BodyResult(keypoints, 0, 0), keypoints, keypoints, keypoints)]
canvas_height = 1080
canvas_width = 1920
result = encode_poses_as_json(poses, canvas_height, canvas_width)
expected = {
expected = json.dumps({
'people': [
{
'pose_keypoints_2d': [
Expand All @@ -73,8 +74,8 @@ def test_single_pose_with_keypoints(self):
],
'canvas_height': canvas_height,
'canvas_width': canvas_width,
}
self.assertDictEqual(result, expected)
}, indent=4)
self.assertEqual(result, expected)

if __name__ == '__main__':
unittest.main()

0 comments on commit 7bdcb90

Please sign in to comment.