diff --git a/src/baby_list.zig b/src/baby_list.zig index 08745e2fe67088..0c30e7f258b909 100644 --- a/src/baby_list.zig +++ b/src/baby_list.zig @@ -304,7 +304,7 @@ pub fn BabyList(comptime Type: type) type { } pub fn one(allocator: std.mem.Allocator, value: Type) !ListType { - var items = try allocator.alloc(Type, 1); + var items = try allocator.allocAdvanced(Type, @alignOf(Type), 1, .exact); items[0] = value; return ListType{ .ptr = @ptrCast([*]Type, items.ptr), diff --git a/src/bun.js/bindings/headers-cpp.h b/src/bun.js/bindings/headers-cpp.h index 627b234cd59ece..4cb7a80e0b46b8 100644 --- a/src/bun.js/bindings/headers-cpp.h +++ b/src/bun.js/bindings/headers-cpp.h @@ -1,4 +1,4 @@ -//-- AUTOGENERATED FILE -- 1657408675 +//-- AUTOGENERATED FILE -- 1660175100 // clang-format off #pragma once diff --git a/src/bun.js/bindings/headers.h b/src/bun.js/bindings/headers.h index 02b2bbb2d271da..9fd86fb2ed4ba8 100644 --- a/src/bun.js/bindings/headers.h +++ b/src/bun.js/bindings/headers.h @@ -1,5 +1,5 @@ // clang-format off -//-- AUTOGENERATED FILE -- 1657408675 +//-- AUTOGENERATED FILE -- 1660175100 #pragma once #include @@ -393,6 +393,7 @@ CPP_DECL JSC__JSValue JSC__JSGlobalObject__generateHeapSnapshot(JSC__JSGlobalObj CPP_DECL JSC__GeneratorFunctionPrototype* JSC__JSGlobalObject__generatorFunctionPrototype(JSC__JSGlobalObject* arg0); CPP_DECL JSC__GeneratorPrototype* JSC__JSGlobalObject__generatorPrototype(JSC__JSGlobalObject* arg0); CPP_DECL JSC__JSValue JSC__JSGlobalObject__getCachedObject(JSC__JSGlobalObject* arg0, const ZigString* arg1); +CPP_DECL void JSC__JSGlobalObject__handleRejectedPromises(JSC__JSGlobalObject* arg0); CPP_DECL JSC__IteratorPrototype* JSC__JSGlobalObject__iteratorPrototype(JSC__JSGlobalObject* arg0); CPP_DECL JSC__JSObject* JSC__JSGlobalObject__jsSetPrototype(JSC__JSGlobalObject* arg0); CPP_DECL JSC__MapIteratorPrototype* JSC__JSGlobalObject__mapIteratorPrototype(JSC__JSGlobalObject* arg0); @@ -407,7 +408,6 @@ CPP_DECL bool JSC__JSGlobalObject__startRemoteInspector(JSC__JSGlobalObject* arg CPP_DECL JSC__StringPrototype* JSC__JSGlobalObject__stringPrototype(JSC__JSGlobalObject* arg0); CPP_DECL JSC__JSObject* JSC__JSGlobalObject__symbolPrototype(JSC__JSGlobalObject* arg0); CPP_DECL JSC__VM* JSC__JSGlobalObject__vm(JSC__JSGlobalObject* arg0); -CPP_DECL void JSC__JSGlobalObject__handleRejectedPromises(JSC__JSGlobalObject* arg0); #pragma mark - WTF::URL @@ -772,7 +772,7 @@ ZIG_DECL void Bun__WebSocketHTTPSClient__register(JSC__JSGlobalObject* arg0, voi ZIG_DECL void Bun__WebSocketClient__close(WebSocketClient* arg0, uint16_t arg1, const ZigString* arg2); ZIG_DECL void Bun__WebSocketClient__finalize(WebSocketClient* arg0); -ZIG_DECL void* Bun__WebSocketClient__init(void* arg0, void* arg1, void* arg2, JSC__JSGlobalObject* arg3); +ZIG_DECL void* Bun__WebSocketClient__init(void* arg0, void* arg1, void* arg2, JSC__JSGlobalObject* arg3, unsigned char* arg4, size_t arg5); ZIG_DECL void Bun__WebSocketClient__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2); ZIG_DECL void Bun__WebSocketClient__writeBinaryData(WebSocketClient* arg0, const unsigned char* arg1, size_t arg2); ZIG_DECL void Bun__WebSocketClient__writeString(WebSocketClient* arg0, const ZigString* arg1); @@ -783,7 +783,7 @@ ZIG_DECL void Bun__WebSocketClient__writeString(WebSocketClient* arg0, const Zig ZIG_DECL void Bun__WebSocketClientTLS__close(WebSocketClientTLS* arg0, uint16_t arg1, const ZigString* arg2); ZIG_DECL void Bun__WebSocketClientTLS__finalize(WebSocketClientTLS* arg0); -ZIG_DECL void* Bun__WebSocketClientTLS__init(void* arg0, void* arg1, void* arg2, JSC__JSGlobalObject* arg3); +ZIG_DECL void* Bun__WebSocketClientTLS__init(void* arg0, void* arg1, void* arg2, JSC__JSGlobalObject* arg3, unsigned char* arg4, size_t arg5); ZIG_DECL void Bun__WebSocketClientTLS__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2); ZIG_DECL void Bun__WebSocketClientTLS__writeBinaryData(WebSocketClientTLS* arg0, const unsigned char* arg1, size_t arg2); ZIG_DECL void Bun__WebSocketClientTLS__writeString(WebSocketClientTLS* arg0, const ZigString* arg1); diff --git a/src/bun.js/bindings/headers.zig b/src/bun.js/bindings/headers.zig index 80600f457f7402..a8379c290a7709 100644 --- a/src/bun.js/bindings/headers.zig +++ b/src/bun.js/bindings/headers.zig @@ -198,6 +198,7 @@ pub extern fn JSC__JSGlobalObject__generateHeapSnapshot(arg0: ?*JSC__JSGlobalObj pub extern fn JSC__JSGlobalObject__generatorFunctionPrototype(arg0: ?*JSC__JSGlobalObject) ?*bindings.GeneratorFunctionPrototype; pub extern fn JSC__JSGlobalObject__generatorPrototype(arg0: ?*JSC__JSGlobalObject) ?*bindings.GeneratorPrototype; pub extern fn JSC__JSGlobalObject__getCachedObject(arg0: ?*JSC__JSGlobalObject, arg1: [*c]const ZigString) JSC__JSValue; +pub extern fn JSC__JSGlobalObject__handleRejectedPromises(arg0: ?*JSC__JSGlobalObject) void; pub extern fn JSC__JSGlobalObject__iteratorPrototype(arg0: ?*JSC__JSGlobalObject) ?*bindings.IteratorPrototype; pub extern fn JSC__JSGlobalObject__jsSetPrototype(arg0: ?*JSC__JSGlobalObject) [*c]JSC__JSObject; pub extern fn JSC__JSGlobalObject__mapIteratorPrototype(arg0: ?*JSC__JSGlobalObject) ?*bindings.MapIteratorPrototype; @@ -212,7 +213,6 @@ pub extern fn JSC__JSGlobalObject__startRemoteInspector(arg0: ?*JSC__JSGlobalObj pub extern fn JSC__JSGlobalObject__stringPrototype(arg0: ?*JSC__JSGlobalObject) ?*bindings.StringPrototype; pub extern fn JSC__JSGlobalObject__symbolPrototype(arg0: ?*JSC__JSGlobalObject) [*c]JSC__JSObject; pub extern fn JSC__JSGlobalObject__vm(arg0: ?*JSC__JSGlobalObject) [*c]JSC__VM; -pub extern fn JSC__JSGlobalObject__handleRejectedPromises(arg0: ?*JSC__JSGlobalObject) void; pub extern fn WTF__URL__encodedPassword(arg0: [*c]WTF__URL) bWTF__StringView; pub extern fn WTF__URL__encodedUser(arg0: [*c]WTF__URL) bWTF__StringView; pub extern fn WTF__URL__fileSystemPath(arg0: [*c]WTF__URL) bWTF__String; diff --git a/src/bun.js/bindings/webcore/JSWebSocket.cpp b/src/bun.js/bindings/webcore/JSWebSocket.cpp index aa351fba39c19e..f140c365d63b61 100644 --- a/src/bun.js/bindings/webcore/JSWebSocket.cpp +++ b/src/bun.js/bindings/webcore/JSWebSocket.cpp @@ -633,11 +633,11 @@ bool JSWebSocketOwner::isReachableFromOpaqueRoots(JSC::Handle hand { auto* jsWebSocket = jsCast(handle.slot()->asCell()); auto& wrapped = jsWebSocket->wrapped(); - // if (!wrapped.isContextStopped() && wrapped.hasPendingActivity()) { - // if (UNLIKELY(reason)) - // *reason = "ActiveDOMObject with pending activity"; - // return true; - // } + if (wrapped.hasPendingActivity()) { + if (UNLIKELY(reason)) + *reason = "ActiveDOMObject with pending activity"; + return true; + } if (jsWebSocket->wrapped().isFiringEventListeners()) { if (UNLIKELY(reason)) *reason = "EventTarget firing event listeners"; diff --git a/src/bun.js/bindings/webcore/WebSocket.cpp b/src/bun.js/bindings/webcore/WebSocket.cpp index 736801a7529e06..9f8bf3ed6f047c 100644 --- a/src/bun.js/bindings/webcore/WebSocket.cpp +++ b/src/bun.js/bindings/webcore/WebSocket.cpp @@ -371,16 +371,19 @@ ExceptionOr WebSocket::connect(const String& url, const Vector& pr if (is_secure) { us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext(); RELEASE_ASSERT(ctx); + this->m_pendingActivityCount++; this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString); } else { us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext(); RELEASE_ASSERT(ctx); + this->m_pendingActivityCount++; this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString); } if (this->m_upgradeClient == nullptr) { // context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, ); m_state = CLOSED; + this->m_pendingActivityCount--; return Exception { SyntaxError, "WebSocket connection failed"_s }; } @@ -714,6 +717,9 @@ ScriptExecutionContext* WebSocket::scriptExecutionContext() const void WebSocket::didConnect() { + // from new WebSocket() -> connect() + this->m_pendingActivityCount--; + LOG(Network, "WebSocket %p didConnect()", this); // queueTaskKeepingObjectAlive(*this, TaskSource::WebSocket, [this] { if (m_state == CLOSED) @@ -725,10 +731,12 @@ void WebSocket::didConnect() m_state = OPEN; if (auto* context = scriptExecutionContext()) { + if (this->hasEventListeners("open"_s)) { // the main reason for dispatching on a separate tick is to handle when you haven't yet attached an event listener dispatchEvent(Event::create(eventNames().openEvent, Event::CanBubble::No, Event::IsCancelable::No)); } else { + this->m_pendingActivityCount++; context->postTask([this, protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); @@ -736,6 +744,7 @@ void WebSocket::didConnect() // m_extensions = m_channel->extensions(); protectedThis->dispatchEvent(Event::create(eventNames().openEvent, Event::CanBubble::No, Event::IsCancelable::No)); // }); + protectedThis->m_pendingActivityCount--; }); } } @@ -762,10 +771,11 @@ void WebSocket::didReceiveMessage(String&& message) } if (auto* context = scriptExecutionContext()) { - - context->postTask([this, message_ = message, protectedThis = Ref { *this }](ScriptExecutionContext& context) { + this->m_pendingActivityCount++; + context->postTask([this, message_ = WTFMove(message), protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); protectedThis->dispatchEvent(MessageEvent::create(message_, protectedThis->m_url.string())); + protectedThis->m_pendingActivityCount--; }); } @@ -797,9 +807,12 @@ void WebSocket::didReceiveBinaryData(Vector&& binaryData) } if (auto* context = scriptExecutionContext()) { - context->postTask([this, binaryData = binaryData, protectedThis = Ref { *this }](ScriptExecutionContext& context) { + auto arrayBuffer = JSC::ArrayBuffer::create(binaryData.data(), binaryData.size()); + this->m_pendingActivityCount++; + context->postTask([this, buffer = WTFMove(arrayBuffer), protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); - protectedThis->dispatchEvent(MessageEvent::create(ArrayBuffer::create(binaryData.data(), binaryData.size()), m_url.string())); + protectedThis->dispatchEvent(MessageEvent::create(buffer, m_url.string())); + protectedThis->m_pendingActivityCount--; }); } @@ -817,6 +830,8 @@ void WebSocket::didReceiveMessageError(WTF::StringImpl::StaticStringImpl* reason return; m_state = CLOSED; if (auto* context = scriptExecutionContext()) { + this->m_pendingActivityCount++; + context->postTask([this, reason, protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); // if (UNLIKELY(InspectorInstrumentation::hasFrontends())) { @@ -826,6 +841,7 @@ void WebSocket::didReceiveMessageError(WTF::StringImpl::StaticStringImpl* reason // FIXME: As per https://html.spec.whatwg.org/multipage/web-sockets.html#feedback-from-the-protocol:concept-websocket-closed, we should synchronously fire a close event. dispatchEvent(CloseEvent::create(false, 0, WTF::String(reason))); + protectedThis->m_pendingActivityCount--; }); } } @@ -874,9 +890,11 @@ void WebSocket::didClose(unsigned unhandledBufferedAmount, unsigned short code, this->m_upgradeClient = nullptr; if (auto* context = scriptExecutionContext()) { + this->m_pendingActivityCount++; context->postTask([this, code, wasClean, reason, protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); protectedThis->dispatchEvent(CloseEvent::create(wasClean, code, reason)); + protectedThis->m_pendingActivityCount++; }); } @@ -898,9 +916,11 @@ void WebSocket::dispatchErrorEventIfNeeded() m_dispatchedErrorEvent = true; if (auto* context = scriptExecutionContext()) { + this->m_pendingActivityCount++; context->postTask([this, protectedThis = Ref { *this }](ScriptExecutionContext& context) { ASSERT(scriptExecutionContext()); protectedThis->dispatchEvent(Event::create(eventNames().errorEvent, Event::CanBubble::No, Event::IsCancelable::No)); + protectedThis->m_pendingActivityCount--; }); } } @@ -910,11 +930,11 @@ void WebSocket::didConnect(us_socket_t* socket, char* bufferedData, size_t buffe this->m_upgradeClient = nullptr; if (m_isSecure) { us_socket_context_t* ctx = (us_socket_context_t*)this->scriptExecutionContext()->connectedWebSocketContext(); - this->m_connectedWebSocket.clientSSL = Bun__WebSocketClientTLS__init(this, socket, ctx, this->scriptExecutionContext()->jsGlobalObject()); + this->m_connectedWebSocket.clientSSL = Bun__WebSocketClientTLS__init(this, socket, ctx, this->scriptExecutionContext()->jsGlobalObject(), reinterpret_cast(bufferedData), bufferedDataSize); this->m_connectedWebSocketKind = ConnectedWebSocketKind::ClientSSL; } else { us_socket_context_t* ctx = (us_socket_context_t*)this->scriptExecutionContext()->connectedWebSocketContext(); - this->m_connectedWebSocket.client = Bun__WebSocketClient__init(this, socket, ctx, this->scriptExecutionContext()->jsGlobalObject()); + this->m_connectedWebSocket.client = Bun__WebSocketClient__init(this, socket, ctx, this->scriptExecutionContext()->jsGlobalObject(), reinterpret_cast(bufferedData), bufferedDataSize); this->m_connectedWebSocketKind = ConnectedWebSocketKind::Client; } @@ -922,9 +942,11 @@ void WebSocket::didConnect(us_socket_t* socket, char* bufferedData, size_t buffe } void WebSocket::didFailWithErrorCode(int32_t code) { + // from new WebSocket() -> connect() if (m_state == CLOSED) return; + this->m_pendingActivityCount = this->m_pendingActivityCount > 0 ? this->m_pendingActivityCount - 1 : 0; this->m_upgradeClient = nullptr; this->m_connectedWebSocketKind = ConnectedWebSocketKind::None; this->m_connectedWebSocket.client = nullptr; diff --git a/src/bun.js/bindings/webcore/WebSocket.h b/src/bun.js/bindings/webcore/WebSocket.h index 03c0d7709fce12..7fd4e24a7ee9de 100644 --- a/src/bun.js/bindings/webcore/WebSocket.h +++ b/src/bun.js/bindings/webcore/WebSocket.h @@ -103,6 +103,11 @@ class WebSocket final : public RefCounted, public EventTargetWithInli void didReceiveData(const char* data, size_t length); void didReceiveBinaryData(Vector&&); + bool hasPendingActivity() const + { + return m_state == State::OPEN || m_state == State::CLOSING || m_pendingActivityCount > 0; + } + private: typedef union AnyWebSocket { WebSocketClient* client; @@ -157,6 +162,7 @@ class WebSocket final : public RefCounted, public EventTargetWithInli bool m_isSecure { false }; AnyWebSocket m_connectedWebSocket { nullptr }; ConnectedWebSocketKind m_connectedWebSocketKind { ConnectedWebSocketKind::None }; + size_t m_pendingActivityCount { 0 }; bool m_dispatchedErrorEvent { false }; // RefPtr> m_pendingActivity; diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index 7b34dea45628d8..7e5bb26baf888f 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -295,10 +295,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { } }; - var buffered_body_data = body[@minimum(@intCast(usize, response.bytes_read), body.len)..]; - buffered_body_data = buffered_body_data[0..@minimum(buffered_body_data.len, this.body_written)]; - - this.processResponse(response, buffered_body_data, overflow); + this.processResponse(response, available_to_read[@intCast(usize, response.bytes_read)..]); } pub fn handleEnd(this: *HTTPClient, socket: Socket) void { @@ -306,7 +303,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.terminate(ErrorCode.ended); } - pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8, overflow_buf: []const u8) void { + pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void { std.debug.assert(this.body_written > 0); var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" }; @@ -316,10 +313,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { // var visited_version = false; std.debug.assert(response.status_code == 101); - if (remain_buf.len > 0) { - std.debug.assert(overflow_buf.len == 0); - } - for (response.headers) |header| { switch (header.name.len) { "Connection".len => { @@ -408,7 +401,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { // TODO: check websocket_accept_header.value - const overflow_len = overflow_buf.len + remain_buf.len; + const overflow_len = remain_buf.len; var overflow: []u8 = &.{}; if (overflow_len > 0) { overflow = bun.default_allocator.alloc(u8, overflow_len) catch { @@ -416,7 +409,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { return; }; if (remain_buf.len > 0) @memcpy(overflow.ptr, remain_buf.ptr, remain_buf.len); - if (overflow_buf.len > 0) @memcpy(overflow.ptr + remain_buf.len, overflow_buf.ptr, overflow_buf.len); } this.clearData(); @@ -1432,6 +1424,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { input_socket: *anyopaque, socket_ctx: *anyopaque, globalThis: *JSC.JSGlobalObject, + buffered_data: [*]u8, + buffered_data_len: usize, ) callconv(.C) ?*anyopaque { var tcp = @ptrCast(*uws.Socket, input_socket); var ctx = @ptrCast(*uws.us_socket_context_t, socket_ctx); @@ -1453,6 +1447,32 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { adopted.event_loop_ref = true; adopted.globalThis.bunVM().us_loop_reference_count +|= 1; _ = globalThis.bunVM().eventLoop().ready_tasks_count.fetchAdd(1, .Monotonic); + var buffered_slice: []u8 = buffered_data[0..buffered_data_len]; + if (buffered_slice.len > 0) { + const InitialDataHandler = struct { + adopted: *WebSocket, + slice: []u8, + + pub fn handle(this: *@This()) void { + defer { + bun.default_allocator.free(this.slice); + bun.default_allocator.destroy(this); + } + + this.adopted.receive_buffer.ensureUnusedCapacity(this.slice.len) catch return; + var writable = this.adopted.receive_buffer.writableSlice(0); + @memcpy(writable.ptr, this.slice.ptr, this.slice.len); + + this.adopted.handleData(this.adopted.tcp, writable); + } + }; + var initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable; + initial_data.* = .{ + .adopted = adopted, + .slice = buffered_slice, + }; + globalThis.bunVM().uws_event_loop.?.nextTick(*InitialDataHandler, initial_data, InitialDataHandler.handle); + } return @ptrCast( *anyopaque, adopted,