diff --git a/CHANGELOG.md b/CHANGELOG.md index 310f80f..0c66dcb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ All public functions in `TestServer` have been moved to `TestServer.HTTP`. The HTTP server adapters in `TestServer.HTTPServer.*` have been moved to `TestServer.HTTP.Server.*`. +TestServer now has SSH support with `TestServer.SSH`. + - Fixed bug where `:match` functions that raised errors always matched in `TestServer.HTTP.add/2` and `TestServer.HTTP.websocket_handle/3` - Fixed UTF-8 response body handling for `TestServer.HTTP.Server.Httpd` - Fixed invalid host header port parsing in `TestServer.HTTP.Server.Httpd` diff --git a/README.md b/README.md index c6fc4bf..21d17b2 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,14 @@ Features: - HTTP/1 - HTTP/2 - WebSocket +- SSH - Built-in TLS with self-signed certificates - Plug route matching ## Protocols - [`TestServer.HTTP`](lib/test_server/http/README.md) - HTTP/1, HTTP/2, and WebSocket. +- [`TestServer.SSH`](lib/test_server/ssh/README.md) - SSH exec and shell. diff --git a/lib/test_server.ex b/lib/test_server.ex index 17956db..44c67ef 100644 --- a/lib/test_server.ex +++ b/lib/test_server.ex @@ -103,6 +103,27 @@ defmodule TestServer do end end + @doc false + def open_port(options) do + {port, options} = + case Keyword.get(options, :port, 0) do + {port, options} -> {port, options} + port -> {port, []} + end + + unless is_integer(port) and port >= 0 and port <= 65_535, + do: raise("Invalid port, got: #{inspect(port)}") + + with {:ok, socket} <- :gen_tcp.listen(port, options), + {:ok, port} <- :inet.port(socket), + true <- :erlang.port_close(socket) do + port + else + {:error, error} -> + raise("Could not listen to port #{inspect(port)}, because: #{inspect(error)}") + end + end + @doc false def fetch_instance!(protocol_module) do instance_module = Module.concat(protocol_module, Instance) @@ -112,4 +133,24 @@ defmodule TestServer do {:ok, instance} -> instance end end + + @doc false + def get_host(options) do + host = Keyword.get(options, :host) + + unless is_nil(host) or is_binary(host), + do: raise("Invalid host, got: #{inspect(host)}") + + case host do + nil -> + "localhost" + + host -> + :inet_db.set_lookup([:file, :dns]) + :inet_db.add_host({127, 0, 0, 1}, [String.to_charlist(host)]) + :inet_db.add_host({0, 0, 0, 0, 0, 0, 0, 1}, [String.to_charlist(host)]) + + host + end + end end diff --git a/lib/test_server/http.ex b/lib/test_server/http.ex index e2013a0..0a98028 100644 --- a/lib/test_server/http.ex +++ b/lib/test_server/http.ex @@ -171,25 +171,12 @@ defmodule TestServer.HTTP do def url(instance, uri, options) do TestServer.ensure_instance_alive!(__MODULE__, instance) - unless is_nil(options[:host]) or is_binary(options[:host]), - do: raise("Invalid host, got: #{inspect(options[:host])}") - - domain = maybe_enable_host(options[:host]) + domain = TestServer.get_host(options) options = Instance.get_options(instance) "#{Keyword.fetch!(options, :scheme)}://#{domain}:#{Keyword.fetch!(options, :port)}#{uri}" end - defp maybe_enable_host(nil), do: "localhost" - - defp maybe_enable_host(host) do - :inet_db.set_lookup([:file, :dns]) - :inet_db.add_host({127, 0, 0, 1}, [String.to_charlist(host)]) - :inet_db.add_host({0, 0, 0, 0, 0, 0, 0, 1}, [String.to_charlist(host)]) - - host - end - @spec add(binary()) :: :ok def add(uri), do: add(uri, []) diff --git a/lib/test_server/http/server.ex b/lib/test_server/http/server.ex index 383cd2e..d1aeb4d 100644 --- a/lib/test_server/http/server.ex +++ b/lib/test_server/http/server.ex @@ -41,7 +41,7 @@ defmodule TestServer.HTTP.Server do @doc false @spec start(pid(), keyword()) :: {:ok, keyword()} | {:error, any()} def start(instance, options) do - port = open_port(options) + port = TestServer.open_port(options) scheme = parse_scheme(options) {tls_options, x509_options} = maybe_generate_x509_suite(options, scheme) ip_family = Keyword.get(options, :ipfamily, :inet) @@ -65,26 +65,6 @@ defmodule TestServer.HTTP.Server do end end - defp open_port(options) do - {port, options} = - case Keyword.get(options, :port, 0) do - {port, options} -> {port, options} - port -> {port, []} - end - - unless is_integer(port) and port >= 0 and port <= 65_535, - do: raise("Invalid port, got: #{inspect(port)}") - - with {:ok, socket} <- :gen_tcp.listen(port, options), - {:ok, port} <- :inet.port(socket), - true <- :erlang.port_close(socket) do - port - else - {:error, error} -> - raise("Could not listen to port #{inspect(port)}, because: #{inspect(error)}") - end - end - defp parse_scheme(options) do scheme = Keyword.get(options, :scheme, :http) diff --git a/lib/test_server/ssh.ex b/lib/test_server/ssh.ex new file mode 100644 index 0000000..99270e4 --- /dev/null +++ b/lib/test_server/ssh.ex @@ -0,0 +1,380 @@ +defmodule TestServer.SSH do + @external_resource "lib/test_server/ssh/README.md" + @moduledoc "lib/test_server/ssh/README.md" + |> File.read!() + |> String.split("") + |> Enum.fetch!(1) + + alias TestServer.SSH.{Instance, Server} + + @type channel :: {pid(), channel_ref()} + @type channel_ref :: reference() + @type channel_id :: :ssh.channel_id() + @type state :: term() + @type channel_msg :: :ssh_connection.channel_msg() + @type connection :: :ssh.connection_ref() + + @type handler_fun :: (channel_msg(), state() -> + {:reply, iodata(), state()} + | {:reply, {iodata(), keyword()}, state()} + | {:ok, state()}) + + @type raw_handler_fun :: (channel_msg(), connection(), state() -> + {:ok, state()} + | {:stop, channel_id(), state()}) + + @type match_fun :: (channel_msg(), state() -> boolean()) + + @type host_key :: %{ + key: + :public_key.rsa_private_key() + | :public_key.eddsa_private_key() + | :public_key.ecdsa_private_key(), + algorithms: [:ssh.pubkey_alg()], + fingerprint: binary() + } + + @doc """ + Start a test server SSH instance. + + The instance will be terminated when the test case finishes. + + ## Options + + * `:port` - integer of port number, defaults to random port + that can be opened; + * `:host_keys` - list of host keys, or a + `c::ssh_server_key_api.host_key/2` function. Default will autogenerate + keys for algorithms specified in `t::ssh.pubkey_alg/0` and they can be + fetched from `host_keys/1`; + * `:auth_keys` - list of `{"user", public_key}` tuples, or a + `c::ssh_server_key_api.is_auth_key/3` function. Defaults to an empty + list. + * `:user_passwords` - list of `{"user", "password"}` tuples; + * `:no_auth_needed` - boolean value indicating whether to allow + connections with no authentication. Defaults to `true` if `:auth_keys` + and `:user_passwords` has not been set, otherwise `false`; + * `:ipfamily` - The IP address type to use, either `:inet` or + `:inet6`. Defaults to `:inet`; + * `:suppress_ssh_strict_kex_ordering_log` - boolean that suppresses OTP + SSH debug messages for strict KEX ordering. Defaults to `true`. Note: + the filter is installed when the server starts and removed when it stops. + If a test crashes the filter may persist into the next test; + * `:daemon` - options to pass directly to `:ssh.daemon/2`. + + ## Examples + + host_key = :public_key.generate_key({:rsa, 2_048, 65_537}) + auth_key = :public_key.generate_key({:rsa, 2_048, 65_537}) + {:RSAPrivateKey, _, mod, exp, _, _, _, _, _, _, _} = auth_key + auth_public_key = {:RSAPublicKey, mod, exp} + user_dir = SSHClient.write_user_dir_pem!(auth_key, "id_rsa") + + {:ok, _instance} = TestServer.SSH.start( + port: 2222, + host_keys: [host_key], + auth_keys: [{"user", auth_public_key}] + ) + + {:ok, channel} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle( + channel, + to: fn {:data, _channel_id, _want_reply, _data}, state -> + {:reply, "pong", state} + end, + match: fn {:data, _channel_id, _want_reply, data}, _state -> + data == "ping" + end + ) + + :ok = TestServer.SSH.handle(channel) + + [%{fingerprint: host_fingerprint}] = TestServer.SSH.host_keys() + + assert {:ok, conn} = + SSHClient.connect( + TestServer.SSH.address(), + user: "user", + auth_methods: "publickey", + user_dir: user_dir, + silently_accept_hosts: fn _peer, fingerprint -> + fingerprint == host_fingerprint + end + ) + + assert {:ok, channel_id} = SSHClient.session_channel(conn) + assert :ok = SSHClient.send(conn, channel_id, "echo") + assert {:ok, "echo"} = SSHClient.receive_data(conn, channel_id) + assert :ok = SSHClient.send(conn, channel_id, "ping") + assert {:ok, "pong"} = SSHClient.receive_data(conn, channel_id) + assert :ok = SSHClient.close(conn) + """ + @spec start(keyword()) :: {:ok, pid()} + def start(options \\ []) do + TestServer.start_instance(__MODULE__, options, &verify!/1) + end + + defp verify!(instance) do + verify_handlers!(instance) + verify_channels!(instance) + end + + defp verify_channels!(instance) do + instance + |> Instance.channels() + |> Enum.filter(&is_nil(&1.channel_id)) + |> case do + [] -> + :ok + + unused_channels -> + raise """ + #{TestServer.format_instance(__MODULE__, instance)} has channels that were not used: + + #{Instance.format_channels(unused_channels)} + """ + end + end + + defp verify_handlers!(instance) do + instance + |> Instance.handlers() + |> Enum.reject(& &1.suspended) + |> case do + [] -> + :ok + + active_handlers -> + raise """ + #{TestServer.format_instance(__MODULE__, instance)} did not receive a message for these handlers before the test ended: + + #{Instance.format_handlers(active_handlers)} + """ + end + end + + @doc """ + Shuts down the current test server SSH instance. + + ## Examples + + {:ok, _instance} = TestServer.SSH.start() + {address, port} = TestServer.SSH.address() + :ok = TestServer.SSH.stop() + + assert SSHClient.connect(TestServer.SSH.address()) == {:error, :econnrefused} + """ + @spec stop() :: :ok | {:error, term()} + def stop, do: stop(TestServer.fetch_instance!(__MODULE__)) + + @doc """ + Shuts down a test server SSH instance. + """ + @spec stop(pid()) :: :ok | {:error, term()} + def stop(instance) do + TestServer.ensure_instance_alive!(__MODULE__, instance) + + Server.stop(Instance.get_options(instance)) + + TestServer.stop_instance(__MODULE__, instance) + end + + @spec address() :: {binary(), non_neg_integer()} + def address, do: address([]) + + @doc """ + Returns the address for current test server. + + ## Options + + * `:host` - binary host value, it'll be added to inet for IP `127.0.0.1` + and `::1`, defaults to `"localhost"`; + + ## Examples + + {:ok, _instance} = TestServer.SSH.start(port: 2222) + + assert TestServer.SSH.address() == {"localhost", 2222} + assert TestServer.SSH.address(host: "myserver.test") == {"myserver.test", 2222} + """ + @spec address(keyword() | pid()) :: {binary(), non_neg_integer()} + def address(options) when is_list(options), + do: address(TestServer.fetch_instance!(__MODULE__), options) + + def address(instance) when is_pid(instance), do: address(instance, []) + + @doc """ + Returns the address for a test server instance. + + See `address/1` for options. + """ + @spec address(pid(), keyword()) :: {binary(), non_neg_integer()} + def address(instance, options) when is_pid(instance) and is_list(options) do + TestServer.ensure_instance_alive!(__MODULE__, instance) + + host = TestServer.get_host(options) + port = instance |> Instance.get_options() |> Keyword.fetch!(:port) + + {host, port} + end + + @spec channel() :: {:ok, channel()} + def channel, do: channel([]) + + @doc """ + Adds a channel to the current test server. + + ## Options + + * `:listen` - list of message types to dispatch to handlers, or `:all`. + Defaults to `[:exec, :data]`. Available types: `:exec`, `:data`, `:env`, + `:pty`, `:shell`, `:eof`; + + ## Examples + + {:ok, channel} = TestServer.SSH.channel() + :ok = TestServer.SSH.handle(channel) + + assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, channel_id} = SSHClient.session_channel(conn) + assert :ok = SSHClient.exec(conn, channel_id, "ping") + assert {:ok, "ping"} = SSHClient.receive_data(conn, channel_id) + assert :ok = SSHClient.close(conn) + """ + @spec channel(keyword()) :: {:ok, channel()} + def channel(options) when is_list(options) do + {:ok, instance} = TestServer.autostart_instance(__MODULE__) + + channel(instance, options) + end + + @doc """ + Adds a channel to a test server instance. + + See `channel/1` for options. + """ + @spec channel(pid(), keyword()) :: {:ok, channel()} + def channel(instance, options) do + TestServer.ensure_instance_alive!(__MODULE__, instance) + + [_first_module_entry | stacktrace] = TestServer.get_pruned_stacktrace(__MODULE__) + + options = Keyword.put_new(options, :listen, [:exec, :data]) + + {:ok, channel} = Instance.register(instance, {:channel, {options, stacktrace}}) + + {:ok, {instance, channel.ref}} + end + + @spec handle(channel()) :: :ok + def handle(channel), do: handle(channel, []) + + @doc """ + Adds a handler to a test server SSH channel. + + Handlers are matched FIFO (first in, first out). Any messages not matched by + a handler, or any handlers not consumed by a message, will raise an error in + the test case. + + The `:to` callback can be either a two-arity `t:handler_fun/0` or a + three-arity `t:raw_handler_fun/0`. A two-arity handler uses the default SSH + handling, including request acknowledgements, sending replies, and closing + `:exec` channels. A three-arity handler gives you full control over the SSH + connection response. + + ## Options + + * `:match` - a `t:match_fun/0` function that returns a boolean. Defaults to + matching anything; + * `:to` - a `t:handler_fun/0` or `t:raw_handler_fun/0` function called + when the handler matches. Defaults to send back the received data for + `:exec` and `:data` type channel messages, with no `:data` message + returned for any other message types; + + ## `:to` return options + + The `{:reply, {data, options}, state}` return from a `t:handler_fun/0` allows + you to specify options in the options keyword list: + + * `:data_type_code` - an integer SSH data type code to send with the reply, + defaults to `0` (SSH_MSG_CHANNEL_DATA); + * `:exit_status` - an integer exit status to send when finishing an `:exec` + channel, defaults to `0`. Ignored for other channel types; + + ## Examples + + {:ok, channel} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle( + channel, + to: fn {:data, _channel_id, _want_reply, _data}, state -> + {:reply, "pong", state} + end, + match: fn {:data, _channel_id, _want_reply, data}, _state -> + data == "ping" + end + ) + + :ok = TestServer.SSH.handle(channel) + + assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, channel_id} = SSHClient.session_channel(conn) + assert :ok = SSHClient.send(conn, channel_id, "echo") + assert {:ok, "echo"} = SSHClient.receive_data(conn, channel_id) + assert :ok = SSHClient.send(conn, channel_id, "ping") + assert {:ok, "pong"} = SSHClient.receive_data(conn, channel_id) + assert :ok = SSHClient.close(conn) + """ + @spec handle(channel(), keyword()) :: :ok + def handle({instance, channel_ref} = _channel, options) do + TestServer.ensure_instance_alive!(__MODULE__, instance) + + [_first_module_entry | stacktrace] = + TestServer.get_pruned_stacktrace(__MODULE__) + + {:ok, _handler} = Instance.register(instance, {:handle, {channel_ref, options, stacktrace}}) + + :ok + end + + @doc """ + Fetches the host keys for the current test server. + + ## Examples + + {:ok, _instance} = TestServer.SSH.start() + [host_key | _] = TestServer.SSH.host_keys() + + assert {:ok, _conn} = + SSHClient.connect( + TestServer.SSH.address(), + silently_accept_hosts: fn _peer, fingerprint -> + fingerprint == host_key.fingerprint + end) + """ + @spec host_keys() :: [host_key()] + def host_keys, do: host_keys(TestServer.fetch_instance!(__MODULE__)) + + @doc """ + Fetches the generated host keys for a test server instance. + + See `host_keys/0` for more. + """ + @spec host_keys(pid()) :: [host_key()] + def host_keys(instance) do + TestServer.ensure_instance_alive!(__MODULE__, instance) + + options = Instance.get_options(instance) + + case is_list(options[:host_keys]) do + true -> + options[:host_keys] + + false -> + raise "#{TestServer.format_instance(__MODULE__, instance)} is running with `[host_keys: function]` option" + end + end +end diff --git a/lib/test_server/ssh/README.md b/lib/test_server/ssh/README.md new file mode 100644 index 0000000..f22392b --- /dev/null +++ b/lib/test_server/ssh/README.md @@ -0,0 +1,94 @@ +# SSH + + + +Mock SSH endpoint with channel and message expectations, and password/public key authentication. + +## Usage + +### Channels + +Session channels can be set up by calling `TestServer.SSH.channel/2`. By default `TestServer.SSH.handle/2` will echo the message sent. The handlers that match the message will be called in the order they were specified. + +```elixir +test "SSHClient" do + {:ok, channel} = TestServer.SSH.channel() + + # Only matches `hi` data message + :ok = + TestServer.SSH.handle( + channel, + match: fn {:data, _channel_id, _want_reply, data}, _state -> + data == "hi" + end + ) + + # Reply with `pong` message + :ok = + TestServer.SSH.handle( + channel, + to: fn {:data, _channel_id, _want_reply, _data}, state -> + {:reply, "pong", state} + end + ) + + # Default data echo + :ok = TestServer.SSH.handle(channel) + + assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, channel_id} = SSHClient.session_channel(conn) + assert :ok = SSHClient.send(conn, channel_id, "ping") + assert {:ok, "pong"} = SSHClient.receive_data(conn, channel_id) + assert :ok = SSHClient.send(conn, channel_id, "hi") + assert {:ok, "hi"} = SSHClient.receive_data(conn, channel_id) + assert :ok = SSHClient.send(conn, channel_id, "hello") + assert {:ok, "hello"} = SSHClient.receive_data(conn, channel_id) + assert :ok = SSHClient.close(conn) +end +``` + +By default, only `:exec` and `:data` messages are dispatched to handlers. Use the `:listen` option on `TestServer.SSH.channel/2` to control which message types are dispatched: + +```elixir +{:ok, channel_1} = TestServer.SSH.channel(listen: :all) +{:ok, channel_2} = TestServer.SSH.channel(listen: [:data, :env, :pty]) +``` + +### Host keys + +Host keys can be configured in `TestServer.SSH.start/1`. By default host keys are auto generated to cover most algorithms and can be fetched with `TestServer.SSH.host_keys/1` + +```elixir +host_key = :public_key.generate_key({:rsa, 2048, 65_537}) +{:ok, _instance} = TestServer.SSH.start(host_keys: [host_key]) + +[%{fingerprint: fingerprint}] = TestServer.SSH.host_keys() +``` + +### Authentication + +By default the test server accepts unauthenticated requests. Pass in `:user_passwords` and/or `:auth_keys` options to `TestServer.SSH.start/1` to require authentication: + +```elixir +{:RSAPrivateKey, _, mod, exp, _, _, _, _, _, _, _} = + :public_key.generate_key({:rsa, 2048, 65_537}) + +TestServer.SSH.start( + user_passwords: [{"user1", "pass"}], + auth_keys: [{"user2", {:RSAPublicKey, mod, exp}}] +) +``` + +### IPv6 + +Use the `:ipfamily` option to test with IPv6 when starting the test server with `TestServer.SSH.start/1`: + +```elixir +{:ok, _instance} = TestServer.SSH.start(ipfamily: :inet6) + +assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address(), inet6: true) +assert {{0, 0, 0, 0, 0, 0, 0, 1}, _port} = SSHClient.sockname(conn) +``` + + + diff --git a/lib/test_server/ssh/channel.ex b/lib/test_server/ssh/channel.ex new file mode 100644 index 0000000..bfc49d2 --- /dev/null +++ b/lib/test_server/ssh/channel.ex @@ -0,0 +1,309 @@ +defmodule TestServer.SSH.Channel do + @moduledoc false + + @behaviour :ssh_server_channel + + alias TestServer.SSH.Instance + + defstruct [:instance, :channel, :state] + + @impl true + def init(options) do + {:ok, %__MODULE__{instance: Keyword.fetch!(options, :instance), state: %{}}} + end + + @impl true + def handle_msg({:EXIT, _pid, _reason}, state) do + {:stop, state.channel.channel_id, state} + end + + def handle_msg({:ssh_channel_up, channel_id, connection}, state) do + case Instance.dispatch(state.instance, {:channel_up, channel_id, connection}) do + {:ok, channel} -> + {:ok, %{state | channel: channel}} + + {:error, :not_found} -> + message = + append_formatted_channels( + "#{TestServer.format_instance(TestServer.SSH, state.instance)} received an unexpected SSH channel up message for channel ID #{channel_id} on connection #{inspect(connection)}.", + state.instance + ) + + send_error(connection, channel_id, {RuntimeError.exception(message), []}, state) + + {:stop, channel_id, state} + end + end + + defp append_formatted_channels(message, instance) do + channels = Enum.split_with(Instance.channels(instance), &is_nil(&1.channel_id)) + + """ + #{message} + + #{format_channels(channels)} + """ + end + + defp format_channels({[], used_channels}) do + message = "No available channels." + + case used_channels do + [] -> + message + + used_channels -> + """ + #{message} The following channels have been used: + + #{Instance.format_channels(used_channels)} + """ + end + end + + defp format_channels({available_channels, _used_channels}) do + """ + Available channels: + + #{Instance.format_channels(available_channels)} + """ + end + + defp send_error(connection, channel_id, {exception, stacktrace}, state) do + Instance.report_error(state.instance, {exception, stacktrace}) + + message = Exception.format(:error, exception, stacktrace) + + :ssh_connection.send(connection, channel_id, message) + end + + @impl true + def handle_ssh_msg({:ssh_cm, connection, frame}, state) do + type = elem(frame, 0) + listen = Keyword.fetch!(state.channel.options, :listen) + + case dispatch(listen, type, connection, frame, state) do + {:raw, {:ok, channel_state}} -> + {:ok, %{state | state: channel_state}} + + {:raw, {:stop, channel_id, channel_state}} -> + {:stop, channel_id, %{state | state: channel_state}} + + response -> + respond(response, connection, frame, state) + end + end + + defp dispatch(listen, type, connection, frame, state) do + case listen == :all or type in listen do + true -> + Instance.dispatch( + state.instance, + {:handle, state.channel.ref, connection, frame, state.state} + ) + + false -> + {:ok, state.state} + end + end + + defp respond(response, connection, frame, state) do + type = elem(frame, 0) + channel_id = elem(frame, 1) + + %{ + type: type, + channel_id: channel_id, + connection: connection, + frame: frame, + state: state + } + |> acknowledge() + |> send_response(response) + |> finish(response) + end + + defp acknowledge( + %{ + type: type, + channel_id: channel_id, + connection: connection, + frame: frame + } = reply + ) + when type in ~w(exec env pty shell subsystem)a do + want_reply = elem(frame, 2) + + :ssh_connection.reply_request(connection, want_reply, :success, channel_id) + + reply + end + + defp acknowledge( + %{ + type: :data, + channel_id: channel_id, + connection: connection, + frame: {:data, _channel_id, _want_reply, data} + } = reply + ) do + :ssh_connection.adjust_window(connection, channel_id, byte_size(data)) + + reply + end + + defp acknowledge( + %{ + type: type, + channel_id: _channel_id, + connection: _connection, + frame: _frame + } = reply + ) + when type in ~w(eof closed signal window_change)a, do: reply + + defp send_response( + %{ + connection: _connection, + channel_id: _channel_id, + state: state + } = reply, + {:ok, channel_state} + ) do + %{reply | state: %{state | state: channel_state}} + end + + defp send_response( + %{ + connection: connection, + channel_id: channel_id, + state: state + } = reply, + {:reply, {data, options}, channel_state} + ) do + data_type_code = Keyword.get(options, :data_type_code, 0) + + :ssh_connection.send(connection, channel_id, data_type_code, data) + + %{reply | state: %{state | state: channel_state}} + end + + defp send_response( + %{ + connection: connection, + channel_id: channel_id, + frame: frame, + state: state + } = reply, + {:error, :not_found} + ) do + message = + "#{TestServer.format_instance(TestServer.SSH, state.instance)} received an unexpected SSH message" + |> append_formatted_frame(frame) + |> append_formatted_handlers(state.instance) + + send_error(connection, channel_id, {RuntimeError.exception(message), []}, state) + + reply + end + + defp send_response( + %{ + connection: connection, + channel_id: channel_id, + frame: _frame, + state: state + } = reply, + {:error, {exception, stacktrace}} + ) do + send_error(connection, channel_id, {exception, stacktrace}, state) + + reply + end + + defp append_formatted_frame(message, frame) do + """ + #{message}: + + #{inspect(frame)} + """ + end + + defp append_formatted_handlers(message, instance) do + handlers = Enum.split_with(Instance.handlers(instance), &(not &1.suspended)) + + """ + #{message} + + #{format_handlers(handlers)} + """ + end + + defp format_handlers({[], suspended_handlers}) do + message = "No active handlers." + + case suspended_handlers do + [] -> + message + + suspended_handlers -> + """ + #{message} The following handlers have been processed: + + #{Instance.format_handlers(suspended_handlers)} + """ + end + end + + defp format_handlers({active_handlers, _suspended_handlers}) do + """ + Active handlers: + + #{Instance.format_handlers(active_handlers)} + """ + end + + defp finish( + %{ + connection: connection, + frame: {:exec, channel_id, _want_reply, _command}, + state: state + }, + response + ) do + exit_status = + case response do + {:reply, {_data, options}, _channel_state} -> Keyword.get(options, :exit_status, 0) + _other -> 0 + end + + :ssh_connection.exit_status(connection, channel_id, exit_status) + :ssh_connection.send_eof(connection, channel_id) + :ssh_connection.close(connection, channel_id) + + {:stop, channel_id, state} + end + + defp finish( + %{ + connection: _, + frame: {:closed, channel_id}, + state: state + }, + _response + ), + do: {:stop, channel_id, state} + + defp finish( + %{ + connection: _, + frame: _, + state: state + }, + _response + ), + do: {:ok, state} + + @impl true + def terminate(_reason, _state), do: :ok +end diff --git a/lib/test_server/ssh/instance.ex b/lib/test_server/ssh/instance.ex new file mode 100644 index 0000000..7f3975b --- /dev/null +++ b/lib/test_server/ssh/instance.ex @@ -0,0 +1,346 @@ +defmodule TestServer.SSH.Instance do + @moduledoc false + + use GenServer + + def start_link(options) do + GenServer.start_link(__MODULE__, options) + end + + def stop(instance) do + GenServer.stop(instance) + end + + @spec register(pid(), {:channel, {keyword(), TestServer.stacktrace()}}) :: + {:ok, %{ref: TestServer.SSH.channel_ref()}} + def register(instance, {:channel, {options, stacktrace}}) do + options[:listen] && ensure_listen!(options[:listen]) + + GenServer.call(instance, {:register, {:channel, {options, stacktrace}}}) + end + + @spec register( + pid(), + {:handle, {TestServer.SSH.channel_ref(), keyword(), TestServer.stacktrace()}} + ) :: + {:ok, map()} + def register(instance, {:handle, {channel_ref, options, stacktrace}}) do + options[:to] && ensure_function!(options[:to]) + options[:match] && ensure_function!(options[:match]) + + GenServer.call(instance, {:register, {:handle, {channel_ref, options, stacktrace}}}) + end + + @listen_events ~w(exec data env pty shell eof)a + + defp ensure_listen!(listen) when is_list(listen) do + case Enum.all?(listen, &(&1 in @listen_events)) do + true -> + :ok + + false -> + raise ArgumentError, + "expected list to only include #{inspect(@listen_events)}, got: #{inspect(listen)}" + end + end + + defp ensure_listen!(listen) do + case listen do + :all -> :ok + _ -> raise ArgumentError, "expected :all, got: #{inspect(listen)}" + end + end + + defp ensure_function!(fun) when is_function(fun), do: :ok + defp ensure_function!(fun), do: raise(BadFunctionError, term: fun) + + @spec dispatch(pid(), {:channel_up, TestServer.SSH.channel_id(), TestServer.SSH.connection()}) :: + {:ok, {TestServer.SSH.channel_ref(), keyword(), TestServer.stacktrace()}} + | {:error, :not_found} + def dispatch(instance, {:channel_up, channel_id, connection}) do + GenServer.call(instance, {:dispatch, {:channel_up, channel_id, connection}}) + end + + @spec dispatch( + pid(), + {:handle, TestServer.SSH.channel_id(), TestServer.SSH.connection(), + TestServer.SSH.channel_msg(), TestServer.SSH.state()} + ) :: + {:raw, {:ok, TestServer.SSH.state()}} + | {:raw, {:stop, TestServer.SSH.channel_id(), TestServer.SSH.state()}} + | {:reply, {binary(), keyword()}, TestServer.SSH.state()} + | {:ok, TestServer.SSH.state()} + | {:error, :not_found} + | {:error, {term(), TestServer.stacktrace()}} + def dispatch(instance, {:handle, channel_ref, connection, message, channel_state}) do + GenServer.call( + instance, + {:dispatch, {:handle, channel_ref, connection, {message, channel_state}}} + ) + end + + @spec handlers(pid()) :: [map()] + def handlers(instance) do + GenServer.call(instance, :handlers) + end + + @spec channels(pid()) :: [map()] + def channels(instance) do + GenServer.call(instance, :channels) + end + + @spec get_options(pid()) :: keyword() + def get_options(instance) do + GenServer.call(instance, :options) + end + + @spec format_handlers([map()]) :: binary() + def format_handlers(handlers) do + handlers + |> Enum.with_index() + |> Enum.map_join("\n\n", fn {handler, index} -> + """ + ##{index + 1}: #{inspect(handler.to)} + #{Enum.map_join(handler.stacktrace, "\n ", &Exception.format_stacktrace_entry/1)} + """ + end) + end + + @spec format_channels([map()]) :: binary() + def format_channels(channels) do + channels + |> Enum.with_index() + |> Enum.map_join("\n\n", fn {channel, index} -> + """ + ##{index + 1}: #{inspect(channel.ref)} + #{Enum.map_join(channel.stacktrace, "\n ", &Exception.format_stacktrace_entry/1)} + """ + end) + end + + @spec report_error(pid(), {struct(), TestServer.stacktrace()}) :: :ok + def report_error(instance, {exception, stacktrace}) do + options = get_options(instance) + caller = Keyword.fetch!(options, :caller) + + unless Keyword.get(options, :suppress_warning, false), + do: IO.warn(Exception.format(:error, exception, stacktrace)) + + ExUnit.OnExitHandler.add(caller, make_ref(), fn -> + reraise exception, stacktrace + end) + + :ok + end + + @impl true + def init(options) do + alias TestServer.SSH.Server + + case Server.start(self(), options) do + {:ok, options} -> + {:ok, + %{ + options: options, + channels: [], + handlers: [] + }} + + {:error, reason} -> + {:stop, reason} + end + end + + @impl true + def handle_call({:register, {:channel, {options, stacktrace}}}, _from, state) do + channel = %{ + ref: make_ref(), + options: options, + stacktrace: stacktrace, + channel_id: nil, + connection_ref: nil + } + + {:reply, {:ok, channel}, %{state | channels: state.channels ++ [channel]}} + end + + def handle_call({:dispatch, {:channel_up, channel_id, connection_ref}}, _from, state) do + case Enum.find_index(state.channels, &(is_nil(&1.channel_id) and is_nil(&1.connection_ref))) do + nil -> + {:reply, {:error, :not_found}, state} + + index -> + updated_channel = %{ + Enum.at(state.channels, index) + | channel_id: channel_id, + connection_ref: connection_ref + } + + channels = List.replace_at(state.channels, index, updated_channel) + + {:reply, {:ok, updated_channel}, %{state | channels: channels}} + end + end + + def handle_call({:register, {:handle, {channel_ref, options, stacktrace}}}, _from, state) do + handler = %{ + ref: make_ref(), + channel_ref: channel_ref, + match: Keyword.get(options, :match), + to: Keyword.get(options, :to, &default_handler/2), + stacktrace: stacktrace, + suspended: false, + received: [] + } + + {:reply, {:ok, handler}, %{state | handlers: state.handlers ++ [handler]}} + end + + def handle_call( + {:dispatch, {:handle, channel_ref, connection, {message, channel_state}}}, + _from, + state + ) do + {res, state} = run_handlers(message, channel_ref, connection, channel_state, state) + + {:reply, res, state} + end + + def handle_call(option, _from, state) when option in [:handlers, :channels, :options] do + {:reply, Map.fetch!(state, option), state} + end + + defp default_handler({:exec, _channel_id, _want_reply, command}, state), + do: {:reply, to_string(command), state} + + defp default_handler({:data, _channel_id, _type, data}, state), + do: {:reply, data, state} + + defp default_handler(_message, state), + do: {:ok, state} + + defp run_handlers(message, channel_ref, connection, channel_state, state) do + state.handlers + |> fetch_match_index([message, channel_state], fn + %{channel_ref: ^channel_ref, suspended: true} -> false + %{channel_ref: ^channel_ref, suspended: false} -> true + %{channel_ref: _other, suspended: _any} -> false + end) + |> case do + {:error, :not_found} -> + {{:error, :not_found}, state} + + {:error, {error, stacktrace}} -> + {{:error, {error, stacktrace}}, state} + + {:ok, index} -> + %{to: handler, stacktrace: stacktrace} = Enum.at(state.handlers, index) + + result = try_run_handler(handler, message, connection, channel_state, stacktrace) + + handlers = + List.update_at(state.handlers, index, fn h -> + %{h | suspended: true, received: h.received ++ [message]} + end) + + {result, %{state | handlers: handlers}} + end + end + + defp fetch_match_index(items, args, callback) do + items + |> Enum.find_index(fn %{match: match} = item -> + callback.(item) && (is_nil(match) || apply(match, args)) + end) + |> case do + nil -> {:error, :not_found} + index -> {:ok, index} + end + rescue + error -> {:error, {error, __STACKTRACE__}} + end + + defp try_run_handler(handler, message, connection, channel_state, stacktrace) do + message + |> run_handler(handler, connection, channel_state) + |> validate_response!(handler, stacktrace) + rescue + error -> {:error, {error, __STACKTRACE__}} + end + + defp run_handler(message, handler, connection, channel_state) when is_function(handler, 3) do + handler.(message, connection, channel_state) + end + + defp run_handler(message, handler, _connection, channel_state) when is_function(handler, 2) do + handler.(message, channel_state) + end + + defp validate_response!(response, handler, stacktrace) when is_function(handler, 3) do + case response do + {:ok, state} -> + {:raw, {:ok, state}} + + {:stop, channel_id, state} -> + {:raw, {:stop, channel_id, state}} + + _other -> + raise """ + Invalid callback response, got: #{inspect(response)}. + + Expected one of the following: + + - {:ok, state} + - {:stop, channel_id, state} + + #{Enum.map_join(stacktrace, "\n ", &Exception.format_stacktrace_entry/1)} + """ + end + end + + defp validate_response!(response, handler, stacktrace) when is_function(handler, 2) do + case response do + {:reply, {data, options}, state} -> + validate_options!({:reply, {data, options}, state}, stacktrace) + + {:reply, {data, options}, state} + + {:reply, data, state} when is_binary(data) -> + {:reply, {data, []}, state} + + {:ok, state} -> + {:ok, state} + + _other -> + raise """ + Invalid callback response, got: #{inspect(response)}. + + Expected one of the following: + + - {:reply, data, state} + - {:reply, {data, options}, state} + - {:ok, state} + + #{Enum.map_join(stacktrace, "\n ", &Exception.format_stacktrace_entry/1)} + """ + end + end + + defp validate_options!({:reply, {_data, options}, _state}, stacktrace) do + valid_options = ~w(exit_status data_type_code)a + + case Keyword.validate(options, valid_options) do + {:ok, _options} -> + :ok + + {:error, _keys} -> + raise """ + Invalid options in callback response, got: #{inspect(options)}. + + Valid options are: #{inspect(valid_options)}. + + #{Enum.map_join(stacktrace, "\n ", &Exception.format_stacktrace_entry/1)} + """ + end + end +end diff --git a/lib/test_server/ssh/key_api.ex b/lib/test_server/ssh/key_api.ex new file mode 100644 index 0000000..120f475 --- /dev/null +++ b/lib/test_server/ssh/key_api.ex @@ -0,0 +1,19 @@ +defmodule TestServer.SSH.KeyAPI do + @moduledoc false + + @behaviour :ssh_server_key_api + + @impl true + def host_key(algorithm, daemon_options) do + host_key = Keyword.fetch!(daemon_options[:key_cb_private], :host_key) + + host_key.(algorithm, daemon_options) + end + + @impl true + def is_auth_key(public_key, user, daemon_options) do + is_auth_key = Keyword.fetch!(daemon_options[:key_cb_private], :is_auth_key) + + is_auth_key.(public_key, user, daemon_options) + end +end diff --git a/lib/test_server/ssh/server.ex b/lib/test_server/ssh/server.ex new file mode 100644 index 0000000..a1622ba --- /dev/null +++ b/lib/test_server/ssh/server.ex @@ -0,0 +1,223 @@ +defmodule TestServer.SSH.Server do + @moduledoc false + + @doc false + @spec start(pid(), keyword()) :: {:ok, keyword()} | {:error, any()} + def start(instance, options) do + port = TestServer.open_port(options) + {host_keys, daemon_options} = daemon_options(instance, options) + + case :ssh.daemon(port, daemon_options) do + {:ok, daemon_ref} -> + suppress_ssh_strict_kex_ordering_log?(options) && install_strict_kex_filter() + + options = + options + |> Keyword.put(:host_keys, host_keys) + |> Keyword.put(:port, port) + |> Keyword.put(:daemon_ref, daemon_ref) + + {:ok, options} + + {:error, reason} -> + {:error, reason} + end + end + + defp daemon_options(instance, options) do + tmp_dir = to_charlist(System.tmp_dir!()) + + {host_keys_fn, host_keys} = + case host_keys(options) do + key_fun when is_function(key_fun, 2) -> {key_fun, :none} + host_keys when is_list(host_keys) -> {&default_host_key/2, to_host_key_maps(host_keys)} + end + + {auth_key_fn, auth_keys} = + case Keyword.get(options, :auth_keys, []) do + key_fun when is_function(key_fun, 3) -> {key_fun, []} + auth_keys when is_list(auth_keys) -> {&default_is_auth_key/3, auth_keys} + end + + key_cb = + { + TestServer.SSH.KeyAPI, + [ + host_key: host_keys_fn, + host_keys: host_keys, + is_auth_key: auth_key_fn, + auth_keys: auth_keys + ] + } + + ssh_cli = + { + TestServer.SSH.Channel, + options + |> Keyword.take([:listen]) + |> Keyword.put(:instance, instance) + } + + daemon_options = + options + |> Keyword.take([:auth_methods, :no_auth_needed, :user_passwords]) + |> normalize_user_passwords() + |> Keyword.put_new( + :no_auth_needed, + not Keyword.has_key?(options, :auth_keys) and + not Keyword.has_key?(options, :user_passwords) + ) + |> Keyword.merge( + key_cb: key_cb, + ssh_cli: ssh_cli, + system_dir: tmp_dir, + user_dir: tmp_dir, + parallel_login: true + ) + |> Keyword.merge(Keyword.get(options, :daemon, [])) + + daemon_options = + case Keyword.get(options, :ipfamily, :inet) do + :inet -> daemon_options + ipfamily -> [ipfamily | daemon_options] + end + + {host_keys, daemon_options} + end + + defp host_keys(options) do + Keyword.get_lazy(options, :host_keys, fn -> + [ + # `:"rsa-sha2-256"` | `:"rsa-sha2-512"` + :public_key.generate_key({:rsa, 2_048, 65_537}), + # `:`ecdsa-sha2-nistp256"` + :public_key.generate_key({:namedCurve, :secp256r1}), + # `:`ecdsa-sha2-nistp384"` + :public_key.generate_key({:namedCurve, :secp384r1}), + # `:`ecdsa-sha2-nistp521"` + :public_key.generate_key({:namedCurve, :secp521r1}), + # `:`ssh-ed25519"` + :public_key.generate_key({:namedCurve, :ed25519}), + # `:`ssh-ed448"` + :public_key.generate_key({:namedCurve, :ed448}) + ] + end) + end + + defp default_host_key(algorithm, daemon_options) do + host_keys = + daemon_options + |> Keyword.fetch!(:key_cb_private) + |> Keyword.fetch!(:host_keys) + + host_keys + |> Enum.find(&(algorithm in &1.algorithms)) + |> case do + nil -> {:error, :unsupported_algorithm} + %{key: host_key} -> {:ok, host_key} + end + end + + defp to_host_key_maps(host_keys) do + Enum.map(host_keys, fn + {:RSAPrivateKey, _, mod, exp, _, _, _, _, _, _, _} = host_key -> + %{ + key: host_key, + algorithms: [:"rsa-sha2-256", :"rsa-sha2-512"], + fingerprint: :ssh.hostkey_fingerprint({:RSAPublicKey, mod, exp}) + } + + {:ECPrivateKey, _, _, {:namedCurve, curve_oid}, public_key, _} = host_key -> + algorithm = algorithm_for_curve_oid(curve_oid) + fingerprint = :ssh.hostkey_fingerprint({{:ECPoint, public_key}, {:namedCurve, curve_oid}}) + + %{ + key: host_key, + algorithms: [algorithm], + fingerprint: fingerprint + } + + other -> + raise "Unsupported host key format: #{inspect(other)}" + end) + end + + defp algorithm_for_curve_oid({1, 3, 101, 112}), do: :"ssh-ed25519" + defp algorithm_for_curve_oid({1, 3, 101, 113}), do: :"ssh-ed448" + + defp algorithm_for_curve_oid(curve_oid) do + case :public_key.oid2ssh_curvename(curve_oid) do + "nistp256" -> :"ecdsa-sha2-nistp256" + "nistp384" -> :"ecdsa-sha2-nistp384" + "nistp521" -> :"ecdsa-sha2-nistp521" + end + end + + defp default_is_auth_key(public_key, user, daemon_options) do + user = to_string(user) + + daemon_options + |> Keyword.fetch!(:key_cb_private) + |> Keyword.fetch!(:auth_keys) + |> Enum.any?(fn + {nil, ^public_key} -> true + {^user, ^public_key} -> true + {_user, _public_key} -> false + end) + end + + defp normalize_user_passwords(options) do + case Keyword.has_key?(options, :user_passwords) do + true -> + Keyword.put( + options, + :user_passwords, + Enum.map(options[:user_passwords], fn {user, pass} -> + {to_charlist(user), to_charlist(pass)} + end) + ) + + false -> + options + end + end + + defp suppress_ssh_strict_kex_ordering_log?(options), + do: Keyword.get(options, :suppress_ssh_strict_kex_ordering_log, true) + + @ssh_strict_kex_filter :test_server_suppress_ssh_strict_kex_ordering + + # Erlang's ssh_transport module logs "server will use strict KEX ordering" at debug + # level using logger:debug/1 (bare string, no metadata). Because there's no module + # metadata, set_module_level/set_application_level can't target it. We use a primary + # filter that drops this specific debug message. + defp install_strict_kex_filter do + %{filters: filters} = :logger.get_primary_config() + + unless List.keyfind(filters, @ssh_strict_kex_filter, 0) do + :logger.add_primary_filter(@ssh_strict_kex_filter, { + fn + %{level: :debug, msg: {:string, ~c"server will use strict KEX ordering"}}, _extra -> + :stop + + _event, _extra -> + :ignore + end, + [] + }) + end + end + + @doc false + @spec stop(keyword()) :: :ok + def stop(options) do + daemon_ref = Keyword.fetch!(options, :daemon_ref) + + try do + :ssh.stop_daemon(daemon_ref) + after + suppress_ssh_strict_kex_ordering_log?(options) && + :logger.remove_primary_filter(@ssh_strict_kex_filter) + end + end +end diff --git a/mix.exs b/mix.exs index ba451b6..6b13c9d 100644 --- a/mix.exs +++ b/mix.exs @@ -28,7 +28,7 @@ defmodule TestServer.MixProject do def application do [ - extra_applications: [:logger, :crypto, :public_key, :inets], + extra_applications: [:logger, :crypto, :public_key, :ssh, :inets], mod: {TestServer.Application, []} ] end @@ -83,6 +83,9 @@ defmodule TestServer.MixProject do TestServer.HTTP.Server.Httpd, TestServer.HTTP.Server.Bandit, TestServer.HTTP.Server.Plug.Cowboy + ], + SSH: [ + TestServer.SSH ] ] ] diff --git a/test/test_server/ssh_test.exs b/test/test_server/ssh_test.exs new file mode 100644 index 0000000..96910de --- /dev/null +++ b/test/test_server/ssh_test.exs @@ -0,0 +1,1203 @@ +defmodule TestServer.SSHTest do + use ExUnit.Case + doctest TestServer.SSH + + import ExUnit.CaptureIO + import ExUnit.CaptureLog + + alias __MODULE__.SSHClient + + describe "start/1" do + test "with invalid port" do + assert_raise RuntimeError, ~r/Invalid port, got: :invalid/, fn -> + TestServer.SSH.start(port: :invalid) + end + + assert_raise RuntimeError, ~r/Invalid port, got: 65536/, fn -> + TestServer.SSH.start(port: 65_536) + end + + assert_raise RuntimeError, ~r/Could not listen to port 2222, because: :eaddrinuse/, fn -> + TestServer.SSH.start(port: 2222) + TestServer.SSH.start(port: 2222) + end + end + + test "starts with multiple ports" do + {:ok, instance_1} = TestServer.SSH.start() + {:ok, instance_2} = TestServer.SSH.start() + + refute instance_1 == instance_2 + + {_, port_1} = TestServer.SSH.address(instance_1) + {_, port_2} = TestServer.SSH.address(instance_2) + + refute port_1 == port_2 + end + + test "with `:host_keys` option list" do + host_key_1 = :public_key.generate_key({:rsa, 2048, 65_537}) + host_key_2 = :public_key.generate_key({:namedCurve, :secp521r1}) + + {:RSAPrivateKey, _, other_mod, other_exp, _, _, _, _, _, _, _} = + :public_key.generate_key({:rsa, 2048, 65_537}) + + other_hostkey_fingerprint = :ssh.hostkey_fingerprint({:RSAPublicKey, other_mod, other_exp}) + + {:ok, _instance} = TestServer.SSH.start(host_keys: [host_key_1, host_key_2]) + + [host_key_1_fingerprint, host_key_2_fingerprint] = + Enum.map(TestServer.SSH.host_keys(), & &1.fingerprint) + + assert capture_log(fn -> + assert SSHClient.connect( + TestServer.SSH.address(), + silently_accept_hosts: fn _peer, fingerprint -> + fingerprint == other_hostkey_fingerprint + end + ) == {:error, "Key exchange failed"} + end) =~ "Key exchange failed" + + assert {:ok, _conn} = + SSHClient.connect( + TestServer.SSH.address(), + silently_accept_hosts: fn _peer, fingerprint -> + fingerprint == host_key_2_fingerprint + end + ) + + assert {:ok, _conn} = + SSHClient.connect( + TestServer.SSH.address(), + silently_accept_hosts: fn _peer, fingerprint -> + fingerprint == host_key_1_fingerprint + end + ) + end + + test "with `:host_keys` option function" do + host_key = :public_key.generate_key({:rsa, 2_048, 65_537}) + {:RSAPrivateKey, _, mod, exp, _, _, _, _, _, _, _} = host_key + hostkey_fingerprint = :ssh.hostkey_fingerprint({:RSAPublicKey, mod, exp}) + + {:RSAPrivateKey, _, other_mod, other_exp, _, _, _, _, _, _, _} = + :public_key.generate_key({:rsa, 2_048, 65_537}) + + other_host_key_fingerprint = :ssh.hostkey_fingerprint({:RSAPublicKey, other_mod, other_exp}) + + {:ok, _instance} = + TestServer.SSH.start( + host_keys: fn _algorithm, _daemon_options -> + {:ok, host_key} + end + ) + + assert capture_log(fn -> + assert {:error, "Key exchange failed"} = + SSHClient.connect(TestServer.SSH.address(), + silently_accept_hosts: fn _peer, fingerprint -> + fingerprint == other_host_key_fingerprint + end + ) + end) =~ "ECDH reply failed. Verify host key: {error,fingerprint_check_failed}" + + assert capture_log(fn -> + assert {:error, "Key exchange failed"} = + SSHClient.connect(TestServer.SSH.address(), + preferred_algorithms: [public_key: [:"ecdsa-sha2-nistp521"]] + ) + end) =~ "No common key algorithm" + + assert {:ok, _conn} = + SSHClient.connect(TestServer.SSH.address(), + silently_accept_hosts: fn _peer, fingerprint -> + fingerprint == hostkey_fingerprint + end + ) + end + + for algorithm <- ~w( + rsa-sha2-256 + rsa-sha2-512 + ecdsa-sha2-nistp256 + ecdsa-sha2-nistp384 + ecdsa-sha2-nistp521 + ssh-ed25519 + ssh-ed448 + )a do + test "with default `:host_keys` option using #{algorithm} algorithm" do + {:ok, _instance} = TestServer.SSH.start() + + assert capture_log(fn -> + assert {:error, "Key exchange failed"} = + SSHClient.connect( + TestServer.SSH.address(), + preferred_algorithms: [public_key: [:"ssh-dss"]] + ) + end) =~ "Key exchange failed" + + {:ok, conn} = + SSHClient.connect( + TestServer.SSH.address(), + preferred_algorithms: [public_key: [unquote(algorithm)]] + ) + + {:algorithms, algs} = :ssh.connection_info(conn, :algorithms) + + assert Keyword.fetch!(algs, :hkey) == unquote(algorithm) + end + end + + test "with `:auth_keys` option list", context do + {:RSAPrivateKey, _, mod, exp, _, _, _, _, _, _, _} = + auth_key_1 = :public_key.generate_key({:rsa, 2048, 65_537}) + + auth_key_1_public_key = {:RSAPublicKey, mod, exp} + + {:ECPrivateKey, _, _, oid, public_key, _} = + auth_key_2 = :public_key.generate_key({:namedCurve, :secp256r1}) + + auth_key_2_public_key = {{:ECPoint, public_key}, oid} + user_dir = write_user_dir_pem!(context, auth_key_2) + client_options = [user_dir: user_dir, auth_methods: "publickey"] + + {:ok, _instance} = + TestServer.SSH.start( + auth_keys: [ + {nil, auth_key_1_public_key}, + {"user", auth_key_2_public_key} + ] + ) + + assert capture_log(fn -> + assert {:error, "Unable to connect using the available authentication methods"} = + SSHClient.connect( + TestServer.SSH.address(), + Keyword.put(client_options, :user, "other") + ) + end) =~ "User auth failed for: \"other\"" + + assert {:ok, _conn} = + SSHClient.connect( + TestServer.SSH.address(), + Keyword.put(client_options, :user, "user") + ) + + write_user_dir_pem!(context, auth_key_1) + + assert {:ok, _conn} = + SSHClient.connect( + TestServer.SSH.address(), + Keyword.put(client_options, :user, "other") + ) + end + + test "with `:auth_keys` option function", context do + {:RSAPrivateKey, _, mod, exp, _, _, _, _, _, _, _} = + auth_key = :public_key.generate_key({:rsa, 2048, 65_537}) + + auth_key_public = {:RSAPublicKey, mod, exp} + user_dir = write_user_dir_pem!(context, auth_key) + client_options = [user_dir: user_dir, auth_methods: "publickey"] + + {:ok, _instance} = + TestServer.SSH.start( + auth_keys: fn public_key, user, _daemon_options -> + user == ~c"user" and public_key == auth_key_public + end + ) + + assert capture_log(fn -> + assert {:error, "Unable to connect using the available authentication methods"} = + SSHClient.connect( + TestServer.SSH.address(), + Keyword.put(client_options, :user, "other") + ) + end) =~ "User auth failed for: \"other\"" + + assert {:ok, _conn} = + SSHClient.connect( + TestServer.SSH.address(), + Keyword.put(client_options, :user, "user") + ) + end + + test "with `:user_passwords` option" do + {:ok, _instance} = TestServer.SSH.start(user_passwords: [{"user", "pass"}]) + + assert capture_log(fn -> + assert SSHClient.connect( + TestServer.SSH.address(), + user: "user", + password: "invalid" + ) == + {:error, "Unable to connect using the available authentication methods"} + end) =~ "Unable to connect using the available authentication methods" + + assert capture_log(fn -> + assert SSHClient.connect( + TestServer.SSH.address(), + user: "other", + password: "pass" + ) == + {:error, "Unable to connect using the available authentication methods"} + end) =~ "Unable to connect using the available authentication methods" + + assert {:ok, _conn} = + SSHClient.connect( + TestServer.SSH.address(), + user: "user", + password: "pass", + auth_methods: "password" + ) + end + + test "with `:no_auth_needed` option" do + {:ok, _instance} = TestServer.SSH.start(no_auth_needed: true) + + assert {:ok, _conn} = SSHClient.connect(TestServer.SSH.address()) + end + + test "with `ipfamily: :inet6` option" do + {:ok, _instance} = TestServer.SSH.start(ipfamily: :inet6) + + {:ok, conn} = SSHClient.connect(TestServer.SSH.address(), inet6: true) + assert {ip, _port} = SSHClient.sockname(conn) + assert ip == {0, 0, 0, 0, 0, 0, 0, 1} + end + + test "with `:daemon` option" do + {:ok, _instance} = TestServer.SSH.start(daemon: [max_sessions: 2]) + + assert {:ok, _conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, _conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:error, "Connection closed"} = SSHClient.connect(TestServer.SSH.address()) + end + + test "with `suppress_ssh_strict_kex_ordering_log: true` option" do + :logger.remove_primary_filter(:test_server_suppress_ssh_strict_kex_ordering) + {:ok, _instance} = TestServer.SSH.start(suppress_ssh_strict_kex_ordering_log: true) + + assert %{filters: filters} = :logger.get_primary_config() + assert List.keyfind(filters, :test_server_suppress_ssh_strict_kex_ordering, 0) + + assert capture_log(fn -> + assert {:ok, _conn} = SSHClient.connect(TestServer.SSH.address()) + end) == "" + + :ok = TestServer.SSH.stop() + + assert %{filters: filters} = :logger.get_primary_config() + refute List.keyfind(filters, :test_server_suppress_ssh_strict_kex_ordering, 0) + end + + test "with `suppress_ssh_strict_kex_ordering_log: false` option" do + :logger.remove_primary_filter(:test_server_suppress_ssh_strict_kex_ordering) + {:ok, _instance} = TestServer.SSH.start(suppress_ssh_strict_kex_ordering_log: false) + + assert %{filters: filters} = :logger.get_primary_config() + refute List.keyfind(filters, :test_server_suppress_ssh_strict_kex_ordering, 0) + + assert capture_log(fn -> + assert {:ok, _conn} = SSHClient.connect(TestServer.SSH.address()) + end) =~ "server will use strict KEX ordering" + end + end + + defp write_user_dir_pem!(context, key) do + test_name = + context.test + |> to_string() + |> String.replace(~r/[^a-zA-Z0-9_-]+/, "-") + + base_name = + case key do + {:RSAPrivateKey, _, _, _, _, _, _, _, _, _, _} -> "id_rsa" + {:ECPrivateKey, _, _, _, _, _} -> "id_ecdsa" + end + + path = Path.join([System.tmp_dir!(), to_string(context.module), test_name]) + + File.rm_rf!(path) + File.mkdir_p!(path) + + SSHClient.write_user_dir_pem!(key, base_name, path) + end + + describe "stop/1" do + test "when not running" do + assert_raise RuntimeError, "No current TestServer.SSH.Instance running", fn -> + TestServer.SSH.stop() + end + + assert_raise RuntimeError, + ~r/TestServer\.SSH\.Instance \#PID\<[0-9.]+\> is not running/, + fn -> + {:ok, instance} = TestServer.SSH.start() + + assert :ok = TestServer.SSH.stop() + + TestServer.SSH.stop(instance) + end + end + + test "stops" do + assert {:ok, pid} = TestServer.SSH.start() + address = TestServer.SSH.address() + + assert :ok = TestServer.SSH.stop() + refute Process.alive?(pid) + + assert SSHClient.connect(address) == {:error, :econnrefused} + end + + test "with multiple instances" do + {:ok, instance_1} = TestServer.SSH.start() + {:ok, _instance_2} = TestServer.SSH.start() + + assert_raise RuntimeError, + ~r/Multiple instances running, please pass instance to `TestServer\.SSH\.stop\/0`/, + fn -> + TestServer.SSH.stop() + end + + assert :ok = TestServer.SSH.stop(instance_1) + assert :ok = TestServer.SSH.stop() + end + end + + describe "address/2" do + test "when instance not running" do + assert_raise RuntimeError, "No current TestServer.SSH.Instance running", fn -> + TestServer.SSH.address() + end + + assert_raise RuntimeError, + ~r/TestServer\.SSH\.Instance \#PID\<[0-9.]+\> is not running/, + fn -> + {:ok, instance} = TestServer.SSH.start() + + assert :ok = TestServer.SSH.stop() + + TestServer.SSH.address(instance) + end + end + + test "with invalid `:host`" do + TestServer.SSH.start() + + assert_raise RuntimeError, ~r/Invalid host, got: :invalid/, fn -> + TestServer.SSH.address(host: :invalid) + end + end + + test "produces address" do + TestServer.SSH.start() + + assert {"localhost", port} = TestServer.SSH.address() + assert is_integer(port) + end + + test "with `:host`" do + {:ok, _instance} = TestServer.SSH.start() + + assert {"myserver.test", _port} = address = TestServer.SSH.address(host: "myserver.test") + + assert {:ok, _conn} = SSHClient.connect(address) + end + + test "with `:host` in IPv6-only mode" do + {:ok, _instance} = TestServer.SSH.start(ipfamily: :inet6) + + assert {:ok, _conn} = SSHClient.connect(TestServer.SSH.address(), inet6: true) + end + + test "with multiple instances" do + {:ok, instance_1} = TestServer.SSH.start() + {:ok, instance_2} = TestServer.SSH.start() + + assert_raise RuntimeError, + ~r/Multiple instances running, please pass instance to `TestServer\.SSH\.address\/1`/, + fn -> + TestServer.SSH.address() + end + + refute TestServer.SSH.address(instance_1) == TestServer.SSH.address(instance_2) + end + end + + describe "channel/2" do + test "when instance not running" do + {:ok, instance} = TestServer.SSH.start() + assert :ok = TestServer.SSH.stop() + + assert_raise RuntimeError, + ~r/TestServer\.SSH\.Instance \#PID\<[0-9.]+\> is not running/, + fn -> + TestServer.SSH.channel(instance, []) + end + end + + test "with invalid options" do + assert_raise ArgumentError, ~r/expected :all, got: :invalid/, fn -> + TestServer.SSH.channel(listen: :invalid) + end + + assert_raise ArgumentError, + ~r/expected list to only include \[:exec, :data, :env, :pty, :shell, :eof\], got: \[:invalid\]/, + fn -> + TestServer.SSH.channel(listen: [:invalid]) + end + end + + test "with multiple instances" do + {:ok, instance_1} = TestServer.SSH.start() + {:ok, _instance_2} = TestServer.SSH.start() + + assert_raise RuntimeError, + ~r/Multiple instances running, please pass instance to `TestServer\.SSH\.channel\/1`/, + fn -> + TestServer.SSH.channel() + end + + assert {:ok, _channel} = TestServer.SSH.channel(instance_1, []) + + TestServer.SSH.stop(instance_1) + end + + test "with no channel up message received" do + defmodule NoChannelUpMessageTest do + use ExUnit.Case + + test "fails" do + {:ok, _channel} = TestServer.SSH.channel() + + assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, _channel_id} = SSHClient.session_channel(conn) + end + end + + assert capture_io(fn -> ExUnit.run() end) =~ + "has channels that were not used:" + end + + test "when receiving unexpected channel up message" do + defmodule TooManyChannelUpMessagesTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start(suppress_warning: true) + + assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, channel_id_1} = SSHClient.session_channel(conn) + assert {:error, :closed} = SSHClient.exec(conn, channel_id_1, "ping") + end + end + + assert io = capture_io(fn -> ExUnit.run() end) + assert io =~ "No available channels" + refute io =~ "The following channels have been used:" + end + + test "when receiving unexpected channel up message after used channels" do + defmodule TooManyChannelUpMessagesAfterUsedTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start(suppress_warning: true) + {:ok, channel_1} = TestServer.SSH.channel() + :ok = TestServer.SSH.handle(channel_1) + + assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, channel_id_1} = SSHClient.session_channel(conn) + assert {:ok, channel_id_2} = SSHClient.session_channel(conn) + assert :ok = SSHClient.exec(conn, channel_id_1, "ping") + assert {:ok, %{data: "ping"}} = SSHClient.receive_until_closed(conn, channel_id_1) + assert {:error, :closed} = SSHClient.exec(conn, channel_id_2, "ping") + end + end + + assert io = capture_io(fn -> ExUnit.run() end) + assert io =~ "No available channels" + assert io =~ "The following channels have been used:" + end + + test "with `listen: :all` option" do + {:ok, channel} = TestServer.SSH.channel(listen: :all) + + :ok = + TestServer.SSH.handle(channel, + to: fn {:env, _channel_id, _want_reply, var, value}, state -> + assert var == "FOO" + assert value == "bar" + + {:ok, state} + end + ) + + :ok = + TestServer.SSH.handle(channel, + to: fn {:shell, _channel_id, _want_reply}, state -> + {:ok, state} + end + ) + + :ok = + TestServer.SSH.handle(channel, + to: fn {:pty, _channel_id, _want_reply, + { + terminal, + _char_width, + _row_height, + _pixel_width, + _pixel_height, + _terminal_modes + }}, + state -> + assert terminal == ~c"xterm-256color" + + {:ok, state} + end + ) + + :ok = + TestServer.SSH.handle(channel, + to: fn {:eof, _channel_id}, state -> + {:ok, state} + end + ) + + :ok = + TestServer.SSH.handle(channel, + to: fn {:data, _channel_id, _want_reply, "ping"}, state -> + {:reply, "pong", state} + end + ) + + assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, channel_id} = SSHClient.session_channel(conn) + assert :success = :ssh_connection.setenv(conn, channel_id, ~c"FOO", ~c"bar", 1_000) + assert {:ok, conn, channel_id} = SSHClient.open_shell(conn, channel_id) + assert :success = :ssh_connection.ptty_alloc(conn, channel_id, term: ~c"xterm-256color") + assert :ok = :ssh_connection.send_eof(conn, channel_id) + assert :ok = SSHClient.send(conn, channel_id, "ping") + assert {:ok, "pong"} = SSHClient.receive_data(conn, channel_id) + assert SSHClient.close(conn, channel_id) == :ok + end + + test "with `:listen` option filtering messages" do + {:ok, channel} = TestServer.SSH.channel(listen: []) + + TestServer.SSH.handle(channel, + to: fn msg, _state -> + flunk("Handler should not be called for ignored message: #{inspect(msg)}") + end + ) + + assert {:ok, conn, channel_id} = ssh_shell() + assert :ok = SSHClient.send(conn, channel_id, "ping") + assert :ok = SSHClient.exec(conn, channel_id, "ping") + assert {:ok, %{data: nil}} = SSHClient.receive_until_closed(conn, channel_id) + assert :ok = SSHClient.close(conn, channel_id) + + TestServer.SSH.stop() + end + + test "with multiple channels" do + {:ok, channel_1} = TestServer.SSH.channel() + {:ok, channel_2} = TestServer.SSH.channel() + TestServer.SSH.handle(channel_1, to: fn _msg, state -> {:reply, "channel1", state} end) + TestServer.SSH.handle(channel_2, to: fn _msg, state -> {:reply, "channel2", state} end) + + {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, channel_id_1} = SSHClient.session_channel(conn) + assert {:ok, channel_id_2} = SSHClient.session_channel(conn) + assert :ok = SSHClient.exec(conn, channel_id_1, "ping") + assert :ok = SSHClient.exec(conn, channel_id_2, "ping") + assert {:ok, %{data: "channel1"}} = SSHClient.receive_until_closed(conn, channel_id_1) + assert {:ok, %{data: "channel2"}} = SSHClient.receive_until_closed(conn, channel_id_2) + end + end + + describe "handle/2" do + test "when instance not running" do + {:ok, _instance} = TestServer.SSH.start() + {:ok, channel} = TestServer.SSH.channel() + :ok = TestServer.SSH.stop() + + assert_raise RuntimeError, + ~r/TestServer\.SSH\.Instance \#PID\<[0-9.]+\> is not running/, + fn -> + TestServer.SSH.handle(channel) + end + end + + test "with invalid options" do + {:ok, _instance} = TestServer.SSH.start() + {:ok, channel} = TestServer.SSH.channel() + + assert_raise BadFunctionError, ~r/expected a function, got: :invalid/, fn -> + TestServer.SSH.handle(channel, to: :invalid) + end + + assert_raise BadFunctionError, ~r/expected a function, got: :invalid/, fn -> + TestServer.SSH.handle(channel, match: :invalid) + end + + TestServer.SSH.stop() + end + + test "with no message received" do + defmodule NoMessageReceivedTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start() + {:ok, channel} = TestServer.SSH.channel() + :ok = TestServer.SSH.handle(channel) + + assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, _channel_id} = SSHClient.session_channel(conn) + end + end + + assert capture_io(fn -> ExUnit.run() end) =~ + "did not receive a message for these handlers before the test ended" + end + + test "when receiving unexpected message" do + defmodule UnexpectedMessageTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start(suppress_warning: true) + {:ok, _channel} = TestServer.SSH.channel() + + assert {:ok, %{data: data, exit_status: 1}} = unquote(__MODULE__).ssh_exec("ping") + assert data =~ "received an unexpected SSH message" + end + end + + assert io = capture_io(fn -> ExUnit.run() end) + assert io =~ "received an unexpected SSH message" + refute io =~ "The following handlers have been processed:" + end + + test "when receiving unexpected message after processed handlers" do + defmodule UnexpectedMessageAfterProcessedTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start(suppress_warning: true) + {:ok, channel} = TestServer.SSH.channel() + :ok = TestServer.SSH.handle(channel) + + assert {:ok, conn, channel_id} = unquote(__MODULE__).ssh_shell() + assert :ok = SSHClient.send(conn, channel_id, "first") + assert :ok = SSHClient.send(conn, channel_id, "second") + assert {:ok, "first"} = SSHClient.receive_data(conn, channel_id) + assert {:ok, message} = SSHClient.receive_data(conn, channel_id) + assert message =~ "received an unexpected SSH message" + assert message =~ "The following handlers have been processed:" + end + end + + assert io = capture_io(fn -> ExUnit.run() end) + assert io =~ "received an unexpected SSH message" + assert io =~ "The following handlers have been processed:" + end + + test "with `:to` 3-arity function raising exception" do + defmodule HandleTo3ArityFunctionRaiseTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start(suppress_warning: true) + {:ok, channel} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle(channel, to: fn _msg, _connection, _state -> raise "boom" end) + + assert {:ok, %{data: data, exit_status: 1}} = unquote(__MODULE__).ssh_exec("ping") + assert data =~ "(RuntimeError) boom" + end + end + + assert io = capture_io(fn -> ExUnit.run() end) + assert io =~ "(RuntimeError) boom" + assert io =~ "anonymous fn/3 in TestServer.SSHTest.HandleTo3ArityFunctionRaiseTest" + end + + test "with `:to` 3-arity function with invalid response" do + defmodule HandleTo3ArityFunctionInvalidResponseTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start(suppress_warning: true) + {:ok, channel} = TestServer.SSH.channel() + :ok = TestServer.SSH.handle(channel, to: fn _msg, _connection, _state -> :invalid end) + + assert {:ok, %{data: data, exit_status: 1}} = unquote(__MODULE__).ssh_exec("ping") + assert data =~ "(RuntimeError) Invalid callback response, got: :invalid." + end + end + + assert io = capture_io(fn -> ExUnit.run() end) + assert io =~ "(RuntimeError) Invalid callback response, got: :invalid." + end + + test "with `:to` 3-arity function returning `{:ok, state}` response" do + {:ok, _instance} = TestServer.SSH.start() + {:ok, channel} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle(channel, + to: fn {:data, channel_id, _want_reply, _data}, connection, state -> + :ssh_connection.adjust_window(connection, channel_id, byte_size("first")) + :ssh_connection.send(connection, channel_id, "first") + :ssh_connection.adjust_window(connection, channel_id, byte_size("second")) + :ssh_connection.send(connection, channel_id, 2, "second") + + {:ok, state} + end + ) + + assert {:ok, conn, channel_id} = ssh_shell() + assert :ok = SSHClient.send(conn, channel_id, "ping") + assert {:ok, "first"} = SSHClient.receive_data(conn, channel_id, 0) + assert {:ok, "second"} = SSHClient.receive_data(conn, channel_id, 2) + end + + test "with `:to` 3-arity function returning `{:stop, channel_id, state}` response" do + {:ok, _instance} = TestServer.SSH.start() + {:ok, channel} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle(channel, + to: fn {:exec, channel_id, want_reply, _command}, connection, state -> + :ssh_connection.reply_request(connection, want_reply, :success, channel_id) + :ssh_connection.send(connection, channel_id, "first") + :ssh_connection.send(connection, channel_id, 2, "second") + :ssh_connection.close(connection, channel_id) + + {:stop, channel_id, state} + end + ) + + assert {:ok, + %{ + data: "firstsecond", + messages: [ + {:data, 0, 0, "first"}, + {:data, 0, 2, "second"}, + {:closed, 0} + ] + }} = ssh_exec("ping") + end + + test "with `:to` 2-arity function raising exception" do + defmodule HandleTo2ArityFunctionRaiseTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start(suppress_warning: true) + {:ok, channel} = TestServer.SSH.channel() + :ok = TestServer.SSH.handle(channel, to: fn _msg, _state -> raise "boom" end) + + assert {:ok, %{data: data, exit_status: 1}} = unquote(__MODULE__).ssh_exec("ping") + assert data =~ "(RuntimeError) boom" + end + end + + assert io = capture_io(fn -> ExUnit.run() end) + assert io =~ "(RuntimeError) boom" + assert io =~ "anonymous fn/2 in TestServer.SSHTest.HandleTo2ArityFunctionRaiseTest" + end + + test "with `:to` 2-arity function with invalid response" do + defmodule HandleTo2ArityFunctionInvalidResponseTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start(suppress_warning: true) + {:ok, channel} = TestServer.SSH.channel() + :ok = TestServer.SSH.handle(channel, to: fn _msg, _state -> :invalid end) + + assert {:ok, %{data: data, exit_status: 1}} = unquote(__MODULE__).ssh_exec("ping") + assert data =~ "(RuntimeError) Invalid callback response, got: :invalid." + end + end + + assert io = capture_io(fn -> ExUnit.run() end) + assert io =~ "(RuntimeError) Invalid callback response, got: :invalid." + end + + test "with `:to` 2-arity function with `{:reply, data, state}` response" do + {:ok, _instance} = TestServer.SSH.start() + {:ok, channel} = TestServer.SSH.channel() + + assert :ok = + TestServer.SSH.handle(channel, + to: fn _msg, state -> + {:reply, "function called", state} + end + ) + + assert {:ok, result} = ssh_exec("ping") + assert result.data == "function called" + assert result.exit_status == 0 + + assert result.messages == [ + {:data, 0, 0, "function called"}, + {:exit_status, 0, 0}, + {:eof, 0}, + {:closed, 0} + ] + end + + test "with `:to` 2-arity function with `{:reply, {data, options}, state}` response" do + test_pid = self() + {:ok, _instance} = TestServer.SSH.start() + + # Shell handling + {:ok, channel_1} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle(channel_1, + to: fn _msg, state -> + # This is to ensure we receive this message before we send close + send(test_pid, :continue) + + {:reply, {"pong", exit_status: 127, data_type_code: 1}, state} + end + ) + + assert {:ok, conn, channel_id} = ssh_shell() + assert :ok = SSHClient.send(conn, channel_id, "ping") + assert_receive :continue + assert :ok = SSHClient.close(conn, channel_id) + assert {:ok, result} = SSHClient.receive_until_closed(conn, channel_id) + assert result.data == "pong" + refute result.exit_status + assert result.messages == [{:data, 0, 1, "pong"}, {:closed, 0}] + + # Exec handling + {:ok, channel_2} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle(channel_2, + to: fn _msg, state -> {:reply, {"pong", exit_status: 127, data_type_code: 1}, state} end + ) + + assert {:ok, result} = ssh_exec("ping") + assert result.data == "pong" + assert result.exit_status == 127 + + assert result.messages == [ + {:data, 0, 1, "pong"}, + {:exit_status, 0, 127}, + {:eof, 0}, + {:closed, 0} + ] + end + + test "with `:to` 2-arity function with `{:reply, {data, options}, state}` response with invalid options" do + defmodule HandleTo2ArityFunctionInvalidOptionsTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start(suppress_warning: true) + {:ok, channel} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle(channel, + to: fn _msg, state -> + {:reply, {"pong", invalid: 1}, state} + end + ) + + assert {:ok, result} = unquote(__MODULE__).ssh_exec("ping") + + assert result.data =~ + "(RuntimeError) Invalid options in callback response, got: [invalid: 1]." + + assert result.exit_status == 0 + end + end + + assert io = capture_io(fn -> ExUnit.run() end) + assert io =~ "(RuntimeError) Invalid options in callback response, got: [invalid: 1]." + end + + test "with `:to` 2-arity function with `{:ok, state}` response" do + {:ok, _instance} = TestServer.SSH.start() + {:ok, channel} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle(channel, + to: fn _msg, state -> {:ok, Map.put(state, :key, "value")} end + ) + + :ok = + TestServer.SSH.handle(channel, + to: fn _msg, state -> {:reply, Map.fetch!(state, :key), state} end + ) + + assert {:ok, conn, channel_id} = ssh_shell() + assert :ok = SSHClient.send(conn, channel_id, "ping") + assert :ok = SSHClient.send(conn, channel_id, "ping") + assert {:ok, "value"} = SSHClient.receive_data(conn, channel_id) + end + + test "when `:match` function raises exception" do + defmodule MatchFunctionRaiseTest do + use ExUnit.Case + + test "fails" do + {:ok, _instance} = TestServer.SSH.start(suppress_warning: true) + {:ok, channel} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle(channel, + match: fn {:exec, _channel_id, _want_reply, _command}, _state -> + raise "boom" + end + ) + + assert {:ok, %{data: data}} = unquote(__MODULE__).ssh_exec("ping") + assert data =~ "(RuntimeError) boom" + end + end + + assert io = capture_io(fn -> ExUnit.run() end) + assert io =~ "(RuntimeError) boom" + assert io =~ "anonymous fn/2 in TestServer.SSHTest.MatchFunctionRaiseTest" + end + + test "with `:match` function" do + {:ok, _instance} = TestServer.SSH.start() + {:ok, channel} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle(channel, + match: fn {:exec, _channel_id, _want_reply, command}, _state -> + to_string(command) == "ping" + end + ) + + assert {:ok, %{data: "ping"}} = ssh_exec("ping") + end + + test "with `:match` function filtering multiple handlers" do + {:ok, _instance} = TestServer.SSH.start() + {:ok, channel} = TestServer.SSH.channel() + + :ok = + TestServer.SSH.handle(channel, + match: fn {:data, _channel_id, _type, data}, _state -> data == "first" end, + to: fn _frame, state -> {:reply, "pong", state} end + ) + + :ok = + TestServer.SSH.handle(channel, + match: fn {:data, _channel_id, _type, data}, _state -> data == "second" end + ) + + assert {:ok, conn, channel_id} = ssh_shell() + assert :ok = SSHClient.send(conn, channel_id, "second") + assert {:ok, "second"} = SSHClient.receive_data(conn, channel_id) + assert :ok = SSHClient.send(conn, channel_id, "first") + assert {:ok, "pong"} = SSHClient.receive_data(conn, channel_id) + end + end + + describe "host_keys/0" do + test "when instance not running" do + assert_raise RuntimeError, "No current TestServer.SSH.Instance running", fn -> + TestServer.SSH.host_keys() + end + + assert_raise RuntimeError, + ~r/TestServer\.SSH\.Instance \#PID\<[0-9.]+\> is not running/, + fn -> + {:ok, instance} = TestServer.SSH.start() + + assert :ok = TestServer.SSH.stop() + + TestServer.SSH.host_keys(instance) + end + end + + test "when instance running with `:host_keys` function" do + TestServer.SSH.start( + host_keys: fn _algorithm, _daemon_options -> + {:ok, :public_key.generate_key({:rsa, 2_048, 65_537})} + end + ) + + assert_raise RuntimeError, + ~r/TestServer\.SSH\.Instance \#PID\<[0-9.]+\> is running with `\[host_keys: function\]` option/, + fn -> + TestServer.SSH.host_keys() + end + end + end + + def ssh_exec(command) do + assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, channel_id} = SSHClient.session_channel(conn) + assert :ok = SSHClient.exec(conn, channel_id, command) + + SSHClient.receive_until_closed(conn, channel_id) + end + + def ssh_shell do + assert {:ok, conn} = SSHClient.connect(TestServer.SSH.address()) + assert {:ok, channel_id} = SSHClient.session_channel(conn) + + SSHClient.open_shell(conn, channel_id) + end + + defmodule SSHClient do + def write_user_dir_pem!(key, base_name, path) do + type = elem(key, 0) + pem_entry = :public_key.pem_entry_encode(type, key) + pem = :public_key.pem_encode([pem_entry]) + File.write!(Path.join(path, "#{base_name}"), pem) + + path + end + + def sockname(conn) do + {:sockname, res} = :ssh.connection_info(conn, :sockname) + + res + end + + def connect({host, port}, options \\ []) do + options = + [ + silently_accept_hosts: true, + user_interaction: false + ] + |> Keyword.merge(options) + |> normalize_ssh_connect_options() + + :logger.add_primary_filter( + :suppress_ssh_log_messages, + {fn + %{level: :debug, msg: {:string, ~c"client will use strict KEX ordering"}}, _extra -> + :stop + + %{level: :notice, msg: {:report, %{report: msg}}}, _extra -> + (to_string(msg) =~ "Ssh login attempt to" && :stop) || :ignore + + _event, _extra -> + :ignore + end, []} + ) + + host + |> to_charlist() + |> :ssh.connect(port, options) + |> handle_resp() + |> case do + {:ok, conn} -> + on_exit(fn -> close(conn) end) + + {:ok, conn} + + {:error, reason} -> + {:error, reason} + end + end + + defp normalize_ssh_connect_options(options) do + Enum.reduce(~w(user password auth_methods user_dir)a, options, fn key, options -> + case Keyword.has_key?(options, key) do + true -> Keyword.update!(options, key, &String.to_charlist/1) + false -> options + end + end) + end + + defp handle_resp({:ok, conn}), do: {:ok, conn} + defp handle_resp({:error, reason}) when is_list(reason), do: {:error, to_string(reason)} + defp handle_resp({:error, reason}), do: {:error, reason} + + def open_shell(conn, channel_id) do + case :ssh_connection.shell(conn, channel_id) do + :ok -> {:ok, conn, channel_id} + :failure -> {:error, :failure} + {:error, :timeout} -> {:error, :timeout} + end + end + + def session_channel(conn, timeout \\ 500) do + conn + |> :ssh_connection.session_channel(timeout) + |> handle_resp() + end + + def send(conn, channel_id, data) do + case :ssh_connection.send(conn, channel_id, data) do + :ok -> :ok + {:error, reason} -> {:error, reason} + end + end + + def exec(conn, channel_id, command, timeout \\ 500) do + case :ssh_connection.exec(conn, channel_id, command, timeout) do + :success -> :ok + :failure -> {:error, :failure} + {:error, reason} -> {:error, reason} + end + end + + def receive_until_closed( + conn, + channel_id, + state \\ %{data: nil, exit_status: nil, messages: []} + ) do + receive do + {:ssh_cm, ^conn, {:exit_status, ^channel_id, status} = msg} -> + receive_until_closed(conn, channel_id, %{ + state + | exit_status: status, + messages: state.messages ++ [msg] + }) + + {:ssh_cm, ^conn, {:data, ^channel_id, _want_reply, data} = msg} -> + data = (state.data || "") <> data + + receive_until_closed(conn, channel_id, %{ + state + | data: data, + messages: state.messages ++ [msg] + }) + + {:ssh_cm, ^conn, {:eof, ^channel_id} = msg} -> + receive_until_closed(conn, channel_id, %{state | messages: state.messages ++ [msg]}) + + {:ssh_cm, ^conn, {:closed, ^channel_id} = msg} -> + {:ok, %{state | messages: state.messages ++ [msg]}} + after + 500 -> {:error, :timeout} + end + end + + def receive_data(conn, channel_id, type \\ 0) do + assert_receive {:ssh_cm, ^conn, {:data, ^channel_id, ^type, data}} + + {:ok, to_string(data)} + end + + def close(conn) do + :ssh.close(conn) + end + + def close(conn, channel_id) do + :ssh_connection.close(conn, channel_id) + end + end +end