supprot qos0, qos1, qos2

This commit is contained in:
Ery Lee 2015-01-15 23:50:37 +08:00
parent 45b63a6b13
commit 999c2b5ebd
8 changed files with 300 additions and 194 deletions

View File

@ -68,10 +68,10 @@
%% MQTT Message
%%------------------------------------------------------------------------------
-record(mqtt_message, {
msgid :: integer() | undefined,
qos = ?QOS_0 :: mqtt_qos(),
retain = false :: boolean(),
dup = false :: boolean(),
msgid :: integer(),
topic :: binary(),
payload :: binary()
}).

View File

@ -90,11 +90,12 @@ handle_cast(Msg, State) ->
handle_info(timeout, State) ->
stop({shutdown, timeout}, State);
handle_info({stop, duplicate_id, NewPid}, State=#state{conn_name=ConnName}) ->
handle_info({stop, duplicate_id, _NewPid}, State=#state{ proto_state = ProtoState, conn_name=ConnName}) ->
%% TODO: to...
%% need transfer data???
%% emqtt_client:transfer(NewPid, Data),
%% lager:error("Shutdown for duplicate clientid:~s, conn:~s", [ClientId, ConnName]),
lager:error("Shutdown for duplicate clientid: ~s, conn:~s",
[emqtt_protocol:client_id(ProtoState), ConnName]),
stop({shutdown, duplicate_id}, State);
%%TODO: ok??
@ -105,8 +106,8 @@ handle_info({dispatch, {From, Message}}, #state{proto_state = ProtoState} = Stat
handle_info({inet_reply, _Ref, ok}, State) ->
{noreply, State, hibernate};
handle_info({inet_async, Sock, _Ref, {ok, Data}}, #state{ peer_name = PeerName, socket = Sock } = State) ->
lager:debug("RECV from ~s: ~p", [State#state.peer_name, Data]),
handle_info({inet_async, Sock, _Ref, {ok, Data}}, State = #state{ peer_name = PeerName, socket = Sock }) ->
lager:debug("RECV from ~s: ~p", [PeerName, Data]),
process_received_bytes(
Data, control_throttle(State #state{ await_recv = false }));
@ -192,7 +193,7 @@ process_received_bytes(Bytes,
end.
%%----------------------------------------------------------------------------
network_error(Reason, State = #state{ peer_name = PeerName, conn_name = ConnStr }) ->
network_error(Reason, State = #state{ peer_name = PeerName }) ->
lager:error("Client ~s: MQTT detected network error '~p'", [PeerName, Reason]),
stop({shutdown, conn_closed}, State).

View File

@ -0,0 +1,89 @@
%%-----------------------------------------------------------------------------
%% Copyright (c) 2012-2015, Feng Lee <feng@emqtt.io>
%%
%% Permission is hereby granted, free of charge, to any person obtaining a copy
%% of this software and associated documentation files (the "Software"), to deal
%% in the Software without restriction, including without limitation the rights
%% to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
%% copies of the Software, and to permit persons to whom the Software is
%% furnished to do so, subject to the following conditions:
%%
%% The above copyright notice and this permission notice shall be included in all
%% copies or substantial portions of the Software.
%%
%% THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
%% IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
%% FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
%% AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
%% LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
%% OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
%% SOFTWARE.
%%------------------------------------------------------------------------------
-module(emqtt_message).
-include("emqtt.hrl").
-include("emqtt_packet.hrl").
-export([to_packet/1]).
%%----------------------------------------------------------------------------
-ifdef(use_specs).
-spec( from_packet( mqtt_packet() ) -> mqtt_message() | undefined ).
-spec( to_packet( mqtt_message() ) -> mqtt_packet() ).
-endif
%%----------------------------------------------------------------------------
from_packet(#mqtt_packet{ header = #mqtt_packet_header{ type = ?PUBLISH,
qos = Qos,
retain = Retain,
dup = Dup },
variable = #mqtt_packet_publish{ topic_name = Topic,
packet_id = PacketId },
payload = Payload }) ->
#mqtt_message{ msgid = PacketId,
qos = Qos,
retain = Retain,
dup = Dup,
topic = Topic,
payload = Payload };
from_packet(#mqtt_packet_connect{ will_flag = false }) ->
undefined;
from_packet(#mqtt_packet_connect{ will_retain = Retain,
will_qos = Qos,
will_topic = Topic,
will_msg = Msg }) ->
#mqtt_message{ retain = Retain,
qos = Qos,
topic = Topic,
dup = false,
payload = Msg }.
to_packet(#mqtt_message{ msgid = MsgId,
qos = Qos,
retain = Retain,
dup = Dup,
topic = Topic,
payload = Payload }) ->
PacketId = if
Qos =:= ?QOS_0 -> undefined;
true -> MsgId
end,
#mqtt_packet{ header = #mqtt_packet_header { type = ?PUBLISH,
qos = Qos,
retain = Retain,
dup = Dup },
variable = #mqtt_packet_publish { topic_name = Topic,
packet_id = PacketId },
payload = Payload }.

View File

@ -30,29 +30,42 @@
%% API Function Exports
%% ------------------------------------------------------------------
-export([initial_state/2]).
-export([initial_state/2, client_id/1]).
-export([handle_packet/2, send_message/2, send_packet/2, shutdown/2]).
-export([info/1]).
%% ------------------------------------------------------------------
%% Protocol State
%% ------------------------------------------------------------------
-record(proto_state, {
socket,
-record(proto_state, {
socket,
peer_name,
connected = false, %received CONNECT action?
proto_vsn,
proto_vsn,
proto_name,
packet_id,
%packet_id,
client_id,
clean_sess,
session, %% session state or session pid
will_msg
}).
-type proto_state() :: #proto_state{}.
%%----------------------------------------------------------------------------
-ifdef(use_specs).
-type(proto_state() :: #proto_state{}).
-spec(send_message({pid() | tuple(), mqtt_message()}, proto_state()) -> {ok, proto_state()}).
-spec(handle_packet(mqtt_packet(), proto_state()) -> {ok, proto_state()} | {error, any()}).
-endif.
%%----------------------------------------------------------------------------
-define(PACKET_TYPE(Packet, Type),
Packet = #mqtt_packet { header = #mqtt_packet_header { type = Type }}).
@ -62,28 +75,23 @@
initial_state(Socket, Peername) ->
#proto_state{
socket = Socket,
peer_name = Peername,
packet_id = 1
peer_name = Peername
}.
client_id(#proto_state { client_id = ClientId }) -> ClientId.
%%SHOULD be registered in emqtt_cm
info(#proto_state{ proto_vsn = ProtoVsn,
proto_name = ProtoName,
packet_id = PacketId,
client_id = ClientId,
clean_sess = CleanSess,
will_msg = WillMsg }) ->
[ {packet_id, PacketId},
{proto_vsn, ProtoVsn},
[ {proto_vsn, ProtoVsn},
{proto_name, ProtoName},
{client_id, ClientId},
{clean_sess, CleanSess},
{will_msg, WillMsg} ].
-spec handle_packet(Packet, State) -> {ok, NewState} | {error, any()} when
Packet :: mqtt_packet(),
State :: proto_state(),
NewState :: proto_state().
%%CONNECT Client requests a connection to a Server
@ -125,7 +133,7 @@ handle_packet(?CONNECT, Packet = #mqtt_packet {
ClientId1 = clientid(ClientId, State),
start_keepalive(KeepAlive),
emqtt_cm:register(ClientId1, self()),
{?CONNACK_ACCEPT, State#proto_state{ will_msg = make_willmsg(Var),
{?CONNACK_ACCEPT, State#proto_state{ will_msg = willmsg(Var),
clean_sess = CleanSess,
client_id = ClientId1 }};
false ->
@ -145,21 +153,21 @@ handle_packet(?CONNECT, Packet = #mqtt_packet {
handle_packet(?PUBLISH, Packet = #mqtt_packet {
header = #mqtt_packet_header {qos = ?QOS_0}},
State = #proto_state{session = Session}) ->
emqtt_session:publish(Session, {?QOS_0, make_message(Packet)}),
emqtt_session:publish(Session, {?QOS_0, emqtt_messsage:from_packet(Packet)}),
{ok, State};
handle_packet(?PUBLISH, Packet = #mqtt_packet {
header = #mqtt_packet_header { qos = ?QOS_1 },
variable = #mqtt_packet_publish{packet_id = PacketId }},
State = #proto_state { session = Session }) ->
emqtt_session:publish(Session, {?QOS_1, make_message(Packet)}),
emqtt_session:publish(Session, {?QOS_1, emqtt_messsage:from_packet(Packet)}),
send_packet( make_packet(?PUBACK, PacketId), State);
handle_packet(?PUBLISH, Packet = #mqtt_packet {
header = #mqtt_packet_header { qos = ?QOS_2 },
variable = #mqtt_packet_publish { packet_id = PacketId } },
State = #proto_state { session = Session }) ->
NewSession = emqtt_session:publish(Session, {?QOS_2, make_message(Packet)}),
NewSession = emqtt_session:publish(Session, {?QOS_2, emqtt_message:from_packet(Packet)}),
send_packet( make_packet(?PUBREC, PacketId), State#proto_state {session = NewSession} );
handle_packet(Puback, #mqtt_packet{variable = ?PUBACK_PACKET(PacketId) },
@ -188,9 +196,9 @@ handle_packet(?SUBSCRIBE, #mqtt_packet {
Topics = [{Name, Qos} || #mqtt_topic{name=Name, qos=Qos} <- TopicTable],
{ok, NewSession, GrantedQos} = emqtt_session:subscribe(Session, Topics),
send_packet(#mqtt_packet { header = #mqtt_packet_header { type = ?SUBACK },
variable = #mqtt_packet_suback{
packet_id = PacketId,
qos_table = GrantedQos }}, State);
variable = #mqtt_packet_suback{ packet_id = PacketId,
qos_table = GrantedQos }},
State#proto_state{ session = NewSession });
handle_packet(?UNSUBSCRIBE, #mqtt_packet {
variable = #mqtt_packet_subscribe{
@ -223,41 +231,19 @@ puback_qos(?PUBREC) -> ?QOS_0;
puback_qos(?PUBREL) -> ?QOS_1;
puback_qos(?PUBCOMP) -> ?QOS_0.
-spec send_message({From, Message}, State) -> {ok, NewState} when
From :: pid(),
Message :: mqtt_message(),
State :: proto_state(),
NewState :: proto_state().
%% qos0 message
send_message({_From, Message = #mqtt_message{ qos = ?QOS_0 }}, State) ->
send_packet(emqtt_message:to_packet(Message), State);
send_message({From, Message = #mqtt_message{
retain = Retain,
qos = Qos,
topic = Topic,
dup = Dup,
payload = Payload}},
State = #proto_state{packet_id = PacketId}) ->
%% message from session
send_message({_From = SessPid, Message}, State = #proto_state{session = SessPid}) when is_pid(SessPid) ->
send_packet(emqtt_message:to_packet(Message), State);
Packet = #mqtt_packet {
header = #mqtt_packet_header {
type = ?PUBLISH,
qos = Qos,
retain = Retain,
dup = Dup },
variable = #mqtt_packet_publish {
topic_name = Topic,
packet_id = if
Qos == ?QOS_0 -> undefined;
true -> PacketId
end },
payload = Payload},
send_packet(Packet, State),
if
Qos == ?QOS_0 ->
{ok, State};
true ->
{ok, next_packet_id(State)}
end.
%% message(qos1, qos2) not from session
send_message({_From, Message = #mqtt_message{ qos = Qos }}, State = #proto_state{ session = Session })
when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) ->
{Message1, NewSession} = emqtt_session:store(Session, Message),
send_packet(emqtt_message:to_packet(Message1), State#proto_state{session = NewSession}).
send_packet(Packet, State = #proto_state{socket = Sock, peer_name = PeerName, client_id = ClientId}) ->
lager:info("SENT to ~s@~s: ~s", [ClientId, PeerName, emqtt_packet:dump(Packet)]),
@ -267,57 +253,20 @@ send_packet(Packet, State = #proto_state{socket = Sock, peer_name = PeerName, cl
erlang:port_command(Sock, Data),
{ok, State}.
shutdown(Error, State = #proto_state{peer_name = PeerName, client_id = ClientId, will_msg = WillMsg}) ->
shutdown(Error, #proto_state{peer_name = PeerName, client_id = ClientId, will_msg = WillMsg}) ->
send_willmsg(WillMsg),
try_unregister(ClientId, self()),
lager:info("Protocol ~s@~s Shutdown: ~p", [ClientId, PeerName, Error]),
ok.
make_message(#mqtt_packet {
header = #mqtt_packet_header{
qos = Qos,
retain = Retain,
dup = Dup },
variable = #mqtt_packet_publish{
topic_name = Topic,
packet_id = PacketId },
payload = Payload }) ->
#mqtt_message{ retain = Retain,
qos = Qos,
topic = Topic,
dup = Dup,
msgid = PacketId,
payload = Payload}.
make_willmsg(#mqtt_packet_connect{ will_flag = false }) ->
undefined;
make_willmsg(#mqtt_packet_connect{ will_retain = Retain,
will_qos = Qos,
will_topic = Topic,
will_msg = Msg }) ->
#mqtt_message{ retain = Retain,
qos = Qos,
topic = Topic,
dup = false,
payload = Msg }.
next_packet_id(State = #proto_state{ packet_id = 16#ffff }) ->
State #proto_state{ packet_id = 1 };
next_packet_id(State = #proto_state{ packet_id = PacketId }) ->
State #proto_state{ packet_id = PacketId + 1 }.
willmsg(Packet) when is_record(Packet, mqtt_packet_connect) ->
emqtt_packet:from_packet(Packet).
clientid(<<>>, #proto_state{peer_name = PeerName}) ->
<<"eMQTT/", (base64:encode(PeerName))/binary>>;
clientid(ClientId, _State) -> ClientId.
maybe_clean_sess(false, _Conn, _ClientId) ->
% todo: establish subscription to deliver old unacknowledged messages
ok.
%%----------------------------------------------------------------------------
send_willmsg(undefined) -> ignore;
@ -328,7 +277,6 @@ start_keepalive(0) -> ignore;
start_keepalive(Sec) when Sec > 0 ->
self() ! {keepalive, start, round(Sec * 1.5)}.
%%----------------------------------------------------------------------------
%% Validate Packets
%%----------------------------------------------------------------------------
@ -365,7 +313,7 @@ validate_packet(#mqtt_packet { header = #mqtt_packet_header { type = ?PUBLISH }
variable = #mqtt_packet_publish{ topic_name = Topic }}) ->
case emqtt_topic:validate({publish, Topic}) of
true -> ok;
false -> lager:error("Error Publish Topic: ~p", [Topic]), {error, badtopic}
false -> lager:warning("Error publish topic: ~p", [Topic]), {error, badtopic}
end;
validate_packet(#mqtt_packet { header = #mqtt_packet_header { type = ?SUBSCRIBE },

View File

@ -127,8 +127,14 @@ publish(Topic, Msg) when is_binary(Topic) ->
end, match(Topic)).
%dispatch locally, should only be called by publish
dispatch(Topic, Msg) when is_binary(Topic) ->
[SubPid ! {dispatch, {self(), Msg}} || #topic_subscriber{subpid=SubPid} <- ets:lookup(topic_subscriber, Topic)].
dispatch(Topic, Msg = #mqtt_message{qos = Qos}) when is_binary(Topic) ->
lists:foreach(fun(#topic_subscriber{qos = SubQos, subpid=SubPid}) ->
Msg1 = if
Qos > SubQos -> Msg#mqtt_message{qos = SubQos};
true -> Msg
end,
SubPid ! {dispatch, {self(), Msg1}}
end, ets:lookup(topic_subscriber, Topic)).
-spec match(Topic :: binary()) -> [topic()].
match(Topic) when is_binary(Topic) ->

View File

@ -21,53 +21,54 @@
%%------------------------------------------------------------------------------
-module(emqtt_queue).
-behaviour(gen_server).
-include("emqtt.hrl").
-define(SERVER, ?MODULE).
-export([new/1, new/2, in/3, all/1, clear/1]).
%% ------------------------------------------------------------------
%% API Function Exports
%% ------------------------------------------------------------------
%%----------------------------------------------------------------------------
-export([start_link/0]).
-ifdef(use_specs).
%% ------------------------------------------------------------------
%% gen_server Function Exports
%% ------------------------------------------------------------------
-type(mqtt_queue() :: #mqtt_queue_wrapper{}).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
-spec(new(non_neg_intger()) -> mqtt_queue()).
%% ------------------------------------------------------------------
%% API Function Definitions
%% ------------------------------------------------------------------
-spec(in(binary(), mqtt_message(), mqtt_queue()) -> mqtt_queue()).
start_link() ->
gen_server:start_link(?MODULE, [], []).
-spec(all(mqtt_queue()) -> list()).
%% ------------------------------------------------------------------
%% gen_server Function Definitions
%% ------------------------------------------------------------------
-spec(clear(mqtt_queue()) -> mqtt_queue()).
init(Args) ->
{ok, Args}.
-endif.
handle_call(_Request, _From, State) ->
{reply, ok, State}.
%%----------------------------------------------------------------------------
handle_cast(_Msg, State) ->
{noreply, State}.
-define(DEFAULT_MAX_LEN, 1000).
handle_info(_Info, State) ->
{noreply, State}.
-record(mqtt_queue_wrapper, { queue = queue:new(), max_len = ?DEFAULT_MAX_LEN, store_qos0 = false }).
terminate(_Reason, _State) ->
ok.
new(MaxLen) -> #mqtt_queue_wrapper{ max_len = MaxLen }.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
new(MaxLen, StoreQos0) -> #mqtt_queue_wrapper{ max_len = MaxLen, store_qos0 = StoreQos0 }.
%% ------------------------------------------------------------------
%% Internal Function Definitions
%% ------------------------------------------------------------------
in(ClientId, Message = #mqtt_message{qos = Qos},
Wrapper = #mqtt_queue_wrapper{ queue = Queue, max_len = MaxLen}) ->
case queue:len(Queue) < MaxLen of
true ->
Wrapper#mqtt_queue_wrapper{ queue = queue:in(Message, Queue) };
false -> % full
if
Qos =:= ?QOS_0 ->
lager:warning("Queue ~s drop qos0 message: ~p", [ClientId, Message]),
Wrapper;
true ->
{{value, Msg}, Queue1} = queue:drop(Queue),
lager:warning("Queue ~s drop message: ~p", [ClientId, Msg]),
Wrapper#mqtt_queue_wrapper{ queue = Queue1 }
end
end.
all(#mqtt_queue_wrapper { queue = Queue }) -> queue:to_list(Queue).
clear(Queue) -> Queue#mqtt_queue_wrapper{ queue = queue:new() }.

View File

@ -47,6 +47,16 @@
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
%%----------------------------------------------------------------------------
-ifdef(use_specs).
-spec(start_link/1 :: () -> {ok, pid()}).
-spec route(mqtt_message()) -> ok.
-endif.
%% ------------------------------------------------------------------
%% API Function Definitions
%% ------------------------------------------------------------------
@ -54,9 +64,8 @@
start_link() ->
gen_server:start_link({local, ?SERVER}, ?MODULE, [], []).
-spec route(Msg :: mqtt_message()) -> any().
route(Msg) ->
emqtt_pubsub:publish(retained(Msg)).
route(Message) ->
emqtt_pubsub:publish(retained(Message)).
%% ------------------------------------------------------------------
%% gen_server Function Definitions

View File

@ -31,6 +31,8 @@
%% ------------------------------------------------------------------
-export([start/1, resume/3, publish/2, puback/2, subscribe/2, unsubscribe/2, destroy/2]).
-export([store/2]).
%%start gen_server
-export([start_link/3]).
@ -44,14 +46,14 @@
-record(session_state, {
client_id :: binary(),
client_pid :: pid(),
packet_id = 1,
message_id = 1,
submap :: map(),
messages = [], %% do not receive rel
msg_queue, %% do not receive rel
awaiting_ack :: map(),
awaiting_rel :: map(),
awaiting_comp :: map(),
expires,
expire_timer,
max_queue }).
expire_timer }).
%% ------------------------------------------------------------------
%% Start Session
@ -74,51 +76,69 @@ resume(SessPid, ClientId, ClientPid) when is_pid(SessPid) ->
gen_server:cast(SessPid, {resume, ClientId, ClientPid}),
SessPid.
publish(_, {?QOS_0, Message}) ->
emqtt_router:route(Message);
%%TODO:
publish(_, {?QOS_1, Message}) ->
emqtt_router:route(Message);
%%TODO:
publish(SessState = #session_state{awaiting_rel = Awaiting},
publish(Session, {?QOS_0, Message}) ->
emqtt_router:route(Message), Session;
publish(Session, {?QOS_1, Message}) ->
emqtt_router:route(Message), Session;
publish(SessState = #session_state{awaiting_rel = AwaitingRel},
{?QOS_2, Message = #mqtt_message{ msgid = MsgId }}) ->
%% store in awaiting map
%%TODO: TIMEOUT
Awaiting1 = maps:put(MsgId, Message, Awaiting),
SessState#session_state{awaiting_rel = Awaiting1};
%% store in awaiting_rel
SessState#session_state{awaiting_rel = maps:put(MsgId, Message, AwaitingRel)};
publish(SessPid, {?QOS_2, Message}) when is_pid(SessPid) ->
gen_server:cast(SessPid, {publish, ?QOS_2, Message}),
SessPid.
%% PUBACK
puback(SessState = #session_state{client_id = ClientId, awaiting_ack = Awaiting}, {?PUBACK, PacketId}) ->
Awaiting1 =
case maps:is_key(PacketId, Awaiting) of
true -> maps:remove(PacketId, Awaiting);
false -> lager:warning("~s puback packetid '~p' not exist", [ClientId, PacketId])
true -> ok;
false -> lager:warning("Session ~s: PUBACK PacketId '~p' not found!", [ClientId, PacketId])
end,
SessState#session_state{awaiting_ack= Awaiting1};
SessState#session_state{awaiting_ack = maps:remove(PacketId, Awaiting)};
puback(SessPid, {?PUBACK, PacketId}) when is_pid(SessPid) ->
gen_server:cast(SessPid, {puback, PacketId}), SessPid;
puback(SessState = #session_state{}, {?PUBREC, PacketId}) ->
%%TODO'
SessState;
%% PUBREC
puback(SessState = #session_state{ client_id = ClientId,
awaiting_ack = AwaitingAck,
awaiting_comp = AwaitingComp }, {?PUBREC, PacketId}) ->
case maps:is_key(PacketId, AwaitingAck) of
true -> ok;
false -> lager:warning("Session ~s: PUBREC PacketId '~p' not found!", [ClientId, PacketId])
end,
SessState#session_state{ awaiting_ack = maps:remove(PacketId, AwaitingAck),
awaiting_comp = maps:put(PacketId, true, AwaitingComp) };
puback(SessPid, {?PUBREC, PacketId}) when is_pid(SessPid) ->
gen_server:cast(SessPid, {pubrec, PacketId}), SessPid;
puback(SessState = #session_state{}, {?PUBREL, PacketId}) ->
%FIXME Later: should release the message here
%%emqtt_router:route(Message).
'TODO', erase({msg, PacketId}), SessState;
%% PUBREL
puback(SessState = #session_state{client_id = ClientId, awaiting_rel = Awaiting}, {?PUBREL, PacketId}) ->
case maps:find(PacketId, Awaiting) of
{ok, Msg} -> emqtt_router:route(Msg);
error -> lager:warning("Session ~s: PUBREL PacketId '~p' not found!", [ClientId, PacketId])
end,
SessState#session_state{awaiting_rel = maps:remove(PacketId, Awaiting)};
puback(SessPid, {?PUBREL, PacketId}) when is_pid(SessPid) ->
gen_server:cast(SessPid, {pubrel, PacketId}), SessPid;
puback(SessState = #session_state{}, {?PUBCOMP, PacketId}) ->
'TODO', SessState;
%% PUBCOMP
puback(SessState = #session_state{ client_id = ClientId,
awaiting_comp = AwaitingComp}, {?PUBCOMP, PacketId}) ->
case maps:is_key(PacketId, AwaitingComp) of
true -> ok;
false -> lager:warning("Session ~s: PUBREC PacketId '~p' not exist", [ClientId, PacketId])
end,
SessState#session_state{ awaiting_comp = maps:remove(PacketId, AwaitingComp) };
puback(SessPid, {?PUBCOMP, PacketId}) when is_pid(SessPid) ->
gen_server:cast(SessPid, {pubcomp, PacketId}), SessPid.
%% SUBSCRIBE
subscribe(SessState = #session_state{client_id = ClientId, submap = SubMap}, Topics) ->
Resubs = [Topic || {Name, _Qos} = Topic <- Topics, maps:is_key(Name, SubMap)],
case Resubs of
@ -127,14 +147,15 @@ subscribe(SessState = #session_state{client_id = ClientId, submap = SubMap}, Top
end,
SubMap1 = lists:foldl(fun({Name, Qos}, Acc) -> maps:put(Name, Qos, Acc) end, SubMap, Topics),
{ok, GrantedQos} = emqtt_pubsub:subscribe(Topics, self()),
%[ok = emqtt_pubsub:subscribe({Topic, Qos}, self()) || {Topic, Qos} <- Topics],
%GrantedQos = [Qos || {_Name, Qos} <- Topics],
{ok, SessState#session_state{submap = SubMap1}, GrantedQos};
subscribe(SessPid, Topics) when is_pid(SessPid) ->
{ok, GrantedQos} = gen_server:call(SessPid, {subscribe, Topics}),
{ok, SessPid, GrantedQos}.
%%
%% @doc UNSUBSCRIBE
%%
unsubscribe(SessState = #session_state{client_id = ClientId, submap = SubMap}, Topics) ->
%%TODO: refactor later.
case Topics -- maps:keys(SubMap) of
@ -153,12 +174,25 @@ unsubscribe(SessPid, Topics) when is_pid(SessPid) ->
destroy(SessPid, ClientId) when is_pid(SessPid) ->
gen_server:cast(SessPid, {destroy, ClientId}).
%store message(qos1) that sent to client
store(SessState = #session_state{ message_id = MsgId, awaiting_ack = Awaiting},
Message = #mqtt_message{ qos = Qos }) when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) ->
%%assign msgid before send
Message1 = Message#mqtt_message{ msgid = MsgId },
Message2 =
if
Qos =:= ?QOS_2 -> Message1#mqtt_message{dup = false};
true -> Message1
end,
Awaiting1 = maps:put(MsgId, Message2, Awaiting),
{Message1, next_msg_id(SessState#session_state{ awaiting_ack = Awaiting1 })}.
initial_state(ClientId) ->
#session_state { client_id = ClientId,
packet_id = 1,
submap = #{},
awaiting_ack = #{},
awaiting_rel = #{} }.
awaiting_rel = #{},
awaiting_comp = #{} }.
initial_state(ClientId, ClientPid) ->
State = initial_state(ClientId),
@ -173,12 +207,14 @@ start_link(SessOpts, ClientId, ClientPid) ->
init([SessOpts, ClientId, ClientPid]) ->
process_flag(trap_exit, true),
%%TODO: OK?
%%TODO: Is this OK?
true = link(ClientPid),
State = initial_state(ClientId, ClientPid),
{ok, State#session_state{
expires = proplists:get_value(expires, SessOpts, 24) * 3600,
max_queue = proplists:get_value(max_queue, SessOpts, 1000) } }.
Expires = proplists:get_value(expires, SessOpts, 1) * 3600,
MsgQueue = emqtt_queue:new( proplists:get_value(max_queue, SessOpts, 1000),
proplists:get_value(store_qos0, SessOpts, false) ),
{ok, State#session_state{ expires = Expires,
msg_queue = MsgQueue }, hibernate}.
handle_call({subscribe, Topics}, _From, State) ->
{ok, NewState, GrantedQos} = subscribe(State, Topics),
@ -194,13 +230,13 @@ handle_call(Req, _From, State) ->
handle_cast({resume, ClientId, ClientPid}, State = #session_state {
client_id = ClientId,
client_pid = undefined,
messages = Messages,
msg_queue = Queue,
expire_timer = ETimer}) ->
lager:info("Session: client ~s resumed by ~p", [ClientId, ClientPid]),
erlang:cancel_timer(ETimer),
[ClientPid ! {dispatch, {self(), Message}} || Message <- lists:reverse(Messages)],
NewState = State#session_state{ client_pid = ClientPid, messages = [], expire_timer = undefined},
{noreply, NewState};
[ClientPid ! {dispatch, {self(), Message}} || Message <- emqtt_queue:all(Queue)],
NewState = State#session_state{ client_pid = ClientPid, msg_queue = emqtt_queue:clear(Queue), expire_timer = undefined},
{noreply, NewState, hibernate};
handle_cast({publish, ?QOS_2, Message}, State) ->
NewState = publish(State, {?QOS_2, Message}),
@ -223,22 +259,14 @@ handle_cast({pubcomp, PacketId}, State) ->
{noreply, NewState};
handle_cast({destroy, ClientId}, State = #session_state{client_id = ClientId}) ->
lager:warning("Session: ~s destroyed", [ClientId]),
lager:warning("Session ~s destroyed", [ClientId]),
{stop, normal, State};
handle_cast(Msg, State) ->
{stop, {badmsg, Msg}, State}.
handle_info({dispatch, {_From, Message}}, State = #session_state{
client_pid = undefined, messages = Messages}) ->
%%TODO: queue len
NewState = State#session_state{messages = [Message | Messages]},
{noreply, NewState};
handle_info({dispatch, {_From, Message}}, State = #session_state{client_pid = ClientPid}) ->
%%TODO: replace From with self(), ok?
ClientPid ! {dispatch, {self(), Message}},
{noreply, State};
handle_info({dispatch, {_From, Message}}, State) ->
{noreply, dispatch(Message, State)};
handle_info({'EXIT', ClientPid, Reason}, State = #session_state{
client_id = ClientId, client_pid = ClientPid, expires = Expires}) ->
@ -247,7 +275,7 @@ handle_info({'EXIT', ClientPid, Reason}, State = #session_state{
{noreply, State#session_state{ client_pid = undefined, expire_timer = Timer}};
handle_info(session_expired, State = #session_state{client_id = ClientId}) ->
lager:warning("Session: ~s session expired!", [ClientId]),
lager:warning("Session ~s expired!", [ClientId]),
{stop, {shutdown, expired}, State};
handle_info(Info, State) ->
@ -263,4 +291,28 @@ code_change(_OldVsn, State, _Extra) ->
%% Internal Function Definitions
%% ------------------------------------------------------------------
dispatch(Message, State = #session_state{ client_id = ClientId,
client_pid = undefined }) ->
queue(ClientId, Message, State);
dispatch(Message = #mqtt_message{ qos = ?QOS_0 }, State = #session_state{
client_pid = ClientPid }) ->
ClientPid ! {dispatch, {self(), Message}},
State;
dispatch(Message = #mqtt_message{ qos = Qos }, State = #session_state{ client_pid = ClientPid })
when (Qos =:= ?QOS_1) orelse (Qos =:= ?QOS_2) ->
{Message1, NewState} = store(State, Message),
ClientPid ! {dispatch, {self(), Message1}},
NewState.
queue(ClientId, Message, State = #session_state{msg_queue = Queue}) ->
State#session_state{msg_queue = emqtt_queue:in(ClientId, Message, Queue)}.
next_msg_id(State = #session_state{ message_id = 16#ffff }) ->
State#session_state{ message_id = 1 };
next_msg_id(State = #session_state{ message_id = MsgId }) ->
State#session_state{ message_id = MsgId + 1 }.