Skip to content

Commit 943cb4f

Browse files
authored
add various improvements for tools (#26)
* add utils for validating hashes * ruff formatting * ensure chain_id is consistent * simplify tool calling * add web3 for address validation * ruff format * add Literal for known inputs * allow response_format to be added to Nebula tools * filter event and transaction outputs * allow chain_id to be None * default to ethereum with insight * update versions
1 parent 1fb054c commit 943cb4f

File tree

11 files changed

+272
-296
lines changed

11 files changed

+272
-296
lines changed

python/thirdweb-ai/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "thirdweb-ai"
3-
version = "0.1.9"
3+
version = "0.1.10"
44
description = "thirdweb AI"
55
authors = [{ name = "thirdweb", email = "[email protected]" }]
66
requires-python = ">=3.10,<4.0"
@@ -20,6 +20,7 @@ dependencies = [
2020
"jsonref>=1.1.0,<2",
2121
"httpx>=0.28.1,<0.29",
2222
"aiohttp>=3.11.14",
23+
"web3>=7.9.0",
2324
]
2425

2526
[project.optional-dependencies]
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import re
2+
3+
from web3 import Web3
4+
5+
6+
def validate_address(address: str) -> str:
7+
if not address.startswith("0x") or len(address) != 42:
8+
raise ValueError(f"Invalid blockchain address format: {address}")
9+
10+
if not Web3.is_checksum_address(address):
11+
try:
12+
return Web3.to_checksum_address(address)
13+
except ValueError as e:
14+
raise ValueError(f"Invalid blockchain address: {address}") from e
15+
16+
return address
17+
18+
19+
def validate_transaction_hash(tx_hash: str) -> str:
20+
pattern = re.compile(r"^0x[a-fA-F0-9]{64}$")
21+
if bool(re.fullmatch(pattern, tx_hash)):
22+
return tx_hash
23+
raise ValueError(f"Invalid transaction hash: {tx_hash}")
24+
25+
26+
def validate_block_identifier(block_id: str) -> str:
27+
if block_id.startswith("0x"):
28+
pattern = re.compile(r"^0x[a-fA-F0-9]{64}$")
29+
if bool(re.fullmatch(pattern, block_id)):
30+
return block_id
31+
elif block_id.isdigit():
32+
return block_id
33+
34+
raise ValueError(f"Invalid block identifier: {block_id}")
35+
36+
37+
def validate_signature(signature: str) -> str:
38+
# Function selector (4 bytes)
39+
if signature.startswith("0x") and len(signature) == 10:
40+
pattern = re.compile(r"^0x[a-fA-F0-9]{8}$")
41+
if bool(re.fullmatch(pattern, signature)):
42+
return signature
43+
# Event topic (32 bytes)
44+
elif signature.startswith("0x") and len(signature) == 66:
45+
pattern = re.compile(r"^0x[a-fA-F0-9]{64}$")
46+
if bool(re.fullmatch(pattern, signature)):
47+
return signature
48+
# Plain text signature (e.g. "transfer(address,uint256)")
49+
elif "(" in signature and ")" in signature:
50+
return signature
51+
52+
raise ValueError(f"Invalid function or event signature: {signature}")
Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,33 @@
11
import re
22
from typing import Any
33

4+
TRANSACTION_KEYS_TO_KEEP = [
5+
"hash",
6+
"block_number",
7+
"block_timestamp",
8+
"from_address",
9+
"to_address",
10+
"value",
11+
"decodedData",
12+
]
13+
EVENT_KEYS_TO_KEEP = [
14+
"block_number",
15+
"block_timestamp",
16+
"address",
17+
"transaction_hash",
18+
"transaction_index",
19+
"log_index",
20+
"topics",
21+
"data",
22+
"decodedData",
23+
]
24+
425

526
def extract_digits(value: int | str) -> int:
27+
"""Extract the integer value from a string or return the integer directly."""
28+
if isinstance(value, int):
29+
return value
30+
631
value_str = str(value).strip("\"'")
732
digit_match = re.search(r"\d+", value_str)
833

@@ -16,21 +41,8 @@ def extract_digits(value: int | str) -> int:
1641
return int(extracted_digits)
1742

1843

19-
def normalize_chain_id(
20-
in_value: int | str | list[int | str] | None,
21-
) -> int | list[int] | None:
22-
"""Normalize str values integers."""
23-
24-
if in_value is None:
25-
return None
26-
27-
if isinstance(in_value, list):
28-
return [extract_digits(c) for c in in_value]
29-
30-
return extract_digits(in_value)
31-
32-
3344
def is_encoded(encoded_data: str) -> bool:
45+
"""Check if a string is a valid hexadecimal value."""
3446
encoded_data = encoded_data.removeprefix("0x")
3547

3648
try:
@@ -41,10 +53,23 @@ def is_encoded(encoded_data: str) -> bool:
4153

4254

4355
def clean_resolve(out: dict[str, Any]):
56+
"""Clean the response from the resolve function."""
4457
if "transactions" in out["data"]:
4558
for transaction in out["data"]["transactions"]:
4659
if "data" in transaction and is_encoded(transaction["data"]):
4760
transaction.pop("data")
4861
if "logs_bloom" in transaction:
4962
transaction.pop("logs_bloom")
5063
return out
64+
65+
66+
def filter_response_keys(items: list[dict[str, Any]], keys_to_keep: list[str] | None) -> list[dict[str, Any]]:
67+
"""Filter the response items to only include the specified keys"""
68+
if not keys_to_keep:
69+
return items
70+
71+
for item in items:
72+
keys_to_remove = [key for key in item if key not in keys_to_keep]
73+
for key in keys_to_remove:
74+
item.pop(key, None)
75+
return items

python/thirdweb-ai/src/thirdweb_ai/services/engine.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import Annotated, Any
1+
from typing import Annotated, Any, Literal
22

3-
from thirdweb_ai.common.utils import extract_digits, normalize_chain_id
3+
from thirdweb_ai.common.utils import extract_digits
44
from thirdweb_ai.services.service import Service
55
from thirdweb_ai.tools.tool import tool
66

@@ -10,15 +10,15 @@ def __init__(
1010
self,
1111
engine_url: str,
1212
engine_auth_jwt: str,
13-
chain_id: int | str | None = None,
13+
chain_id: int | None = None,
1414
backend_wallet_address: str | None = None,
1515
secret_key: str = "",
1616
):
1717
super().__init__(base_url=engine_url, secret_key=secret_key)
1818
self.engine_url = engine_url
1919
self.engine_auth_jwt = engine_auth_jwt
2020
self.backend_wallet_address = backend_wallet_address
21-
self.chain_id = normalize_chain_id(chain_id)
21+
self.chain_id = chain_id
2222

2323
def _make_headers(self):
2424
headers = super()._make_headers()
@@ -34,7 +34,7 @@ def _make_headers(self):
3434
def create_backend_wallet(
3535
self,
3636
wallet_type: Annotated[
37-
str,
37+
Literal["local", "smart:local"],
3838
"The type of backend wallet to create. Currently supported options are 'local' (stored locally in Engine's database) or 'smart:local' (for smart account wallets with advanced features). Choose 'local' for standard EOA wallets, and 'smart:local' for smart contract wallets with batching capabilities.",
3939
],
4040
label: Annotated[
@@ -76,7 +76,7 @@ def get_all_backend_wallet(
7676
def get_wallet_balance(
7777
self,
7878
chain_id: Annotated[
79-
str | int,
79+
int | None,
8080
"The numeric blockchain network ID to query (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.",
8181
],
8282
backend_wallet_address: Annotated[
@@ -85,9 +85,13 @@ def get_wallet_balance(
8585
] = None,
8686
) -> dict[str, Any]:
8787
"""Get wallet balance for native or ERC20 tokens."""
88-
normalized_chain = normalize_chain_id(chain_id) or self.chain_id
88+
if self.chain_id is not None and chain_id is None:
89+
chain_id = self.chain_id
90+
elif chain_id is None:
91+
raise ValueError("chain_id is required")
92+
8993
backend_wallet_address = backend_wallet_address or self.backend_wallet_address
90-
return self._get(f"backend-wallet/{normalized_chain}/{backend_wallet_address}/get-balance")
94+
return self._get(f"backend-wallet/{chain_id}/{backend_wallet_address}/get-balance")
9195

9296
@tool(
9397
description="Send an on-chain transaction. This powerful function can transfer native currency (ETH, MATIC), ERC20 tokens, or execute any arbitrary contract interaction. The transaction is signed and broadcast to the blockchain automatically."
@@ -107,7 +111,7 @@ def send_transaction(
107111
"The hexadecimal transaction data payload for contract interactions (e.g., '0x23b872dd...'). For simple native currency transfers, leave this empty. For ERC20 transfers or contract calls, this contains the ABI-encoded function call.",
108112
],
109113
chain_id: Annotated[
110-
str | int,
114+
int | None,
111115
"The numeric blockchain network ID to send the transaction on (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.",
112116
],
113117
backend_wallet_address: Annotated[
@@ -126,10 +130,14 @@ def send_transaction(
126130
"data": data or "0x",
127131
}
128132

129-
normalized_chain = normalize_chain_id(chain_id) or self.chain_id
133+
if self.chain_id is not None and chain_id is None:
134+
chain_id = self.chain_id
135+
elif chain_id is None:
136+
raise ValueError("chain_id is required")
137+
130138
backend_wallet_address = backend_wallet_address or self.backend_wallet_address
131139
return self._post(
132-
f"backend-wallet/{normalized_chain}/send-transaction",
140+
f"backend-wallet/{chain_id}/send-transaction",
133141
payload,
134142
headers={"X-Backend-Wallet-Address": backend_wallet_address},
135143
)
@@ -165,7 +173,7 @@ def read_contract(
165173
"An ordered list of arguments to pass to the function (e.g., [address, tokenId]). Must match the types and order expected by the function. For functions with no parameters, use an empty list or None.",
166174
],
167175
chain_id: Annotated[
168-
str | int,
176+
int | None,
169177
"The numeric blockchain network ID where the contract is deployed (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.",
170178
],
171179
) -> dict[str, Any]:
@@ -174,8 +182,12 @@ def read_contract(
174182
"functionName": function_name,
175183
"args": function_args or [],
176184
}
177-
normalized_chain = normalize_chain_id(chain_id) or self.chain_id
178-
return self._get(f"contract/{normalized_chain}/{contract_address}/read", payload)
185+
if self.chain_id is not None and chain_id is None:
186+
chain_id = self.chain_id
187+
elif chain_id is None:
188+
raise ValueError("chain_id is required")
189+
190+
return self._get(f"contract/{chain_id}/{contract_address}/read", payload)
179191

180192
@tool(
181193
description="Execute a state-changing function on a smart contract by sending a transaction. This allows you to modify on-chain data, such as transferring tokens, minting NFTs, or updating contract configuration. The transaction is automatically signed by your backend wallet and submitted to the blockchain."
@@ -199,7 +211,7 @@ def write_contract(
199211
"The amount of native currency (ETH, MATIC, etc.) to send with the transaction, in wei (e.g., '1000000000000000000' for 1 ETH). Required for payable functions, use '0' for non-payable functions. Default to '0'.",
200212
],
201213
chain_id: Annotated[
202-
str | int,
214+
int | None,
203215
"The numeric blockchain network ID where the contract is deployed (e.g., '1' for Ethereum mainnet, '137' for Polygon). If not provided, uses the default chain ID configured in the Engine instance.",
204216
],
205217
) -> dict[str, Any]:
@@ -212,9 +224,13 @@ def write_contract(
212224
if value and value != "0":
213225
payload["txOverrides"] = {"value": value}
214226

215-
normalized_chain = normalize_chain_id(chain_id) or self.chain_id
227+
if self.chain_id is not None and chain_id is None:
228+
chain_id = self.chain_id
229+
elif chain_id is None:
230+
raise ValueError("chain_id is required")
231+
216232
return self._post(
217-
f"contract/{normalized_chain}/{contract_address}/write",
233+
f"contract/{chain_id}/{contract_address}/write",
218234
payload,
219235
headers={"X-Backend-Wallet-Address": self.backend_wallet_address},
220236
)

0 commit comments

Comments
 (0)