diff --git a/lib/pleroma/web/mastodon_api/websocket_handler.ex b/lib/pleroma/web/mastodon_api/websocket_handler.ex index 88444106d3..97a1f14012 100644 --- a/lib/pleroma/web/mastodon_api/websocket_handler.ex +++ b/lib/pleroma/web/mastodon_api/websocket_handler.ex @@ -32,8 +32,15 @@ def init(%{qs: qs} = req, state) do req end + topics = + if topic do + [topic] + else + [] + end + {:cowboy_websocket, req, - %{user: user, topic: topic, oauth_token: oauth_token, count: 0, timer: nil}, + %{user: user, topics: topics, oauth_token: oauth_token, count: 0, timer: nil}, %{idle_timeout: @timeout}} else {:error, :bad_topic} -> @@ -50,10 +57,10 @@ def init(%{qs: qs} = req, state) do def websocket_init(state) do Logger.debug( - "#{__MODULE__} accepted websocket connection for user #{(state.user || %{id: "anonymous"}).id}, topic #{state.topic}" + "#{__MODULE__} accepted websocket connection for user #{(state.user || %{id: "anonymous"}).id}, topics #{state.topics}" ) - Streamer.add_socket(state.topic, state.oauth_token) + Enum.each(state.topics, fn topic -> Streamer.add_socket(topic, state.oauth_token) end) {:ok, %{state | timer: timer()}} end @@ -109,10 +116,10 @@ def terminate(_reason, _req, []), do: :ok def terminate(reason, _req, state) do Logger.debug( - "#{__MODULE__} terminating websocket connection for user #{(state.user || %{id: "anonymous"}).id}, topic #{state.topic || "?"}: #{inspect(reason)}" + "#{__MODULE__} terminating websocket connection for user #{(state.user || %{id: "anonymous"}).id}, topics #{state.topics || "?"}: #{inspect(reason)}" ) - Streamer.remove_socket(state.topic) + Enum.each(state.topics, fn topic -> Streamer.remove_socket(topic) end) :ok end diff --git a/lib/pleroma/web/streamer.ex b/lib/pleroma/web/streamer.ex index b9a04cc767..6b5fdbf0c8 100644 --- a/lib/pleroma/web/streamer.ex +++ b/lib/pleroma/web/streamer.ex @@ -59,10 +59,14 @@ defp can_access_stream(user, oauth_token, kind) do end @doc "Expand and authorizes a stream" - @spec get_topic(stream :: String.t(), User.t() | nil, Token.t() | nil, Map.t()) :: - {:ok, topic :: String.t()} | {:error, :bad_topic} + @spec get_topic(stream :: String.t() | nil, User.t() | nil, Token.t() | nil, Map.t()) :: + {:ok, topic :: String.t() | nil} | {:error, :bad_topic} def get_topic(stream, user, oauth_token, params \\ %{}) + def get_topic(nil = _stream, _user, _oauth_token, _params) do + {:ok, nil} + end + # Allow all public steams if the instance allows unauthenticated access. # Otherwise, only allow users with valid oauth tokens. def get_topic(stream, user, oauth_token, _params) when stream in @public_streams do diff --git a/test/pleroma/integration/mastodon_websocket_test.exs b/test/pleroma/integration/mastodon_websocket_test.exs index 9be0445c0e..3b120c10e5 100644 --- a/test/pleroma/integration/mastodon_websocket_test.exs +++ b/test/pleroma/integration/mastodon_websocket_test.exs @@ -33,7 +33,6 @@ def start_socket(qs \\ nil, headers \\ []) do test "refuses invalid requests" do capture_log(fn -> - assert {:error, %WebSockex.RequestError{code: 404}} = start_socket() assert {:error, %WebSockex.RequestError{code: 404}} = start_socket("?stream=ncjdk") Process.sleep(30) end) @@ -49,6 +48,10 @@ test "requires authentication and a valid token for protected streams" do end) end + test "allows unified stream" do + assert {:ok, _} = start_socket() + end + test "allows public streams without authentication" do assert {:ok, _} = start_socket("?stream=public") assert {:ok, _} = start_socket("?stream=public:local") diff --git a/test/pleroma/web/streamer_test.exs b/test/pleroma/web/streamer_test.exs index 7ab0e379b4..da97b87d8f 100644 --- a/test/pleroma/web/streamer_test.exs +++ b/test/pleroma/web/streamer_test.exs @@ -22,6 +22,10 @@ defmodule Pleroma.Web.StreamerTest do setup do: clear_config([:instance, :skip_thread_containment]) describe "get_topic/_ (unauthenticated)" do + test "allows no stream" do + assert {:ok, nil} = Streamer.get_topic(nil, nil, nil) + end + test "allows public" do assert {:ok, "public"} = Streamer.get_topic("public", nil, nil) assert {:ok, "public:local"} = Streamer.get_topic("public:local", nil, nil)