diff --git a/apps/emqttd/include/emqttd_protocol.hrl b/apps/emqttd/include/emqttd_protocol.hrl index 5266fae54..35637e3d5 100644 --- a/apps/emqttd/include/emqttd_protocol.hrl +++ b/apps/emqttd/include/emqttd_protocol.hrl @@ -196,6 +196,11 @@ packet_id = PacketId}, payload = Payload}). +-define(PUBLISH(Qos, PacketId), + #mqtt_packet{header = #mqtt_packet_header{type = ?PUBLISH, + qos = Qos}, + variable = #mqtt_packet_publish{packet_id = PacketId}}). + -define(PUBACK_PACKET(Type, PacketId), #mqtt_packet{header = #mqtt_packet_header{type = Type}, variable = #mqtt_packet_puback{packet_id = PacketId}}). diff --git a/apps/emqttd/src/emqttd_protocol.erl b/apps/emqttd/src/emqttd_protocol.erl index 2c4f9568c..7bdc1d67c 100644 --- a/apps/emqttd/src/emqttd_protocol.erl +++ b/apps/emqttd/src/emqttd_protocol.erl @@ -170,37 +170,18 @@ handle(Packet = ?CONNECT_PACKET(Var), State0 = #proto_state{peername = Peername} %% Send connack send(?CONNACK_PACKET(ReturnCode1), State3); -handle(Packet = ?PUBLISH_PACKET(?QOS_0, Topic, _PacketId, _Payload), - State = #proto_state{clientid = ClientId, session = Session}) -> +handle(Packet = ?PUBLISH_PACKET(_Qos, Topic, _PacketId, _Payload), + State = #proto_state{clientid = ClientId}) -> + case check_acl(publish, Topic, State) of - allow -> - do_publish(Session, ClientId, Packet); + allow -> + publish(Packet, State); deny -> lager:error("ACL Deny: ~s cannot publish to ~s", [ClientId, Topic]) end, {ok, State}; -handle(Packet = ?PUBLISH_PACKET(?QOS_1, Topic, PacketId, _Payload), - State = #proto_state{clientid = ClientId, session = Session}) -> - case check_acl(publish, Topic, State) of - allow -> - do_publish(Session, ClientId, Packet), - send(?PUBACK_PACKET(?PUBACK, PacketId), State); - deny -> - lager:error("ACL Deny: ~s cannot publish to ~s", [ClientId, Topic]), - {ok, State} - end; -handle(Packet = ?PUBLISH_PACKET(?QOS_2, Topic, PacketId, _Payload), - State = #proto_state{clientid = ClientId, session = Session}) -> - case check_acl(publish, Topic, State) of - allow -> - do_publish(Session, ClientId, Packet), - send(?PUBACK_PACKET(?PUBREC, PacketId), State); - deny -> - lager:error("ACL Deny: ~s cannot publish to ~s", [ClientId, Topic]), - {ok, State} - end; handle(?PUBACK_PACKET(?PUBACK, PacketId), State = #proto_state{session = Session}) -> emqttd_session:puback(Session, PacketId), @@ -256,10 +237,26 @@ handle(?PACKET(?DISCONNECT), State) -> % clean willmsg {stop, normal, State#proto_state{will_msg = undefined}}. -do_publish(Session, ClientId, Packet) -> - Msg = emqttd_message:from_packet(ClientId, Packet), - Msg1 = emqttd_broker:foldl_hooks(client_publish, [], Msg), - emqttd_session:publish(Session, Msg1). +publish(Packet = ?PUBLISH(?QOS_0, _PacketId), #proto_state{clientid = ClientId, session = Session}) -> + emqttd_session:publish(Session, emqttd_message:from_packet(ClientId, Packet)); + +publish(Packet = ?PUBLISH(?QOS_1, PacketId), State = #proto_state{clientid = ClientId, session = Session}) -> + case emqttd_session:publish(Session, emqttd_message:from_packet(ClientId, Packet)) of + ok -> + send(?PUBACK_PACKET(?PUBACK, PacketId), State); + {error, Error} -> + %%TODO: log format... + lager:error("Client ~s: publish qos1 error ~p", [ClientId, Error]) + end; + +publish(Packet = ?PUBLISH(?QOS_2, PacketId), State = #proto_state{clientid = ClientId, session = Session}) -> + case emqttd_session:publish(Session, emqttd_message:from_packet(ClientId, Packet)) of + ok -> + send(?PUBACK_PACKET(?PUBREC, PacketId), State); + {error, Error} -> + %%TODO: log format... + lager:error("Client ~s: publish qos2 error ~p", [ClientId, Error]) + end. -spec send(mqtt_message() | mqtt_packet(), proto_state()) -> {ok, proto_state()}. send(Msg, State) when is_record(Msg, mqtt_message) -> @@ -323,7 +320,7 @@ send_willmsg(_ClientId, undefined) -> ignore; send_willmsg(ClientId, WillMsg) -> lager:info("Client ~s send willmsg: ~p", [ClientId, WillMsg]), - emqttd_pubsub:publish(ClientId, WillMsg). + emqttd_pubsub:publish(WillMsg#mqtt_message{from = ClientId}). start_keepalive(0) -> ignore; diff --git a/apps/emqttd/src/emqttd_pubsub.erl b/apps/emqttd/src/emqttd_pubsub.erl index 1e1304d1a..3699bfc9f 100644 --- a/apps/emqttd/src/emqttd_pubsub.erl +++ b/apps/emqttd/src/emqttd_pubsub.erl @@ -47,7 +47,7 @@ -export([create/1, subscribe/1, unsubscribe/1, - publish/2, + publish/1, %local node dispatch/2, match/1]). @@ -81,7 +81,7 @@ mnesia(boot) -> {ram_copies, [node()]}, {record_name, mqtt_subscriber}, {attributes, record_info(fields, mqtt_subscriber)}, - {index, [pid]}, + {index, [subpid]}, {local_content, true}]); mnesia(copy) -> @@ -156,19 +156,23 @@ cast(Msg) -> %% @doc Publish to cluster nodes %% @end %%------------------------------------------------------------------------------ --spec publish(From :: mqtt_clientid() | atom(), Msg :: mqtt_message()) -> ok. -publish(From, #mqtt_message{topic=Topic} = Msg) -> +-spec publish(Msg :: mqtt_message()) -> ok. +publish(#mqtt_message{topic=Topic, from = From} = Msg) -> trace(publish, From, Msg), + + %%TODO:call hooks here... + %%Msg1 = emqttd_broker:foldl_hooks(client_publish, [], Msg), + %% Retain message first. Don't create retained topic. case emqttd_msg_store:retain(Msg) of ok -> %TODO: why unset 'retain' flag? - publish(From, Topic, emqttd_message:unset_flag(Msg)); + publish(Topic, emqttd_message:unset_flag(Msg)); ignore -> - publish(From, Topic, Msg) + publish(Topic, Msg) end. -publish(From, <<"$Q/", _/binary>> = Queue, #mqtt_message{qos = Qos} = Msg) -> +publish(<<"$Q/", _/binary>> = Queue, #mqtt_message{qos = Qos} = Msg) -> lists:foreach( fun(#mqtt_queue{subpid = SubPid, qos = SubQos}) -> Msg1 = if @@ -178,7 +182,7 @@ publish(From, <<"$Q/", _/binary>> = Queue, #mqtt_message{qos = Qos} = Msg) -> SubPid ! {dispatch, Msg1} end, mnesia:dirty_read(queue, Queue)); -publish(_From, Topic, Msg) when is_binary(Topic) -> +publish(Topic, Msg) when is_binary(Topic) -> lists:foreach(fun(#mqtt_topic{topic=Name, node=Node}) -> case Node =:= node() of true -> dispatch(Name, Msg); diff --git a/apps/emqttd/src/emqttd_session.erl b/apps/emqttd/src/emqttd_session.erl index d4e8f081d..384e49ae1 100644 --- a/apps/emqttd/src/emqttd_session.erl +++ b/apps/emqttd/src/emqttd_session.erl @@ -21,13 +21,13 @@ %%%----------------------------------------------------------------------------- %%% @doc %%% -%%% emqttd session for persistent client. +%%% Session for persistent MQTT client. %%% %%% Session State in the broker consists of: %%% %%% 1. The Client’s subscriptions. %%% -%%% 2. inflight qos1, qos2 messages sent to the client but unacked, QoS 1 and QoS 2 +%%% 2. inflight qos1/2 messages sent to the client but unacked, QoS 1 and QoS 2 %%% messages which have been sent to the Client, but have not been completely %%% acknowledged. %%% @@ -59,6 +59,8 @@ puback/2, pubrec/2, pubrel/2, pubcomp/2, subscribe/2, unsubscribe/2]). +-behaviour(gen_server). + %% gen_server Function Exports -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). @@ -116,7 +118,7 @@ max_awaiting_rel = 100, %% session expired after 48 hours - expired_after = 172800, + expired_after = 48, expired_timer, @@ -126,6 +128,7 @@ %% @doc Start a session. %% @end %%------------------------------------------------------------------------------ +-spec start_link(boolean(), binary(), pid()) -> {ok, pid()} | {error, any()}. start_link(CleanSess, ClientId, ClientPid) -> gen_server:start_link(?MODULE, [CleanSess, ClientId, ClientPid], []). @@ -133,7 +136,8 @@ start_link(CleanSess, ClientId, ClientPid) -> %% @doc Resume a session. %% @end %%------------------------------------------------------------------------------ -resume(Session, ClientId, ClientPid) when is_pid(Session) -> +-spec resume(pid(), binary(), pid()) -> ok. +resume(Session, ClientId, ClientPid) -> gen_server:cast(Session, {resume, ClientId, ClientPid}). %%------------------------------------------------------------------------------ @@ -141,73 +145,69 @@ resume(Session, ClientId, ClientPid) when is_pid(Session) -> %% @end %%------------------------------------------------------------------------------ -spec destroy(Session:: pid(), ClientId :: binary()) -> ok. -destroy(Session, ClientId) when is_pid(Session) -> +destroy(Session, ClientId) -> gen_server:call(Session, {destroy, ClientId}). +%%------------------------------------------------------------------------------ +%% @doc Subscribe Topics +%% @end +%%------------------------------------------------------------------------------ +-spec subscribe(pid(), [{binary(), mqtt_qos()}]) -> {ok, [mqtt_qos()]}. +subscribe(Session, TopicTable) -> + gen_server:call(Session, {subscribe, TopicTable}). + %%------------------------------------------------------------------------------ %% @doc Publish message %% @end %%------------------------------------------------------------------------------ -spec publish(Session :: pid(), {mqtt_qos(), mqtt_message()}) -> ok. -publish(Session, Msg = #mqtt_message{qos = ?QOS_0}) when is_pid(Session) -> +publish(_Session, Msg = #mqtt_message{qos = ?QOS_0}) -> %% publish qos0 directly emqttd_pubsub:publish(Msg); -publish(Session, Msg = #mqtt_message{qos = ?QOS_1}) when is_pid(Session) -> - %% publish qos1 directly, and client will puback +publish(_Session, Msg = #mqtt_message{qos = ?QOS_1}) -> + %% publish qos1 directly, and client will puback automatically emqttd_pubsub:publish(Msg); -publish(Session, Msg = #mqtt_message{qos = ?QOS_2}) when is_pid(Session) -> +publish(Session, Msg = #mqtt_message{qos = ?QOS_2}) -> %% publish qos2 by session - gen_server:cast(Session, {publish, Msg}). + gen_server:call(Session, {publish, Msg}). %%------------------------------------------------------------------------------ %% @doc PubAck message %% @end %%------------------------------------------------------------------------------ --spec puback(Session :: pid(), MsgId :: mqtt_packet_id()) -> ok. -puback(Session, MsgId) when is_pid(Session) -> +-spec puback(pid(), mqtt_msgid()) -> ok. +puback(Session, MsgId) -> gen_server:cast(Session, {puback, MsgId}). --spec pubrec(Session :: pid(), MsgId :: mqtt_packet_id()) -> ok. -pubrec(Session, MsgId) when is_pid(Session) -> +-spec pubrec(pid(), mqtt_msgid()) -> ok. +pubrec(Session, MsgId) -> gen_server:cast(Session, {pubrec, MsgId}). --spec pubrel(Session :: pid(), MsgId :: mqtt_packet_id()) -> ok. -pubrel(Session, MsgId) when is_pid(Session) -> +-spec pubrel(pid(), mqtt_msgid()) -> ok. +pubrel(Session, MsgId) -> gen_server:cast(Session, {pubrel, MsgId}). --spec pubcomp(Session :: pid(), MsgId :: mqtt_packet_id()) -> ok. -pubcomp(Session, MsgId) when is_pid(Session) -> +-spec pubcomp(pid(), mqtt_msgid()) -> ok. +pubcomp(Session, MsgId) -> gen_server:cast(Session, {pubcomp, MsgId}). -%%------------------------------------------------------------------------------ -%% @doc Subscribe Topics -%% @end -%%------------------------------------------------------------------------------ --spec subscribe(Session :: pid(), [{binary(), mqtt_qos()}]) -> {ok, [mqtt_qos()]}. -subscribe(Session, Topics) when is_pid(Session) -> - gen_server:call(Session, {subscribe, Topics}). - %%------------------------------------------------------------------------------ %% @doc Unsubscribe Topics %% @end %%------------------------------------------------------------------------------ --spec unsubscribe(Session :: pid(), [Topic :: binary()]) -> ok. -unsubscribe(Session, Topics) when is_pid(Session) -> +-spec unsubscribe(pid(), [binary()]) -> ok. +unsubscribe(Session, Topics) -> gen_server:call(Session, {unsubscribe, Topics}). %%%============================================================================= %%% gen_server callbacks %%%============================================================================= + init([CleanSess, ClientId, ClientPid]) -> - if - CleanSess =:= false -> - process_flag(trap_exit, true), - true = link(ClientPid); - CleanSess =:= true -> - ok - end, + process_flag(trap_exit, true), + true = link(ClientPid), QEnv = emqttd:env(mqtt, queue), SessEnv = emqttd:env(mqtt, session), PendingQ = emqttd_mqueue:new(ClientId, QEnv), @@ -227,25 +227,25 @@ init([CleanSess, ClientId, ClientPid]) -> await_rel_timeout = emqttd_opts:g(await_rel_timeout, SessEnv), max_awaiting_rel = emqttd_opts:g(max_awaiting_rel, SessEnv), expired_after = emqttd_opts:g(expired_after, SessEnv) * 3600, - timestamp = os:timestamp() - }, + timestamp = os:timestamp()}, {ok, Session, hibernate}. -handle_call({subscribe, Topics}, _From, Session = #session{clientid = ClientId, subscriptions = Subscriptions}) -> +handle_call({subscribe, Topics}, _From, Session = #session{clientid = ClientId, + subscriptions = Subscriptions}) -> %% subscribe first and don't care if the subscriptions have been existed {ok, GrantedQos} = emqttd_pubsub:subscribe(Topics), - lager:info([{client, ClientId}], "Session ~s subscribe ~p. Granted QoS: ~p", - [ClientId, Topics, GrantedQos]), + lager:info([{client, ClientId}], "Session ~s subscribe ~p, Granted QoS: ~p", + [ClientId, Topics, GrantedQos]), Subscriptions1 = lists:foldl(fun({Topic, Qos}, Acc) -> case lists:keyfind(Topic, 1, Acc) of {Topic, Qos} -> - lager:warning([{client, ClientId}], "~s resubscribe ~p: qos = ~p", [ClientId, Topic, Qos]), Acc; + lager:warning([{client, ClientId}], "Session ~s resubscribe ~p: qos = ~p", [ClientId, Topic, Qos]), Acc; {Topic, Old} -> - lager:warning([{client, ClientId}], "~s resubscribe ~p: old qos=~p, new qos=~p", + lager:warning([{client, ClientId}], "Session ~s resubscribe ~p: old qos=~p, new qos=~p", [ClientId, Topic, Old, Qos]), lists:keyreplace(Topic, 1, Acc, {Topic, Qos}); false -> @@ -263,7 +263,7 @@ handle_call({unsubscribe, Topics}, _From, Session = #session{clientid = ClientId %%unsubscribe from topic tree ok = emqttd_pubsub:unsubscribe(Topics), - lager:info([{client, ClientId}], "Client ~s unsubscribe ~p.", [ClientId, Topics]), + lager:info([{client, ClientId}], "Session ~s unsubscribe ~p.", [ClientId, Topics]), Subscriptions1 = lists:foldl(fun(Topic, Acc) -> @@ -277,12 +277,24 @@ handle_call({unsubscribe, Topics}, _From, Session = #session{clientid = ClientId {reply, ok, Session#session{subscriptions = Subscriptions1}}; +handle_call({publish, Message = #mqtt_message{qos = ?QOS_2, msgid = MsgId}}, _From, + Session = #session{clientid = ClientId, awaiting_rel = AwaitingRel, await_rel_timeout = Timeout}) -> + case check_awaiting_rel(Session) of + true -> + TRef = timer(Timeout, {timeout, awaiting_rel, MsgId}), + {reply, ok, Session#session{awaiting_rel = maps:put(MsgId, {Message, TRef}, AwaitingRel)}}; + false -> + lager:error([{clientid, ClientId}], "Session ~s " + " dropped Qos2 message for too many awaiting_rel: ~p", [ClientId, Message]), + {reply, {error, dropped}, Session} + end; + handle_call({destroy, ClientId}, _From, Session = #session{clientid = ClientId}) -> lager:warning("Session ~s destroyed", [ClientId]), {stop, {shutdown, destroy}, ok, Session}; handle_call(Req, _From, State) -> - lager:error("Unexpected Request: ~p", [Req]), + lager:critical("Unexpected Request: ~p", [Req]), {reply, {error, badreq}, State}. handle_cast({resume, ClientId, ClientPid}, State = #session{ @@ -331,9 +343,6 @@ handle_cast({resume, ClientId, ClientPid}, State = #session{ pending_queue = emqttd_queue:clear(Queue), expired_timer = undefined}, hibernate}; -handle_cast({publish, Message = #mqtt_message{qos = ?QOS_2}}, Session) -> - {noreply, publish_qos2(Message, Session)}; - handle_cast({puback, MsgId}, Session = #session{clientid = ClientId, inflight_queue = Q, awaiting_ack = Awaiting}) -> case maps:find(MsgId, Awaiting) of @@ -362,14 +371,14 @@ handle_cast({pubrec, MsgId}, Session = #session{clientid = ClientId, {noreply, Session} end; -handle_cast({pubrel, MsgId}, Session = #session{clientid = ClientId, awaiting_rel = Awaiting}) -> - case maps:find(MsgId, Awaiting) of +handle_cast({pubrel, MsgId}, Session = #session{clientid = ClientId, awaiting_rel = AwaitingRel}) -> + case maps:find(MsgId, AwaitingRel) of {ok, {Msg, TRef}} -> catch erlang:cancel_timer(TRef), emqttd_pubsub:publish(Msg), - {noreply, Session#session{awaiting_rel = maps:remove(MsgId, Awaiting)}}; + {noreply, Session#session{awaiting_rel = maps:remove(MsgId, AwaitingRel)}}; error -> - lager:error("Session ~s cannot find PUBREL'~p'!", [ClientId, MsgId]), + lager:error("Session ~s cannot find PUBREL '~p'!", [ClientId, MsgId]), {noreply, Session} end; @@ -408,6 +417,10 @@ handle_info({'EXIT', ClientPid, Reason}, Session = #session{clean_sess = false, TRef = timer(Expires * 1000, session_expired), {noreply, Session#session{expired_timer = TRef}}; +handle_info({'EXIT', ClientPid, _Reason}, Session = #session{clean_sess = true, client_pid = ClientPid}) -> + %%TODO: reason... + {stop, normal, Session}; + handle_info({'EXIT', ClientPid0, _Reason}, State = #session{client_pid = ClientPid}) -> lager:critical("Unexpected Client EXIT: pid=~p, pid(state): ~p", [ClientPid0, ClientPid]), {noreply, State}; @@ -456,19 +469,6 @@ code_change(_OldVsn, Session, _Extra) -> %% @end %%------------------------------------------------------------------------------ -publish_qos2(Message = #mqtt_message{qos = ?QOS_2,msgid = MsgId}, Session = #session{clientid = ClientId, - awaiting_rel = AwaitingRel, - await_rel_timeout = Timeout}) -> - - case check_awaiting_rel(Session) of - true -> - TRef = timer(Timeout, {timeout, awaiting_rel, MsgId}), - Session#session{awaiting_rel = maps:put(MsgId, {Message, TRef}, AwaitingRel)}; - false -> - lager:error([{clientid, ClientId}], "Session ~s " - " dropped Qos2 message for too many awaiting_rel: ~p", [ClientId, Message]), - Session - end. check_awaiting_rel(#session{max_awaiting_rel = 0}) -> true;