Skip to content

Commit

Permalink
convert css shape to string (Skyvern-AI#1092)
Browse files Browse the repository at this point in the history
  • Loading branch information
LawyZheng authored Oct 30, 2024
1 parent 01fbdee commit 8762865
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 16 deletions.
158 changes: 152 additions & 6 deletions skyvern/forge/agent_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Dict, List

import structlog
from playwright.async_api import Page
from playwright.async_api import Frame, Page

from skyvern.config import settings
from skyvern.constants import SKYVERN_ID_ATTR
Expand All @@ -19,19 +19,54 @@

LOG = structlog.get_logger()

USELESS_SVG_ATTRIBUTE = [SKYVERN_ID_ATTR, "id", "aria-describedby"]
SVG_RETRY_ATTEMPT = 3
USELESS_SHAPE_ATTRIBUTE = [SKYVERN_ID_ATTR, "id", "aria-describedby"]
SHAPE_CONVERTION_RETRY_ATTEMPT = 3


def _remove_rect(element: dict) -> None:
if "rect" in element:
del element["rect"]


def _should_css_shape_convert(element: Dict) -> bool:
if "id" not in element:
return False

tag_name = element.get("tagName")
if tag_name not in ["a", "span", "i"]:
return False

# if <span> and <i> without any text in the element, we try to convert the shape
if tag_name in ["span", "i"] and not element.get("text"):
return True

# if <a>, it should be no text, no children, no href/target attribute
if tag_name == "a":
attributes = element.get("attributes", {})
if element.get("text"):
return False

if len(element.get("children", [])) > 0:
return False

if "href" in attributes:
return False

if "target" in attributes:
return False
return True

return False


def _get_svg_cache_key(hash: str) -> str:
return f"skyvern:svg:{hash}"


def _get_shape_cache_key(hash: str) -> str:
return f"skyvern:shape:{hash}"


def _remove_skyvern_attributes(element: Dict) -> Dict:
"""
To get the original HTML element without skyvern attributes
Expand All @@ -44,7 +79,7 @@ def _remove_skyvern_attributes(element: Dict) -> Dict:
if "attributes" in element_copied:
attributes: dict = copy.deepcopy(element_copied.get("attributes", {}))
for key in attributes.keys():
if key in USELESS_SVG_ATTRIBUTE:
if key in USELESS_SHAPE_ATTRIBUTE:
del element_copied["attributes"][key]

children: List[Dict] | None = element_copied.get("children", None)
Expand Down Expand Up @@ -80,6 +115,8 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
except Exception:
LOG.warning(
"Failed to loaded SVG cache",
task_id=task.task_id,
step_id=step.step_id,
exc_info=True,
key=svg_key,
)
Expand All @@ -92,6 +129,8 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
LOG.warning(
"SVG element is too large to convert, going to drop the svg element.",
element_id=element_id,
task_id=task.task_id,
step_id=step.step_id,
length=len(svg_html),
)
del element["children"]
Expand All @@ -101,7 +140,7 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
LOG.debug("call LLM to convert SVG to string shape", element_id=element_id)
svg_convert_prompt = prompt_engine.load_prompt("svg-convert", svg_element=svg_html)

for retry in range(SVG_RETRY_ATTEMPT):
for retry in range(SHAPE_CONVERTION_RETRY_ATTEMPT):
try:
json_response = await app.SECONDARY_LLM_API_HANDLER(prompt=svg_convert_prompt, step=step)
svg_shape = json_response.get("shape", "")
Expand All @@ -113,6 +152,8 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
except Exception:
LOG.exception(
"Failed to convert SVG to string shape by secondary llm. Will retry if haven't met the max try attempt after 3s.",
task_id=task.task_id,
step_id=step.step_id,
element_id=element_id,
retry=retry,
)
Expand All @@ -126,6 +167,101 @@ async def _convert_svg_to_string(task: Task, step: Step, organization: Organizat
return


async def _convert_css_shape_to_string(
task: Task, step: Step, organization: Organization | None, frame: Page | Frame, element: Dict
) -> None:
element_id: str = element.get("id", "")

shape_element = _remove_skyvern_attributes(element)
svg_html = json_to_html(shape_element)
hash_object = hashlib.sha256()
hash_object.update(svg_html.encode("utf-8"))
shape_hash = hash_object.hexdigest()
shape_key = _get_shape_cache_key(shape_hash)

css_shape: str | None = None
try:
css_shape = await app.CACHE.get(shape_key)
except Exception:
LOG.warning(
"Failed to loaded CSS shape cache",
task_id=task.task_id,
step_id=step.step_id,
exc_info=True,
key=shape_key,
)

if css_shape:
LOG.debug("CSS shape loaded from cache", element_id=element_id, shape=css_shape)
else:
# FIXME: support element in iframe
locater = frame.locator(f'[{SKYVERN_ID_ATTR}="{element_id}"]')
if await locater.count() == 0:
LOG.info(
"No locater found to convert css shape",
task_id=task.task_id,
step_id=step.step_id,
element_id=element_id,
)
return None

if await locater.count() > 1:
LOG.info(
"multiple locaters found to convert css shape",
task_id=task.task_id,
step_id=step.step_id,
element_id=element_id,
)
return None

try:
LOG.debug("call LLM to convert css shape to string shape", element_id=element_id)
screenshot = await locater.screenshot(timeout=settings.BROWSER_SCREENSHOT_TIMEOUT_MS)
prompt = prompt_engine.load_prompt("css-shape-convert")

for retry in range(SHAPE_CONVERTION_RETRY_ATTEMPT):
try:
json_response = await app.SECONDARY_LLM_API_HANDLER(
prompt=prompt, screenshots=[screenshot], step=step
)
css_shape = json_response.get("shape", "")
if not css_shape:
raise Exception("Empty css shape replied by secondary llm")
LOG.info("CSS Shape converted by LLM", element_id=element_id, shape=css_shape)
await app.CACHE.set(shape_key, css_shape)
break
except Exception:
LOG.exception(
"Failed to convert css shape to string shape by secondary llm. Will retry if haven't met the max try attempt after 3s.",
task_id=task.task_id,
step_id=step.step_id,
element_id=element_id,
retry=retry,
)
await asyncio.sleep(3)
else:
LOG.info(
"Max css shape convertion retry, going to abort the convertion.",
task_id=task.task_id,
step_id=step.step_id,
element_id=element_id,
)
return None
except Exception:
LOG.exception(
"Failed to convert css shape to string shape by LLM",
task_id=task.task_id,
step_id=step.step_id,
element_id=element_id,
)
return None

if "attributes" not in element:
element["attributes"] = dict()
element["attributes"]["shape-description"] = css_shape
return None


class AgentFunction:
async def validate_step_execution(
self,
Expand Down Expand Up @@ -181,7 +317,7 @@ def cleanup_element_tree_factory(
step: Step,
organization: Organization | None = None,
) -> CleanupElementTreeFunc:
async def cleanup_element_tree_func(url: str, element_tree: list[dict]) -> list[dict]:
async def cleanup_element_tree_func(frame: Page | Frame, url: str, element_tree: list[dict]) -> list[dict]:
"""
Remove rect and attribute.unique_id from the elements.
The reason we're doing it is to
Expand All @@ -197,6 +333,16 @@ async def cleanup_element_tree_func(url: str, element_tree: list[dict]) -> list[
queue_ele = queue.pop(0)
_remove_rect(queue_ele)
await _convert_svg_to_string(task, step, organization, queue_ele)

if _should_css_shape_convert(element=queue_ele):
await _convert_css_shape_to_string(
task=task,
step=step,
organization=organization,
frame=frame,
element=queue_ele,
)

# TODO: we can come back to test removing the unique_id
# from element attributes to make sure this won't increase hallucination
# _remove_unique_id(queue_ele)
Expand Down
8 changes: 8 additions & 0 deletions skyvern/forge/prompts/skyvern/css-shape-convert.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
You are given a screenshot of an HTML element. You need to figure out what its shape means.

MAKE SURE YOU OUTPUT VALID JSON. No text before or after JSON, no trailing commas, no comments (//), no unnecessary quotes, etc.
Reply in JSON format with the following keys:
{
"confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
"shape": string, // A short description of the shape of element and its meaning
}
10 changes: 6 additions & 4 deletions skyvern/webeye/actions/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import structlog
from deprecation import deprecated
from playwright.async_api import FileChooser, Locator, Page, TimeoutError
from playwright.async_api import FileChooser, Frame, Locator, Page, TimeoutError
from pydantic import BaseModel

from skyvern.constants import REPO_ROOT_DIR, SKYVERN_ID_ATTR
Expand Down Expand Up @@ -165,8 +165,10 @@ def remove_exist_elements(element_tree: list[dict], check_exist: CheckExistIDFun
def clean_and_remove_element_tree_factory(
task: Task, step: Step, check_exist_funcs: list[CheckExistIDFunc]
) -> CleanupElementTreeFunc:
async def helper_func(url: str, element_tree: list[dict]) -> list[dict]:
element_tree = await app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step)(url, element_tree)
async def helper_func(frame: Page | Frame, url: str, element_tree: list[dict]) -> list[dict]:
element_tree = await app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step)(
frame, url, element_tree
)
for check_exist in check_exist_funcs:
element_tree = remove_exist_elements(element_tree=element_tree, check_exist=check_exist)
return element_tree
Expand Down Expand Up @@ -1270,7 +1272,7 @@ async def choose_auto_completion_dropdown(

if len(confirmed_preserved_list) > 0:
confirmed_preserved_list = await app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step)(
skyvern_frame.get_frame().url, copy.deepcopy(confirmed_preserved_list)
skyvern_frame.get_frame(), skyvern_frame.get_frame().url, copy.deepcopy(confirmed_preserved_list)
)
confirmed_preserved_list = trim_element_tree(copy.deepcopy(confirmed_preserved_list))

Expand Down
16 changes: 16 additions & 0 deletions skyvern/webeye/scraper/domUtils.js
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,14 @@ const checkRequiredFromStyle = (element) => {
return element.className.toLowerCase().includes("require");
};

function checkDisabledFromStyle(element) {
const className = element.className.toString().toLowerCase();
if (className.includes("react-datepicker__day--disabled")) {
return true;
}
return false;
}

function getElementContext(element) {
// dfs to collect the non unique_id context
let fullContext = new Array();
Expand Down Expand Up @@ -872,6 +880,14 @@ function buildElementObject(frame, element, interactable, purgeable = false) {
attrs[attr.name] = attrValue;
}

if (
checkDisabledFromStyle(element) &&
!attrs["disabled"] &&
!attrs["aria-disabled"]
) {
attrs["disabled"] = true;
}

if (
checkRequiredFromStyle(element) &&
!attrs["required"] &&
Expand Down
11 changes: 6 additions & 5 deletions skyvern/webeye/scraper/scraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
from skyvern.webeye.utils.page import SkyvernFrame

LOG = structlog.get_logger()
CleanupElementTreeFunc = Callable[[str, list[dict]], Awaitable[list[dict]]]
CleanupElementTreeFunc = Callable[[Page | Frame, str, list[dict]], Awaitable[list[dict]]]

RESERVED_ATTRIBUTES = {
"accept", # for input file
"alt",
"shape-description", # for css shape
"aria-checked", # for option tag
"aria-current",
"aria-label",
Expand Down Expand Up @@ -122,8 +123,8 @@ def json_to_html(element: dict, need_skyvern_attrs: bool = True) -> str:
if element.get("purgeable", False):
return children_html + option_html

before_pseudo_text = element.get("beforePseudoText", "")
after_pseudo_text = element.get("afterPseudoText", "")
before_pseudo_text = element.get("beforePseudoText") or ""
after_pseudo_text = element.get("afterPseudoText") or ""

# Check if the element is self-closing
if (
Expand Down Expand Up @@ -347,7 +348,7 @@ async def scrape_web_unsafe(
screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=url, draw_boxes=True)

elements, element_tree = await get_interactable_element_tree(page, scrape_exclude)
element_tree = await cleanup_element_tree(url, copy.deepcopy(element_tree))
element_tree = await cleanup_element_tree(page, url, copy.deepcopy(element_tree))

id_to_css_dict, id_to_element_dict, id_to_frame_dict, id_to_element_hash, hash_to_element_ids = build_element_dict(
elements
Expand Down Expand Up @@ -486,7 +487,7 @@ async def get_incremental_element_tree(

self.elements = incremental_elements

incremental_tree = await cleanup_element_tree(frame.url, copy.deepcopy(incremental_tree))
incremental_tree = await cleanup_element_tree(frame, frame.url, copy.deepcopy(incremental_tree))
trimmed_element_tree = trim_element_tree(copy.deepcopy(incremental_tree))

self.element_tree = incremental_tree
Expand Down
7 changes: 6 additions & 1 deletion skyvern/webeye/utils/dom.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.webeye.scraper.scraper import IncrementalScrapePage, ScrapedPage, json_to_html, trim_element
from skyvern.webeye.utils.page import SkyvernFrame

LOG = structlog.get_logger()

Expand Down Expand Up @@ -224,10 +225,14 @@ async def is_disabled(self, dynamic: bool = False) -> bool:

disabled_attr: bool | str | None = None
aria_disabled_attr: bool | str | None = None
style_disabled: bool = False

try:
disabled_attr = await self.get_attr("disabled", dynamic=dynamic)
aria_disabled_attr = await self.get_attr("aria-disabled", dynamic=dynamic)
skyvern_frame = await SkyvernFrame.create_instance(self.get_frame())
style_disabled = await skyvern_frame.get_disabled_from_style(await self.get_element_handler())

except Exception:
# FIXME: maybe it should be considered as "disabled" element if failed to get the attributes?
LOG.exception(
Expand All @@ -250,7 +255,7 @@ async def is_disabled(self, dynamic: bool = False) -> bool:
if isinstance(aria_disabled_attr, str):
aria_disabled = aria_disabled_attr.lower() != "false"

return disabled or aria_disabled
return disabled or aria_disabled or style_disabled

async def is_selectable(self) -> bool:
return self.get_selectable() or self.get_tag_name() in SELECTABLE_ELEMENT
Expand Down
4 changes: 4 additions & 0 deletions skyvern/webeye/utils/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ async def get_element_visible(self, element: ElementHandle) -> bool:
js_script = "(element) => isElementVisible(element) && !isHidden(element)"
return await self.frame.evaluate(js_script, element)

async def get_disabled_from_style(self, element: ElementHandle) -> bool:
js_script = "(element) => checkDisabledFromStyle(element)"
return await self.frame.evaluate(js_script, element)

async def scroll_to_top(self, draw_boxes: bool) -> float:
"""
Scroll to the top of the page and take a screenshot.
Expand Down

0 comments on commit 8762865

Please sign in to comment.