Skip to content

Commit

Permalink
Restrict requests based on origin
Browse files Browse the repository at this point in the history
  • Loading branch information
ericmj committed Mar 10, 2015
1 parent 8677699 commit 87562a7
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 11 deletions.
30 changes: 30 additions & 0 deletions lib/phoenix/channel/transport.ex
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,34 @@ defmodule Phoenix.Channel.Transport do
end
:ok
end

def origin_allowed?(nil, _) do
true
end

def origin_allowed?(_, nil) do
true
end

def origin_allowed?(origin, allowed_origins) do
origin = URI.parse(origin)

Enum.any?(allowed_origins, fn allowed ->
allowed = URI.parse(allowed)

success? = compare?(origin.scheme, allowed.scheme) and
compare?(origin.port, allowed.port)

# "example.com" parses into path so compare it instead of host
if allowed.host == nil do
success? and compare?(origin.host, allowed.path)
else
success? and compare?(origin.host, allowed.host)
end
end)
end

defp compare?(nil, _), do: true
defp compare?(_, nil), do: true
defp compare?(x, y), do: x == y
end
30 changes: 23 additions & 7 deletions lib/phoenix/transports/long_poller.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,35 @@ defmodule Phoenix.Transports.LongPoller do

alias Phoenix.Socket.Message
alias Phoenix.Transports.LongPoller
alias Phoenix.Channel.Transport

plug :action

@doc """
Listens for `%Phoenix.Socket.Message{}`'s from `Phoenix.LongPoller.Server`.
As soon as messages are received, they are encoded as JSON and sent down
to the longpolling client, which immediately repolls. If a timeout occurrs,
to the longpolling client, which immediately repolls. If a timeout occurs,
a `:no_content` response is returned, and the client should immediately repoll.
"""
def poll(conn, _params) do
case resume_session(conn) do
{:ok, conn, priv_topic} -> listen(conn, priv_topic)
{:error, conn, :terminated} ->
{conn, priv_topic, sig, _server_pid} = start_session(conn)
{:ok, conn, priv_topic} ->
listen(conn, priv_topic)

conn
|> put_status(:gone)
|> json(%{token: priv_topic, sig: sig})
{:error, conn, :terminated} ->
endpoint = endpoint_module(conn)
allowed_origins = Dict.get(endpoint.config(:transports), :origins)
origin = Plug.Conn.get_req_header(conn, "origin") |> List.first

if Transport.origin_allowed?(origin, allowed_origins) do
new_session(conn)
else
Plug.Conn.send_resp(conn, :forbidden, "")
end
end
end

defp listen(conn, priv_topic) do
ref = :erlang.make_ref()
:ok = broadcast_from(conn, priv_topic, {:flush, ref})
Expand All @@ -43,6 +51,14 @@ defmodule Phoenix.Transports.LongPoller do
end
end

defp new_session(conn) do
{conn, priv_topic, sig, _server_pid} = start_session(conn)

conn
|> put_status(:gone)
|> json(%{token: priv_topic, sig: sig})
end

