feat: replace hash by random id (#3)
JanPokorny authored Oct 14, 2024
1 parent 495becb commit 026992e
Showing 4 changed files with 24 additions and 49 deletions.
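
In short: judging from the diff below, the ids that the service reports for workspace files are no longer the SHA-256 of the file contents but random tokens generated server-side, so clients (including the e2e tests) must take the id from the response instead of recomputing it locally. A minimal sketch of the difference (not part of the commit; names are illustrative):

    import hashlib
    import secrets

    content = b"Hello, World!"
    old_style_id = hashlib.sha256(content).hexdigest()  # content-addressed: a client can derive it
    new_style_id = secrets.token_hex(32)                # random: known only once the server returns it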
18 changes: 11 additions & 7 deletions src/code_interpreter/services/kubernetes_code_executor.py
@@ -128,7 +128,7 @@ async def upload_file(file_path, file_hash):
             if file["old_hash"] != file["new_hash"] and file["new_hash"]
         }
 
-        async def download_file(file_path, file_hash):
+        async def download_file(file_path, file_hash) -> str:
             if await self.file_storage.exists(file_hash):
                 return
             async with self.file_storage.writer() as stored_file, client.stream(
@@ -138,20 +138,24 @@ async def download_file(file_path, file_hash):
                 pod_file.raise_for_status()
                 async for chunk in pod_file.aiter_bytes():
                     await stored_file.write(chunk)
+            return file_path, stored_file.hash
 
         logger.info("Collecting %s changed files", len(changed_files))
-        await asyncio.gather(
-            *(
-                download_file(file_path, file_hash)
-                for file_path, file_hash in changed_files.items()
+        stored_files = {
+            stored_file_path: stored_file_hash
+            for stored_file_path, stored_file_hash in await asyncio.gather(
+                *(
+                    download_file(file_path, file_hash)
+                    for file_path, file_hash in changed_files.items()
+                )
             )
-        )
+        }
 
         return KubernetesCodeExecutor.Result(
             stdout=response["stdout"],
             stderr=response["stderr"],
             exit_code=response["exit_code"],
-            files=changed_files,
+            files=stored_files,
         )
 
     async def fill_executor_pod_queue(self):
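The executor change above has each download coroutine return a (path, id) pair and builds a mapping from the gathered results. A standalone sketch of that asyncio.gather pattern (all names and values here are made up for illustration; this is not the repo's code):

    import asyncio

    async def download_file(path: str, file_hash: str) -> tuple[str, str]:
        await asyncio.sleep(0)  # stand-in for streaming the file out of the pod
        return path, f"stored-{file_hash}"

    async def main() -> None:
        changed_files = {"/workspace/a.txt": "h1", "/workspace/b.txt": "h2"}
        # gather preserves argument order, so each result tuple can be unpacked
        # directly in a dict comprehension keyed by path
        stored_files = {
            path: stored_hash
            for path, stored_hash in await asyncio.gather(
                *(download_file(p, h) for p, h in changed_files.items())
            )
        }
        print(stored_files)

    asyncio.run(main())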
43 changes: 9 additions & 34 deletions src/code_interpreter/services/storage.py
@@ -13,10 +13,9 @@
 # limitations under the License.
 
 from contextlib import asynccontextmanager
-import hashlib
 import secrets
 from typing import AsyncIterator, Protocol
-from anyio import AsyncFile, Path
+from anyio import Path
 from pydantic import validate_call
 
 from code_interpreter.utils.validation import Hash
@@ -27,8 +26,9 @@ async def read(self, size: int = -1) -> bytes: ...
 
 
 class ObjectWriter(Protocol):
+    hash: str
+
     async def write(self, data: bytes) -> None: ...
-    def hash(self) -> str: ...
 
 
 class Storage:
@@ -46,46 +46,21 @@ async def writer(self) -> AsyncIterator[ObjectWriter]:
         """
         Async context manager for writing a new object to the storage.
-        Internally, we write to a temporary file first, then rename it to the final name after writing is complete.
-        This is because the file hash is computed on-the-fly as the data is written.
-        This is internally done by wrapping the `write` function of the returned `ObjectWriter` to also call update the hash.
-        The final hash can be retrieved by using the `.hash()` method after the writing is completed.
+        The final hash can be retrieved by using the `.hash` attribute
         """
 
-        class _AsyncFileWrapper:
-            def __init__(self, file: AsyncFile[bytes]):
-                self._file = file
-                self._hash = hashlib.sha256()
-
-            async def write(self, data: bytes) -> None:
-                self._hash.update(data)
-                await self._file.write(data)
-
-            def hash(self) -> str:
-                return self._hash.hexdigest()
-
         await self.storage_path.mkdir(parents=True, exist_ok=True)
-        tmp_name = f"tmp-{secrets.token_hex(32)}"
-        tmp_file = self.storage_path / tmp_name
-        try:
-            async with await tmp_file.open("wb") as f:
-                wrapped_file = _AsyncFileWrapper(f)
-                yield wrapped_file
-            final_name = wrapped_file.hash()
-            final_file = self.storage_path / final_name
-            if not await final_file.exists():
-                await tmp_file.rename(final_file)
-        finally:
-            if await tmp_file.exists():
-                await tmp_file.unlink()
+        hash = secrets.token_hex(32)
+        async with await (self.storage_path / hash).open("wb") as file:
+            file.__setattr__("hash", hash)
+            yield file
 
     async def write(self, data: bytes) -> str:
         """
         Writes the data to the storage and returns the hash of the object.
         """
         async with self.writer() as f:
             await f.write(data)
-            return f.hash()
+            return f.hash
 
     @asynccontextmanager
     @validate_call
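With this change, Storage.writer() hands out a file object whose storage key is a pre-generated random token exposed as a plain .hash attribute, whereas previously it was a .hash() method computed from the written bytes. A hedged usage sketch (assumes a Storage instance from this repo; the helper below is hypothetical, not part of the commit):

    async def store_greeting(storage) -> str:
        # storage is assumed to be an instance of this repo's Storage class
        async with storage.writer() as f:
            await f.write(b"Hello, World!")
            return f.hash  # attribute, not a call: the random id chosen up front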
5 changes: 1 addition & 4 deletions test/e2e/test_grpc.py
@@ -14,7 +14,6 @@
 import json
 from pathlib import Path
 import grpc
-import hashlib
 import pytest
 from code_interpreter.config import Config
 from proto.code_interpreter.v1.code_interpreter_service_pb2 import (
@@ -80,7 +79,6 @@ def test_create_file_in_interpreter(
     grpc_stub: CodeInterpreterServiceStub, config: Config
 ):
     file_content = "Hello, World!"
-    file_hash = hashlib.sha256(file_content.encode()).hexdigest()
 
     response: ExecuteResponse = grpc_stub.Execute(
         ExecuteRequest(
@@ -92,15 +90,14 @@
     )
 
     assert response.exit_code == 0
-    assert response.files["/workspace/file.txt"] == file_hash
 
     response: ExecuteResponse = grpc_stub.Execute(
         ExecuteRequest(
             source_code="""
 with open('file.txt', 'r') as f:
     print(f.read())
 """,
-            files={"/workspace/file.txt": file_hash},
+            files={"/workspace/file.txt": response.files["/workspace/file.txt"]},
         )
     )
     assert response.exit_code == 0
7 changes: 3 additions & 4 deletions test/e2e/test_http.py
@@ -1,6 +1,5 @@
 import json
 from pathlib import Path
-import hashlib
 import pytest
 import httpx
 from code_interpreter.config import Config
@@ -47,7 +46,6 @@ def test_ad_hoc_import(http_client: httpx.Client):
 
 def test_create_file_in_interpreter(http_client: httpx.Client, config: Config):
     file_content = "Hello, World!"
-    file_hash = hashlib.sha256(file_content.encode()).hexdigest()
 
     # Create the file in the workspace
     response = http_client.post(
@@ -64,7 +62,6 @@ def test_create_file_in_interpreter(http_client: httpx.Client, config: Config):
     assert response.status_code == 200
     response_json = response.json()
     assert response_json["exit_code"] == 0
-    assert response_json["files"]["/workspace/file.txt"] == file_hash
 
     # Read the file back
     response = http_client.post(
@@ -74,7 +71,9 @@ def test_create_file_in_interpreter(http_client: httpx.Client, config: Config):
 with open('file.txt', 'r') as f:
     print(f.read())
 """,
-            "files": {"/workspace/file.txt": file_hash},
+            "files": {
+                "/workspace/file.txt": response_json["files"]["/workspace/file.txt"]
+            },
         },
     )

