diff --git a/lib/pleroma/web/mastodon_api/mastodon_api_controller.ex b/lib/pleroma/web/mastodon_api/mastodon_api_controller.ex index 89fd7629a8..bcc79b08a0 100644 --- a/lib/pleroma/web/mastodon_api/mastodon_api_controller.ex +++ b/lib/pleroma/web/mastodon_api/mastodon_api_controller.ex @@ -1091,9 +1091,7 @@ def list_timeline(%{assigns: %{user: user}} = conn, %{"list_id" => id} = params) end def index(%{assigns: %{user: user}} = conn, _params) do - token = - conn - |> get_session(:oauth_token) + token = get_session(conn, :oauth_token) if user && token do mastodon_emoji = mastodonized_emoji() @@ -1194,6 +1192,7 @@ def index(%{assigns: %{user: user}} = conn, _params) do |> render("index.html", %{initial_state: initial_state, flavour: flavour}) else conn + |> put_session(:return_to, conn.request_path) |> redirect(to: "/web/login") end end @@ -1278,12 +1277,20 @@ def login(conn, _) do scope: Enum.join(app.scopes, " ") ) - conn - |> redirect(to: path) + redirect(conn, to: path) end end - defp local_mastodon_root_path(conn), do: mastodon_api_path(conn, :index, ["getting-started"]) + defp local_mastodon_root_path(conn) do + case get_session(conn, :return_to) do + nil -> + mastodon_api_path(conn, :index, ["getting-started"]) + + return_to -> + delete_session(conn, :return_to) + return_to + end + end defp get_or_make_app do find_attrs = %{client_name: @local_mastodon_name, redirect_uris: "."} diff --git a/test/support/factory.ex b/test/support/factory.ex index e1a08315a2..b37bc2c075 100644 --- a/test/support/factory.ex +++ b/test/support/factory.ex @@ -240,6 +240,16 @@ def oauth_token_factory do } end + def oauth_authorization_factory do + %Pleroma.Web.OAuth.Authorization{ + token: :crypto.strong_rand_bytes(32) |> Base.url_encode64(padding: false), + scopes: ["read", "write", "follow", "push"], + valid_until: NaiveDateTime.add(NaiveDateTime.utc_now(), 60 * 10), + user: build(:user), + app: build(:oauth_app) + } + end + def push_subscription_factory do %Pleroma.Web.Push.Subscription{ user: build(:user), diff --git a/test/web/mastodon_api/mastodon_api_controller_test.exs b/test/web/mastodon_api/mastodon_api_controller_test.exs index 6060cc97fe..438e9507d4 100644 --- a/test/web/mastodon_api/mastodon_api_controller_test.exs +++ b/test/web/mastodon_api/mastodon_api_controller_test.exs @@ -2340,4 +2340,71 @@ test "accounts fetches correct account for nicknames beginning with numbers", %{ refute acc_one == acc_two assert acc_two == acc_three end + + describe "index/2 redirections" do + setup %{conn: conn} do + session_opts = [ + store: :cookie, + key: "_test", + signing_salt: "cooldude" + ] + + conn = + conn + |> Plug.Session.call(Plug.Session.init(session_opts)) + |> fetch_session() + + test_path = "/web/statuses/test" + %{conn: conn, path: test_path} + end + + test "redirects not logged-in users to the login page", %{conn: conn, path: path} do + conn = get(conn, path) + + assert conn.status == 302 + assert redirected_to(conn) == "/web/login" + end + + test "does not redirect logged in users to the login page", %{conn: conn, path: path} do + token = insert(:oauth_token) + + conn = + conn + |> assign(:user, token.user) + |> put_session(:oauth_token, token.token) + |> get(path) + + assert conn.status == 200 + end + + test "saves referer path to session", %{conn: conn, path: path} do + conn = get(conn, path) + return_to = Plug.Conn.get_session(conn, :return_to) + + assert return_to == path + end + + test "redirects to the saved path after log in", %{conn: conn, path: path} do + app = insert(:oauth_app, client_name: "Mastodon-Local", redirect_uris: ".") + auth = insert(:oauth_authorization, app: app) + + conn = + conn + |> put_session(:return_to, path) + |> get("/web/login", %{code: auth.token}) + + assert conn.status == 302 + assert redirected_to(conn) == path + end + + test "redirects to the getting-started page when referer is not present", %{conn: conn} do + app = insert(:oauth_app, client_name: "Mastodon-Local", redirect_uris: ".") + auth = insert(:oauth_authorization, app: app) + + conn = get(conn, "/web/login", %{code: auth.token}) + + assert conn.status == 302 + assert redirected_to(conn) == "/web/getting-started" + end + end end