Skip to content

Commit

Permalink
Add utils for multimedia data (microsoft#805)
Browse files Browse the repository at this point in the history
# Description

Add utils for multimedia data.

#### Changes:
1. Provide some common functions in `multimedia_utils`.
2. Support image inputs in jinja template.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Lina Tang <[email protected]>
  • Loading branch information
lumoslnt and Lina Tang authored Oct 19, 2023
1 parent 9581318 commit 983b53a
Show file tree
Hide file tree
Showing 11 changed files with 254 additions and 229 deletions.
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/_core/run_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from promptflow._utils.dataclass_serializer import serialize
from promptflow._utils.exception_utils import ExceptionPresenter
from promptflow._utils.logger_utils import flow_logger
from promptflow._utils.multimedia_utils import default_json_encoder
from promptflow._utils.openai_metrics_calculator import OpenAIMetricsCalculator
from promptflow.contracts.multimedia import PFBytes
from promptflow.contracts.run_info import FlowRunInfo, RunInfo, Status
from promptflow.contracts.run_mode import RunMode
from promptflow.contracts.tool import ConnectionType
Expand Down Expand Up @@ -241,7 +241,7 @@ def _ensure_serializable_value(self, val, warning_msg: Optional[str] = None):
if self.allow_generator_types and isinstance(val, GeneratorType):
return str(val)
try:
json.dumps(val, default=PFBytes.default_json_encoder)
json.dumps(val, default=default_json_encoder)
return val
except Exception:
if not warning_msg:
Expand Down
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/_core/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from promptflow._core.generator_proxy import GeneratorProxy, generate_from_proxy
from promptflow._utils.dataclass_serializer import serialize
from promptflow.contracts.multimedia import PFBytes
from promptflow._utils.multimedia_utils import default_json_encoder
from promptflow.contracts.tool import ConnectionType
from promptflow.contracts.trace import Trace, TraceType

Expand Down Expand Up @@ -81,7 +81,7 @@ def to_serializable(obj):
return obj
try:
obj = serialize(obj)
json.dumps(obj, default=PFBytes.default_json_encoder)
json.dumps(obj, default=default_json_encoder)
except Exception:
# We don't want to fail the whole function call because of a serialization error,
# so we simply convert it to str if it cannot be serialized.
Expand Down
3 changes: 2 additions & 1 deletion src/promptflow/promptflow/_sdk/operations/_test_submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.dataclass_serializer import serialize
from promptflow._utils.exception_utils import ErrorResponse
from promptflow._utils.multimedia_utils import persist_multimedia_date
from promptflow.contracts.flow import Flow as ExecutableFlow
from promptflow.contracts.run_info import Status
from promptflow.exceptions import UserErrorException
Expand Down Expand Up @@ -157,7 +158,7 @@ def flow_test(
)
flow_executor.enable_streaming_for_llm_flow(lambda: True)
line_result = flow_executor.exec_line(inputs, index=0, allow_generator_output=allow_generator_output)
line_result.output = flow_executor._persist_images_from_output(
line_result.output = persist_multimedia_date(
line_result.output, base_dir=self.flow.code, sub_dir=Path(".promptflow/output")
)
if line_result.aggregation_inputs:
Expand Down
195 changes: 195 additions & 0 deletions src/promptflow/promptflow/_utils/multimedia_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import base64
import imghdr
import os
import re
import requests
import uuid

from functools import partial
from pathlib import Path
from typing import Any, Callable
from urllib.parse import urlparse

from promptflow.contracts._errors import InvalidImageInput
from promptflow.contracts.multimedia import Image, PFBytes
from promptflow.exceptions import ErrorTarget

MIME_PATTERN = re.compile(r"^data:image/(.*);(path|base64|url)$")


def get_mime_type_from_path(path: Path):
ext = path.suffix[1:]
return f"image/{ext}" if ext else "image/*"


def get_extension_from_mime_type(mime_type: str):
ext = mime_type.split("/")[-1]
if ext == "*":
return None
return ext


def is_multimedia_dict(multimedia_dict: dict):
if len(multimedia_dict) != 1:
return False
key = list(multimedia_dict.keys())[0]
if re.match(MIME_PATTERN, key):
return True
return False


def get_multimedia_info(key: str):
match = re.match(MIME_PATTERN, key)
if match:
return match.group(1), match.group(2)
return None, None


def is_url(value: str):
try:
result = urlparse(value)
return all([result.scheme, result.netloc])
except ValueError:
return False


def is_base64(value: str):
base64_regex = re.compile(r"^([A-Za-z0-9+/]{4})*(([A-Za-z0-9+/]{2})*(==|[A-Za-z0-9+/]=)?)?$")
if re.match(base64_regex, value):
return True
return False


def create_image_from_file(f: Path, mime_type: str = None):
if not mime_type:
mime_type = get_mime_type_from_path(f)
with open(f, "rb") as fin:
return Image(fin.read(), mime_type=mime_type)


def create_image_from_base64(base64_str: str, mime_type: str = None):
image_bytes = base64.b64decode(base64_str)
if not mime_type:
format = imghdr.what(None, image_bytes)
mime_type = f"image/{format}" if format else "image/*"
return Image(image_bytes, mime_type=mime_type)


def create_image_from_url(url: str, mime_type: str = None):
response = requests.get(url)
if response.status_code == 200:
if not mime_type:
format = imghdr.what(None, response.content)
mime_type = f"image/{format}" if format else "image/*"
return Image(response.content, mime_type=mime_type)
else:
raise InvalidImageInput(
message_format=f"Error while fetching image from URL: {url}. "
"Error code: {response.status_code}. Error message: {response.text}.",
target=ErrorTarget.EXECUTOR,
)


def create_image_from_dict(image_dict: dict):
for k, v in image_dict.items():
format, resource = get_multimedia_info(k)
if resource == "path":
return create_image_from_file(v, mime_type=f"image/{format}")
elif resource == "base64":
return create_image_from_base64(v, mime_type=f"image/{format}")
elif resource == "url":
return create_image_from_url(v, mime_type=f"image/{format}")
else:
raise InvalidImageInput(
message_format=f"Unsupported image resource: {resource}. "
"Supported Resources are [path, base64, url].",
target=ErrorTarget.EXECUTOR,
)


def create_image_from_string(value: str, base_dir: Path = None):
if is_base64(value):
return create_image_from_base64(value)
elif is_url(value):
return create_image_from_url(value)
else:
path = Path(value)
if base_dir and not path.is_absolute():
path = Path.joinpath(base_dir, path)
return create_image_from_file(path)


def create_image(value: any, base_dir: Path = None):
if isinstance(value, PFBytes):
return value
elif isinstance(value, dict):
if is_multimedia_dict(value):
return create_image_from_dict(value)
else:
raise InvalidImageInput(
message_format="Invalid image input format. The image input should be a dictionary like: "
"{data:image/<image_type>;[path|base64|url]: <image_data>}.",
target=ErrorTarget.EXECUTOR,
)
elif isinstance(value, str):
return create_image_from_string(value, base_dir)
else:
raise InvalidImageInput(
message_format=f"Unsupported image input type: {type(value)}. "
"The image inputs should be a string or a dictionary.",
target=ErrorTarget.EXECUTOR,
)


def save_image_to_file(image: Image, file_name: str, folder_path: Path, relative_path: Path = None):
ext = get_extension_from_mime_type(image._mime_type)
file_name = f"{file_name}.{ext}" if ext else file_name
image_reference = {
f"data:{image._mime_type};path": str(relative_path / file_name) if relative_path else file_name
}
path = folder_path / relative_path if relative_path else folder_path
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, file_name), 'wb') as file:
file.write(image)
return image_reference


