diff --git a/apps/emqx_stomp/rebar.config b/apps/emqx_stomp/rebar.config index a9e3eebd6..7ac3b98c8 100644 --- a/apps/emqx_stomp/rebar.config +++ b/apps/emqx_stomp/rebar.config @@ -14,5 +14,3 @@ {cover_enabled, true}. {cover_opts, [verbose]}. {cover_export_enabled, true}. - -{plugins, [coveralls]}. diff --git a/apps/emqx_stomp/src/emqx_stomp_connection.erl b/apps/emqx_stomp/src/emqx_stomp_connection.erl index 3bb3a58fd..740320489 100644 --- a/apps/emqx_stomp/src/emqx_stomp_connection.erl +++ b/apps/emqx_stomp/src/emqx_stomp_connection.erl @@ -127,10 +127,6 @@ handle_info(timeout, State) -> handle_info({shutdown, Reason}, State) -> shutdown(Reason, State); -handle_info({transaction, {timeout, Id}}, State) -> - emqx_stomp_transaction:timeout(Id), - noreply(State); - handle_info({timeout, TRef, TMsg}, State) when TMsg =:= incoming; TMsg =:= outgoing -> @@ -145,6 +141,9 @@ handle_info({timeout, TRef, TMsg}, State) when TMsg =:= incoming; shutdown({sock_error, Reason}, State) end; +handle_info({timeout, TRef, TMsg}, State) -> + with_proto(timeout, [TRef, TMsg], State); + handle_info({'EXIT', HbProc, Error}, State = #state{heartbeat = HbProc}) -> stop(Error, State); diff --git a/apps/emqx_stomp/src/emqx_stomp_protocol.erl b/apps/emqx_stomp/src/emqx_stomp_protocol.erl index a366105b8..184b2cc90 100644 --- a/apps/emqx_stomp/src/emqx_stomp_protocol.erl +++ b/apps/emqx_stomp/src/emqx_stomp_protocol.erl @@ -38,6 +38,12 @@ , timeout/3 ]). +%% for trans callback +-export([ handle_recv_send_frame/2 + , handle_recv_ack_frame/2 + , handle_recv_nack_frame/2 + ]). + -record(pstate, { peername, heartfun, @@ -50,14 +56,18 @@ allow_anonymous, default_user, subscriptions = [], - timers :: #{atom() => disable | undefined | reference()} + timers :: #{atom() => disable | undefined | reference()}, + transaction :: #{binary() => list()} }). -define(TIMER_TABLE, #{ incoming_timer => incoming, - outgoing_timer => outgoing + outgoing_timer => outgoing, + clean_trans_timer => clean_trans }). +-define(TRANS_TIMEOUT, 60000). + -type(pstate() :: #pstate{}). %% @doc Init protocol @@ -70,6 +80,7 @@ init(#{peername := Peername, heartfun = HeartFun, sendfun = SendFun, timers = #{}, + transaction = #{}, allow_anonymous = AllowAnonymous, default_user = DefaultUser}. @@ -121,18 +132,10 @@ received(#stomp_frame{command = <<"CONNECT">>, headers = Headers}, received(#stomp_frame{command = <<"CONNECT">>}, State = #pstate{connected = true}) -> {error, unexpected_connect, State}; -received(#stomp_frame{command = <<"SEND">>, headers = Headers, body = Body}, State) -> - Topic = header(<<"destination">>, Headers), - Action = fun(State0) -> - maybe_send_receipt(receipt_id(Headers), State0), - emqx_broker:publish( - make_mqtt_message(Topic, Headers, iolist_to_binary(Body)) - ), - State0 - end, +received(Frame = #stomp_frame{command = <<"SEND">>, headers = Headers}, State) -> case header(<<"transaction">>, Headers) of - undefined -> {ok, Action(State)}; - TransactionId -> add_action(TransactionId, Action, receipt_id(Headers), State) + undefined -> {ok, handle_recv_send_frame(Frame, State)}; + TransactionId -> add_action(TransactionId, {fun ?MODULE:handle_recv_send_frame/2, [Frame]}, receipt_id(Headers), State) end; received(#stomp_frame{command = <<"SUBSCRIBE">>, headers = Headers}, @@ -167,15 +170,10 @@ received(#stomp_frame{command = <<"UNSUBSCRIBE">>, headers = Headers}, %% transaction:tx1 %% %% ^@ -received(#stomp_frame{command = <<"ACK">>, headers = Headers}, State) -> - Id = header(<<"id">>, Headers), - Action = fun(State0) -> - maybe_send_receipt(receipt_id(Headers), State0), - ack(Id, State0) - end, +received(Frame = #stomp_frame{command = <<"ACK">>, headers = Headers}, State) -> case header(<<"transaction">>, Headers) of - undefined -> {ok, Action(State)}; - TransactionId -> add_action(TransactionId, Action, receipt_id(Headers), State) + undefined -> {ok, handle_recv_ack_frame(Frame, State)}; + TransactionId -> add_action(TransactionId, {fun ?MODULE:handle_recv_ack_frame/2, [Frame]}, receipt_id(Headers), State) end; %% NACK @@ -183,29 +181,25 @@ received(#stomp_frame{command = <<"ACK">>, headers = Headers}, State) -> %% transaction:tx1 %% %% ^@ -received(#stomp_frame{command = <<"NACK">>, headers = Headers}, State) -> - Id = header(<<"id">>, Headers), - Action = fun(State0) -> - maybe_send_receipt(receipt_id(Headers), State0), - nack(Id, State0) - end, +received(Frame = #stomp_frame{command = <<"NACK">>, headers = Headers}, State) -> case header(<<"transaction">>, Headers) of - undefined -> {ok, Action(State)}; - TransactionId -> add_action(TransactionId, Action, receipt_id(Headers), State) + undefined -> {ok, handle_recv_nack_frame(Frame, State)}; + TransactionId -> add_action(TransactionId, {fun ?MODULE:handle_recv_nack_frame/2, [Frame]}, receipt_id(Headers), State) end; %% BEGIN %% transaction:tx1 %% %% ^@ -received(#stomp_frame{command = <<"BEGIN">>, headers = Headers}, State) -> - Id = header(<<"transaction">>, Headers), - %% self() ! TimeoutMsg - TimeoutMsg = {transaction, {timeout, Id}}, - case emqx_stomp_transaction:start(Id, TimeoutMsg) of - {ok, _Transaction} -> - maybe_send_receipt(receipt_id(Headers), State); - {error, already_started} -> +received(#stomp_frame{command = <<"BEGIN">>, headers = Headers}, + State = #pstate{transaction = Trans}) -> + Id = header(<<"transaction">>, Headers), + case maps:get(Id, Trans, undefined) of + undefined -> + Ts = erlang:system_time(millisecond), + NState = ensure_clean_trans_timer(State#pstate{transaction = Trans#{Id => {Ts, []}}}), + maybe_send_receipt(receipt_id(Headers), NState); + _ -> send(error_frame(receipt_id(Headers), ["Transaction ", Id, " already started"]), State) end; @@ -213,12 +207,16 @@ received(#stomp_frame{command = <<"BEGIN">>, headers = Headers}, State) -> %% transaction:tx1 %% %% ^@ -received(#stomp_frame{command = <<"COMMIT">>, headers = Headers}, State) -> +received(#stomp_frame{command = <<"COMMIT">>, headers = Headers}, + State = #pstate{transaction = Trans}) -> Id = header(<<"transaction">>, Headers), - case emqx_stomp_transaction:commit(Id, State) of - {ok, NState} -> + case maps:get(Id, Trans, undefined) of + {_, Actions} -> + NState = lists:foldr(fun({Func, Args}, S) -> + erlang:apply(Func, Args ++ [S]) + end, State#pstate{transaction = maps:remove(Id, Trans)}, Actions), maybe_send_receipt(receipt_id(Headers), NState); - {error, not_found} -> + _ -> send(error_frame(receipt_id(Headers), ["Transaction ", Id, " not found"]), State) end; @@ -226,12 +224,14 @@ received(#stomp_frame{command = <<"COMMIT">>, headers = Headers}, State) -> %% transaction:tx1 %% %% ^@ -received(#stomp_frame{command = <<"ABORT">>, headers = Headers}, State) -> +received(#stomp_frame{command = <<"ABORT">>, headers = Headers}, + State = #pstate{transaction = Trans}) -> Id = header(<<"transaction">>, Headers), - case emqx_stomp_transaction:abort(Id) of - ok -> - maybe_send_receipt(receipt_id(Headers), State); - {error, not_found} -> + case maps:get(Id, Trans, undefined) of + {_, _Actions} -> + NState = State#pstate{transaction = maps:remove(Id, Trans)}, + maybe_send_receipt(receipt_id(Headers), NState); + _ -> send(error_frame(receipt_id(Headers), ["Transaction ", Id, " not found"]), State) end; @@ -247,8 +247,8 @@ send(Msg = #message{topic = Topic, headers = Headers, payload = Payload}, {<<"message-id">>, next_msgid()}, {<<"destination">>, Topic}, {<<"content-type">>, <<"text/plain">>}], - Headers1 = case Ack of - _ when Ack =:= <<"client">> orelse Ack =:= <<"client-individual">> -> + Headers1 = case Ack of + _ when Ack =:= <<"client">> orelse Ack =:= <<"client-individual">> -> Headers0 ++ [{<<"ack">>, next_ackid()}]; _ -> Headers0 @@ -290,7 +290,12 @@ timeout(_TRef, {outgoing, NewVal}, {ok, State}; {ok, NHrtBt} -> {ok, reset_timer(outgoing_timer, State#pstate{heart_beats = NHrtBt})} - end. + end; + +timeout(_TRef, clean_trans, State = #pstate{transaction = Trans}) -> + Now = erlang:system_time(millisecond), + NTrans = maps:filter(fun(_, {Ts, _}) -> Ts + ?TRANS_TIMEOUT < Now end, Trans), + {ok, ensure_clean_trans_timer(State#pstate{transaction = NTrans})}. negotiate_version(undefined) -> {ok, <<"1.0">>}; @@ -318,11 +323,12 @@ check_login(Login, Passcode, _, DefaultUser) -> {_, _ } -> false end. -add_action(Id, Action, ReceiptId, State) -> - case emqx_stomp_transaction:add(Id, Action) of - {ok, _} -> - {ok, State}; - {error, not_found} -> +add_action(Id, Action, ReceiptId, State = #pstate{transaction = Trans}) -> + case maps:get(Id, Trans, undefined) of + {Ts, Actions} -> + NTrans = Trans#{Id => {Ts, [Action|Actions]}}, + {ok, State#pstate{transaction = NTrans}}; + _ -> send(error_frame(ReceiptId, ["Transaction ", Id, " not found"]), State) end. @@ -331,7 +337,7 @@ maybe_send_receipt(undefined, State) -> maybe_send_receipt(ReceiptId, State) -> send(receipt_frame(ReceiptId), State). -ack(_Id, State) -> +ack(_Id, State) -> State. nack(_Id, State) -> State. @@ -360,7 +366,7 @@ next_msgid() -> undefined -> 1; I -> I end, - put(msgid, MsgId + 1), + put(msgid, MsgId + 1), MsgId. next_ackid() -> @@ -368,16 +374,16 @@ next_ackid() -> undefined -> 1; I -> I end, - put(ackid, AckId + 1), + put(ackid, AckId + 1), AckId. make_mqtt_message(Topic, Headers, Body) -> Msg = emqx_message:make(stomp, Topic, Body), Headers1 = lists:foldl(fun(Key, Headers0) -> proplists:delete(Key, Headers0) - end, Headers, [<<"destination">>, - <<"content-length">>, - <<"content-type">>, + end, Headers, [<<"destination">>, + <<"content-length">>, + <<"content-type">>, <<"transaction">>, <<"receipt">>]), emqx_message:set_headers(#{stomp_headers => Headers1}, Msg). @@ -385,6 +391,33 @@ make_mqtt_message(Topic, Headers, Body) -> receipt_id(Headers) -> header(<<"receipt">>, Headers). +%%-------------------------------------------------------------------- +%% Transaction Handle + +handle_recv_send_frame(#stomp_frame{command = <<"SEND">>, headers = Headers, body = Body}, State) -> + Topic = header(<<"destination">>, Headers), + maybe_send_receipt(receipt_id(Headers), State), + emqx_broker:publish( + make_mqtt_message(Topic, Headers, iolist_to_binary(Body)) + ), + State. + +handle_recv_ack_frame(#stomp_frame{command = <<"ACK">>, headers = Headers}, State) -> + Id = header(<<"id">>, Headers), + maybe_send_receipt(receipt_id(Headers), State), + ack(Id, State). + +handle_recv_nack_frame(#stomp_frame{command = <<"NACK">>, headers = Headers}, State) -> + Id = header(<<"id">>, Headers), + maybe_send_receipt(receipt_id(Headers), State), + nack(Id, State). + +ensure_clean_trans_timer(State = #pstate{transaction = Trans}) -> + case maps:size(Trans) of + 0 -> State; + _ -> ensure_timer(clean_trans_timer, State) + end. + %%-------------------------------------------------------------------- %% Heartbeat @@ -433,4 +466,7 @@ clean_timer(Name, State = #pstate{timers = Timers}) -> interval(incoming_timer, #pstate{heart_beats = HrtBt}) -> emqx_stomp_heartbeat:interval(incoming, HrtBt); interval(outgoing_timer, #pstate{heart_beats = HrtBt}) -> - emqx_stomp_heartbeat:interval(outgoing, HrtBt). + emqx_stomp_heartbeat:interval(outgoing, HrtBt); +interval(clean_trans_timer, _) -> + ?TRANS_TIMEOUT. + diff --git a/apps/emqx_stomp/src/emqx_stomp_transaction.erl b/apps/emqx_stomp/src/emqx_stomp_transaction.erl deleted file mode 100644 index 6f15e8aa5..000000000 --- a/apps/emqx_stomp/src/emqx_stomp_transaction.erl +++ /dev/null @@ -1,77 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2020 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - -%% @doc Stomp Transaction - --module(emqx_stomp_transaction). - --include("emqx_stomp.hrl"). - --export([ start/2 - , add/2 - , commit/2 - , abort/1 - , timeout/1 - ]). - --record(transaction, {id, actions, tref}). - --define(TIMEOUT, 60000). - -start(Id, TimeoutMsg) -> - case get({transaction, Id}) of - undefined -> - TRef = erlang:send_after(?TIMEOUT, self(), TimeoutMsg), - Transaction = #transaction{id = Id, actions = [], tref = TRef}, - put({transaction, Id}, Transaction), - {ok, Transaction}; - _Transaction -> - {error, already_started} - end. - -add(Id, Action) -> - Fun = fun(Transaction = #transaction{actions = Actions}) -> - Transaction1 = Transaction#transaction{actions = [Action | Actions]}, - put({transaction, Id}, Transaction1), - {ok, Transaction1} - end, - with_transaction(Id, Fun). - -commit(Id, InitState) -> - Fun = fun(Transaction = #transaction{actions = Actions}) -> - done(Transaction), - {ok, lists:foldr(fun(Action, State) -> Action(State) end, - InitState, Actions)} - end, - with_transaction(Id, Fun). - -abort(Id) -> - with_transaction(Id, fun done/1). - -timeout(Id) -> - erase({transaction, Id}). - -done(#transaction{id = Id, tref = TRef}) -> - erase({transaction, Id}), - catch erlang:cancel_timer(TRef), - ok. - -with_transaction(Id, Fun) -> - case get({transaction, Id}) of - undefined -> {error, not_found}; - Transaction -> Fun(Transaction) - end. -