From 9e47d31f7961d168ec75711532ac0e5eb7e9b426 Mon Sep 17 00:00:00 2001 From: JianBo He Date: Wed, 9 Dec 2020 19:17:51 +0800 Subject: [PATCH 1/4] refactor(stomp): avoid anonymous functions --- apps/emqx_stomp/src/emqx_stomp_connection.erl | 85 +++++---- apps/emqx_stomp/src/emqx_stomp_frame.erl | 64 ++++--- apps/emqx_stomp/src/emqx_stomp_heartbeat.erl | 126 ++++++------- apps/emqx_stomp/src/emqx_stomp_protocol.erl | 165 +++++++++++++----- apps/emqx_stomp/test/emqx_stomp_SUITE.erl | 6 +- .../test/emqx_stomp_heartbeat_SUITE.erl | 53 ++++++ 6 files changed, 316 insertions(+), 183 deletions(-) create mode 100644 apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl diff --git a/apps/emqx_stomp/src/emqx_stomp_connection.erl b/apps/emqx_stomp/src/emqx_stomp_connection.erl index b6c05ac68..dc1977944 100644 --- a/apps/emqx_stomp/src/emqx_stomp_connection.erl +++ b/apps/emqx_stomp/src/emqx_stomp_connection.erl @@ -33,8 +33,11 @@ , terminate/2 ]). +%% for protocol +-export([send/4, heartbeat/2]). + -record(stomp_client, {transport, socket, peername, conn_name, conn_state, - await_recv, rate_limit, parse_fun, proto_state, + await_recv, rate_limit, parser, proto_state, proto_env, heartbeat}). -define(INFO_KEYS, [peername, await_recv, conn_state]). @@ -55,9 +58,12 @@ init([Transport, Sock, ProtoEnv]) -> {ok, NewSock} -> {ok, Peername} = Transport:ensure_ok_or_exit(peername, [NewSock]), ConnName = esockd:format(Peername), - SendFun = send_fun(Transport, Sock), - ParseFun = emqx_stomp_frame:parser(ProtoEnv), - ProtoState = emqx_stomp_protocol:init(Peername, SendFun, ProtoEnv), + SendFun = {fun ?MODULE:send/4, [Transport, Sock, self()]}, + HrtBtFun = {fun ?MODULE:heartbeat/2, [Transport, Sock]}, + Parser = emqx_stomp_frame:init_parer_state(ProtoEnv), + ProtoState = emqx_stomp_protocol:init(#{peername => Peername, + sendfun => SendFun, + heartfun => HrtBtFun}, ProtoEnv), RateLimit = init_rate_limit(proplists:get_value(rate_limit, ProtoEnv)), State = run_socket(#stomp_client{transport = Transport, socket = NewSock, @@ -66,7 +72,7 @@ init([Transport, Sock, ProtoEnv]) -> conn_state = running, await_recv = false, rate_limit = RateLimit, - parse_fun = ParseFun, + parser = Parser, proto_env = ProtoEnv, proto_state = ProtoState}), gen_server:enter_loop(?MODULE, [{hibernate_after, 5000}], State, 20000); @@ -79,17 +85,17 @@ init_rate_limit(undefined) -> init_rate_limit({Rate, Burst}) -> esockd_rate_limit:new(Rate, Burst). -send_fun(Transport, Sock) -> - Self = self(), - fun(Data) -> - try Transport:async_send(Sock, Data) of - ok -> ok; - {error, Reason} -> Self ! {shutdown, Reason} - catch - error:Error -> Self ! {shutdown, Error} - end +send(Data, Transport, Sock, ConnPid) -> + try Transport:async_send(Sock, Data) of + ok -> ok; + {error, Reason} -> ConnPid ! {shutdown, Reason} + catch + error:Error -> ConnPid ! {shutdown, Error} end. +heartbeat(Transport, Sock) -> + Transport:send(Sock, <<$\n>>). + handle_call(info, _From, State = #stomp_client{transport = Transport, socket = Sock, peername = Peername, @@ -124,16 +130,19 @@ handle_info({transaction, {timeout, Id}}, State) -> emqx_stomp_transaction:timeout(Id), noreply(State); -handle_info({heartbeat, start, {Cx, Cy}}, - State = #stomp_client{transport = Transport, socket = Sock}) -> - Self = self(), - Incomming = {Cx, statfun(recv_oct, State), fun() -> Self ! {heartbeat, timeout} end}, - Outgoing = {Cy, statfun(send_oct, State), fun() -> Transport:send(Sock, <<$\n>>) end}, - {ok, HbProc} = emqx_stomp_heartbeat:start_link(Incomming, Outgoing), - noreply(State#stomp_client{heartbeat = HbProc}); +handle_info({timeout, TRef, TMsg}, State) when TMsg =:= incoming; + TMsg =:= outgoing -> -handle_info({heartbeat, timeout}, State) -> - stop({shutdown, heartbeat_timeout}, State); + Stat = case TMsg of + incoming -> recv_oct; + _ -> send_oct + end, + case getstat(Stat, State) of + {ok, Val} -> + with_proto(timeout, [TRef, {TMsg, Val}], State); + {error, Reason} -> + shutdown({sock_error, Reason}, State) + end; handle_info({'EXIT', HbProc, Error}, State = #stomp_client{heartbeat = HbProc}) -> stop(Error, State); @@ -186,14 +195,24 @@ code_change(_OldVsn, State, _Extra) -> %% Receive and Parse data %%-------------------------------------------------------------------- +with_proto(Fun, Args, State = #stomp_client{proto_state = ProtoState}) -> + case erlang:apply(emqx_stomp_protocol, Fun, Args ++ [ProtoState]) of + {ok, NProtoState} -> + noreply(State#stomp_client{proto_state = NProtoState}); + {F, Reason, NProtoState} when F == stop; + F == error; + F == shutdown -> + shutdown(Reason, State#stomp_client{proto_state = NProtoState}) + end. + received(<<>>, State) -> noreply(State); -received(Bytes, State = #stomp_client{parse_fun = ParseFun, +received(Bytes, State = #stomp_client{parser = Parser, proto_state = ProtoState}) -> - try ParseFun(Bytes) of - {more, NewParseFun} -> - noreply(State#stomp_client{parse_fun = NewParseFun}); + try emqx_stomp_frame:parse(Bytes, Parser) of + {more, NewParser} -> + noreply(State#stomp_client{parser = NewParser}); {ok, Frame, Rest} -> ?LOG(info, "RECV Frame: ~s", [emqx_stomp_frame:format(Frame)], State), case emqx_stomp_protocol:received(Frame, ProtoState) of @@ -216,7 +235,7 @@ received(Bytes, State = #stomp_client{parse_fun = ParseFun, end. reset_parser(State = #stomp_client{proto_env = ProtoEnv}) -> - State#stomp_client{parse_fun = emqx_stomp_frame:parser(ProtoEnv)}. + State#stomp_client{parser = emqx_stomp_frame:init_parer_state(ProtoEnv)}. rate_limit(_Size, State = #stomp_client{rate_limit = undefined}) -> run_socket(State); @@ -238,12 +257,10 @@ run_socket(State = #stomp_client{transport = Transport, socket = Sock}) -> Transport:async_recv(Sock, 0, infinity), State#stomp_client{await_recv = true}. -statfun(Stat, #stomp_client{transport = Transport, socket = Sock}) -> - fun() -> - case Transport:getstat(Sock, [Stat]) of - {ok, [{Stat, Val}]} -> {ok, Val}; - {error, Error} -> {error, Error} - end +getstat(Stat, #stomp_client{transport = Transport, socket = Sock}) -> + case Transport:getstat(Sock, [Stat]) of + {ok, [{Stat, Val}]} -> {ok, Val}; + {error, Error} -> {error, Error} end. noreply(State) -> diff --git a/apps/emqx_stomp/src/emqx_stomp_frame.erl b/apps/emqx_stomp/src/emqx_stomp_frame.erl index e5795335a..5d5ddffe1 100644 --- a/apps/emqx_stomp/src/emqx_stomp_frame.erl +++ b/apps/emqx_stomp/src/emqx_stomp_frame.erl @@ -70,7 +70,8 @@ -include("emqx_stomp.hrl"). --export([ parser/1 +-export([ init_parer_state/1 + , parse/2 , serialize/1 ]). @@ -91,14 +92,18 @@ -record(frame_limit, {max_header_num, max_header_length, max_body_length}). --type(parser() :: fun((binary()) -> {ok, stomp_frame(), binary()} - | {more, parser()} - | {error, any()})). +-type(result() :: {ok, stomp_frame(), binary()} + | {more, parser()} + | {error, any()}). + +-type(parser() :: #{phase := none | command | headers | hdname | hdvalue | body, + pre => binary(), + state := #parser_state{}}). %% @doc Initialize a parser --spec parser([proplists:property()]) -> parser(). -parser(Opts) -> - fun(Bin) -> parse(none, Bin, #parser_state{limit = limit(Opts)}) end. +-spec init_parer_state([proplists:property()]) -> parser(). +init_parer_state(Opts) -> + #{phase => none, state => #parser_state{limit = limit(Opts)}}. limit(Opts) -> #frame_limit{max_header_num = g(max_header_num, Opts, ?MAX_HEADER_NUM), @@ -108,29 +113,31 @@ limit(Opts) -> g(Key, Opts, Val) -> proplists:get_value(Key, Opts, Val). -%% @doc Parse frame --spec(parse(Phase :: atom(), binary(), #parser_state{}) -> - {ok, stomp_frame(), binary()} | {more, parser()} | {error, any()}). -parse(none, <<>>, State) -> - {more, fun(Bin) -> parse(none, Bin, State) end}; -parse(none, <>, State) -> - parse(none, Bin, State); -parse(none, Bin, State) -> - parse(command, Bin, State); +-spec parse(binary(), parser()) -> result(). +parse(<<>>, Parser) -> + {more, Parser}; -parse(Phase, <<>>, State) -> - {more, fun(Bin) -> parse(Phase, Bin, State) end}; -parse(Phase, <>, State) -> - {more, fun(Bin) -> parse(Phase, <>, State) end}; -parse(Phase, <>, State) -> +parse(Bytes, #{phase := body, len := Len, state := State}) -> + parse(body, Bytes, State, Len); + +parse(Bytes, Parser = #{pre := Pre}) -> + parse(<
>, maps:without([pre], Parser));
+parse(<>, #{phase := Phase, state := State}) ->
     parse(Phase, <>, State);
-parse(_Phase, <>, _State) ->
+parse(<>, Parser) ->
+    {more, Parser#{pre => <>}};
+parse(<>, _Parser) ->
     {error, linefeed_expected};
-parse(Phase, <>, State) when Phase =:= hdname; Phase =:= hdvalue ->
-    {more, fun(Bin) -> parse(Phase, <>, State) end};
-parse(Phase, <>, State) when Phase =:= hdname; Phase =:= hdvalue ->
+
+parse(<>, Parser = #{phase := Phase}) when Phase =:= hdname; Phase =:= hdvalue ->
+    {more, Parser#{pre => <>}};
+parse(<>, #{phase := Phase, state := State}) when Phase =:= hdname; Phase =:= hdvalue ->
     parse(Phase, Rest, acc(unescape(Ch), State));
 
+parse(Bytes, #{phase := none, state := State}) ->
+    parse(command, Bytes, State).
+
+%% @private
 parse(command, <>, State = #parser_state{acc = Acc}) ->
     parse(headers, Rest, State#parser_state{cmd = Acc, acc = <<>>});
 parse(command, <>, State) ->
@@ -153,20 +160,21 @@ parse(hdvalue, <>, State = #parser_state{headers = Headers, hd
 parse(hdvalue, <>, State) ->
     parse(hdvalue, Rest, acc(Ch, State)).
 
+%% @private
 parse(body, <<>>, State, Length) ->
-    {more, fun(Bin) -> parse(body, Bin, State, Length) end};
+    {more, #{phase => body, length => Length, state => State}};
 parse(body, Bin, State, none) ->
     case binary:split(Bin, <>) of
         [Chunk, Rest] ->
             {ok, new_frame(acc(Chunk, State)), Rest};
         [Chunk] ->
-            {more, fun(More) -> parse(body, More, acc(Chunk, State), none) end}
+            {more, #{phase => body, length => none, state => acc(Chunk, State)}}
     end;
 parse(body, Bin, State, Len) when byte_size(Bin) >= (Len+1) ->
     <> = Bin,
     {ok, new_frame(acc(Chunk, State)), Rest};
 parse(body, Bin, State, Len) ->
-    {more, fun(More) -> parse(body, More, acc(Bin, State), Len - byte_size(Bin)) end}.
+    {more, #{phase => body, length => Len - byte_size(Bin), state => acc(Bin, State)}}.
 
 add_header(Name, Value, Headers) ->
     case lists:keyfind(Name, 1, Headers) of
diff --git a/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl b/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl
index 22e1c3eb3..79cc8f435 100644
--- a/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl
+++ b/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl
@@ -19,88 +19,74 @@
 
 -include("emqx_stomp.hrl").
 
--export([ start_link/2
-        , stop/1
-        ]).
-
-%% callback
 -export([ init/1
-        , loop/3
+        , check/3
+        , info/1
+        , interval/2
         ]).
 
--define(MAX_REPEATS, 1).
+-record(heartbeater, {interval, statval, repeat}).
 
--record(heartbeater, {name, cycle, tref, val, statfun, action, repeat = 0}).
+-type name() :: incoming | outgoing.
 
-start_link({0, _, _}, {0, _, _}) ->
-    {ok, none};
+-type heartbeat() :: #{incoming => #heartbeater{},
+                       outgoing => #heartbeater{}
+                      }.
 
-start_link(Incoming, Outgoing) ->
-    Params = [self(), Incoming, Outgoing],
-    {ok, spawn_link(?MODULE, init, [Params])}.
 
-stop(Pid) ->
-    Pid ! stop.
+%%--------------------------------------------------------------------
+%% APIs
+%%--------------------------------------------------------------------
 
-init([Parent, Incoming, Outgoing]) ->
-    loop(Parent, heartbeater(incomming, Incoming), heartbeater(outgoing,  Outgoing)).
+-spec init({non_neg_integer(), non_neg_integer()}) -> heartbeat().
+init({0, 0}) ->
+    #{};
+init({Cx, Cy}) ->
+    maps:filter(fun(_, V) -> V /= undefined end,
+      #{incoming => heartbeater(Cx),
+        outgoing => heartbeater(Cy)
+       }).
 
-heartbeater(_, {0, _, _}) ->
+heartbeater(0) ->
     undefined;
+heartbeater(I) ->
+    #heartbeater{
+       interval = I,
+       statval = 0,
+       repeat = 0
+      }.
 
-heartbeater(InOut, {Cycle, StatFun, ActionFun}) ->
-    {ok, Val} = StatFun(),
-    #heartbeater{name = InOut, cycle = Cycle,
-                 tref = timer(InOut, Cycle),
-                 val = Val, statfun = StatFun,
-                 action = ActionFun}.
-
-loop(Parent, Incomming, Outgoing) ->
-    receive
-        {heartbeat, incomming} ->
-            #heartbeater{val = LastVal, statfun = StatFun,
-                         action = Action, repeat = Repeat} = Incomming,
-            case StatFun() of
-                {ok, Val} ->
-                    if Val =/= LastVal ->
-                           hibernate([Parent, resume(Incomming, Val), Outgoing]);
-                       Repeat < ?MAX_REPEATS ->
-                           hibernate([Parent, resume(Incomming, Val, Repeat+1), Outgoing]);
-                       true ->
-                           Action()
-                    end;
-                {error, Error} -> %% einval
-                    exit({shutdown, Error})
-            end;
-        {heartbeat, outgoing}  ->
-            #heartbeater{val = LastVal, statfun = StatFun, action = Action} = Outgoing,
-            case StatFun() of
-                {ok, Val} ->
-                    if Val =:= LastVal ->
-                           Action(), {ok, NewVal} = StatFun(),
-                           hibernate([Parent, Incomming, resume(Outgoing, NewVal)]);
-                       true ->
-                           hibernate([Parent, Incomming, resume(Outgoing, Val)])
-                    end;
-                {error, Error} -> %% einval
-                    exit({shutdown, Error})
-            end;
-        stop ->
-            ok;
-        _Other ->
-            loop(Parent, Incomming, Outgoing)
+-spec check(name(), pos_integer(), heartbeat())
+    -> {ok, heartbeat()}
+     | {error, timeout}.
+check(Name, NewVal, HrtBt) ->
+    HrtBter = maps:get(Name, HrtBt),
+    case check(NewVal, HrtBter) of
+        {error, _} = R -> R;
+        {ok, NHrtBter} ->
+            {ok, HrtBt#{Name => NHrtBter}}
     end.
 
-resume(Hb, NewVal) ->
-    resume(Hb, NewVal, 0).
-resume(Hb = #heartbeater{name = InOut, cycle = Cycle}, NewVal, Repeat) ->
-    Hb#heartbeater{tref = timer(InOut, Cycle), val = NewVal, repeat = Repeat}.
+check(NewVal, HrtBter = #heartbeater{statval = OldVal,
+                                     repeat = Repeat}) ->
+    if
+        NewVal =/= OldVal ->
+            {ok, HrtBter#heartbeater{statval = NewVal, repeat = 0}};
+        Repeat < 1 ->
+            {ok, HrtBter#heartbeater{repeat = Repeat + 1}};
+        true -> {error, timeout}
+    end.
 
-timer(_InOut, 0) ->
-    undefined;
-timer(InOut, Cycle) ->
-    erlang:send_after(Cycle, self(), {heartbeat, InOut}).
-
-hibernate(Args) ->
-    erlang:hibernate(?MODULE, loop, Args).
+-spec info(heartbeat()) -> map().
+info(HrtBt) ->
+    maps:map(fun(_, #heartbeater{interval = Intv,
+                                 statval = Val,
+                                 repeat = Repeat}) ->
+            #{interval => Intv, statval => Val, repeat => Repeat}
+             end, HrtBt).
 
+interval(Type, HrtBt) ->
+    case maps:get(Type, HrtBt, undefined) of
+        undefined -> undefined;
+        #heartbeater{interval = Intv} -> Intv
+    end.
diff --git a/apps/emqx_stomp/src/emqx_stomp_protocol.erl b/apps/emqx_stomp/src/emqx_stomp_protocol.erl
index fa75f08e3..4834955a2 100644
--- a/apps/emqx_stomp/src/emqx_stomp_protocol.erl
+++ b/apps/emqx_stomp/src/emqx_stomp_protocol.erl
@@ -18,50 +18,58 @@
 -module(emqx_stomp_protocol).
 
 -include("emqx_stomp.hrl").
+
 -include_lib("emqx/include/emqx.hrl").
+-include_lib("emqx/include/logger.hrl").
 -include_lib("emqx/include/emqx_mqtt.hrl").
 
+-logger_header("[Stomp-Proto]").
+
 -import(proplists, [get_value/2, get_value/3]).
 
 %% API
--export([ init/3
+-export([ init/2
         , info/1
         ]).
 
 -export([ received/2
         , send/2
         , shutdown/2
+        , timeout/3
         ]).
 
--record(stomp_proto, {peername,
-                      sendfun,
-                      connected = false,
-                      proto_ver,
-                      proto_name,
-                      heart_beats,
-                      login,
-                      allow_anonymous,
-                      default_user,
-                      subscriptions = []}).
+-record(stomp_proto, {
+          peername,
+          heartfun,
+          sendfun,
+          connected = false,
+          proto_ver,
+          proto_name,
+          heart_beats,
+          login,
+          allow_anonymous,
+          default_user,
+          subscriptions = [],
+          timers :: #{atom() => disable | undefined | reference()}
+         }).
+
+-define(TIMER_TABLE, #{
+          incoming_timer => incoming,
+          outgoing_timer => outgoing
+        }).
 
 -type(stomp_proto() :: #stomp_proto{}).
 
--define(LOG(Level, Format, Args, State),
-        emqx_logger:Level("Stomp(~s): " ++ Format, [esockd:format(State#stomp_proto.peername) | Args])).
-
--define(record_to_proplist(Def, Rec),
-        lists:zip(record_info(fields, Def), tl(tuple_to_list(Rec)))).
-
--define(record_to_proplist(Def, Rec, Fields),
-    [{K, V} || {K, V} <- ?record_to_proplist(Def, Rec),
-                         lists:member(K, Fields)]).
-
 %% @doc Init protocol
-init(Peername, SendFun, Env) ->
+init(#{peername := Peername,
+       sendfun := SendFun,
+       heartfun := HeartFun}, Env) ->
     AllowAnonymous = get_value(allow_anonymous, Env, false),
     DefaultUser = get_value(default_user, Env),
 	#stomp_proto{peername = Peername,
+                 heartfun = HeartFun,
                  sendfun = SendFun,
+                 timers = #{},
                  allow_anonymous = AllowAnonymous,
                  default_user = DefaultUser}.
 
@@ -78,9 +86,10 @@ info(#stomp_proto{connected     = Connected,
      {login, Login},
      {subscriptions, Subscriptions}].
 
--spec(received(stomp_frame(), stomp_proto()) -> {ok, stomp_proto()}
-                                              | {error, any(), stomp_proto()}
-                                              | {stop, any(), stomp_proto()}).
+-spec(received(stomp_frame(), stomp_proto())
+    -> {ok, stomp_proto()}
+     | {error, any(), stomp_proto()}
+     | {stop, any(), stomp_proto()}).
 received(Frame = #stomp_frame{command = <<"STOMP">>}, State) ->
     received(Frame#stomp_frame{command = <<"CONNECT">>}, State);
 
@@ -92,12 +101,11 @@ received(#stomp_frame{command = <<"CONNECT">>, headers = Headers},
             Passc = header(<<"passcode">>, Headers),
             case check_login(Login, Passc, AllowAnonymous, DefaultUser) of
                 true ->
-                    Heartbeats = header(<<"heart-beat">>, Headers, <<"0,0">>),
-                    self() ! {heartbeat, start, parse_heartbeats(Heartbeats)},
-                    NewState = State#stomp_proto{connected = true, proto_ver = Version,
-                                                 heart_beats = Heartbeats, login = Login},
+                    Heartbeats = parse_heartbeats(header(<<"heart-beat">>, Headers, <<"0,0">>)),
+                    NState = start_heartbeart_timer(Heartbeats, State#stomp_proto{connected = true,
+                                                                                  proto_ver = Version, login = Login}),
                     send(connected_frame([{<<"version">>, Version},
-                                          {<<"heart-beat">>, reverse_heartbeats(Heartbeats)}]), NewState);
+                                          {<<"heart-beat">>, reverse_heartbeats(Heartbeats)}]), NState);
                 false ->
                     send(error_frame(undefined, <<"Login or passcode error!">>), State),
                     {error, login_or_passcode_error, State}
@@ -206,8 +214,8 @@ received(#stomp_frame{command = <<"BEGIN">>, headers = Headers}, State) ->
 received(#stomp_frame{command = <<"COMMIT">>, headers = Headers}, State) ->
     Id = header(<<"transaction">>, Headers),
     case emqx_stomp_transaction:commit(Id, State) of
-        {ok, NewState} ->
-            maybe_send_receipt(receipt_id(Headers), NewState);
+        {ok, NState} ->
+            maybe_send_receipt(receipt_id(Headers), NState);
         {error, not_found} ->
             send(error_frame(receipt_id(Headers), ["Transaction ", Id, " not found"]), State)
     end;
@@ -248,17 +256,40 @@ send(Msg = #message{topic = Topic, headers = Headers, payload = Payload},
                                  body = Payload},
             send(Frame, State);
         false ->
-            ?LOG(error, "Stomp dropped: ~p", [Msg], State),
+            ?LOG(error, "Stomp dropped: ~p", [Msg]),
             {error, dropped, State}
     end;
 
-send(Frame, State = #stomp_proto{sendfun = SendFun}) ->
-    ?LOG(info, "SEND Frame: ~s", [emqx_stomp_frame:format(Frame)], State),
+send(Frame, State = #stomp_proto{sendfun = {Fun, Args}}) ->
+    ?LOG(info, "SEND Frame: ~s", [emqx_stomp_frame:format(Frame)]),
     Data = emqx_stomp_frame:serialize(Frame),
-    ?LOG(debug, "SEND ~p", [Data], State),
-    SendFun(Data),
+    ?LOG(debug, "SEND ~p", [Data]),
+    erlang:apply(Fun, [Data] ++ Args),
     {ok, State}.
 
+shutdown(_Reason, _State) ->
+    ok.
+
+timeout(_TRef, {incoming, NewVal},
+        State = #stomp_proto{heart_beats = HrtBt}) ->
+    case emqx_stomp_heartbeat:check(incoming, NewVal, HrtBt) of
+        {error, timeout} ->
+            {shutdown, heartbeat_timeout, State};
+        {ok, NHrtBt} ->
+            {ok, reset_timer(incoming_timer, State#stomp_proto{heart_beats = NHrtBt})}
+    end;
+
+timeout(_TRef, {outgoing, NewVal},
+        State = #stomp_proto{heart_beats = HrtBt,
+                             heartfun = {Fun, Args}}) ->
+    case emqx_stomp_heartbeat:check(outgoing, NewVal, HrtBt) of
+        {error, timeout} ->
+            _ = erlang:apply(Fun, Args),
+            {ok, State};
+        {ok, NHrtBt} ->
+            {ok, reset_timer(outgoing_timer, State#stomp_proto{heart_beats = NHrtBt})}
+    end.
+
 negotiate_version(undefined) ->
     {ok, <<"1.0">>};
 negotiate_version(Accepts) ->
@@ -322,17 +353,6 @@ error_frame(Headers, undefined, Msg) ->
 error_frame(Headers, ReceiptId, Msg) ->
     emqx_stomp_frame:make(<<"ERROR">>, [{<<"receipt-id">>, ReceiptId} | Headers], Msg).
 
-parse_heartbeats(Heartbeats) ->
-    CxCy = re:split(Heartbeats, <<",">>, [{return, list}]),
-    list_to_tuple([list_to_integer(S) || S <- CxCy]).
-
-reverse_heartbeats(Heartbeats) ->
-    CxCy = re:split(Heartbeats, <<",">>, [{return, list}]),
-    list_to_binary(string:join(lists:reverse(CxCy), ",")).
-
-shutdown(_Reason, _State) ->
-    ok.
-
 next_msgid() ->
     MsgId = case get(msgid) of
                 undefined -> 1;
@@ -363,3 +383,52 @@ make_mqtt_message(Topic, Headers, Body) ->
 receipt_id(Headers) ->
     header(<<"receipt">>, Headers).
 
+%%--------------------------------------------------------------------
+%% Heartbeat
+
+parse_heartbeats(Heartbeats) ->
+    CxCy = re:split(Heartbeats, <<",">>, [{return, list}]),
+    list_to_tuple([list_to_integer(S) || S <- CxCy]).
+
+reverse_heartbeats({Cx, Cy}) ->
+    iolist_to_binary(io_lib:format("~w,~w", [Cy, Cx])).
+
+start_heartbeart_timer(Heartbeats, State) ->
+    ensure_timer(
+      [incoming_timer, outgoing_timer],
+      State#stomp_proto{heart_beats = emqx_stomp_heartbeat:init(Heartbeats)}).
+
+%%--------------------------------------------------------------------
+%% Timer
+
+ensure_timer([Name], State) ->
+    ensure_timer(Name, State);
+ensure_timer([Name | Rest], State) ->
+    ensure_timer(Rest, ensure_timer(Name, State));
+
+ensure_timer(Name, State = #stomp_proto{timers = Timers}) ->
+    TRef = maps:get(Name, Timers, undefined),
+    Time = interval(Name, State),
+    case TRef == undefined andalso is_integer(Time) andalso Time > 0 of
+        true  -> ensure_timer(Name, Time, State);
+        false -> State %% Timer disabled or exists
+    end.
+
+ensure_timer(Name, Time, State = #stomp_proto{timers = Timers}) ->
+    Msg = maps:get(Name, ?TIMER_TABLE),
+    TRef = emqx_misc:start_timer(Time, Msg),
+    State#stomp_proto{timers = Timers#{Name => TRef}}.
+
+reset_timer(Name, State) ->
+    ensure_timer(Name, clean_timer(Name, State)).
+
+reset_timer(Name, Time, State) ->
+    ensure_timer(Name, Time, clean_timer(Name, State)).
+
+clean_timer(Name, State = #stomp_proto{timers = Timers}) ->
+    State#stomp_proto{timers = maps:remove(Name, Timers)}.
+
+interval(incoming_timer, #stomp_proto{heart_beats = HrtBt}) ->
+    emqx_stomp_heartbeat:interval(incoming, HrtBt);
+interval(outgoing_timer, #stomp_proto{heart_beats = HrtBt}) ->
+    emqx_stomp_heartbeat:interval(outgoing, HrtBt).
diff --git a/apps/emqx_stomp/test/emqx_stomp_SUITE.erl b/apps/emqx_stomp/test/emqx_stomp_SUITE.erl
index d8b5cc5b2..ca46762ed 100644
--- a/apps/emqx_stomp/test/emqx_stomp_SUITE.erl
+++ b/apps/emqx_stomp/test/emqx_stomp_SUITE.erl
@@ -100,7 +100,7 @@ t_heartbeat(_) ->
                                                      {<<"host">>, <<"127.0.0.1:61613">>},
                                                      {<<"login">>, <<"guest">>},
                                                      {<<"passcode">>, <<"guest">>},
-                                                     {<<"heart-beat">>, <<"500,800">>}])),
+                                                     {<<"heart-beat">>, <<"1000,800">>}])),
                         {ok, Data} = gen_tcp:recv(Sock, 0),
                         {ok, #stomp_frame{command = <<"CONNECTED">>,
                                           headers = _,
@@ -345,5 +345,5 @@ parse(Data) ->
     ProtoEnv = [{max_headers, 10},
                 {max_header_length, 1024},
                 {max_body_length, 8192}],
-    ParseFun = emqx_stomp_frame:parser(ProtoEnv),
-    ParseFun(Data).
+    Parser = emqx_stomp_frame:init_parer_state(ProtoEnv),
+    emqx_stomp_frame:parse(Data, Parser).
diff --git a/apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl b/apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl
new file mode 100644
index 000000000..0d01bfcd4
--- /dev/null
+++ b/apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl
@@ -0,0 +1,53 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2020 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_stomp_heartbeat_SUITE).
+
+-compile(export_all).
+-compile(nowarn_export_all).
+
+all() -> emqx_ct:all(?MODULE).
+
+%%--------------------------------------------------------------------
+%% Test Cases
+%%--------------------------------------------------------------------
+
+t_init(_) ->
+    #{} = emqx_stomp_heartbeat:init({0, 0}),
+    #{incoming := _} = emqx_stomp_heartbeat:init({1, 0}),
+    #{outgoing := _} = emqx_stomp_heartbeat:init({0, 1}).
+
+t_check_1(_) ->
+    HrtBt = emqx_stomp_heartbeat:init({1, 1}),
+    {ok, HrtBt1} = emqx_stomp_heartbeat:check(incoming, 0, HrtBt),
+    {error, timeout} = emqx_stomp_heartbeat:check(incoming, 0, HrtBt1),
+
+    {ok, HrtBt2} = emqx_stomp_heartbeat:check(outgoing, 0, HrtBt1),
+    {error, timeout} = emqx_stomp_heartbeat:check(outgoing, 0, HrtBt2),
+    ok.
+
+t_check_2(_) ->
+    HrtBt = emqx_stomp_heartbeat:init({1, 0}),
+    #{incoming := _} = lists:foldl(fun(I, Acc) ->
+                            {ok, NAcc} = emqx_stomp_heartbeat:check(incoming, I, Acc),
+                            NAcc
+                       end, HrtBt, lists:seq(1,1000)),
+    ok.
+
+t_info(_) ->
+    HrtBt = emqx_stomp_heartbeat:init({100, 100}),
+    #{incoming := _,
+      outgoing := _} = emqx_stomp_heartbeat:info(HrtBt).

From 1263a05bbcdc6f9f26b3a3d02622adc0f9ccede8 Mon Sep 17 00:00:00 2001
From: JianBo He 
Date: Thu, 10 Dec 2020 09:38:46 +0800
Subject: [PATCH 2/4] refactor(stomp): improve code form naming

---
 apps/emqx_stomp/src/emqx_stomp_connection.erl | 165 +++++++++---------
 apps/emqx_stomp/src/emqx_stomp_protocol.erl   |  60 ++++---
 2 files changed, 114 insertions(+), 111 deletions(-)

diff --git a/apps/emqx_stomp/src/emqx_stomp_connection.erl b/apps/emqx_stomp/src/emqx_stomp_connection.erl
index dc1977944..3bb3a58fd 100644
--- a/apps/emqx_stomp/src/emqx_stomp_connection.erl
+++ b/apps/emqx_stomp/src/emqx_stomp_connection.erl
@@ -19,6 +19,9 @@
 -behaviour(gen_server).
 
 -include("emqx_stomp.hrl").
+-include_lib("emqx/include/logger.hrl").
+
+-logger_header("[Stomp-Conn]").
 
 -export([ start_link/3
         , info/1
@@ -36,16 +39,13 @@
 %% for protocol
 -export([send/4, heartbeat/2]).
 
--record(stomp_client, {transport, socket, peername, conn_name, conn_state,
-                       await_recv, rate_limit, parser, proto_state,
-                       proto_env, heartbeat}).
+-record(state, {transport, socket, peername, conn_name, conn_state,
+                await_recv, rate_limit, parser, pstate,
+                proto_env, heartbeat}).
 
 -define(INFO_KEYS, [peername, await_recv, conn_state]).
 -define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]).
 
--define(LOG(Level, Format, Args, State),
-        emqx_logger:Level("Stomp(~s): " ++ Format, [State#stomp_client.conn_name | Args])).
-
 start_link(Transport, Sock, ProtoEnv) ->
     {ok, proc_lib:spawn_link(?MODULE, init, [[Transport, Sock, ProtoEnv]])}.
 
@@ -61,20 +61,21 @@ init([Transport, Sock, ProtoEnv]) ->
             SendFun = {fun ?MODULE:send/4, [Transport, Sock, self()]},
             HrtBtFun = {fun ?MODULE:heartbeat/2, [Transport, Sock]},
             Parser = emqx_stomp_frame:init_parer_state(ProtoEnv),
-            ProtoState = emqx_stomp_protocol:init(#{peername => Peername,
-                                                    sendfun => SendFun,
-                                                    heartfun => HrtBtFun}, ProtoEnv),
+            PState = emqx_stomp_protocol:init(#{peername => Peername,
+                                                sendfun => SendFun,
+                                                heartfun => HrtBtFun}, ProtoEnv),
             RateLimit = init_rate_limit(proplists:get_value(rate_limit, ProtoEnv)),
-            State = run_socket(#stomp_client{transport   = Transport,
-                                             socket      = NewSock,
-                                             peername    = Peername,
-                                             conn_name   = ConnName,
-                                             conn_state  = running,
-                                             await_recv  = false,
-                                             rate_limit  = RateLimit,
-                                             parser      = Parser,
-                                             proto_env   = ProtoEnv,
-                                             proto_state = ProtoState}),
+            State = run_socket(#state{transport   = Transport,
+                                      socket      = NewSock,
+                                      peername    = Peername,
+                                      conn_name   = ConnName,
+                                      conn_state  = running,
+                                      await_recv  = false,
+                                      rate_limit  = RateLimit,
+                                      parser      = Parser,
+                                      proto_env   = ProtoEnv,
+                                      pstate      = PState}),
+            emqx_logger:set_metadata_peername(esockd:format(Peername)),
             gen_server:enter_loop(?MODULE, [{hibernate_after, 5000}], State, 20000);
         {error, Reason} ->
             {stop, Reason}
@@ -96,15 +97,15 @@ send(Data, Transport, Sock, ConnPid) ->
 heartbeat(Transport, Sock) ->
     Transport:send(Sock, <<$\n>>).
 
-handle_call(info, _From, State = #stomp_client{transport   = Transport,
-                                               socket      = Sock,
-                                               peername    = Peername,
-                                               await_recv  = AwaitRecv,
-                                               conn_state  = ConnState,
-                                               proto_state = ProtoState}) ->
+handle_call(info, _From, State = #state{transport   = Transport,
+                                        socket      = Sock,
+                                        peername    = Peername,
+                                        await_recv  = AwaitRecv,
+                                        conn_state  = ConnState,
+                                        pstate      = PState}) ->
     ClientInfo = [{peername,  Peername}, {await_recv, AwaitRecv},
                   {conn_state, ConnState}],
-    ProtoInfo  = emqx_stomp_protocol:info(ProtoState),
+    ProtoInfo  = emqx_stomp_protocol:info(PState),
     case Transport:getstat(Sock, ?SOCK_STATS) of
         {ok, SockStats} ->
             {reply, lists:append([ClientInfo, ProtoInfo, SockStats]), State};
@@ -113,11 +114,11 @@ handle_call(info, _From, State = #stomp_client{transport   = Transport,
     end;
 
 handle_call(Req, _From, State) ->
-    ?LOG(error, "unexpected request: ~p", [Req], State),
+    ?LOG(error, "unexpected request: ~p", [Req]),
     {reply, ignored, State}.
 
 handle_cast(Msg, State) ->
-    ?LOG(error, "unexpected msg: ~p", [Msg], State),
+    ?LOG(error, "unexpected msg: ~p", [Msg]),
     noreply(State).
 
 handle_info(timeout, State) ->
@@ -144,15 +145,15 @@ handle_info({timeout, TRef, TMsg}, State) when TMsg =:= incoming;
             shutdown({sock_error, Reason}, State)
     end;
 
-handle_info({'EXIT', HbProc, Error}, State = #stomp_client{heartbeat = HbProc}) ->
+handle_info({'EXIT', HbProc, Error}, State = #state{heartbeat = HbProc}) ->
     stop(Error, State);
 
 handle_info(activate_sock, State) ->
-    noreply(run_socket(State#stomp_client{conn_state = running}));
+    noreply(run_socket(State#state{conn_state = running}));
 
 handle_info({inet_async, _Sock, _Ref, {ok, Bytes}}, State) ->
-    ?LOG(debug, "RECV ~p", [Bytes], State),
-    received(Bytes, rate_limit(size(Bytes), State#stomp_client{await_recv = false}));
+    ?LOG(debug, "RECV ~p", [Bytes]),
+    received(Bytes, rate_limit(size(Bytes), State#state{await_recv = false}));
 
 handle_info({inet_async, _Sock, _Ref, {error, Reason}}, State) ->
     shutdown(Reason, State);
@@ -163,29 +164,29 @@ handle_info({inet_reply, _Ref, ok}, State) ->
 handle_info({inet_reply, _Sock, {error, Reason}}, State) ->
     shutdown(Reason, State);
 
-handle_info({deliver, _Topic, Msg}, State = #stomp_client{proto_state = ProtoState}) ->
-    noreply(State#stomp_client{proto_state = case emqx_stomp_protocol:send(Msg, ProtoState) of 
-                                                 {ok, ProtoState1} ->
-                                                     ProtoState1;
-                                                 {error, dropped, ProtoState1} ->
-                                                     ProtoState1
-                                             end});
+handle_info({deliver, _Topic, Msg}, State = #state{pstate = PState}) ->
+    noreply(State#state{pstate = case emqx_stomp_protocol:send(Msg, PState) of
+                                     {ok, PState1} ->
+                                         PState1;
+                                     {error, dropped, PState1} ->
+                                         PState1
+                                 end});
 
 handle_info(Info, State) ->
-    ?LOG(error, "Unexpected info: ~p", [Info], State),
+    ?LOG(error, "Unexpected info: ~p", [Info]),
     noreply(State).
 
-terminate(Reason, State = #stomp_client{transport   = Transport,
-                                        socket      = Sock,
-                                        proto_state = ProtoState}) ->
-    ?LOG(info, "terminated for ~p", [Reason], State),
+terminate(Reason, #state{transport = Transport,
+                         socket    = Sock,
+                         pstate    = PState}) ->
+    ?LOG(info, "terminated for ~p", [Reason]),
     Transport:fast_close(Sock),
-    case {ProtoState, Reason} of
+    case {PState, Reason} of
         {undefined, _} -> ok;
         {_, {shutdown, Error}} ->
-            emqx_stomp_protocol:shutdown(Error, ProtoState);
+            emqx_stomp_protocol:shutdown(Error, PState);
         {_,  Reason} ->
-            emqx_stomp_protocol:shutdown(Reason, ProtoState)
+            emqx_stomp_protocol:shutdown(Reason, PState)
     end.
 
 code_change(_OldVsn, State, _Extra) ->
@@ -195,69 +196,69 @@ code_change(_OldVsn, State, _Extra) ->
 %% Receive and Parse data
 %%--------------------------------------------------------------------
 
-with_proto(Fun, Args, State = #stomp_client{proto_state = ProtoState}) ->
-    case erlang:apply(emqx_stomp_protocol, Fun, Args ++ [ProtoState]) of
-        {ok, NProtoState} ->
-            noreply(State#stomp_client{proto_state = NProtoState});
-        {F, Reason, NProtoState} when F == stop;
-                                      F == error;
-                                      F == shutdown ->
-            shutdown(Reason, State#stomp_client{proto_state = NProtoState})
+with_proto(Fun, Args, State = #state{pstate = PState}) ->
+    case erlang:apply(emqx_stomp_protocol, Fun, Args ++ [PState]) of
+        {ok, NPState} ->
+            noreply(State#state{pstate = NPState});
+        {F, Reason, NPState} when F == stop;
+                                  F == error;
+                                  F == shutdown ->
+            shutdown(Reason, State#state{pstate = NPState})
     end.
 
 received(<<>>, State) ->
     noreply(State);
 
-received(Bytes, State = #stomp_client{parser   = Parser,
-                                      proto_state = ProtoState}) ->
+received(Bytes, State = #state{parser   = Parser,
+                               pstate = PState}) ->
     try emqx_stomp_frame:parse(Bytes, Parser) of
         {more, NewParser} ->
-            noreply(State#stomp_client{parser = NewParser});
+            noreply(State#state{parser = NewParser});
         {ok, Frame, Rest} ->
-            ?LOG(info, "RECV Frame: ~s", [emqx_stomp_frame:format(Frame)], State),
-            case emqx_stomp_protocol:received(Frame, ProtoState) of
-                {ok, ProtoState1}           ->
-                    received(Rest, reset_parser(State#stomp_client{proto_state = ProtoState1}));
-                {error, Error, ProtoState1} ->
-                    shutdown(Error, State#stomp_client{proto_state = ProtoState1});
-                {stop, Reason, ProtoState1} ->
-                    stop(Reason, State#stomp_client{proto_state = ProtoState1})
+            ?LOG(info, "RECV Frame: ~s", [emqx_stomp_frame:format(Frame)]),
+            case emqx_stomp_protocol:received(Frame, PState) of
+                {ok, PState1}           ->
+                    received(Rest, reset_parser(State#state{pstate = PState1}));
+                {error, Error, PState1} ->
+                    shutdown(Error, State#state{pstate = PState1});
+                {stop, Reason, PState1} ->
+                    stop(Reason, State#state{pstate = PState1})
             end;
         {error, Error} ->
-            ?LOG(error, "Framing error - ~s", [Error], State),
-            ?LOG(error, "Bytes: ~p", [Bytes], State),
+            ?LOG(error, "Framing error - ~s", [Error]),
+            ?LOG(error, "Bytes: ~p", [Bytes]),
             shutdown(frame_error, State)
     catch
         _Error:Reason ->
-            ?LOG(error, "Parser failed for ~p", [Reason], State),
-            ?LOG(error, "Error data: ~p", [Bytes], State),
+            ?LOG(error, "Parser failed for ~p", [Reason]),
+            ?LOG(error, "Error data: ~p", [Bytes]),
             shutdown(parse_error, State)
     end.
 
-reset_parser(State = #stomp_client{proto_env = ProtoEnv}) ->
-    State#stomp_client{parser = emqx_stomp_frame:init_parer_state(ProtoEnv)}.
+reset_parser(State = #state{proto_env = ProtoEnv}) ->
+    State#state{parser = emqx_stomp_frame:init_parer_state(ProtoEnv)}.
 
-rate_limit(_Size, State = #stomp_client{rate_limit = undefined}) ->
+rate_limit(_Size, State = #state{rate_limit = undefined}) ->
     run_socket(State);
-rate_limit(Size, State = #stomp_client{rate_limit = Rl}) ->
+rate_limit(Size, State = #state{rate_limit = Rl}) ->
     case esockd_rate_limit:check(Size, Rl) of
         {0, Rl1} ->
-            run_socket(State#stomp_client{conn_state = running, rate_limit = Rl1});
+            run_socket(State#state{conn_state = running, rate_limit = Rl1});
         {Pause, Rl1} ->
-            ?LOG(error, "Rate limiter pause for ~p", [Pause], State),
+            ?LOG(error, "Rate limiter pause for ~p", [Pause]),
             erlang:send_after(Pause, self(), activate_sock),
-            State#stomp_client{conn_state = blocked, rate_limit = Rl1}
+            State#state{conn_state = blocked, rate_limit = Rl1}
     end.
 
-run_socket(State = #stomp_client{conn_state = blocked}) ->
+run_socket(State = #state{conn_state = blocked}) ->
     State;
-run_socket(State = #stomp_client{await_recv = true}) ->
+run_socket(State = #state{await_recv = true}) ->
     State;
-run_socket(State = #stomp_client{transport = Transport, socket = Sock}) ->
+run_socket(State = #state{transport = Transport, socket = Sock}) ->
     Transport:async_recv(Sock, 0, infinity),
-    State#stomp_client{await_recv = true}.
+    State#state{await_recv = true}.
 
-getstat(Stat, #stomp_client{transport = Transport, socket = Sock}) ->
+getstat(Stat, #state{transport = Transport, socket = Sock}) ->
     case Transport:getstat(Sock, [Stat]) of
         {ok, [{Stat, Val}]} -> {ok, Val};
         {error, Error}      -> {error, Error}
diff --git a/apps/emqx_stomp/src/emqx_stomp_protocol.erl b/apps/emqx_stomp/src/emqx_stomp_protocol.erl
index 4834955a2..a366105b8 100644
--- a/apps/emqx_stomp/src/emqx_stomp_protocol.erl
+++ b/apps/emqx_stomp/src/emqx_stomp_protocol.erl
@@ -38,7 +38,7 @@
         , timeout/3
         ]).
 
--record(stomp_proto, {
+-record(pstate, {
           peername,
           heartfun,
           sendfun,
@@ -58,7 +58,7 @@
           outgoing_timer => outgoing
         }).
 
--type(stomp_proto() :: #stomp_proto{}).
+-type(pstate() :: #pstate{}).
 
 %% @doc Init protocol
 init(#{peername := Peername,
@@ -66,14 +66,14 @@ init(#{peername := Peername,
        heartfun := HeartFun}, Env) ->
     AllowAnonymous = get_value(allow_anonymous, Env, false),
     DefaultUser = get_value(default_user, Env),
-	#stomp_proto{peername = Peername,
+	#pstate{peername = Peername,
                  heartfun = HeartFun,
                  sendfun = SendFun,
                  timers = #{},
                  allow_anonymous = AllowAnonymous,
                  default_user = DefaultUser}.
 
-info(#stomp_proto{connected     = Connected,
+info(#pstate{connected     = Connected,
                   proto_ver     = ProtoVer,
                   proto_name    = ProtoName,
                   heart_beats   = Heartbeats,
@@ -86,23 +86,25 @@ info(#stomp_proto{connected     = Connected,
      {login, Login},
      {subscriptions, Subscriptions}].
 
--spec(received(stomp_frame(), stomp_proto())
-    -> {ok, stomp_proto()}
-     | {error, any(), stomp_proto()}
-     | {stop, any(), stomp_proto()}).
+-spec(received(stomp_frame(), pstate())
+    -> {ok, pstate()}
+     | {error, any(), pstate()}
+     | {stop, any(), pstate()}).
 received(Frame = #stomp_frame{command = <<"STOMP">>}, State) ->
     received(Frame#stomp_frame{command = <<"CONNECT">>}, State);
 
 received(#stomp_frame{command = <<"CONNECT">>, headers = Headers},
-         State = #stomp_proto{connected = false, allow_anonymous = AllowAnonymous, default_user = DefaultUser}) ->
+         State = #pstate{connected = false, allow_anonymous = AllowAnonymous, default_user = DefaultUser}) ->
     case negotiate_version(header(<<"accept-version">>, Headers)) of
         {ok, Version} ->
             Login = header(<<"login">>, Headers),
             Passc = header(<<"passcode">>, Headers),
             case check_login(Login, Passc, AllowAnonymous, DefaultUser) of
                 true ->
+                    emqx_logger:set_metadata_clientid(Login),
+
                     Heartbeats = parse_heartbeats(header(<<"heart-beat">>, Headers, <<"0,0">>)),
-                    NState = start_heartbeart_timer(Heartbeats, State#stomp_proto{connected = true,
+                    NState = start_heartbeart_timer(Heartbeats, State#pstate{connected = true,
                                                                                   proto_ver = Version, login = Login}),
                     send(connected_frame([{<<"version">>, Version},
                                           {<<"heart-beat">>, reverse_heartbeats(Heartbeats)}]), NState);
@@ -116,7 +118,7 @@ received(#stomp_frame{command = <<"CONNECT">>, headers = Headers},
             {error, unsupported_version, State}
     end;
 
-received(#stomp_frame{command = <<"CONNECT">>}, State = #stomp_proto{connected = true}) ->
+received(#stomp_frame{command = <<"CONNECT">>}, State = #pstate{connected = true}) ->
     {error, unexpected_connect, State};
 
 received(#stomp_frame{command = <<"SEND">>, headers = Headers, body = Body}, State) ->
@@ -134,7 +136,7 @@ received(#stomp_frame{command = <<"SEND">>, headers = Headers, body = Body}, Sta
     end;
 
 received(#stomp_frame{command = <<"SUBSCRIBE">>, headers = Headers},
-            State = #stomp_proto{subscriptions = Subscriptions}) ->
+            State = #pstate{subscriptions = Subscriptions}) ->
     Id    = header(<<"id">>, Headers),
     Topic = header(<<"destination">>, Headers),
     Ack   = header(<<"ack">>, Headers, <<"auto">>),
@@ -143,18 +145,18 @@ received(#stomp_frame{command = <<"SUBSCRIBE">>, headers = Headers},
                            {ok, State};
                        false ->
                            emqx_broker:subscribe(Topic),
-                           {ok, State#stomp_proto{subscriptions = [{Id, Topic, Ack}|Subscriptions]}}
+                           {ok, State#pstate{subscriptions = [{Id, Topic, Ack}|Subscriptions]}}
                    end,
     maybe_send_receipt(receipt_id(Headers), State1);
 
 received(#stomp_frame{command = <<"UNSUBSCRIBE">>, headers = Headers},
-            State = #stomp_proto{subscriptions = Subscriptions}) ->
+            State = #pstate{subscriptions = Subscriptions}) ->
     Id = header(<<"id">>, Headers),
 
     {ok, State1} = case lists:keyfind(Id, 1, Subscriptions) of
                        {Id, Topic, _Ack} ->
                            ok = emqx_broker:unsubscribe(Topic),
-                           {ok, State#stomp_proto{subscriptions = lists:keydelete(Id, 1, Subscriptions)}};
+                           {ok, State#pstate{subscriptions = lists:keydelete(Id, 1, Subscriptions)}};
                        false ->
                            {ok, State}
                    end,
@@ -238,7 +240,7 @@ received(#stomp_frame{command = <<"DISCONNECT">>, headers = Headers}, State) ->
     {stop, normal, State}.
 
 send(Msg = #message{topic = Topic, headers = Headers, payload = Payload},
-     State = #stomp_proto{subscriptions = Subscriptions}) ->
+     State = #pstate{subscriptions = Subscriptions}) ->
     case lists:keyfind(Topic, 2, Subscriptions) of
         {Id, Topic, Ack} ->
             Headers0 = [{<<"subscription">>, Id},
@@ -260,7 +262,7 @@ send(Msg = #message{topic = Topic, headers = Headers, payload = Payload},
             {error, dropped, State}
     end;
 
-send(Frame, State = #stomp_proto{sendfun = {Fun, Args}}) ->
+send(Frame, State = #pstate{sendfun = {Fun, Args}}) ->
     ?LOG(info, "SEND Frame: ~s", [emqx_stomp_frame:format(Frame)]),
     Data = emqx_stomp_frame:serialize(Frame),
     ?LOG(debug, "SEND ~p", [Data]),
@@ -271,23 +273,23 @@ shutdown(_Reason, _State) ->
     ok.
 
 timeout(_TRef, {incoming, NewVal},
-        State = #stomp_proto{heart_beats = HrtBt}) ->
+        State = #pstate{heart_beats = HrtBt}) ->
     case emqx_stomp_heartbeat:check(incoming, NewVal, HrtBt) of
         {error, timeout} ->
             {shutdown, heartbeat_timeout, State};
         {ok, NHrtBt} ->
-            {ok, reset_timer(incoming_timer, State#stomp_proto{heart_beats = NHrtBt})}
+            {ok, reset_timer(incoming_timer, State#pstate{heart_beats = NHrtBt})}
     end;
 
 timeout(_TRef, {outgoing, NewVal},
-        State = #stomp_proto{heart_beats = HrtBt,
+        State = #pstate{heart_beats = HrtBt,
                              heartfun = {Fun, Args}}) ->
     case emqx_stomp_heartbeat:check(outgoing, NewVal, HrtBt) of
         {error, timeout} ->
             _ = erlang:apply(Fun, Args),
             {ok, State};
         {ok, NHrtBt} ->
-            {ok, reset_timer(outgoing_timer, State#stomp_proto{heart_beats = NHrtBt})}
+            {ok, reset_timer(outgoing_timer, State#pstate{heart_beats = NHrtBt})}
     end.
 
 negotiate_version(undefined) ->
@@ -396,7 +398,7 @@ reverse_heartbeats({Cx, Cy}) ->
 start_heartbeart_timer(Heartbeats, State) ->
     ensure_timer(
       [incoming_timer, outgoing_timer],
-      State#stomp_proto{heart_beats = emqx_stomp_heartbeat:init(Heartbeats)}).
+      State#pstate{heart_beats = emqx_stomp_heartbeat:init(Heartbeats)}).
 
 %%--------------------------------------------------------------------
 %% Timer
@@ -406,7 +408,7 @@ ensure_timer([Name], State) ->
 ensure_timer([Name | Rest], State) ->
     ensure_timer(Rest, ensure_timer(Name, State));
 
-ensure_timer(Name, State = #stomp_proto{timers = Timers}) ->
+ensure_timer(Name, State = #pstate{timers = Timers}) ->
     TRef = maps:get(Name, Timers, undefined),
     Time = interval(Name, State),
     case TRef == undefined andalso is_integer(Time) andalso Time > 0 of
@@ -414,10 +416,10 @@ ensure_timer(Name, State = #stomp_proto{timers = Timers}) ->
         false -> State %% Timer disabled or exists
     end.
 
-ensure_timer(Name, Time, State = #stomp_proto{timers = Timers}) ->
+ensure_timer(Name, Time, State = #pstate{timers = Timers}) ->
     Msg = maps:get(Name, ?TIMER_TABLE),
     TRef = emqx_misc:start_timer(Time, Msg),
-    State#stomp_proto{timers = Timers#{Name => TRef}}.
+    State#pstate{timers = Timers#{Name => TRef}}.
 
 reset_timer(Name, State) ->
     ensure_timer(Name, clean_timer(Name, State)).
@@ -425,10 +427,10 @@ reset_timer(Name, State) ->
 reset_timer(Name, Time, State) ->
     ensure_timer(Name, Time, clean_timer(Name, State)).
 
-clean_timer(Name, State = #stomp_proto{timers = Timers}) ->
-    State#stomp_proto{timers = maps:remove(Name, Timers)}.
+clean_timer(Name, State = #pstate{timers = Timers}) ->
+    State#pstate{timers = maps:remove(Name, Timers)}.
 
-interval(incoming_timer, #stomp_proto{heart_beats = HrtBt}) ->
+interval(incoming_timer, #pstate{heart_beats = HrtBt}) ->
     emqx_stomp_heartbeat:interval(incoming, HrtBt);
-interval(outgoing_timer, #stomp_proto{heart_beats = HrtBt}) ->
+interval(outgoing_timer, #pstate{heart_beats = HrtBt}) ->
     emqx_stomp_heartbeat:interval(outgoing, HrtBt).

From 713b4c780477d4974004fd732dac58a0dbd160da Mon Sep 17 00:00:00 2001
From: JianBo He 
Date: Thu, 10 Dec 2020 12:32:18 +0800
Subject: [PATCH 3/4] refactor(stomp): remove transaction module

---
 apps/emqx_stomp/rebar.config                  |   2 -
 apps/emqx_stomp/src/emqx_stomp_connection.erl |   7 +-
 apps/emqx_stomp/src/emqx_stomp_protocol.erl   | 158 +++++++++++-------
 .../emqx_stomp/src/emqx_stomp_transaction.erl |  77 ---------
 4 files changed, 100 insertions(+), 144 deletions(-)
 delete mode 100644 apps/emqx_stomp/src/emqx_stomp_transaction.erl

diff --git a/apps/emqx_stomp/rebar.config b/apps/emqx_stomp/rebar.config
index a9e3eebd6..7ac3b98c8 100644
--- a/apps/emqx_stomp/rebar.config
+++ b/apps/emqx_stomp/rebar.config
@@ -14,5 +14,3 @@
 {cover_enabled, true}.
 {cover_opts, [verbose]}.
 {cover_export_enabled, true}.
-
-{plugins, [coveralls]}.
diff --git a/apps/emqx_stomp/src/emqx_stomp_connection.erl b/apps/emqx_stomp/src/emqx_stomp_connection.erl
index 3bb3a58fd..740320489 100644
--- a/apps/emqx_stomp/src/emqx_stomp_connection.erl
+++ b/apps/emqx_stomp/src/emqx_stomp_connection.erl
@@ -127,10 +127,6 @@ handle_info(timeout, State) ->
 handle_info({shutdown, Reason}, State) ->
     shutdown(Reason, State);
 
-handle_info({transaction, {timeout, Id}}, State) ->
-    emqx_stomp_transaction:timeout(Id),
-    noreply(State);
-
 handle_info({timeout, TRef, TMsg}, State) when TMsg =:= incoming;
                                                TMsg =:= outgoing ->
 
@@ -145,6 +141,9 @@ handle_info({timeout, TRef, TMsg}, State) when TMsg =:= incoming;
             shutdown({sock_error, Reason}, State)
     end;
 
+handle_info({timeout, TRef, TMsg}, State) ->
+    with_proto(timeout, [TRef, TMsg], State);
+
 handle_info({'EXIT', HbProc, Error}, State = #state{heartbeat = HbProc}) ->
     stop(Error, State);
 
diff --git a/apps/emqx_stomp/src/emqx_stomp_protocol.erl b/apps/emqx_stomp/src/emqx_stomp_protocol.erl
index a366105b8..184b2cc90 100644
--- a/apps/emqx_stomp/src/emqx_stomp_protocol.erl
+++ b/apps/emqx_stomp/src/emqx_stomp_protocol.erl
@@ -38,6 +38,12 @@
         , timeout/3
         ]).
 
+%% for trans callback
+-export([ handle_recv_send_frame/2
+        , handle_recv_ack_frame/2
+        , handle_recv_nack_frame/2
+        ]).
+
 -record(pstate, {
           peername,
           heartfun,
@@ -50,14 +56,18 @@
           allow_anonymous,
           default_user,
           subscriptions = [],
-          timers :: #{atom() => disable | undefined | reference()}
+          timers :: #{atom() => disable | undefined | reference()},
+          transaction :: #{binary() => list()}
          }).
 
 -define(TIMER_TABLE, #{
           incoming_timer => incoming,
-          outgoing_timer => outgoing
+          outgoing_timer => outgoing,
+          clean_trans_timer => clean_trans
         }).
 
+-define(TRANS_TIMEOUT, 60000).
+
 -type(pstate() :: #pstate{}).
 
 %% @doc Init protocol
@@ -70,6 +80,7 @@ init(#{peername := Peername,
                  heartfun = HeartFun,
                  sendfun = SendFun,
                  timers = #{},
+                 transaction = #{},
                  allow_anonymous = AllowAnonymous,
                  default_user = DefaultUser}.
 
@@ -121,18 +132,10 @@ received(#stomp_frame{command = <<"CONNECT">>, headers = Headers},
 received(#stomp_frame{command = <<"CONNECT">>}, State = #pstate{connected = true}) ->
     {error, unexpected_connect, State};
 
-received(#stomp_frame{command = <<"SEND">>, headers = Headers, body = Body}, State) ->
-    Topic = header(<<"destination">>, Headers),
-    Action = fun(State0) ->
-                 maybe_send_receipt(receipt_id(Headers), State0),
-                 emqx_broker:publish(
-                     make_mqtt_message(Topic, Headers, iolist_to_binary(Body))
-                 ),
-                 State0
-             end,
+received(Frame = #stomp_frame{command = <<"SEND">>, headers = Headers}, State) ->
     case header(<<"transaction">>, Headers) of
-        undefined     -> {ok, Action(State)};
-        TransactionId -> add_action(TransactionId, Action, receipt_id(Headers), State)
+        undefined     -> {ok, handle_recv_send_frame(Frame, State)};
+        TransactionId -> add_action(TransactionId, {fun ?MODULE:handle_recv_send_frame/2, [Frame]}, receipt_id(Headers), State)
     end;
 
 received(#stomp_frame{command = <<"SUBSCRIBE">>, headers = Headers},
@@ -167,15 +170,10 @@ received(#stomp_frame{command = <<"UNSUBSCRIBE">>, headers = Headers},
 %% transaction:tx1
 %%
 %% ^@
-received(#stomp_frame{command = <<"ACK">>, headers = Headers}, State) ->
-    Id = header(<<"id">>, Headers),
-    Action = fun(State0) -> 
-                 maybe_send_receipt(receipt_id(Headers), State0),
-                 ack(Id, State0) 
-             end,
+received(Frame = #stomp_frame{command = <<"ACK">>, headers = Headers}, State) ->
     case header(<<"transaction">>, Headers) of
-        undefined     -> {ok, Action(State)};
-        TransactionId -> add_action(TransactionId, Action, receipt_id(Headers), State)
+        undefined     -> {ok, handle_recv_ack_frame(Frame, State)};
+        TransactionId -> add_action(TransactionId, {fun ?MODULE:handle_recv_ack_frame/2, [Frame]}, receipt_id(Headers), State)
     end;
 
 %% NACK
@@ -183,29 +181,25 @@ received(#stomp_frame{command = <<"ACK">>, headers = Headers}, State) ->
 %% transaction:tx1
 %%
 %% ^@
-received(#stomp_frame{command = <<"NACK">>, headers = Headers}, State) ->
-    Id = header(<<"id">>, Headers),
-    Action = fun(State0) -> 
-                 maybe_send_receipt(receipt_id(Headers), State0),
-                 nack(Id, State0) 
-             end,
+received(Frame = #stomp_frame{command = <<"NACK">>, headers = Headers}, State) ->
     case header(<<"transaction">>, Headers) of
-        undefined     -> {ok, Action(State)};
-        TransactionId -> add_action(TransactionId, Action, receipt_id(Headers), State)
+        undefined     -> {ok, handle_recv_nack_frame(Frame, State)};
+        TransactionId -> add_action(TransactionId, {fun ?MODULE:handle_recv_nack_frame/2, [Frame]}, receipt_id(Headers), State)
     end;
 
 %% BEGIN
 %% transaction:tx1
 %%
 %% ^@
-received(#stomp_frame{command = <<"BEGIN">>, headers = Headers}, State) ->
-    Id        = header(<<"transaction">>, Headers),
-    %% self() ! TimeoutMsg
-    TimeoutMsg = {transaction, {timeout, Id}},
-    case emqx_stomp_transaction:start(Id, TimeoutMsg) of
-        {ok, _Transaction} ->
-            maybe_send_receipt(receipt_id(Headers), State);
-        {error, already_started} ->
+received(#stomp_frame{command = <<"BEGIN">>, headers = Headers},
+         State = #pstate{transaction = Trans}) ->
+    Id = header(<<"transaction">>, Headers),
+    case maps:get(Id, Trans, undefined) of
+        undefined ->
+            Ts = erlang:system_time(millisecond),
+            NState = ensure_clean_trans_timer(State#pstate{transaction = Trans#{Id => {Ts, []}}}),
+            maybe_send_receipt(receipt_id(Headers), NState);
+        _ ->
             send(error_frame(receipt_id(Headers), ["Transaction ", Id, " already started"]), State)
     end;
 
@@ -213,12 +207,16 @@ received(#stomp_frame{command = <<"BEGIN">>, headers = Headers}, State) ->
 %% transaction:tx1
 %%
 %% ^@
-received(#stomp_frame{command = <<"COMMIT">>, headers = Headers}, State) ->
+received(#stomp_frame{command = <<"COMMIT">>, headers = Headers},
+         State = #pstate{transaction = Trans}) ->
     Id = header(<<"transaction">>, Headers),
-    case emqx_stomp_transaction:commit(Id, State) of
-        {ok, NState} ->
+    case maps:get(Id, Trans, undefined) of
+        {_, Actions} ->
+            NState = lists:foldr(fun({Func, Args}, S) ->
+                erlang:apply(Func, Args ++ [S])
+            end, State#pstate{transaction = maps:remove(Id, Trans)}, Actions),
             maybe_send_receipt(receipt_id(Headers), NState);
-        {error, not_found} ->
+        _ ->
             send(error_frame(receipt_id(Headers), ["Transaction ", Id, " not found"]), State)
     end;
 
@@ -226,12 +224,14 @@ received(#stomp_frame{command = <<"COMMIT">>, headers = Headers}, State) ->
 %% transaction:tx1
 %%
 %% ^@
-received(#stomp_frame{command = <<"ABORT">>, headers = Headers}, State) ->
+received(#stomp_frame{command = <<"ABORT">>, headers = Headers},
+         State = #pstate{transaction = Trans}) ->
     Id = header(<<"transaction">>, Headers),
-    case emqx_stomp_transaction:abort(Id) of
-        ok ->
-            maybe_send_receipt(receipt_id(Headers), State);
-        {error, not_found} ->
+    case maps:get(Id, Trans, undefined) of
+        {_, _Actions} ->
+            NState = State#pstate{transaction = maps:remove(Id, Trans)},
+            maybe_send_receipt(receipt_id(Headers), NState);
+        _ ->
             send(error_frame(receipt_id(Headers), ["Transaction ", Id, " not found"]), State)
     end;
 
@@ -247,8 +247,8 @@ send(Msg = #message{topic = Topic, headers = Headers, payload = Payload},
                         {<<"message-id">>, next_msgid()},
                         {<<"destination">>, Topic},
                         {<<"content-type">>, <<"text/plain">>}],
-            Headers1 = case Ack of 
-                           _ when Ack =:= <<"client">> orelse Ack =:= <<"client-individual">> -> 
+            Headers1 = case Ack of
+                           _ when Ack =:= <<"client">> orelse Ack =:= <<"client-individual">> ->
                                Headers0 ++ [{<<"ack">>, next_ackid()}];
                            _ ->
                                Headers0
@@ -290,7 +290,12 @@ timeout(_TRef, {outgoing, NewVal},
             {ok, State};
         {ok, NHrtBt} ->
             {ok, reset_timer(outgoing_timer, State#pstate{heart_beats = NHrtBt})}
-    end.
+    end;
+
+timeout(_TRef, clean_trans, State = #pstate{transaction = Trans}) ->
+    Now = erlang:system_time(millisecond),
+    NTrans = maps:filter(fun(_, {Ts, _}) -> Ts + ?TRANS_TIMEOUT < Now end, Trans),
+    {ok, ensure_clean_trans_timer(State#pstate{transaction = NTrans})}.
 
 negotiate_version(undefined) ->
     {ok, <<"1.0">>};
@@ -318,11 +323,12 @@ check_login(Login, Passcode, _, DefaultUser) ->
         {_,     _       } -> false
     end.
 
-add_action(Id, Action, ReceiptId, State) ->
-    case emqx_stomp_transaction:add(Id, Action) of
-        {ok, _}           ->
-            {ok, State};
-        {error, not_found} ->
+add_action(Id, Action, ReceiptId, State = #pstate{transaction = Trans}) ->
+    case maps:get(Id, Trans, undefined) of
+        {Ts, Actions} ->
+            NTrans = Trans#{Id => {Ts, [Action|Actions]}},
+            {ok, State#pstate{transaction = NTrans}};
+        _ ->
             send(error_frame(ReceiptId, ["Transaction ", Id, " not found"]), State)
     end.
 
@@ -331,7 +337,7 @@ maybe_send_receipt(undefined, State) ->
 maybe_send_receipt(ReceiptId, State) ->
     send(receipt_frame(ReceiptId), State).
 
-ack(_Id, State) -> 
+ack(_Id, State) ->
     State.
 
 nack(_Id, State) -> State.
@@ -360,7 +366,7 @@ next_msgid() ->
                 undefined -> 1;
                 I         -> I
             end,
-    put(msgid, MsgId + 1), 
+    put(msgid, MsgId + 1),
     MsgId.
 
 next_ackid() ->
@@ -368,16 +374,16 @@ next_ackid() ->
                 undefined -> 1;
                 I         -> I
             end,
-    put(ackid, AckId + 1), 
+    put(ackid, AckId + 1),
     AckId.
 
 make_mqtt_message(Topic, Headers, Body) ->
     Msg = emqx_message:make(stomp, Topic, Body),
     Headers1 = lists:foldl(fun(Key, Headers0) ->
                                proplists:delete(Key, Headers0)
-                           end, Headers, [<<"destination">>, 
-                                          <<"content-length">>, 
-                                          <<"content-type">>, 
+                           end, Headers, [<<"destination">>,
+                                          <<"content-length">>,
+                                          <<"content-type">>,
                                           <<"transaction">>,
                                           <<"receipt">>]),
     emqx_message:set_headers(#{stomp_headers => Headers1}, Msg).
@@ -385,6 +391,33 @@ make_mqtt_message(Topic, Headers, Body) ->
 receipt_id(Headers) ->
     header(<<"receipt">>, Headers).
 
+%%--------------------------------------------------------------------
+%% Transaction Handle
+
+handle_recv_send_frame(#stomp_frame{command = <<"SEND">>, headers = Headers, body = Body}, State) ->
+    Topic = header(<<"destination">>, Headers),
+    maybe_send_receipt(receipt_id(Headers), State),
+    emqx_broker:publish(
+        make_mqtt_message(Topic, Headers, iolist_to_binary(Body))
+    ),
+    State.
+
+handle_recv_ack_frame(#stomp_frame{command = <<"ACK">>, headers = Headers}, State) ->
+    Id = header(<<"id">>, Headers),
+    maybe_send_receipt(receipt_id(Headers), State),
+    ack(Id, State).
+
+handle_recv_nack_frame(#stomp_frame{command = <<"NACK">>, headers = Headers}, State) ->
+    Id = header(<<"id">>, Headers),
+     maybe_send_receipt(receipt_id(Headers), State),
+     nack(Id, State).
+
+ensure_clean_trans_timer(State = #pstate{transaction = Trans}) ->
+    case maps:size(Trans) of
+        0 -> State;
+        _ -> ensure_timer(clean_trans_timer, State)
+    end.
+
 %%--------------------------------------------------------------------
 %% Heartbeat
 
@@ -433,4 +466,7 @@ clean_timer(Name, State = #pstate{timers = Timers}) ->
 interval(incoming_timer, #pstate{heart_beats = HrtBt}) ->
     emqx_stomp_heartbeat:interval(incoming, HrtBt);
 interval(outgoing_timer, #pstate{heart_beats = HrtBt}) ->
-    emqx_stomp_heartbeat:interval(outgoing, HrtBt).
+    emqx_stomp_heartbeat:interval(outgoing, HrtBt);
+interval(clean_trans_timer, _) ->
+    ?TRANS_TIMEOUT.
+
diff --git a/apps/emqx_stomp/src/emqx_stomp_transaction.erl b/apps/emqx_stomp/src/emqx_stomp_transaction.erl
deleted file mode 100644
index 6f15e8aa5..000000000
--- a/apps/emqx_stomp/src/emqx_stomp_transaction.erl
+++ /dev/null
@@ -1,77 +0,0 @@
-%%--------------------------------------------------------------------
-%% Copyright (c) 2020 EMQ Technologies Co., Ltd. All Rights Reserved.
-%%
-%% Licensed under the Apache License, Version 2.0 (the "License");
-%% you may not use this file except in compliance with the License.
-%% You may obtain a copy of the License at
-%%
-%%     http://www.apache.org/licenses/LICENSE-2.0
-%%
-%% Unless required by applicable law or agreed to in writing, software
-%% distributed under the License is distributed on an "AS IS" BASIS,
-%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-%% See the License for the specific language governing permissions and
-%% limitations under the License.
-%%--------------------------------------------------------------------
-
-%% @doc Stomp Transaction
-
--module(emqx_stomp_transaction).
-
--include("emqx_stomp.hrl").
-
--export([ start/2
-        , add/2
-        , commit/2
-        , abort/1
-        , timeout/1
-        ]).
-
--record(transaction, {id, actions, tref}).
-
--define(TIMEOUT, 60000).
-
-start(Id, TimeoutMsg) ->
-    case get({transaction, Id}) of
-        undefined    ->
-            TRef = erlang:send_after(?TIMEOUT, self(), TimeoutMsg),
-            Transaction = #transaction{id = Id, actions = [], tref = TRef},
-            put({transaction, Id}, Transaction),
-            {ok, Transaction};
-        _Transaction ->
-            {error, already_started}
-    end.
-
-add(Id, Action) ->
-    Fun = fun(Transaction = #transaction{actions = Actions}) ->
-            Transaction1 = Transaction#transaction{actions = [Action | Actions]},
-            put({transaction, Id}, Transaction1),
-            {ok, Transaction1}
-          end,
-    with_transaction(Id, Fun).
-
-commit(Id, InitState) ->
-    Fun = fun(Transaction = #transaction{actions = Actions}) ->
-            done(Transaction),
-            {ok, lists:foldr(fun(Action, State) -> Action(State) end,
-                             InitState, Actions)}
-          end,
-    with_transaction(Id, Fun).
-
-abort(Id) ->
-    with_transaction(Id, fun done/1).
-
-timeout(Id) ->
-    erase({transaction, Id}).
-
-done(#transaction{id = Id, tref = TRef}) ->
-    erase({transaction, Id}),
-    catch erlang:cancel_timer(TRef),
-    ok.
-
-with_transaction(Id, Fun) ->
-    case get({transaction, Id}) of
-        undefined   -> {error, not_found};
-        Transaction -> Fun(Transaction)
-    end.
-

From 2b1429fe03efd7367c6d0d28f9ab9688ee166876 Mon Sep 17 00:00:00 2001
From: JianBo He 
Date: Thu, 10 Dec 2020 13:08:03 +0800
Subject: [PATCH 4/4] test(stomp): cover the emqx_stom_heartbeat:interval/2

---
 apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl b/apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl
index 0d01bfcd4..b3ea25aa1 100644
--- a/apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl
+++ b/apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl
@@ -51,3 +51,9 @@ t_info(_) ->
     HrtBt = emqx_stomp_heartbeat:init({100, 100}),
     #{incoming := _,
       outgoing := _} = emqx_stomp_heartbeat:info(HrtBt).
+
+t_interval(_) ->
+    HrtBt = emqx_stomp_heartbeat:init({1, 0}),
+    1 = emqx_stomp_heartbeat:interval(incoming, HrtBt),
+    undefined = emqx_stomp_heartbeat:interval(outgoing, HrtBt).
+