diff --git a/apps/emqtt/src/emqtt_app.erl b/apps/emqtt/src/emqtt_app.erl index abf59a323..f742a0167 100644 --- a/apps/emqtt/src/emqtt_app.erl +++ b/apps/emqtt/src/emqtt_app.erl @@ -59,6 +59,7 @@ print_vsn() -> ?PRINT("~s ~s is running now~n", [Desc, Vsn]). start_servers(Sup) -> + {ok, SessOpts} = application:get_env(session), lists:foreach( fun({Name, F}) when is_function(F) -> ?PRINT("~s is starting...", [Name]), @@ -75,7 +76,7 @@ start_servers(Sup) -> end, [{"emqtt config", emqtt_config}, {"emqtt client manager", emqtt_cm}, - {"emqtt session manager", emqtt_sm}, + {"emqtt session manager", emqtt_sm, SessOpts}, {"emqtt auth", emqtt_auth}, {"emqtt retained", emqtt_retained}, {"emqtt pubsub", emqtt_pubsub}, diff --git a/apps/emqtt/src/emqtt_packet.erl b/apps/emqtt/src/emqtt_packet.erl index 80e0f3bfc..13bc32d61 100644 --- a/apps/emqtt/src/emqtt_packet.erl +++ b/apps/emqtt/src/emqtt_packet.erl @@ -31,7 +31,7 @@ -export([parse/2, serialise/1]). --export([validate/2, dump/1]). +-export([dump/1]). -define(MAX_LEN, 16#fffffff). -define(HIGHBIT, 2#10000000). @@ -259,21 +259,6 @@ opt(X) when is_integer(X) -> X. protocol_name_approved(Ver, Name) -> lists:member({Ver, Name}, ?PROTOCOL_NAMES). -validate(protocol, {Ver, Name}) -> - protocol_name_approved(Ver, Name); - -validate(clientid, {_, ClientId}) when ( size(ClientId) >= 1 ) - andalso ( size(ClientId) >= ?MAX_CLIENTID_LEN ) -> - true; - -%% MQTT3.1.1 allow null clientId. -validate(clientid, {?MQTT_PROTO_V311, ClientId}) - when size(ClientId) =:= 0 -> - true; - -validate(clientid, {_, _}) -> - false. - dump(#mqtt_packet{header = Header, variable = Variable, payload = Payload}) when Payload =:= undefined orelse Payload =:= <<>> -> dump_header(Header, dump_variable(Variable)); diff --git a/apps/emqtt/src/emqtt_protocol.erl b/apps/emqtt/src/emqtt_protocol.erl index c0a6d29c8..bd58c8f7e 100644 --- a/apps/emqtt/src/emqtt_protocol.erl +++ b/apps/emqtt/src/emqtt_protocol.erl @@ -26,6 +26,19 @@ -include("emqtt_packet.hrl"). +%% ------------------------------------------------------------------ +%% API Function Exports +%% ------------------------------------------------------------------ + +-export([initial_state/2]). + +-export([handle_packet/2, send_message/2, send_packet/2, connection_lost/1]). + +-export([info/1]). + +%% ------------------------------------------------------------------ +%% Protocol State +%% ------------------------------------------------------------------ -record(proto_state, { socket, peer_name, @@ -35,20 +48,12 @@ packet_id, client_id, clean_sess, - will_msg, - subscriptions, - awaiting_ack, - awaiting_rel + session, %% session state or session pid + will_msg }). -type proto_state() :: #proto_state{}. --export([initial_state/2]). - --export([handle_packet/2, send_message/2, send_packet/2, connection_lost/1]). - --export([info/1]). - -define(PACKET_TYPE(Packet, Type), Packet = #mqtt_packet { header = #mqtt_packet_header { type = Type }}). @@ -56,26 +61,22 @@ initial_state(Socket, Peername) -> #proto_state{ socket = Socket, peer_name = Peername, - packet_id = 1, - subscriptions = [], - awaiting_ack = gb_trees:empty(), - awaiting_rel = gb_trees:empty() + packet_id = 1 }. +%%SHOULD be registered in emqtt_cm info(#proto_state{ proto_vsn = ProtoVsn, proto_name = ProtoName, packet_id = PacketId, client_id = ClientId, clean_sess = CleanSess, - will_msg = WillMsg, - subscriptions= Subs }) -> + will_msg = WillMsg }) -> [ {packet_id, PacketId}, {proto_vsn, ProtoVsn}, {proto_name, ProtoName}, {client_id, ClientId}, {clean_sess, CleanSess}, - {will_msg, WillMsg}, - {subscriptions, Subs} ]. + {will_msg, WillMsg} ]. -spec handle_packet(Packet, State) -> {ok, NewState} | {error, any()} when Packet :: mqtt_packet(), @@ -84,7 +85,7 @@ info(#proto_state{ proto_vsn = ProtoVsn, %%CONNECT – Client requests a connection to a Server -%%A Client can only send the CONNECT Packet once over a Network Connection. 369 +%%A Client can only send the CONNECT Packet once over a Network Connection. handle_packet(?PACKET_TYPE(Packet, ?CONNECT), State = #proto_state{connected = false}) -> handle_packet(?CONNECT, Packet, State#proto_state{connected = true}); @@ -98,7 +99,7 @@ handle_packet(_Packet, State = #proto_state{connected = false}) -> handle_packet(?PACKET_TYPE(Packet, Type), State = #proto_state { peer_name = PeerName, client_id = ClientId }) -> lager:info("RECV from ~s@~s: ~s", [ClientId, PeerName, emqtt_packet:dump(Packet)]), - case validate_packet(Type, Packet) of + case validate_packet(Packet) of ok -> handle_packet(Type, Packet, State); {error, Reason} -> @@ -122,8 +123,9 @@ handle_packet(?CONNECT, Packet = #mqtt_packet { ClientId1 = clientid(ClientId, State), start_keepalive(KeepAlive), emqtt_cm:register(ClientId1, self()), - {?CONNACK_ACCEPT, - State#proto_state{ will_msg = make_will_msg(Var), client_id = ClientId1 }}; + {?CONNACK_ACCEPT, State#proto_state{ will_msg = make_will_msg(Var), + clean_sess = CleanSess, + client_id = ClientId1 }}; false -> lager:error("~s@~s: username '~s' login failed - no credentials", [ClientId, PeerName, Username]), {?CONNACK_CREDENTIALS, State#proto_state{client_id = ClientId}} @@ -134,30 +136,31 @@ handle_packet(?CONNECT, Packet = #mqtt_packet { send_packet( #mqtt_packet { header = #mqtt_packet_header { type = ?CONNACK }, variable = #mqtt_packet_connack{ return_code = ReturnCode1 }}, State1 ), - {ok, State1}; + %% + {ok, Session} = emqtt_session:start({CleanSess, ClientId, self()}), + emqtt_session:resume(Session), + %%TODO: Resume session + {ok, State1#proto_state { session = Session }}; handle_packet(?PUBLISH, Packet = #mqtt_packet { header = #mqtt_packet_header {qos = ?QOS_0}}, State) -> - emqtt_router:route(make_message(Packet)), + emqtt_session:publish(Session, ?QOS_0, make_message(Packet)), {ok, State}; handle_packet(?PUBLISH, Packet = #mqtt_packet { - header = #mqtt_packet_header { qos = ?QOS_1 }, + header = #mqtt_packet_header { qos = ?QOS_1 }, variable = #mqtt_packet_publish{packet_id = PacketId}}, - State) -> - emqtt_router:route(make_message(Packet)), - send_packet( make_packet(?PUBACK, PacketId), State ), - {ok, State}; + State = #proto_state { session = Session }) -> + emqtt_session:publish(Session, {?QOS_1, make_message(Packet)}), + send_packet( make_packet(?PUBACK, PacketId), State); handle_packet(?PUBLISH, Packet = #mqtt_packet { - header = #mqtt_packet_header { qos = ?QOS_2 }, - variable = #mqtt_packet_publish{packet_id = PacketId}}, + header = #mqtt_packet_header { qos = ?QOS_2 }, + variable = #mqtt_packet_publish { packet_id = PacketId } }, State) -> %%FIXME: this is not right...should store it first... - emqtt_router:route(make_message(Packet)), - put({msg, PacketId}, pubrec), - send_packet( make_packet(?PUBREC, PacketId), State ), - {ok, State}; + NewSession = emqtt_session:publish(Session, {?QOS_2, make_message(Packet)}), + send_packet( make_packet(?PUBREC, PacketId), State#proto_state {session = NewSession} ); handle_packet(?PUBACK, #mqtt_packet {}, State) -> %FIXME Later @@ -167,14 +170,12 @@ handle_packet(?PUBREC, #mqtt_packet { variable = #mqtt_packet_puback { packet_id = PacketId }}, State) -> %FIXME Later: should release the message here - send_packet( make_packet(?PUBREL, PacketId), State ), - {ok, State}; + send_packet( make_packet(?PUBREL, PacketId), State ); handle_packet(?PUBREL, #mqtt_packet { variable = #mqtt_packet_puback { packet_id = PacketId}}, State) -> %%FIXME: not right... erase({msg, PacketId}), - send_packet( make_packet(?PUBCOMP, PacketId), State ), - {ok, State}; + send_packet( make_packet(?PUBCOMP, PacketId), State ); handle_packet(?PUBCOMP, #mqtt_packet { variable = #mqtt_packet_puback{packet_id = _PacketId}}, State) -> @@ -197,9 +198,8 @@ handle_packet(?SUBSCRIBE, #mqtt_packet { send_packet(#mqtt_packet { header = #mqtt_packet_header { type = ?SUBACK }, variable = #mqtt_packet_suback{ packet_id = PacketId, - qos_table = GrantedQos }}, State), + qos_table = GrantedQos }}, State); - {ok, State}; handle_packet(?UNSUBSCRIBE, #mqtt_packet { variable = #mqtt_packet_subscribe{ @@ -211,13 +211,10 @@ handle_packet(?UNSUBSCRIBE, #mqtt_packet { [emqtt_pubsub:unsubscribe(Name, self()) || #mqtt_topic{name=Name} <- Topics], send_packet(#mqtt_packet { header = #mqtt_packet_header {type = ?UNSUBACK }, - variable = #mqtt_packet_suback{packet_id = PacketId }}, State), - - {ok, State}; + variable = #mqtt_packet_suback{packet_id = PacketId }}, State); handle_packet(?PINGREQ, #mqtt_packet{}, State) -> - send_packet(make_packet(?PINGRESP), State), - {ok, State}; + send_packet(make_packet(?PINGRESP), State); handle_packet(?DISCONNECT, #mqtt_packet{}, State=#proto_state{peer_name = PeerName, client_id = ClientId}) -> {stop, State}. @@ -275,6 +272,7 @@ send_packet(Packet, #proto_state{socket = Sock, peer_name = PeerName, client_id lager:debug("SENT to ~s: ~p", [PeerName, Data]), %%FIXME Later... erlang:port_command(Sock, Data). + {ok, State}; %%TODO: fix me later... connection_lost(#proto_state{client_id = ClientId} = State) -> @@ -317,50 +315,7 @@ next_packet_id(State = #proto_state{ packet_id = PacketId }) -> State #proto_state{ packet_id = PacketId + 1 }. -validate_connect( #mqtt_packet_connect { - proto_ver = Ver, - proto_name = Name, - client_id = ClientId } ) -> - case emqtt_packet:validate(protocol, {Ver, Name}) of - true -> - case emqtt_packet:validate(clientid, {Ver, ClientId}) of - true -> - ?CONNACK_ACCEPT; - false -> - ?CONNACK_INVALID_ID - end; - false -> - ?CONNACK_PROTO_VER - end. -validate_packet(?PUBLISH, #mqtt_packet { - variable = #mqtt_packet_publish{ - topic_name = Topic }}) -> - case emqtt_topic:validate({publish, Topic}) of - true -> ok; - false -> {error, badtopic} - end; - -validate_packet(?UNSUBSCRIBE, #mqtt_packet { - variable = #mqtt_packet_subscribe{ - topic_table = Topics }}) -> - ErrTopics = [Topic || #mqtt_topic{name=Topic, qos=Qos} <- Topics, - not emqtt_topic:validate({subscribe, Topic})], - case ErrTopics of - [] -> ok; - _ -> lager:error("error topics: ~p", [ErrTopics]), {error, badtopic} - end; - -validate_packet(?SUBSCRIBE, #mqtt_packet{variable = #mqtt_packet_subscribe{topic_table = Topics}}) -> - ErrTopics = [Topic || #mqtt_topic{name=Topic, qos=Qos} <- Topics, - not (emqtt_topic:validate({subscribe, Topic}) and (Qos < 3))], - case ErrTopics of - [] -> ok; - _ -> lager:error("error topics: ~p", [ErrTopics]), {error, badtopic} - end; - -validate_packet(_Type, _Frame) -> - ok. clientid(<<>>, #proto_state{peer_name = PeerName}) -> <<"eMQTT/", (base64:encode(PeerName))/binary>>; @@ -382,3 +337,68 @@ start_keepalive(0) -> ignore; start_keepalive(Sec) when Sec > 0 -> self() ! {keepalive, start, round(Sec * 1.5)}. + +%%---------------------------------------------------------------------------- +%% Validate Packets +%%---------------------------------------------------------------------------- +validate_connect( Connect = #mqtt_packet_connect{} ) -> + case validate_protocol(Connect) of + true -> + case validate_clientid(Connect) of + true -> + ?CONNACK_ACCEPT; + false -> + ?CONNACK_INVALID_ID + end; + false -> + ?CONNACK_PROTO_VER + end. + +validate_protocol(#mqtt_packet_connect { proto_ver = Ver, proto_name = Name }) -> + lists:member({Ver, Name}, ?PROTOCOL_NAMES). + +validate_clientid(#mqtt_packet_connect { client_id = ClientId }) + when ( size(ClientId) >= 1 ) andalso ( size(ClientId) >= ?MAX_CLIENTID_LEN ) -> + true; + +%% MQTT3.1.1 allow null clientId. +validate_clientid(#mqtt_packet_connect { proto_ver =?MQTT_PROTO_V311, client_id = ClientId }) + when size(ClientId) =:= 0 -> + true; + +validate_clientid(#mqtt_packet_connect { proto_ver = Ver, clean_sess = CleanSess, client_id = ClientId }) -> + lager:warning("Invalid ClientId: ~s, ProtoVer: ~p, CleanSess: ~s", [ClientId, Ver, CleanSess]), + false. + +validate_packet(#mqtt_packet { header = #mqtt_packet_header { type = ?PUBLISH }, + variable = #mqtt_packet_publish{ topic_name = Topic }}) -> + case emqtt_topic:validate({publish, Topic}) of + true -> ok; + false -> lager:error("Error Publish Topic: ~p", [Topic]), {error, badtopic} + end; + +validate_packet(#mqtt_packet { header = #mqtt_packet_header { type = ?SUBSCRIBE }, + variable = #mqtt_packet_subscribe{topic_table = Topics }}) -> + + validate_topics(subscribe, Topics); + +validate_packet(#mqtt_packet{ header = #mqtt_packet_header { type = ?UNSUBSCRIBE }, + variable = #mqtt_packet_subscribe{ topic_table = Topics }}) -> + + validate_topics(unsubscribe, Topics); + +validate_packet(_Packet) -> + ok. + +validate_topics(Type, Topics) when Type =:= subscribe orelse Type =:= unsubscribe -> + ErrTopics = [Topic || #mqtt_topic{name=Topic, qos=Qos} <- Topics, + not (emqtt_topic:validate({Type, Topic}) and validate_qos(Qos))], + case ErrTopics of + [] -> ok; + _ -> lager:error("Error Topics: ~p", [ErrTopics]), {error, badtopic} + end. + +validate_qos(undefined) -> true; +validate_qos(Qos) when Qos =< ?QOS_2 -> true; +validate_qos(_) -> false. + diff --git a/apps/emqtt/src/emqtt_session.erl b/apps/emqtt/src/emqtt_session.erl index 145d7fb18..da52036aa 100644 --- a/apps/emqtt/src/emqtt_session.erl +++ b/apps/emqtt/src/emqtt_session.erl @@ -22,3 +22,97 @@ -module(emqtt_session). +-record(session_state, { + client_id, + client_pid, + packet_id = 1, + subscriptions = [], + messages = [], %% do not receive rel + awaiting_ack, + awaiting_rel }). + +%% ------------------------------------------------------------------ +%% API Function Exports +%% ------------------------------------------------------------------ +-export([start/1, resume/1, publish/2]). + +%% ------------------------------------------------------------------ +%% gen_server Function Exports +%% ------------------------------------------------------------------ + +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +%% ------------------------------------------------------------------ +%% API Function Definitions +%% ------------------------------------------------------------------ +start({true = CleanSess, ClientId, ClientPid}) -> + %%destroy old session + %%TODO: emqtt_sm:destory_session(ClientId), + {ok, initial_state(ClientId)}; + +start({false = CleanSess, ClientId, ClientPid}) -> + %%TODO: emqtt_sm:start_session({ClientId, ClientPid}) + gen_server:start_link(?MODULE, [ClientId, ClientPid], []). + +resume(#session_state {}) -> 'TODO'; +resume(SessPid) when is_pid(SessPid) -> 'TODO'. + +publish(_, {?QOS_0, Message}) -> + emqtt_router:route(Message); + +%%TODO: +publish(_, {?QOS_1, Message}) -> + emqtt_router:route(Message), + +%%TODO: +publish(Session = #session_state{awaiting_rel = Awaiting}, {?QOS_2, Message}) -> + %% store gb_tree: + Session#session_state{awaiting_rel = Awaiting}; + +publish(_, {?QOS_2, Message}) -> + %TODO: + put({msg, PacketId}, pubrec), + emqtt_router:route(Message), + +initial_state(ClientId) -> + #session_state { client_id = ClientId, + packet_id = 1, + subscriptions = [], + awaiting_ack = gb_trees:empty(), + awaiting_rel = gb_trees:empty() }. + +initial_state(ClientId, ClientPid) -> + State = initial_state(ClientId), + State#session_state{client_pid = ClientPid}. + +%% ------------------------------------------------------------------ +%% gen_server Function Definitions +%% ------------------------------------------------------------------ + +init([ClientId, ClientPid]) -> + process_flag(trap_exit, true), + State = initial_state(ClientId, ClientPid), + {ok, State}. + +handle_call(_Request, _From, State) -> + {reply, ok, State}. + +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info(_Info, State) -> + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%% ------------------------------------------------------------------ +%% Internal Function Definitions +%% ------------------------------------------------------------------ + + + diff --git a/apps/emqtt/src/emqtt_sm.erl b/apps/emqtt/src/emqtt_sm.erl index bb711d581..791d60ad7 100644 --- a/apps/emqtt/src/emqtt_sm.erl +++ b/apps/emqtt/src/emqtt_sm.erl @@ -51,9 +51,9 @@ %% API Function Exports %% ------------------------------------------------------------------ --export([start_link/0]). +-export([start_link/1]). --export([lookup/1, create/2, resume/2, destroy/1]). +-export([lookup/1, register/2, resume/2, destroy/1]). %% ------------------------------------------------------------------ %% gen_server Function Exports @@ -75,7 +75,7 @@ start_link(SessOpts) -> lookup(ClientId) -> ok. -create(ClientId, Pid) -> ok. +register(ClientId, Pid) -> ok. resume(ClientId, Pid) -> ok. @@ -86,7 +86,6 @@ destroy(ClientId) -> ok. %% ------------------------------------------------------------------ init(SessOpts) -> - {ok, SessOpts} = application:get_env(session), State = #state{ expires = proplists:get_value(expires, SessOpts, 24) * 3600, max_queue = proplists:get_value(max_queue, SessOpts, 1000) }, {ok, State}.