diff --git a/apps/emqtt/include/emqtt.hrl b/apps/emqtt/include/emqtt.hrl index 2d172945d..c73fc8f71 100644 --- a/apps/emqtt/include/emqtt.hrl +++ b/apps/emqtt/include/emqtt.hrl @@ -68,10 +68,10 @@ %% MQTT Message %%------------------------------------------------------------------------------ -record(mqtt_message, { + msgid :: integer() | undefined, qos = ?QOS_0 :: mqtt_qos(), retain = false :: boolean(), dup = false :: boolean(), - msgid :: integer(), topic :: binary(), payload :: binary() }). diff --git a/apps/emqtt/src/emqtt_client.erl b/apps/emqtt/src/emqtt_client.erl index f438e16ef..d0b4774f2 100644 --- a/apps/emqtt/src/emqtt_client.erl +++ b/apps/emqtt/src/emqtt_client.erl @@ -90,11 +90,12 @@ handle_cast(Msg, State) -> handle_info(timeout, State) -> stop({shutdown, timeout}, State); -handle_info({stop, duplicate_id, NewPid}, State=#state{conn_name=ConnName}) -> +handle_info({stop, duplicate_id, _NewPid}, State=#state{ proto_state = ProtoState, conn_name=ConnName}) -> %% TODO: to... %% need transfer data??? %% emqtt_client:transfer(NewPid, Data), - %% lager:error("Shutdown for duplicate clientid:~s, conn:~s", [ClientId, ConnName]), + lager:error("Shutdown for duplicate clientid: ~s, conn:~s", + [emqtt_protocol:client_id(ProtoState), ConnName]), stop({shutdown, duplicate_id}, State); %%TODO: ok?? @@ -105,8 +106,8 @@ handle_info({dispatch, {From, Message}}, #state{proto_state = ProtoState} = Stat handle_info({inet_reply, _Ref, ok}, State) -> {noreply, State, hibernate}; -handle_info({inet_async, Sock, _Ref, {ok, Data}}, #state{ peer_name = PeerName, socket = Sock } = State) -> - lager:debug("RECV from ~s: ~p", [State#state.peer_name, Data]), +handle_info({inet_async, Sock, _Ref, {ok, Data}}, State = #state{ peer_name = PeerName, socket = Sock }) -> + lager:debug("RECV from ~s: ~p", [PeerName, Data]), process_received_bytes( Data, control_throttle(State #state{ await_recv = false })); @@ -192,7 +193,7 @@ process_received_bytes(Bytes, end. %%---------------------------------------------------------------------------- -network_error(Reason, State = #state{ peer_name = PeerName, conn_name = ConnStr }) -> +network_error(Reason, State = #state{ peer_name = PeerName }) -> lager:error("Client ~s: MQTT detected network error '~p'", [PeerName, Reason]), stop({shutdown, conn_closed}, State). diff --git a/apps/emqtt/src/emqtt_message.erl b/apps/emqtt/src/emqtt_message.erl new file mode 100644 index 000000000..5ea06405d --- /dev/null +++ b/apps/emqtt/src/emqtt_message.erl @@ -0,0 +1,89 @@ +%%----------------------------------------------------------------------------- +%% Copyright (c) 2012-2015, Feng Lee +%% +%% Permission is hereby granted, free of charge, to any person obtaining a copy +%% of this software and associated documentation files (the "Software"), to deal +%% in the Software without restriction, including without limitation the rights +%% to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +%% copies of the Software, and to permit persons to whom the Software is +%% furnished to do so, subject to the following conditions: +%% +%% The above copyright notice and this permission notice shall be included in all +%% copies or substantial portions of the Software. +%% +%% THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +%% IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +%% FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +%% AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +%% LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +%% OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +%% SOFTWARE. +%%------------------------------------------------------------------------------ + +-module(emqtt_message). + +-include("emqtt.hrl"). + +-include("emqtt_packet.hrl"). + +-export([to_packet/1]). + +%%---------------------------------------------------------------------------- + +-ifdef(use_specs). + +-spec( from_packet( mqtt_packet() ) -> mqtt_message() | undefined ). + +-spec( to_packet( mqtt_message() ) -> mqtt_packet() ). + +-endif + +%%---------------------------------------------------------------------------- + +from_packet(#mqtt_packet{ header = #mqtt_packet_header{ type = ?PUBLISH, + qos = Qos, + retain = Retain, + dup = Dup }, + variable = #mqtt_packet_publish{ topic_name = Topic, + packet_id = PacketId }, + payload = Payload }) -> + #mqtt_message{ msgid = PacketId, + qos = Qos, + retain = Retain, + dup = Dup, + topic = Topic, + payload = Payload }; + +from_packet(#mqtt_packet_connect{ will_flag = false }) -> + undefined; + +from_packet(#mqtt_packet_connect{ will_retain = Retain, + will_qos = Qos, + will_topic = Topic, + will_msg = Msg }) -> + #mqtt_message{ retain = Retain, + qos = Qos, + topic = Topic, + dup = false, + payload = Msg }. + +to_packet(#mqtt_message{ msgid = MsgId, + qos = Qos, + retain = Retain, + dup = Dup, + topic = Topic, + payload = Payload }) -> + + PacketId = if + Qos =:= ?QOS_0 -> undefined; + true -> MsgId + end, + + #mqtt_packet{ header = #mqtt_packet_header { type = ?PUBLISH, + qos = Qos, + retain = Retain, + dup = Dup }, + variable = #mqtt_packet_publish { topic_name = Topic, + packet_id = PacketId }, + payload = Payload }. + diff --git a/apps/emqtt/src/emqtt_protocol.erl b/apps/emqtt/src/emqtt_protocol.erl index 72e9813d4..364d60398 100644 --- a/apps/emqtt/src/emqtt_protocol.erl +++ b/apps/emqtt/src/emqtt_protocol.erl @@ -30,29 +30,42 @@ %% API Function Exports %% ------------------------------------------------------------------ --export([initial_state/2]). +-export([initial_state/2, client_id/1]). -export([handle_packet/2, send_message/2, send_packet/2, shutdown/2]). -export([info/1]). + %% ------------------------------------------------------------------ %% Protocol State %% ------------------------------------------------------------------ --record(proto_state, { - socket, +-record(proto_state, { + socket, peer_name, connected = false, %received CONNECT action? - proto_vsn, + proto_vsn, proto_name, - packet_id, + %packet_id, client_id, clean_sess, session, %% session state or session pid will_msg }). --type proto_state() :: #proto_state{}. +%%---------------------------------------------------------------------------- + +-ifdef(use_specs). + +-type(proto_state() :: #proto_state{}). + +-spec(send_message({pid() | tuple(), mqtt_message()}, proto_state()) -> {ok, proto_state()}). + +-spec(handle_packet(mqtt_packet(), proto_state()) -> {ok, proto_state()} | {error, any()}). + +-endif. + +%%---------------------------------------------------------------------------- -define(PACKET_TYPE(Packet, Type), Packet = #mqtt_packet { header = #mqtt_packet_header { type = Type }}). @@ -62,28 +75,23 @@ initial_state(Socket, Peername) -> #proto_state{ socket = Socket, - peer_name = Peername, - packet_id = 1 + peer_name = Peername }. +client_id(#proto_state { client_id = ClientId }) -> ClientId. + %%SHOULD be registered in emqtt_cm info(#proto_state{ proto_vsn = ProtoVsn, proto_name = ProtoName, - packet_id = PacketId, client_id = ClientId, clean_sess = CleanSess, will_msg = WillMsg }) -> - [ {packet_id, PacketId}, - {proto_vsn, ProtoVsn}, + [ {proto_vsn, ProtoVsn}, {proto_name, ProtoName}, {client_id, ClientId}, {clean_sess, CleanSess}, {will_msg, WillMsg} ]. --spec handle_packet(Packet, State) -> {ok, NewState} | {error, any()} when - Packet :: mqtt_packet(), - State :: proto_state(), - NewState :: proto_state(). %%CONNECT – Client requests a connection to a Server @@ -125,7 +133,7 @@ handle_packet(?CONNECT, Packet = #mqtt_packet { ClientId1 = clientid(ClientId, State), start_keepalive(KeepAlive), emqtt_cm:register(ClientId1, self()), - {?CONNACK_ACCEPT, State#proto_state{ will_msg = make_willmsg(Var), + {?CONNACK_ACCEPT, State#proto_state{ will_msg = willmsg(Var), clean_sess = CleanSess, client_id = ClientId1 }}; false -> @@ -145,21 +153,21 @@ handle_packet(?CONNECT, Packet = #mqtt_packet { handle_packet(?PUBLISH, Packet = #mqtt_packet { header = #mqtt_packet_header {qos = ?QOS_0}}, State = #proto_state{session = Session}) -> - emqtt_session:publish(Session, {?QOS_0, make_message(Packet)}), + emqtt_session:publish(Session, {?QOS_0, emqtt_messsage:from_packet(Packet)}), {ok, State}; handle_packet(?PUBLISH, Packet = #mqtt_packet { header = #mqtt_packet_header { qos = ?QOS_1 }, variable = #mqtt_packet_publish{packet_id = PacketId }}, State = #proto_state { session = Session }) -> - emqtt_session:publish(Session, {?QOS_1, make_message(Packet)}), + emqtt_session:publish(Session, {?QOS_1, emqtt_messsage:from_packet(Packet)}), send_packet( make_packet(?PUBACK, PacketId), State); handle_packet(?PUBLISH, Packet = #mqtt_packet { header = #mqtt_packet_header { qos = ?QOS_2 }, variable = #mqtt_packet_publish { packet_id = PacketId } }, State = #proto_state { session = Session }) -> - NewSession = emqtt_session:publish(Session, {?QOS_2, make_message(Packet)}), + NewSession = emqtt_session:publish(Session, {?QOS_2, emqtt_message:from_packet(Packet)}), send_packet( make_packet(?PUBREC, PacketId), State#proto_state {session = NewSession} ); handle_packet(Puback, #mqtt_packet{variable = ?PUBACK_PACKET(PacketId) }, @@ -188,9 +196,9 @@ handle_packet(?SUBSCRIBE, #mqtt_packet { Topics = [{Name, Qos} || #mqtt_topic{name=Name, qos=Qos} <- TopicTable], {ok, NewSession, GrantedQos} = emqtt_session:subscribe(Session, Topics), send_packet(#mqtt_packet { header = #mqtt_packet_header { type = ?SUBACK }, - variable = #mqtt_packet_suback{ - packet_id = PacketId, - qos_table = GrantedQos }}, State); + variable = #mqtt_packet_suback{ packet_id = PacketId, + qos_table = GrantedQos }}, + State#proto_state{ session = NewSession }); handle_packet(?UNSUBSCRIBE, #mqtt_packet { variable = #mqtt_packet_subscribe{ @@ -223,41 +231,19 @@ puback_qos(?PUBREC) -> ?QOS_0; puback_qos(?PUBREL) -> ?QOS_1; puback_qos(?PUBCOMP) -> ?QOS_0. --spec send_message({From, Message}, State) -> {ok, NewState} when - From :: pid(), - Message :: mqtt_message(), - State :: proto_state(), - NewState :: proto_state(). +%% qos0 message +send_message({_From, Message = #mqtt_message{ qos = ?QOS_0 }}, State) -> + send_packet(emqtt_message:to_packet(Message), State); -send_message({From, Message = #mqtt_message{ - retain = Retain, - qos = Qos, - topic = Topic, - dup = Dup, - payload = Payload}}, - State = #proto_state{packet_id = PacketId}) -> +%% message from session +send_message({_From = SessPid, Message}, State = #proto_state{session = SessPid}) when is_pid(SessPid) -> + send_packet(emqtt_message:to_packet(Message), State); - Packet = #mqtt_packet { - header = #mqtt_packet_header { - type = ?PUBLISH, - qos = Qos, - retain = Retain, - dup = Dup }, - variable = #mqtt_packet_publish { - topic_name = Topic, - packet_id = if - Qos == ?QOS_0 -> undefined; - true -> PacketId - end }, - payload = Payload}, - - send_packet(Packet, State), - if - Qos == ?QOS_0 -> - {ok, State}; - true -> - {ok, next_packet_id(State)} - end. +%% message(qos1, qos2) not from session +send_message({_From, Message = #mqtt_message{ qos = Qos }}, State = #proto_state{ session = Session }) + when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) -> + {Message1, NewSession} = emqtt_session:store(Session, Message), + send_packet(emqtt_message:to_packet(Message1), State#proto_state{session = NewSession}). send_packet(Packet, State = #proto_state{socket = Sock, peer_name = PeerName, client_id = ClientId}) -> lager:info("SENT to ~s@~s: ~s", [ClientId, PeerName, emqtt_packet:dump(Packet)]), @@ -267,57 +253,20 @@ send_packet(Packet, State = #proto_state{socket = Sock, peer_name = PeerName, cl erlang:port_command(Sock, Data), {ok, State}. -shutdown(Error, State = #proto_state{peer_name = PeerName, client_id = ClientId, will_msg = WillMsg}) -> +shutdown(Error, #proto_state{peer_name = PeerName, client_id = ClientId, will_msg = WillMsg}) -> send_willmsg(WillMsg), try_unregister(ClientId, self()), lager:info("Protocol ~s@~s Shutdown: ~p", [ClientId, PeerName, Error]), ok. -make_message(#mqtt_packet { - header = #mqtt_packet_header{ - qos = Qos, - retain = Retain, - dup = Dup }, - variable = #mqtt_packet_publish{ - topic_name = Topic, - packet_id = PacketId }, - payload = Payload }) -> - - #mqtt_message{ retain = Retain, - qos = Qos, - topic = Topic, - dup = Dup, - msgid = PacketId, - payload = Payload}. - -make_willmsg(#mqtt_packet_connect{ will_flag = false }) -> - undefined; - -make_willmsg(#mqtt_packet_connect{ will_retain = Retain, - will_qos = Qos, - will_topic = Topic, - will_msg = Msg }) -> - #mqtt_message{ retain = Retain, - qos = Qos, - topic = Topic, - dup = false, - payload = Msg }. - -next_packet_id(State = #proto_state{ packet_id = 16#ffff }) -> - State #proto_state{ packet_id = 1 }; -next_packet_id(State = #proto_state{ packet_id = PacketId }) -> - State #proto_state{ packet_id = PacketId + 1 }. - +willmsg(Packet) when is_record(Packet, mqtt_packet_connect) -> + emqtt_packet:from_packet(Packet). clientid(<<>>, #proto_state{peer_name = PeerName}) -> <<"eMQTT/", (base64:encode(PeerName))/binary>>; clientid(ClientId, _State) -> ClientId. -maybe_clean_sess(false, _Conn, _ClientId) -> - % todo: establish subscription to deliver old unacknowledged messages - ok. - %%---------------------------------------------------------------------------- send_willmsg(undefined) -> ignore; @@ -328,7 +277,6 @@ start_keepalive(0) -> ignore; start_keepalive(Sec) when Sec > 0 -> self() ! {keepalive, start, round(Sec * 1.5)}. - %%---------------------------------------------------------------------------- %% Validate Packets %%---------------------------------------------------------------------------- @@ -365,7 +313,7 @@ validate_packet(#mqtt_packet { header = #mqtt_packet_header { type = ?PUBLISH } variable = #mqtt_packet_publish{ topic_name = Topic }}) -> case emqtt_topic:validate({publish, Topic}) of true -> ok; - false -> lager:error("Error Publish Topic: ~p", [Topic]), {error, badtopic} + false -> lager:warning("Error publish topic: ~p", [Topic]), {error, badtopic} end; validate_packet(#mqtt_packet { header = #mqtt_packet_header { type = ?SUBSCRIBE }, diff --git a/apps/emqtt/src/emqtt_pubsub.erl b/apps/emqtt/src/emqtt_pubsub.erl index 51274128d..c05e344f1 100644 --- a/apps/emqtt/src/emqtt_pubsub.erl +++ b/apps/emqtt/src/emqtt_pubsub.erl @@ -127,8 +127,14 @@ publish(Topic, Msg) when is_binary(Topic) -> end, match(Topic)). %dispatch locally, should only be called by publish -dispatch(Topic, Msg) when is_binary(Topic) -> - [SubPid ! {dispatch, {self(), Msg}} || #topic_subscriber{subpid=SubPid} <- ets:lookup(topic_subscriber, Topic)]. +dispatch(Topic, Msg = #mqtt_message{qos = Qos}) when is_binary(Topic) -> + lists:foreach(fun(#topic_subscriber{qos = SubQos, subpid=SubPid}) -> + Msg1 = if + Qos > SubQos -> Msg#mqtt_message{qos = SubQos}; + true -> Msg + end, + SubPid ! {dispatch, {self(), Msg1}} + end, ets:lookup(topic_subscriber, Topic)). -spec match(Topic :: binary()) -> [topic()]. match(Topic) when is_binary(Topic) -> diff --git a/apps/emqtt/src/emqtt_queue.erl b/apps/emqtt/src/emqtt_queue.erl index bb6e3cb77..c025be04f 100644 --- a/apps/emqtt/src/emqtt_queue.erl +++ b/apps/emqtt/src/emqtt_queue.erl @@ -21,53 +21,54 @@ %%------------------------------------------------------------------------------ -module(emqtt_queue). --behaviour(gen_server). +-include("emqtt.hrl"). --define(SERVER, ?MODULE). +-export([new/1, new/2, in/3, all/1, clear/1]). -%% ------------------------------------------------------------------ -%% API Function Exports -%% ------------------------------------------------------------------ +%%---------------------------------------------------------------------------- --export([start_link/0]). +-ifdef(use_specs). -%% ------------------------------------------------------------------ -%% gen_server Function Exports -%% ------------------------------------------------------------------ +-type(mqtt_queue() :: #mqtt_queue_wrapper{}). --export([init/1, handle_call/3, handle_cast/2, handle_info/2, - terminate/2, code_change/3]). +-spec(new(non_neg_intger()) -> mqtt_queue()). -%% ------------------------------------------------------------------ -%% API Function Definitions -%% ------------------------------------------------------------------ +-spec(in(binary(), mqtt_message(), mqtt_queue()) -> mqtt_queue()). -start_link() -> - gen_server:start_link(?MODULE, [], []). +-spec(all(mqtt_queue()) -> list()). -%% ------------------------------------------------------------------ -%% gen_server Function Definitions -%% ------------------------------------------------------------------ +-spec(clear(mqtt_queue()) -> mqtt_queue()). -init(Args) -> - {ok, Args}. +-endif. -handle_call(_Request, _From, State) -> - {reply, ok, State}. +%%---------------------------------------------------------------------------- -handle_cast(_Msg, State) -> - {noreply, State}. +-define(DEFAULT_MAX_LEN, 1000). -handle_info(_Info, State) -> - {noreply, State}. +-record(mqtt_queue_wrapper, { queue = queue:new(), max_len = ?DEFAULT_MAX_LEN, store_qos0 = false }). -terminate(_Reason, _State) -> - ok. +new(MaxLen) -> #mqtt_queue_wrapper{ max_len = MaxLen }. -code_change(_OldVsn, State, _Extra) -> - {ok, State}. +new(MaxLen, StoreQos0) -> #mqtt_queue_wrapper{ max_len = MaxLen, store_qos0 = StoreQos0 }. -%% ------------------------------------------------------------------ -%% Internal Function Definitions -%% ------------------------------------------------------------------ +in(ClientId, Message = #mqtt_message{qos = Qos}, + Wrapper = #mqtt_queue_wrapper{ queue = Queue, max_len = MaxLen}) -> + case queue:len(Queue) < MaxLen of + true -> + Wrapper#mqtt_queue_wrapper{ queue = queue:in(Message, Queue) }; + false -> % full + if + Qos =:= ?QOS_0 -> + lager:warning("Queue ~s drop qos0 message: ~p", [ClientId, Message]), + Wrapper; + true -> + {{value, Msg}, Queue1} = queue:drop(Queue), + lager:warning("Queue ~s drop message: ~p", [ClientId, Msg]), + Wrapper#mqtt_queue_wrapper{ queue = Queue1 } + end + end. + +all(#mqtt_queue_wrapper { queue = Queue }) -> queue:to_list(Queue). + +clear(Queue) -> Queue#mqtt_queue_wrapper{ queue = queue:new() }. diff --git a/apps/emqtt/src/emqtt_router.erl b/apps/emqtt/src/emqtt_router.erl index 9b552e8d3..e7e6bf033 100644 --- a/apps/emqtt/src/emqtt_router.erl +++ b/apps/emqtt/src/emqtt_router.erl @@ -47,6 +47,16 @@ -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). +%%---------------------------------------------------------------------------- + +-ifdef(use_specs). + +-spec(start_link/1 :: () -> {ok, pid()}). + +-spec route(mqtt_message()) -> ok. + +-endif. + %% ------------------------------------------------------------------ %% API Function Definitions %% ------------------------------------------------------------------ @@ -54,9 +64,8 @@ start_link() -> gen_server:start_link({local, ?SERVER}, ?MODULE, [], []). --spec route(Msg :: mqtt_message()) -> any(). -route(Msg) -> - emqtt_pubsub:publish(retained(Msg)). +route(Message) -> + emqtt_pubsub:publish(retained(Message)). %% ------------------------------------------------------------------ %% gen_server Function Definitions diff --git a/apps/emqtt/src/emqtt_session.erl b/apps/emqtt/src/emqtt_session.erl index 42e47e27e..c4abd8e3f 100644 --- a/apps/emqtt/src/emqtt_session.erl +++ b/apps/emqtt/src/emqtt_session.erl @@ -31,6 +31,8 @@ %% ------------------------------------------------------------------ -export([start/1, resume/3, publish/2, puback/2, subscribe/2, unsubscribe/2, destroy/2]). +-export([store/2]). + %%start gen_server -export([start_link/3]). @@ -44,14 +46,14 @@ -record(session_state, { client_id :: binary(), client_pid :: pid(), - packet_id = 1, + message_id = 1, submap :: map(), - messages = [], %% do not receive rel + msg_queue, %% do not receive rel awaiting_ack :: map(), awaiting_rel :: map(), + awaiting_comp :: map(), expires, - expire_timer, - max_queue }). + expire_timer }). %% ------------------------------------------------------------------ %% Start Session @@ -74,51 +76,69 @@ resume(SessPid, ClientId, ClientPid) when is_pid(SessPid) -> gen_server:cast(SessPid, {resume, ClientId, ClientPid}), SessPid. -publish(_, {?QOS_0, Message}) -> - emqtt_router:route(Message); -%%TODO: -publish(_, {?QOS_1, Message}) -> - emqtt_router:route(Message); -%%TODO: -publish(SessState = #session_state{awaiting_rel = Awaiting}, +publish(Session, {?QOS_0, Message}) -> + emqtt_router:route(Message), Session; + +publish(Session, {?QOS_1, Message}) -> + emqtt_router:route(Message), Session; + +publish(SessState = #session_state{awaiting_rel = AwaitingRel}, {?QOS_2, Message = #mqtt_message{ msgid = MsgId }}) -> - %% store in awaiting map - %%TODO: TIMEOUT - Awaiting1 = maps:put(MsgId, Message, Awaiting), - SessState#session_state{awaiting_rel = Awaiting1}; + %% store in awaiting_rel + SessState#session_state{awaiting_rel = maps:put(MsgId, Message, AwaitingRel)}; publish(SessPid, {?QOS_2, Message}) when is_pid(SessPid) -> gen_server:cast(SessPid, {publish, ?QOS_2, Message}), SessPid. +%% PUBACK puback(SessState = #session_state{client_id = ClientId, awaiting_ack = Awaiting}, {?PUBACK, PacketId}) -> - Awaiting1 = case maps:is_key(PacketId, Awaiting) of - true -> maps:remove(PacketId, Awaiting); - false -> lager:warning("~s puback packetid '~p' not exist", [ClientId, PacketId]) + true -> ok; + false -> lager:warning("Session ~s: PUBACK PacketId '~p' not found!", [ClientId, PacketId]) end, - SessState#session_state{awaiting_ack= Awaiting1}; + SessState#session_state{awaiting_ack = maps:remove(PacketId, Awaiting)}; puback(SessPid, {?PUBACK, PacketId}) when is_pid(SessPid) -> gen_server:cast(SessPid, {puback, PacketId}), SessPid; -puback(SessState = #session_state{}, {?PUBREC, PacketId}) -> - %%TODO' - SessState; +%% PUBREC +puback(SessState = #session_state{ client_id = ClientId, + awaiting_ack = AwaitingAck, + awaiting_comp = AwaitingComp }, {?PUBREC, PacketId}) -> + case maps:is_key(PacketId, AwaitingAck) of + true -> ok; + false -> lager:warning("Session ~s: PUBREC PacketId '~p' not found!", [ClientId, PacketId]) + end, + SessState#session_state{ awaiting_ack = maps:remove(PacketId, AwaitingAck), + awaiting_comp = maps:put(PacketId, true, AwaitingComp) }; + puback(SessPid, {?PUBREC, PacketId}) when is_pid(SessPid) -> gen_server:cast(SessPid, {pubrec, PacketId}), SessPid; -puback(SessState = #session_state{}, {?PUBREL, PacketId}) -> - %FIXME Later: should release the message here - %%emqtt_router:route(Message). - 'TODO', erase({msg, PacketId}), SessState; +%% PUBREL +puback(SessState = #session_state{client_id = ClientId, awaiting_rel = Awaiting}, {?PUBREL, PacketId}) -> + case maps:find(PacketId, Awaiting) of + {ok, Msg} -> emqtt_router:route(Msg); + error -> lager:warning("Session ~s: PUBREL PacketId '~p' not found!", [ClientId, PacketId]) + end, + SessState#session_state{awaiting_rel = maps:remove(PacketId, Awaiting)}; + puback(SessPid, {?PUBREL, PacketId}) when is_pid(SessPid) -> gen_server:cast(SessPid, {pubrel, PacketId}), SessPid; -puback(SessState = #session_state{}, {?PUBCOMP, PacketId}) -> - 'TODO', SessState; +%% PUBCOMP +puback(SessState = #session_state{ client_id = ClientId, + awaiting_comp = AwaitingComp}, {?PUBCOMP, PacketId}) -> + case maps:is_key(PacketId, AwaitingComp) of + true -> ok; + false -> lager:warning("Session ~s: PUBREC PacketId '~p' not exist", [ClientId, PacketId]) + end, + SessState#session_state{ awaiting_comp = maps:remove(PacketId, AwaitingComp) }; + puback(SessPid, {?PUBCOMP, PacketId}) when is_pid(SessPid) -> gen_server:cast(SessPid, {pubcomp, PacketId}), SessPid. +%% SUBSCRIBE subscribe(SessState = #session_state{client_id = ClientId, submap = SubMap}, Topics) -> Resubs = [Topic || {Name, _Qos} = Topic <- Topics, maps:is_key(Name, SubMap)], case Resubs of @@ -127,14 +147,15 @@ subscribe(SessState = #session_state{client_id = ClientId, submap = SubMap}, Top end, SubMap1 = lists:foldl(fun({Name, Qos}, Acc) -> maps:put(Name, Qos, Acc) end, SubMap, Topics), {ok, GrantedQos} = emqtt_pubsub:subscribe(Topics, self()), - %[ok = emqtt_pubsub:subscribe({Topic, Qos}, self()) || {Topic, Qos} <- Topics], - %GrantedQos = [Qos || {_Name, Qos} <- Topics], {ok, SessState#session_state{submap = SubMap1}, GrantedQos}; subscribe(SessPid, Topics) when is_pid(SessPid) -> {ok, GrantedQos} = gen_server:call(SessPid, {subscribe, Topics}), {ok, SessPid, GrantedQos}. +%% +%% @doc UNSUBSCRIBE +%% unsubscribe(SessState = #session_state{client_id = ClientId, submap = SubMap}, Topics) -> %%TODO: refactor later. case Topics -- maps:keys(SubMap) of @@ -153,12 +174,25 @@ unsubscribe(SessPid, Topics) when is_pid(SessPid) -> destroy(SessPid, ClientId) when is_pid(SessPid) -> gen_server:cast(SessPid, {destroy, ClientId}). +%store message(qos1) that sent to client +store(SessState = #session_state{ message_id = MsgId, awaiting_ack = Awaiting}, + Message = #mqtt_message{ qos = Qos }) when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) -> + %%assign msgid before send + Message1 = Message#mqtt_message{ msgid = MsgId }, + Message2 = + if + Qos =:= ?QOS_2 -> Message1#mqtt_message{dup = false}; + true -> Message1 + end, + Awaiting1 = maps:put(MsgId, Message2, Awaiting), + {Message1, next_msg_id(SessState#session_state{ awaiting_ack = Awaiting1 })}. + initial_state(ClientId) -> #session_state { client_id = ClientId, - packet_id = 1, submap = #{}, awaiting_ack = #{}, - awaiting_rel = #{} }. + awaiting_rel = #{}, + awaiting_comp = #{} }. initial_state(ClientId, ClientPid) -> State = initial_state(ClientId), @@ -173,12 +207,14 @@ start_link(SessOpts, ClientId, ClientPid) -> init([SessOpts, ClientId, ClientPid]) -> process_flag(trap_exit, true), - %%TODO: OK? + %%TODO: Is this OK? true = link(ClientPid), State = initial_state(ClientId, ClientPid), - {ok, State#session_state{ - expires = proplists:get_value(expires, SessOpts, 24) * 3600, - max_queue = proplists:get_value(max_queue, SessOpts, 1000) } }. + Expires = proplists:get_value(expires, SessOpts, 1) * 3600, + MsgQueue = emqtt_queue:new( proplists:get_value(max_queue, SessOpts, 1000), + proplists:get_value(store_qos0, SessOpts, false) ), + {ok, State#session_state{ expires = Expires, + msg_queue = MsgQueue }, hibernate}. handle_call({subscribe, Topics}, _From, State) -> {ok, NewState, GrantedQos} = subscribe(State, Topics), @@ -194,13 +230,13 @@ handle_call(Req, _From, State) -> handle_cast({resume, ClientId, ClientPid}, State = #session_state { client_id = ClientId, client_pid = undefined, - messages = Messages, + msg_queue = Queue, expire_timer = ETimer}) -> lager:info("Session: client ~s resumed by ~p", [ClientId, ClientPid]), erlang:cancel_timer(ETimer), - [ClientPid ! {dispatch, {self(), Message}} || Message <- lists:reverse(Messages)], - NewState = State#session_state{ client_pid = ClientPid, messages = [], expire_timer = undefined}, - {noreply, NewState}; + [ClientPid ! {dispatch, {self(), Message}} || Message <- emqtt_queue:all(Queue)], + NewState = State#session_state{ client_pid = ClientPid, msg_queue = emqtt_queue:clear(Queue), expire_timer = undefined}, + {noreply, NewState, hibernate}; handle_cast({publish, ?QOS_2, Message}, State) -> NewState = publish(State, {?QOS_2, Message}), @@ -223,22 +259,14 @@ handle_cast({pubcomp, PacketId}, State) -> {noreply, NewState}; handle_cast({destroy, ClientId}, State = #session_state{client_id = ClientId}) -> - lager:warning("Session: ~s destroyed", [ClientId]), + lager:warning("Session ~s destroyed", [ClientId]), {stop, normal, State}; handle_cast(Msg, State) -> {stop, {badmsg, Msg}, State}. -handle_info({dispatch, {_From, Message}}, State = #session_state{ - client_pid = undefined, messages = Messages}) -> - %%TODO: queue len - NewState = State#session_state{messages = [Message | Messages]}, - {noreply, NewState}; - -handle_info({dispatch, {_From, Message}}, State = #session_state{client_pid = ClientPid}) -> - %%TODO: replace From with self(), ok? - ClientPid ! {dispatch, {self(), Message}}, - {noreply, State}; +handle_info({dispatch, {_From, Message}}, State) -> + {noreply, dispatch(Message, State)}; handle_info({'EXIT', ClientPid, Reason}, State = #session_state{ client_id = ClientId, client_pid = ClientPid, expires = Expires}) -> @@ -247,7 +275,7 @@ handle_info({'EXIT', ClientPid, Reason}, State = #session_state{ {noreply, State#session_state{ client_pid = undefined, expire_timer = Timer}}; handle_info(session_expired, State = #session_state{client_id = ClientId}) -> - lager:warning("Session: ~s session expired!", [ClientId]), + lager:warning("Session ~s expired!", [ClientId]), {stop, {shutdown, expired}, State}; handle_info(Info, State) -> @@ -263,4 +291,28 @@ code_change(_OldVsn, State, _Extra) -> %% Internal Function Definitions %% ------------------------------------------------------------------ +dispatch(Message, State = #session_state{ client_id = ClientId, + client_pid = undefined }) -> + queue(ClientId, Message, State); +dispatch(Message = #mqtt_message{ qos = ?QOS_0 }, State = #session_state{ + client_pid = ClientPid }) -> + ClientPid ! {dispatch, {self(), Message}}, + State; + +dispatch(Message = #mqtt_message{ qos = Qos }, State = #session_state{ client_pid = ClientPid }) + when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) -> + {Message1, NewState} = store(State, Message), + ClientPid ! {dispatch, {self(), Message1}}, + NewState. + +queue(ClientId, Message, State = #session_state{msg_queue = Queue}) -> + State#session_state{msg_queue = emqtt_queue:in(ClientId, Message, Queue)}. + +next_msg_id(State = #session_state{ message_id = 16#ffff }) -> + State#session_state{ message_id = 1 }; + +next_msg_id(State = #session_state{ message_id = MsgId }) -> + State#session_state{ message_id = MsgId + 1 }. + +