@doc """
Publishes a `%Phoenix.Socket.Message{}` to a channel.
Expand Down
11 changes: 11 additions & 0 deletions lib/phoenix/transports/websocket.ex
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ defmodule Phoenix.Transports.WebSocket do
plug :upgrade

def upgrade(%Plug.Conn{method: "GET"} = conn, _) do
endpoint = endpoint_module(conn)
allowed_origins = Dict.get(endpoint.config(:transports), :origins)
origin = Plug.Conn.get_req_header(conn, "origin") |> List.first

if Transport.origin_allowed?(origin, allowed_origins) do
do_upgrade(conn)
else
Plug.Conn.send_resp(conn, :forbidden, "")
end
end
defp do_upgrade(conn) do
put_private(conn, :phoenix_upgrade, {:websocket, __MODULE__}) |> halt
end

Expand Down
2 changes: 1 addition & 1 deletion mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
"poolboy": {:hex, :poolboy, "1.4.2"},
"ranch": {:hex, :ranch, "1.0.0"},
"redo": {:git, "git://github.com/heroku/redo.git", "b25e7bd3b197564192069d609d1e8e02183b71a6", []},
"websocket_client": {:git, "git://github.com/jeremyong/websocket_client.git", "2b8d9805306d36f22330f432ae6472f1f2625c30", []}}
"websocket_client": {:git, "git://github.com/jeremyong/websocket_client.git", "48c118682292f2e4d80491161134a665525c4934", []}}
40 changes: 39 additions & 1 deletion test/phoenix/integration/channel_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ defmodule Phoenix.Integration.ChannelTest do
http: [port: @port],
secret_key_base: String.duplicate("abcdefgh", 8),
debug_errors: false,
transports: [longpoller_window_ms: @window_ms, longpoller_pubsub_timeout_ms: @pubsub_window_ms],
transports: [
longpoller_window_ms: @window_ms,
longpoller_pubsub_timeout_ms: @pubsub_window_ms,
origins: ["//example.com", "http://scheme.com", "//port.com:81"]],
server: true,
pubsub: [adapter: Phoenix.PubSub.PG2, name: :int_pub]
])
Expand Down Expand Up @@ -103,6 +106,20 @@ defmodule Phoenix.Integration.ChannelTest do
refute_receive {:socket_reply, %Message{}}
end

test "websocket refuses unallowed origins" do
assert {:ok, _} = WebsocketClient.start_link(self, "ws://127.0.0.1:#{@port}/ws",
[{"origin", "https://example.com"}])
assert {:ok, _} = WebsocketClient.start_link(self, "ws://127.0.0.1:#{@port}/ws",
[{"origin", "http://port.com:81"}])

refute {:ok, _} = WebsocketClient.start_link(self, "ws://127.0.0.1:#{@port}/ws",
[{"origin", "http://notallowed.com"}])
refute {:ok, _} = WebsocketClient.start_link(self, "ws://127.0.0.1:#{@port}/ws",
[{"origin", "https://scheme.com"}])
refute {:ok, _} = WebsocketClient.start_link(self, "ws://127.0.0.1:#{@port}/ws",
[{"origin", "http://port.com:82"}])
end

## Longpoller Transport

@doc """
Expand Down Expand Up @@ -245,4 +262,25 @@ defmodule Phoenix.Integration.ChannelTest do
assert resp.status == 410
end
end

test "longpoller refuses unallowed origins" do
import Plug.Test

conn = conn(:get, "/ws/poll", [], headers: [{"origin", "https://example.com"}])
|> Endpoint.call([])
assert conn.status == 410
conn = conn(:get, "/ws/poll", [], headers: [{"origin", "http://port.com:81"}])
|> Endpoint.call([])
assert conn.status == 410

conn = conn(:get, "/ws/poll", [], headers: [{"origin", "http://notallowed.com"}])
|> Endpoint.call([])
refute conn.status == 410
conn = conn(:get, "/ws/poll", [], headers: [{"origin", "https://scheme.com"}])
|> Endpoint.call([])
refute conn.status == 410
conn = conn(:get, "/ws/poll", [], headers: [{"origin", "http://port.com:82"}])
|> Endpoint.call([])
refute conn.status == 410
end
end
5 changes: 3 additions & 2 deletions test/phoenix/integration/websocket_client.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ defmodule Phoenix.Integration.WebsocketClient do
Starts the WebSocket server for given ws URL. Received Socket.Message's
are forwarded to the sender pid
"""
def start_link(sender, url) do
def start_link(sender, url, headers \\ []) do
:crypto.start
:ssl.start
:websocket_client.start_link(String.to_char_list(url), __MODULE__, [sender])
:websocket_client.start_link(String.to_char_list(url), __MODULE__, [sender],
extra_headers: headers)
end

def init([sender], _conn_state) do
Expand Down

0 comments on commit 87562a7

Please sign in to comment.