diff --git a/lib/off_broadway/emqqt/broker.ex b/lib/off_broadway/emqqt/broker.ex index 50be09e..733a44f 100644 --- a/lib/off_broadway/emqqt/broker.ex +++ b/lib/off_broadway/emqqt/broker.ex @@ -3,10 +3,14 @@ defmodule OffBroadway.EMQTT.Broker do require Logger def start_link(opts) do - client_id = get_in(opts, [:config, :clientid]) - GenServer.start_link(__MODULE__, opts, name: :"#{__MODULE__}-#{client_id}") + name = get_in(opts, [:config, :name]) + GenServer.start_link(__MODULE__, opts, name: :"#{__MODULE__}-#{name}") end + def stop_emqtt(pid), do: GenServer.cast(pid, :stop_emqtt) + def pause_emqtt(pid), do: GenServer.cast(pid, :pause_emqtt) + def resume_emqtt(pid), do: GenServer.cast(pid, :resume_emqtt) + @impl true def init(args) do with {:ok, config} <- Keyword.fetch(args, :config), @@ -27,6 +31,7 @@ defmodule OffBroadway.EMQTT.Broker do ets_table: String.to_existing_atom(client_id), emqtt: emqtt, emqtt_ref: Process.monitor(emqtt), + emqtt_config: config, topics: topics, topic_subscriptions: [] }, {:continue, :create_ets_table}} @@ -81,6 +86,32 @@ defmodule OffBroadway.EMQTT.Broker do {:noreply, state} end + def handle_info({:DOWN, ref, :process, _, :normal}, state) when ref == state.emqtt_ref, do: {:noreply, state} + + def handle_info({:DOWN, ref, :process, _, _reason}, state) when ref == state.emqtt_ref do + {:ok, pid} = :emqtt.start_link(state.emqtt_config) + {:ok, _props} = :emqtt.connect(pid) + {:noreply, %{state | emqtt: pid, emqtt_ref: Process.monitor(pid)}, {:continue, :subscribe_to_topics}} + end + + def handle_info({:EXIT, _, _reason}, state), do: {:noreply, state} + + @impl true + def handle_cast(:stop_emqtt, state) do + if Process.alive?(state.emqtt), do: :ok = :emqtt.stop(state.emqtt) + {:noreply, state} + end + + def handle_cast(:pause_emqtt, state) do + :ok = :emqtt.pause(state.emqtt) + {:noreply, state} + end + + def handle_cast(:resume_emqtt, state) do + :ok = :emqtt.resume(state.emqtt) + {:noreply, state} + end + @impl true def terminate(_reason, state) do Process.demonitor(state.emqtt_ref) diff --git a/lib/off_broadway/emqqt/producer.ex b/lib/off_broadway/emqqt/producer.ex index afff155..96cdbc8 100644 --- a/lib/off_broadway/emqqt/producer.ex +++ b/lib/off_broadway/emqqt/producer.ex @@ -95,7 +95,7 @@ defmodule OffBroadway.EMQTT.Producer do end end - @impl true + @impl Producer def prepare_for_start(_module, broadway_opts) do {producer_module, client_opts} = broadway_opts[:producer][:module] @@ -132,6 +132,13 @@ defmodule OffBroadway.EMQTT.Producer do end end + @impl Producer + def prepare_for_draining(%{receive_timer: timer} = state) do + timer && Process.cancel_timer(timer) + Broker.stop_emqtt(state.emqtt) + {:noreply, [], %{state | drain: true, receive_timer: nil}} + end + @spec emqtt_process_name(String.t()) :: atom() def emqtt_process_name(client_id), do: String.to_atom(client_id) diff --git a/mix.exs b/mix.exs index 644f87b..c3642f7 100644 --- a/mix.exs +++ b/mix.exs @@ -58,7 +58,7 @@ defmodule OffBroadway.EMQTT.MixProject do {:cowlib, "~> 2.13", override: true}, {:ex_doc, "~> 0.34.2", only: [:dev, :test], runtime: false}, {:credo, "~> 1.7", only: :dev}, - {:dialyxir, "~> 1.4", only: :dev}, + {:dialyxir, "~> 1.4", only: :dev, runtime: false}, {:excoveralls, "~> 0.18", only: :test} ] end diff --git a/test/off_broadway/emqtt/producer_test.exs b/test/off_broadway/emqtt/producer_test.exs index 8d1f58b..987ce76 100644 --- a/test/off_broadway/emqtt/producer_test.exs +++ b/test/off_broadway/emqtt/producer_test.exs @@ -156,5 +156,14 @@ defmodule OffBroadway.EMQTT.ProducerTest do :telemetry.detach("telemetry-events") stop_process(pid) end + + test "stops the emqtt server when draining" do + {:ok, pid} = start_broadway(nil, unique_name(), @broadway_opts ++ [topics: [{"#", :at_least_once}]]) + Broadway.stop(pid, :normal) + + # Make sure to not kill the producer before it can respond + Process.sleep(10) + stop_process(pid) + end end end