def get_file_reference_encoder(folder_path: Path, relative_path: Path = None) -> Callable:
def pfbytes_file_reference_encoder(obj):
"""Dumps PFBytes to a file and returns its reference."""
if isinstance(obj, PFBytes):
file_name = str(uuid.uuid4())
return save_image_to_file(obj, file_name, folder_path, relative_path)
raise TypeError(f"Not supported to dump type '{type(obj).__name__}'.")
return pfbytes_file_reference_encoder


def default_json_encoder(obj):
if isinstance(obj, PFBytes):
return str(obj)
else:
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def persist_multimedia_date(value: Any, base_dir: Path, sub_dir: Path = None):
pfbytes_file_reference_encoder = get_file_reference_encoder(base_dir, sub_dir)
serialization_funcs = {Image: partial(Image.serialize, **{"encoder": pfbytes_file_reference_encoder})}
return recursive_process(value, process_funcs=serialization_funcs)


def convert_multimedia_date_to_base64(value: Any):
to_base64_funcs = {PFBytes: PFBytes.to_base64}
return recursive_process(value, process_funcs=to_base64_funcs)


# TODO: Move this function to a more general place and integrate serialization to this function.
def recursive_process(value: Any, process_funcs: dict[type, Callable] = None) -> dict:
if process_funcs:
for cls, f in process_funcs.items():
if isinstance(value, cls):
return f(value)
if isinstance(value, list):
return [recursive_process(v, process_funcs) for v in value]
if isinstance(value, dict):
return {k: recursive_process(v, process_funcs) for k, v in value.items()}
return value
Loading

0 comments on commit 983b53a

Please sign in to comment.