Skip to content

Commit

Permalink
feat: check files using modified time instead of hash
Browse files Browse the repository at this point in the history
  • Loading branch information
JanPokorny committed Oct 15, 2024
1 parent 4f02a2d commit 1da8498
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 65 deletions.
69 changes: 18 additions & 51 deletions executor/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,12 @@

use actix_web::{middleware::Logger, web, App, Error, HttpResponse, HttpServer};
use futures::StreamExt;
use futures::TryStreamExt;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use std::env;
use std::path::Path;
use std::time::Duration;
use std::time::{Duration, SystemTime};
use tempfile::TempDir;
use tokio_stream::wrappers::ReadDirStream;
use tokio_util::io::ReaderStream;
use tokio::fs::{self, OpenOptions};
use tokio::io::{AsyncWriteExt, AsyncBufReadExt};
use tokio::process::Command;
Expand All @@ -39,14 +35,7 @@ struct ExecuteResult {
stdout: String,
stderr: String,
exit_code: i32,
files: Vec<File>,
}

#[derive(Serialize)]
struct File {
path: String,
old_hash: Option<String>,
new_hash: Option<String>,
files: Vec<String>,
}

static REQUIREMENTS: std::sync::LazyLock<HashSet<String>> = std::sync::LazyLock::new(|| {
Expand Down Expand Up @@ -103,35 +92,30 @@ async fn download_file(path: web::Path<String>) -> Result<HttpResponse, Error> {
.streaming(tokio_util::io::ReaderStream::new(file)))
}

async fn calculate_sha256(path: &str) -> Result<String, Box<dyn std::error::Error>> {
let file = tokio::fs::File::open(path).await?;
let stream = ReaderStream::new(file);
let mut hasher = Sha256::new();
let mut stream = stream.map_err(|e: std::io::Error| e);
while let Some(chunk) = stream.try_next().await? { hasher.update(&chunk); }
Ok(format!("{:x}", hasher.finalize()))
}

async fn get_file_hashes(dir: &str) -> HashMap<String, String> {
let mut hashes = HashMap::new();
let mut entries = ReadDirStream::new(tokio::fs::read_dir(dir).await.unwrap());
while let Some(Ok(entry)) = entries.next().await {
async fn get_modified_files(dir: &str, since: SystemTime) -> Vec<String> {
let mut modified_files = Vec::new();
let mut read_dir = fs::read_dir(dir).await.unwrap();
while let Some(entry) = read_dir.next_entry().await.unwrap() {
let path = entry.path();
if !path.is_file() {
continue;
}
if let Some(path_str) = path.to_str() {
if let Ok(hash) = calculate_sha256(path_str).await {
hashes.insert(path_str.to_string(), hash);
if let Ok(metadata) = entry.metadata().await {
if let Ok(modified) = metadata.modified() {
if modified > since {
if let Some(path_str) = path.to_str() {
modified_files.push(path_str.to_string());
}
}
}
}
}
hashes
modified_files
}

async fn execute(payload: web::Json<ExecuteRequest>) -> Result<HttpResponse, Error> {
let workspace = env::var("APP_WORKSPACE").unwrap_or_else(|_| "/workspace".to_string());
let before_hashes = get_file_hashes(&workspace).await;
let execution_start_time = SystemTime::now();
let source_dir = TempDir::new()?;

tokio::fs::write(source_dir.path().join("script.py"), &payload.source_code).await?;
Expand Down Expand Up @@ -178,25 +162,8 @@ async fn execute(payload: web::Json<ExecuteRequest>) -> Result<HttpResponse, Err
})
})
.unwrap_or_else(|_| Ok((String::new(), "Execution timed out".to_string(), -1)))?;
let after_hashes = get_file_hashes(&workspace).await;
let files = before_hashes
.iter()
.map(|(path, old_hash)| File {
path: path.clone(),
old_hash: Some(old_hash.clone()),
new_hash: after_hashes.get(path).cloned(),
})
.chain(
after_hashes
.iter()
.filter(|(path, _)| !before_hashes.contains_key(*path))
.map(|(path, new_hash)| File {
path: path.clone(),
old_hash: None,
new_hash: Some(new_hash.clone()),
}),
)
.collect();

let files = get_modified_files(&workspace, execution_start_time).await;

Ok(HttpResponse::Ok().json(ExecuteResult {
stdout,
Expand Down
17 changes: 3 additions & 14 deletions src/code_interpreter/services/kubernetes_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,7 @@ async def upload_file(file_path, file_hash):
)
).json()

changed_files = {
file["path"]: file["new_hash"]
for file in response["files"]
if file["old_hash"] != file["new_hash"] and file["new_hash"]
}

async def download_file(file_path, file_hash) -> str:
if await self.file_storage.exists(file_hash):
return
async def download_file(file_path) -> str:
async with self.file_storage.writer() as stored_file, client.stream(
"GET",
f"http://{executor_pod_ip}:8000/workspace/{file_path.removeprefix("/workspace/")}",
Expand All @@ -140,14 +132,11 @@ async def download_file(file_path, file_hash) -> str:
await stored_file.write(chunk)
return file_path, stored_file.hash

logger.info("Collecting %s changed files", len(changed_files))
logger.info("Collecting %s changed files", len(response["files"]))
stored_files = {
stored_file_path: stored_file_hash
for stored_file_path, stored_file_hash in await asyncio.gather(
*(
download_file(file_path, file_hash)
for file_path, file_hash in changed_files.items()
)
*(download_file(file_path) for file_path in response["files"])
)
}

Expand Down
2 changes: 2 additions & 0 deletions test/e2e/test_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_create_file_in_interpreter(
)

assert response.exit_code == 0
assert response.files.keys() == {"/workspace/file.txt"}

response: ExecuteResponse = grpc_stub.Execute(
ExecuteRequest(
Expand All @@ -102,6 +103,7 @@ def test_create_file_in_interpreter(
)
assert response.exit_code == 0
assert response.stdout == file_content + "\n"
assert not response.files


def test_parse_custom_tool_success(grpc_stub: CodeInterpreterServiceStub):
Expand Down
2 changes: 2 additions & 0 deletions test/e2e/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def test_create_file_in_interpreter(http_client: httpx.Client, config: Config):
assert response.status_code == 200
response_json = response.json()
assert response_json["exit_code"] == 0
assert response_json["files"].keys() == {"/workspace/file.txt"}

# Read the file back
response = http_client.post(
Expand All @@ -81,6 +82,7 @@ def test_create_file_in_interpreter(http_client: httpx.Client, config: Config):
response_json = response.json()
assert response_json["exit_code"] == 0
assert response_json["stdout"] == file_content + "\n"
assert not response_json["files"]


def test_parse_custom_tool_success(http_client: httpx.Client):
Expand Down

0 comments on commit 1da8498

Please sign in to comment.