diff --git a/mix.exs b/mix.exs index 6e84fe4825..81c4cf9aeb 100644 --- a/mix.exs +++ b/mix.exs @@ -211,7 +211,8 @@ defp deps do {:excoveralls, "0.12.3", only: :test}, {:hackney, "~> 1.18.0", override: true}, {:mox, "~> 1.0", only: :test}, - {:websocket_client, git: "https://github.com/jeremyong/websocket_client.git", only: :test} + {:mint, "~> 1.4", only: :test, override: true}, + {:mint_web_socket, "~> 0.3.0", only: :test} ] ++ oauth_deps() end diff --git a/mix.lock b/mix.lock index 14e43c7034..bd550041d8 100644 --- a/mix.lock +++ b/mix.lock @@ -79,6 +79,7 @@ "mime": {:hex, :mime, "1.6.0", "dabde576a497cef4bbdd60aceee8160e02a6c89250d6c0b29e56c0dfb00db3d2", [:mix], [], "hexpm", "31a1a8613f8321143dde1dafc36006a17d28d02bdfecb9e95a880fa7aabd19a7"}, "mimerl": {:hex, :mimerl, "1.2.0", "67e2d3f571088d5cfd3e550c383094b47159f3eee8ffa08e64106cdf5e981be3", [:rebar3], [], "hexpm", "f278585650aa581986264638ebf698f8bb19df297f66ad91b18910dfc6e19323"}, "mint": {:hex, :mint, "1.4.0", "cd7d2451b201fc8e4a8fd86257fb3878d9e3752899eb67b0c5b25b180bde1212", [:mix], [{:castore, "~> 0.1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "10a99e144b815cbf8522dccbc8199d15802440fc7a64d67b6853adb6fa170217"}, + "mint_web_socket": {:hex, :mint_web_socket, "0.3.0", "c9e130dcc778d673fd713eb66434e16cf7d89cee0754e75f26f8bd9a9e592b63", [:mix], [{:mint, "~> 1.4 and >= 1.4.1", [hex: :mint, repo: "hexpm", optional: false]}], "hexpm", "0605bc3fa684e1a7719b22a3f74be4de5e6a16dd43ac18ebcea72e2adc33b532"}, "mochiweb": {:hex, :mochiweb, "2.18.0", "eb55f1db3e6e960fac4e6db4e2db9ec3602cc9f30b86cd1481d56545c3145d2e", [:rebar3], [], "hexpm"}, "mock": {:hex, :mock, "0.3.7", "75b3bbf1466d7e486ea2052a73c6e062c6256fb429d6797999ab02fa32f29e03", [:mix], [{:meck, "~> 0.9.2", [hex: :meck, repo: "hexpm", optional: false]}], "hexpm", "4da49a4609e41fd99b7836945c26f373623ea968cfb6282742bcb94440cf7e5c"}, "mogrify": {:hex, :mogrify, "0.9.1", "a26f107c4987477769f272bd0f7e3ac4b7b75b11ba597fd001b877beffa9c068", [:mix], [], "hexpm", "134edf189337d2125c0948bf0c228fdeef975c594317452d536224069a5b7f05"}, diff --git a/test/pleroma/integration/mastodon_websocket_test.exs b/test/pleroma/integration/mastodon_websocket_test.exs index 2d4c7f63b2..16525c7400 100644 --- a/test/pleroma/integration/mastodon_websocket_test.exs +++ b/test/pleroma/integration/mastodon_websocket_test.exs @@ -28,21 +28,28 @@ def start_socket(qs \\ nil, headers \\ []) do qs -> @path <> qs end - WebsocketClient.start_link(self(), path, headers) + WebsocketClient.connect(self(), path, headers) end test "refuses invalid requests" do capture_log(fn -> - assert {:error, {404, _}} = start_socket() - assert {:error, {404, _}} = start_socket("?stream=ncjdk") + assert {:error, %Mint.WebSocket.UpgradeFailureError{status_code: 404}} = start_socket() + + assert {:error, %Mint.WebSocket.UpgradeFailureError{status_code: 404}} = + start_socket("?stream=ncjdk") + Process.sleep(30) end) end test "requires authentication and a valid token for protected streams" do capture_log(fn -> - assert {:error, {401, _}} = start_socket("?stream=user&access_token=aaaaaaaaaaaa") - assert {:error, {401, _}} = start_socket("?stream=user") + assert {:error, %Mint.WebSocket.UpgradeFailureError{status_code: 401}} = + start_socket("?stream=user&access_token=aaaaaaaaaaaa") + + assert {:error, %Mint.WebSocket.UpgradeFailureError{status_code: 401}} = + start_socket("?stream=user") + Process.sleep(30) end) end @@ -102,7 +109,9 @@ test "accepts the 'user' stream", %{token: token} = _state do assert {:ok, _} = start_socket("?stream=user&access_token=#{token.token}") capture_log(fn -> - assert {:error, {401, _}} = start_socket("?stream=user") + assert {:error, %Mint.WebSocket.UpgradeFailureError{status_code: 401}} = + start_socket("?stream=user") + Process.sleep(30) end) end @@ -111,7 +120,9 @@ test "accepts the 'user:notification' stream", %{token: token} = _state do assert {:ok, _} = start_socket("?stream=user:notification&access_token=#{token.token}") capture_log(fn -> - assert {:error, {401, _}} = start_socket("?stream=user:notification") + assert {:error, %Mint.WebSocket.UpgradeFailureError{status_code: 401}} = + start_socket("?stream=user:notification") + Process.sleep(30) end) end @@ -120,7 +131,7 @@ test "accepts valid token on Sec-WebSocket-Protocol header", %{token: token} do assert {:ok, _} = start_socket("?stream=user", [{"Sec-WebSocket-Protocol", token.token}]) capture_log(fn -> - assert {:error, {401, _}} = + assert {:error, %Mint.WebSocket.UpgradeFailureError{status_code: 401}} = start_socket("?stream=user", [{"Sec-WebSocket-Protocol", "I am a friend"}]) Process.sleep(30) diff --git a/test/support/websocket_client.ex b/test/support/websocket_client.ex index d149b324ec..43f2854de4 100644 --- a/test/support/websocket_client.ex +++ b/test/support/websocket_client.ex @@ -3,60 +3,199 @@ # SPDX-License-Identifier: AGPL-3.0-only defmodule Pleroma.Integration.WebsocketClient do - # https://github.com/phoenixframework/phoenix/blob/master/test/support/websocket_client.exs + @moduledoc """ + A WebSocket client used to test Mastodon API streaming + + Based on Phoenix Framework's WebsocketClient + https://github.com/phoenixframework/phoenix/blob/master/test/support/websocket_client.exs + """ + + use GenServer + import Kernel, except: [send: 2] + + defstruct [ + :conn, + :request_ref, + :websocket, + :caller, + :status, + :resp_headers, + :sender, + closing?: false + ] @doc """ - Starts the WebSocket server for given ws URL. Received Socket.Message's - are forwarded to the sender pid + Starts the WebSocket client for given ws URL. `Phoenix.Socket.Message`s + received from the server are forwarded to the sender pid. """ - def start_link(sender, url, headers \\ []) do - :crypto.start() - :ssl.start() - - :websocket_client.start_link( - String.to_charlist(url), - __MODULE__, - [sender], - extra_headers: headers - ) + def connect(sender, url, headers \\ []) do + with {:ok, socket} <- GenServer.start_link(__MODULE__, {sender}), + {:ok, :connected} <- GenServer.call(socket, {:connect, url, headers}) do + {:ok, socket} + end end @doc """ Closes the socket """ def close(socket) do - send(socket, :close) + GenServer.cast(socket, :close) end @doc """ Sends a low-level text message to the client. """ def send_text(server_pid, msg) do - send(server_pid, {:text, msg}) + GenServer.call(server_pid, {:text, msg}) end @doc false - def init([sender], _conn_state) do - {:ok, %{sender: sender}} - end + def init({sender}) do + state = %__MODULE__{sender: sender} - @doc false - def websocket_handle(frame, _conn_state, state) do - send(state.sender, frame) {:ok, state} end @doc false - def websocket_info({:text, msg}, _conn_state, state) do - {:reply, {:text, msg}, state} - end + def handle_call({:connect, url, headers}, from, state) do + uri = URI.parse(url) - def websocket_info(:close, _conn_state, _state) do - {:close, <<>>, "done"} + http_scheme = + case uri.scheme do + "ws" -> :http + "wss" -> :https + end + + ws_scheme = + case uri.scheme do + "ws" -> :ws + "wss" -> :wss + end + + path = + case uri.query do + nil -> uri.path + query -> uri.path <> "?" <> query + end + + with {:ok, conn} <- Mint.HTTP.connect(http_scheme, uri.host, uri.port), + {:ok, conn, ref} <- Mint.WebSocket.upgrade(ws_scheme, conn, path, headers) do + state = %{state | conn: conn, request_ref: ref, caller: from} + {:noreply, state} + else + {:error, reason} -> + {:reply, {:error, reason}, state} + + {:error, conn, reason} -> + {:reply, {:error, reason}, put_in(state.conn, conn)} + end end @doc false - def websocket_terminate(_reason, _conn_state, _state) do - :ok + def handle_info(message, state) do + case Mint.WebSocket.stream(state.conn, message) do + {:ok, conn, responses} -> + state = put_in(state.conn, conn) |> handle_responses(responses) + if state.closing?, do: do_close(state), else: {:noreply, state} + + {:error, conn, reason, _responses} -> + state = put_in(state.conn, conn) |> reply({:error, reason}) + {:noreply, state} + + :unknown -> + {:noreply, state} + end + end + + defp do_close(state) do + # Streaming a close frame may fail if the server has already closed + # for writing. + _ = stream_frame(state, :close) + Mint.HTTP.close(state.conn) + {:stop, :normal, state} + end + + defp handle_responses(state, responses) + + defp handle_responses(%{request_ref: ref} = state, [{:status, ref, status} | rest]) do + put_in(state.status, status) + |> handle_responses(rest) + end + + defp handle_responses(%{request_ref: ref} = state, [{:headers, ref, resp_headers} | rest]) do + put_in(state.resp_headers, resp_headers) + |> handle_responses(rest) + end + + defp handle_responses(%{request_ref: ref} = state, [{:done, ref} | rest]) do + case Mint.WebSocket.new(state.conn, ref, state.status, state.resp_headers) do + {:ok, conn, websocket} -> + %{state | conn: conn, websocket: websocket, status: nil, resp_headers: nil} + |> reply({:ok, :connected}) + |> handle_responses(rest) + + {:error, conn, reason} -> + put_in(state.conn, conn) + |> reply({:error, reason}) + end + end + + defp handle_responses(%{request_ref: ref, websocket: websocket} = state, [ + {:data, ref, data} | rest + ]) + when websocket != nil do + case Mint.WebSocket.decode(websocket, data) do + {:ok, websocket, frames} -> + put_in(state.websocket, websocket) + |> handle_frames(frames) + |> handle_responses(rest) + + {:error, websocket, reason} -> + put_in(state.websocket, websocket) + |> reply({:error, reason}) + end + end + + defp handle_responses(state, [_response | rest]) do + handle_responses(state, rest) + end + + defp handle_responses(state, []), do: state + + defp handle_frames(state, frames) do + {frames, state} = + Enum.flat_map_reduce(frames, state, fn + # prepare to close the connection when a close frame is received + {:close, _code, _data}, state -> + {[], put_in(state.closing?, true)} + + frame, state -> + {[frame], state} + end) + + Enum.each(frames, &Kernel.send(state.sender, &1)) + + state + end + + defp reply(state, response) do + if state.caller, do: GenServer.reply(state.caller, response) + put_in(state.caller, nil) + end + + # Encodes a frame as a binary and sends it along the wire, keeping `conn` + # and `websocket` up to date in `state`. + defp stream_frame(state, frame) do + with {:ok, websocket, data} <- Mint.WebSocket.encode(state.websocket, frame), + state = put_in(state.websocket, websocket), + {:ok, conn} <- Mint.WebSocket.stream_request_body(state.conn, state.request_ref, data) do + {:ok, put_in(state.conn, conn)} + else + {:error, %Mint.WebSocket{} = websocket, reason} -> + {:error, put_in(state.websocket, websocket), reason} + + {:error, conn, reason} -> + {:error, put_in(state.conn, conn), reason} + end end end