Skip to content

Commit

Permalink
improve reliability of WebSocket
Browse files Browse the repository at this point in the history
- Fix GC not keeping WebSocket alive
- Fix ignoring messages sent immediately after upgrade

Fixes oven-sh#521
  • Loading branch information
Jarred-Sumner committed Aug 11, 2022
1 parent e511b14 commit f09e7ac
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/baby_list.zig
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ pub fn BabyList(comptime Type: type) type {
}

pub fn one(allocator: std.mem.Allocator, value: Type) !ListType {
var items = try allocator.alloc(Type, 1);
var items = try allocator.allocAdvanced(Type, @alignOf(Type), 1, .exact);
items[0] = value;
return ListType{
.ptr = @ptrCast([*]Type, items.ptr),
Expand Down
2 changes: 1 addition & 1 deletion src/bun.js/bindings/headers-cpp.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//-- AUTOGENERATED FILE -- 1657408675
//-- AUTOGENERATED FILE -- 1660175100
// clang-format off
#pragma once

Expand Down
8 changes: 4 additions & 4 deletions src/bun.js/bindings/headers.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// clang-format off
//-- AUTOGENERATED FILE -- 1657408675
//-- AUTOGENERATED FILE -- 1660175100
#pragma once

#include <stddef.h>
Expand Down Expand Up @@ -393,6 +393,7 @@ CPP_DECL JSC__JSValue JSC__JSGlobalObject__generateHeapSnapshot(JSC__JSGlobalObj
CPP_DECL JSC__GeneratorFunctionPrototype* JSC__JSGlobalObject__generatorFunctionPrototype(JSC__JSGlobalObject* arg0);
CPP_DECL JSC__GeneratorPrototype* JSC__JSGlobalObject__generatorPrototype(JSC__JSGlobalObject* arg0);
CPP_DECL JSC__JSValue JSC__JSGlobalObject__getCachedObject(JSC__JSGlobalObject* arg0, const ZigString* arg1);
CPP_DECL void JSC__JSGlobalObject__handleRejectedPromises(JSC__JSGlobalObject* arg0);
CPP_DECL JSC__IteratorPrototype* JSC__JSGlobalObject__iteratorPrototype(JSC__JSGlobalObject* arg0);
CPP_DECL JSC__JSObject* JSC__JSGlobalObject__jsSetPrototype(JSC__JSGlobalObject* arg0);
CPP_DECL JSC__MapIteratorPrototype* JSC__JSGlobalObject__mapIteratorPrototype(JSC__JSGlobalObject* arg0);
Expand All @@ -407,7 +408,6 @@ CPP_DECL bool JSC__JSGlobalObject__startRemoteInspector(JSC__JSGlobalObject* arg
CPP_DECL JSC__StringPrototype* JSC__JSGlobalObject__stringPrototype(JSC__JSGlobalObject* arg0);
CPP_DECL JSC__JSObject* JSC__JSGlobalObject__symbolPrototype(JSC__JSGlobalObject* arg0);
CPP_DECL JSC__VM* JSC__JSGlobalObject__vm(JSC__JSGlobalObject* arg0);
CPP_DECL void JSC__JSGlobalObject__handleRejectedPromises(JSC__JSGlobalObject* arg0);

#pragma mark - WTF::URL

Expand Down Expand Up @@ -772,7 +772,7 @@ ZIG_DECL void Bun__WebSocketHTTPSClient__register(JSC__JSGlobalObject* arg0, voi

ZIG_DECL void Bun__WebSocketClient__close(WebSocketClient* arg0, uint16_t arg1, const ZigString* arg2);
ZIG_DECL void Bun__WebSocketClient__finalize(WebSocketClient* arg0);
ZIG_DECL void* Bun__WebSocketClient__init(void* arg0, void* arg1, void* arg2, JSC__JSGlobalObject* arg3);
ZIG_DECL void* Bun__WebSocketClient__init(void* arg0, void* arg1, void* arg2, JSC__JSGlobalObject* arg3, unsigned char* arg4, size_t arg5);
ZIG_DECL void Bun__WebSocketClient__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2);
ZIG_DECL void Bun__WebSocketClient__writeBinaryData(WebSocketClient* arg0, const unsigned char* arg1, size_t arg2);
ZIG_DECL void Bun__WebSocketClient__writeString(WebSocketClient* arg0, const ZigString* arg1);
Expand All @@ -783,7 +783,7 @@ ZIG_DECL void Bun__WebSocketClient__writeString(WebSocketClient* arg0, const Zig

ZIG_DECL void Bun__WebSocketClientTLS__close(WebSocketClientTLS* arg0, uint16_t arg1, const ZigString* arg2);
ZIG_DECL void Bun__WebSocketClientTLS__finalize(WebSocketClientTLS* arg0);
ZIG_DECL void* Bun__WebSocketClientTLS__init(void* arg0, void* arg1, void* arg2, JSC__JSGlobalObject* arg3);
ZIG_DECL void* Bun__WebSocketClientTLS__init(void* arg0, void* arg1, void* arg2, JSC__JSGlobalObject* arg3, unsigned char* arg4, size_t arg5);
ZIG_DECL void Bun__WebSocketClientTLS__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2);
ZIG_DECL void Bun__WebSocketClientTLS__writeBinaryData(WebSocketClientTLS* arg0, const unsigned char* arg1, size_t arg2);
ZIG_DECL void Bun__WebSocketClientTLS__writeString(WebSocketClientTLS* arg0, const ZigString* arg1);
Expand Down
2 changes: 1 addition & 1 deletion src/bun.js/bindings/headers.zig
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ pub extern fn JSC__JSGlobalObject__generateHeapSnapshot(arg0: ?*JSC__JSGlobalObj
pub extern fn JSC__JSGlobalObject__generatorFunctionPrototype(arg0: ?*JSC__JSGlobalObject) ?*bindings.GeneratorFunctionPrototype;
pub extern fn JSC__JSGlobalObject__generatorPrototype(arg0: ?*JSC__JSGlobalObject) ?*bindings.GeneratorPrototype;
pub extern fn JSC__JSGlobalObject__getCachedObject(arg0: ?*JSC__JSGlobalObject, arg1: [*c]const ZigString) JSC__JSValue;
pub extern fn JSC__JSGlobalObject__handleRejectedPromises(arg0: ?*JSC__JSGlobalObject) void;
pub extern fn JSC__JSGlobalObject__iteratorPrototype(arg0: ?*JSC__JSGlobalObject) ?*bindings.IteratorPrototype;
pub extern fn JSC__JSGlobalObject__jsSetPrototype(arg0: ?*JSC__JSGlobalObject) [*c]JSC__JSObject;
pub extern fn JSC__JSGlobalObject__mapIteratorPrototype(arg0: ?*JSC__JSGlobalObject) ?*bindings.MapIteratorPrototype;
Expand All @@ -212,7 +213,6 @@ pub extern fn JSC__JSGlobalObject__startRemoteInspector(arg0: ?*JSC__JSGlobalObj
pub extern fn JSC__JSGlobalObject__stringPrototype(arg0: ?*JSC__JSGlobalObject) ?*bindings.StringPrototype;
pub extern fn JSC__JSGlobalObject__symbolPrototype(arg0: ?*JSC__JSGlobalObject) [*c]JSC__JSObject;
pub extern fn JSC__JSGlobalObject__vm(arg0: ?*JSC__JSGlobalObject) [*c]JSC__VM;
pub extern fn JSC__JSGlobalObject__handleRejectedPromises(arg0: ?*JSC__JSGlobalObject) void;
pub extern fn WTF__URL__encodedPassword(arg0: [*c]WTF__URL) bWTF__StringView;
pub extern fn WTF__URL__encodedUser(arg0: [*c]WTF__URL) bWTF__StringView;
pub extern fn WTF__URL__fileSystemPath(arg0: [*c]WTF__URL) bWTF__String;
Expand Down
10 changes: 5 additions & 5 deletions src/bun.js/bindings/webcore/JSWebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,11 +633,11 @@ bool JSWebSocketOwner::isReachableFromOpaqueRoots(JSC::Handle<JSC::Unknown> hand
{
auto* jsWebSocket = jsCast<JSWebSocket*>(handle.slot()->asCell());
auto& wrapped = jsWebSocket->wrapped();
// if (!wrapped.isContextStopped() && wrapped.hasPendingActivity()) {
// if (UNLIKELY(reason))
// *reason = "ActiveDOMObject with pending activity";
// return true;
// }
if (wrapped.hasPendingActivity()) {
if (UNLIKELY(reason))
*reason = "ActiveDOMObject with pending activity";
return true;
}
if (jsWebSocket->wrapped().isFiringEventListeners()) {
if (UNLIKELY(reason))
*reason = "EventTarget firing event listeners";
Expand Down
34 changes: 28 additions & 6 deletions src/bun.js/bindings/webcore/WebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,19 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr
if (is_secure) {
us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<true>();
RELEASE_ASSERT(ctx);
this->m_pendingActivityCount++;
this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString);
} else {
us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<false>();
RELEASE_ASSERT(ctx);
this->m_pendingActivityCount++;
this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString);
}

if (this->m_upgradeClient == nullptr) {
// context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, );
m_state = CLOSED;
this->m_pendingActivityCount--;
return Exception { SyntaxError, "WebSocket connection failed"_s };
}

Expand Down Expand Up @@ -714,6 +717,9 @@ ScriptExecutionContext* WebSocket::scriptExecutionContext() const

void WebSocket::didConnect()
{
// from new WebSocket() -> connect()
this->m_pendingActivityCount--;

LOG(Network, "WebSocket %p didConnect()", this);
// queueTaskKeepingObjectAlive(*this, TaskSource::WebSocket, [this] {
if (m_state == CLOSED)
Expand All @@ -725,17 +731,20 @@ void WebSocket::didConnect()
m_state = OPEN;

if (auto* context = scriptExecutionContext()) {

if (this->hasEventListeners("open"_s)) {
// the main reason for dispatching on a separate tick is to handle when you haven't yet attached an event listener
dispatchEvent(Event::create(eventNames().openEvent, Event::CanBubble::No, Event::IsCancelable::No));
} else {
this->m_pendingActivityCount++;
context->postTask([this, protectedThis = Ref { *this }](ScriptExecutionContext& context) {
ASSERT(scriptExecutionContext());

// m_subprotocol = m_channel->subprotocol();
// m_extensions = m_channel->extensions();
protectedThis->dispatchEvent(Event::create(eventNames().openEvent, Event::CanBubble::No, Event::IsCancelable::No));
// });
protectedThis->m_pendingActivityCount--;
});
}
}
Expand All @@ -762,10 +771,11 @@ void WebSocket::didReceiveMessage(String&& message)
}

if (auto* context = scriptExecutionContext()) {

context->postTask([this, message_ = message, protectedThis = Ref { *this }](ScriptExecutionContext& context) {
this->m_pendingActivityCount++;
context->postTask([this, message_ = WTFMove(message), protectedThis = Ref { *this }](ScriptExecutionContext& context) {
ASSERT(scriptExecutionContext());
protectedThis->dispatchEvent(MessageEvent::create(message_, protectedThis->m_url.string()));
protectedThis->m_pendingActivityCount--;
});
}

Expand Down Expand Up @@ -797,9 +807,12 @@ void WebSocket::didReceiveBinaryData(Vector<uint8_t>&& binaryData)
}

if (auto* context = scriptExecutionContext()) {
context->postTask([this, binaryData = binaryData, protectedThis = Ref { *this }](ScriptExecutionContext& context) {
auto arrayBuffer = JSC::ArrayBuffer::create(binaryData.data(), binaryData.size());
this->m_pendingActivityCount++;
context->postTask([this, buffer = WTFMove(arrayBuffer), protectedThis = Ref { *this }](ScriptExecutionContext& context) {
ASSERT(scriptExecutionContext());
protectedThis->dispatchEvent(MessageEvent::create(ArrayBuffer::create(binaryData.data(), binaryData.size()), m_url.string()));
protectedThis->dispatchEvent(MessageEvent::create(buffer, m_url.string()));
protectedThis->m_pendingActivityCount--;
});
}

Expand All @@ -817,6 +830,8 @@ void WebSocket::didReceiveMessageError(WTF::StringImpl::StaticStringImpl* reason
return;
m_state = CLOSED;
if (auto* context = scriptExecutionContext()) {
this->m_pendingActivityCount++;

context->postTask([this, reason, protectedThis = Ref { *this }](ScriptExecutionContext& context) {
ASSERT(scriptExecutionContext());
// if (UNLIKELY(InspectorInstrumentation::hasFrontends())) {
Expand All @@ -826,6 +841,7 @@ void WebSocket::didReceiveMessageError(WTF::StringImpl::StaticStringImpl* reason

// FIXME: As per https://html.spec.whatwg.org/multipage/web-sockets.html#feedback-from-the-protocol:concept-websocket-closed, we should synchronously fire a close event.
dispatchEvent(CloseEvent::create(false, 0, WTF::String(reason)));
protectedThis->m_pendingActivityCount--;
});
}
}
Expand Down Expand Up @@ -874,9 +890,11 @@ void WebSocket::didClose(unsigned unhandledBufferedAmount, unsigned short code,
this->m_upgradeClient = nullptr;

if (auto* context = scriptExecutionContext()) {
this->m_pendingActivityCount++;
context->postTask([this, code, wasClean, reason, protectedThis = Ref { *this }](ScriptExecutionContext& context) {
ASSERT(scriptExecutionContext());
protectedThis->dispatchEvent(CloseEvent::create(wasClean, code, reason));
protectedThis->m_pendingActivityCount++;
});
}

Expand All @@ -898,9 +916,11 @@ void WebSocket::dispatchErrorEventIfNeeded()
m_dispatchedErrorEvent = true;

if (auto* context = scriptExecutionContext()) {
this->m_pendingActivityCount++;
context->postTask([this, protectedThis = Ref { *this }](ScriptExecutionContext& context) {
ASSERT(scriptExecutionContext());
protectedThis->dispatchEvent(Event::create(eventNames().errorEvent, Event::CanBubble::No, Event::IsCancelable::No));
protectedThis->m_pendingActivityCount--;
});
}
}
Expand All @@ -910,21 +930,23 @@ void WebSocket::didConnect(us_socket_t* socket, char* bufferedData, size_t buffe
this->m_upgradeClient = nullptr;
if (m_isSecure) {
us_socket_context_t* ctx = (us_socket_context_t*)this->scriptExecutionContext()->connectedWebSocketContext<true, false>();
this->m_connectedWebSocket.clientSSL = Bun__WebSocketClientTLS__init(this, socket, ctx, this->scriptExecutionContext()->jsGlobalObject());
this->m_connectedWebSocket.clientSSL = Bun__WebSocketClientTLS__init(this, socket, ctx, this->scriptExecutionContext()->jsGlobalObject(), reinterpret_cast<unsigned char*>(bufferedData), bufferedDataSize);
this->m_connectedWebSocketKind = ConnectedWebSocketKind::ClientSSL;
} else {
us_socket_context_t* ctx = (us_socket_context_t*)this->scriptExecutionContext()->connectedWebSocketContext<false, false>();
this->m_connectedWebSocket.client = Bun__WebSocketClient__init(this, socket, ctx, this->scriptExecutionContext()->jsGlobalObject());
this->m_connectedWebSocket.client = Bun__WebSocketClient__init(this, socket, ctx, this->scriptExecutionContext()->jsGlobalObject(), reinterpret_cast<unsigned char*>(bufferedData), bufferedDataSize);
this->m_connectedWebSocketKind = ConnectedWebSocketKind::Client;
}

this->didConnect();
}
void WebSocket::didFailWithErrorCode(int32_t code)
{
// from new WebSocket() -> connect()
if (m_state == CLOSED)
return;

this->m_pendingActivityCount = this->m_pendingActivityCount > 0 ? this->m_pendingActivityCount - 1 : 0;
this->m_upgradeClient = nullptr;
this->m_connectedWebSocketKind = ConnectedWebSocketKind::None;
this->m_connectedWebSocket.client = nullptr;
Expand Down
6 changes: 6 additions & 0 deletions src/bun.js/bindings/webcore/WebSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ class WebSocket final : public RefCounted<WebSocket>, public EventTargetWithInli
void didReceiveData(const char* data, size_t length);
void didReceiveBinaryData(Vector<uint8_t>&&);

bool hasPendingActivity() const
{
return m_state == State::OPEN || m_state == State::CLOSING || m_pendingActivityCount > 0;
}

private:
typedef union AnyWebSocket {
WebSocketClient* client;
Expand Down Expand Up @@ -157,6 +162,7 @@ class WebSocket final : public RefCounted<WebSocket>, public EventTargetWithInli
bool m_isSecure { false };
AnyWebSocket m_connectedWebSocket { nullptr };
ConnectedWebSocketKind m_connectedWebSocketKind { ConnectedWebSocketKind::None };
size_t m_pendingActivityCount { 0 };

bool m_dispatchedErrorEvent { false };
// RefPtr<PendingActivity<WebSocket>> m_pendingActivity;
Expand Down
42 changes: 31 additions & 11 deletions src/http/websocket_http_client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -295,18 +295,15 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
}
};

var buffered_body_data = body[@minimum(@intCast(usize, response.bytes_read), body.len)..];
buffered_body_data = buffered_body_data[0..@minimum(buffered_body_data.len, this.body_written)];

this.processResponse(response, buffered_body_data, overflow);
this.processResponse(response, available_to_read[@intCast(usize, response.bytes_read)..]);
}

pub fn handleEnd(this: *HTTPClient, socket: Socket) void {
std.debug.assert(socket.socket == this.tcp.socket);
this.terminate(ErrorCode.ended);
}

pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8, overflow_buf: []const u8) void {
pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void {
std.debug.assert(this.body_written > 0);

var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" };
Expand All @@ -316,10 +313,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
// var visited_version = false;
std.debug.assert(response.status_code == 101);

if (remain_buf.len > 0) {
std.debug.assert(overflow_buf.len == 0);
}

for (response.headers) |header| {
switch (header.name.len) {
"Connection".len => {
Expand Down Expand Up @@ -408,15 +401,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {

// TODO: check websocket_accept_header.value

const overflow_len = overflow_buf.len + remain_buf.len;
const overflow_len = remain_buf.len;
var overflow: []u8 = &.{};
if (overflow_len > 0) {
overflow = bun.default_allocator.alloc(u8, overflow_len) catch {
this.terminate(ErrorCode.invalid_response);
return;
};
if (remain_buf.len > 0) @memcpy(overflow.ptr, remain_buf.ptr, remain_buf.len);
if (overflow_buf.len > 0) @memcpy(overflow.ptr + remain_buf.len, overflow_buf.ptr, overflow_buf.len);
}

this.clearData();
Expand Down Expand Up @@ -1432,6 +1424,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
input_socket: *anyopaque,
socket_ctx: *anyopaque,
globalThis: *JSC.JSGlobalObject,
buffered_data: [*]u8,
buffered_data_len: usize,
) callconv(.C) ?*anyopaque {
var tcp = @ptrCast(*uws.Socket, input_socket);
var ctx = @ptrCast(*uws.us_socket_context_t, socket_ctx);
Expand All @@ -1453,6 +1447,32 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
adopted.event_loop_ref = true;
adopted.globalThis.bunVM().us_loop_reference_count +|= 1;
_ = globalThis.bunVM().eventLoop().ready_tasks_count.fetchAdd(1, .Monotonic);
var buffered_slice: []u8 = buffered_data[0..buffered_data_len];
if (buffered_slice.len > 0) {
const InitialDataHandler = struct {
adopted: *WebSocket,
slice: []u8,

pub fn handle(this: *@This()) void {
defer {
bun.default_allocator.free(this.slice);
bun.default_allocator.destroy(this);
}

this.adopted.receive_buffer.ensureUnusedCapacity(this.slice.len) catch return;
var writable = this.adopted.receive_buffer.writableSlice(0);
@memcpy(writable.ptr, this.slice.ptr, this.slice.len);

this.adopted.handleData(this.adopted.tcp, writable);
}
};
var initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable;
initial_data.* = .{
.adopted = adopted,
.slice = buffered_slice,
};
globalThis.bunVM().uws_event_loop.?.nextTick(*InitialDataHandler, initial_data, InitialDataHandler.handle);
}
return @ptrCast(
*anyopaque,
adopted,
Expand Down

0 comments on commit f09e7ac

Please sign in to comment.