diff --git a/CHANGELOG.md b/CHANGELOG.md index 04ed16428..a7bd8d018 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,16 @@ emqttd ChangeLog ================== +0.13.0-alpha (2015-11-02) +------------------------- + +eSockd 3.0 + +MochiWeb 4.0 + +...... + + 0.12.3-beta (2015-10-22) ------------------------- diff --git a/include/emqttd.hrl b/include/emqttd.hrl index 51bce2f6d..ab4b6ee72 100644 --- a/include/emqttd.hrl +++ b/include/emqttd.hrl @@ -91,6 +91,7 @@ -record(mqtt_client, { client_id :: binary() | undefined, client_pid :: pid(), + client_mon :: reference(), username :: binary() | undefined, peername :: {inet:ip_address(), integer()}, clean_sess :: boolean(), diff --git a/src/emqttd_throttle.erl b/include/emqttd_internal.hrl similarity index 82% rename from src/emqttd_throttle.erl rename to include/emqttd_internal.hrl index 256ae27d4..4f2eb378c 100644 --- a/src/emqttd_throttle.erl +++ b/include/emqttd_internal.hrl @@ -20,13 +20,16 @@ %%% SOFTWARE. %%%----------------------------------------------------------------------------- %%% @doc -%%% emqttd client throttle. +%%% MQTT Internal Header. %%% %%% @end %%%----------------------------------------------------------------------------- --module(emqttd_throttle). --author("Feng Lee "). +-define(record_to_proplist(Def, Rec), + lists:zip(record_info(fields, Def), + tl(tuple_to_list(Rec)))). -%% TODO:... 0.11.0... +-define(record_to_proplist(Def, Rec, Fields), + [{K, V} || {K, V} <- ?record_to_proplist(Def, Rec), + lists:member(K, Fields)]). diff --git a/plugins/emqttd_sockjs b/plugins/emqttd_sockjs deleted file mode 160000 index 6d5ba0dfe..000000000 --- a/plugins/emqttd_sockjs +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6d5ba0dfe62d375da09f1d53823b8aa54046aa11 diff --git a/rebar.config b/rebar.config index 9dcb12ee6..84968128f 100644 --- a/rebar.config +++ b/rebar.config @@ -29,8 +29,8 @@ {deps, [ {gproc, ".*", {git, "git://github.com/uwiger/gproc.git", {branch, "master"}}}, {lager, ".*", {git, "git://github.com/basho/lager.git", {branch, "master"}}}, - {esockd, "2.*", {git, "git://github.com/emqtt/esockd.git", {branch, "master"}}}, - {mochiweb, ".*", {git, "git://github.com/emqtt/mochiweb.git", {branch, "master"}}} + {esockd, "3.*", {git, "git://github.com/emqtt/esockd.git", {branch, "master"}}}, + {mochiweb, "4.*", {git, "git://github.com/emqtt/mochiweb.git", {branch, "master"}}} ]}. {recursive_cmds, [ct, eunit, clean]}. diff --git a/rel/files/acl.config b/rel/files/acl.config index 3359b2b04..9b1d512a6 100644 --- a/rel/files/acl.config +++ b/rel/files/acl.config @@ -1,21 +1,21 @@ %%%----------------------------------------------------------------------------- -%% -%% [ACL](https://github.com/emqtt/emqttd/wiki/ACL) -%% -%% -type who() :: all | binary() | -%% {ipaddr, esockd_access:cidr()} | -%% {client, binary()} | -%% {user, binary()}. -%% -%% -type access() :: subscribe | publish | pubsub. -%% -%% -type topic() :: binary(). -%% -%% -type rule() :: {allow, all} | -%% {allow, who(), access(), list(topic())} | -%% {deny, all} | -%% {deny, who(), access(), list(topic())}. -%% +%%% +%%% [ACL](https://github.com/emqtt/emqttd/wiki/ACL) +%%% +%%% -type who() :: all | binary() | +%%% {ipaddr, esockd_access:cidr()} | +%%% {client, binary()} | +%%% {user, binary()}. +%%% +%%% -type access() :: subscribe | publish | pubsub. +%%% +%%% -type topic() :: binary(). +%%% +%%% -type rule() :: {allow, all} | +%%% {allow, who(), access(), list(topic())} | +%%% {deny, all} | +%%% {deny, who(), access(), list(topic())}. +%%% %%%----------------------------------------------------------------------------- {allow, {user, "dashboard"}, subscribe, ["$SYS/#"]}. diff --git a/rel/files/emqttd.config.development b/rel/files/emqttd.config.development index 4989110f9..16c475523 100644 --- a/rel/files/emqttd.config.development +++ b/rel/files/emqttd.config.development @@ -176,34 +176,50 @@ %% File to store loaded plugin names. {loaded_file, "./data/loaded_plugins"} ]}, + %% Listeners {listeners, [ {mqtt, 1883, [ %% Size of acceptor pool {acceptors, 16}, + %% Maximum number of concurrent clients {max_clients, 512}, + %% Socket Access Control {access, [{allow, all}]}, + + %% Connection Options + {connopts, [ + %% Rate Limit. Format is 'burst, rate', Unit is KB/Sec + %% {rate_limit, "100,10"} %% 100K burst, 10K rate + ]}, + %% Socket Options {sockopts, [ - {backlog, 512} %Set buffer if hight thoughtput %{recbuf, 4096}, - %{sndbuf, 4096} + %{sndbuf, 4096}, %{buffer, 4096}, + %{nodelay, true}, + {backlog, 512} ]} ]}, + {mqtts, 8883, [ %% Size of acceptor pool {acceptors, 4}, + %% Maximum number of concurrent clients {max_clients, 512}, + %% Socket Access Control {access, [{allow, all}]}, + %% SSL certificate and key files {ssl, [{certfile, "etc/ssl/ssl.crt"}, {keyfile, "etc/ssl/ssl.key"}]}, + %% Socket Options {sockopts, [ {backlog, 1024} @@ -227,6 +243,7 @@ %% {backlog, 1024} %% ]} %%]}, + %% HTTP and WebSocket Listener {http, 8083, [ %% Size of acceptor pool diff --git a/rel/files/emqttd.config.production b/rel/files/emqttd.config.production index 53b1ea3a0..d96cb1373 100644 --- a/rel/files/emqttd.config.production +++ b/rel/files/emqttd.config.production @@ -175,15 +175,18 @@ {acceptors, 16}, %% Maximum number of concurrent clients {max_clients, 8192}, + %% Rate Limit. Format is 'burst, rate', Unit is KB/Sec. + %% {rate_limit, "10,1"}, %% 10K burst, 1K rate %% Socket Access Control {access, [{allow, all}]}, %% Socket Options {sockopts, [ - {backlog, 512} %Set buffer if hight thoughtput %{recbuf, 4096}, - %{sndbuf, 4096} + %{sndbuf, 4096}, %{buffer, 4096}, + %{nodelay, true}, + {backlog, 1024} ]} ]}, {mqtts, 8883, [ @@ -243,7 +246,7 @@ {long_gc, false}, %% Long Schedule(ms) - {long_schedule, 50}, + {long_schedule, 100}, %% 8M words. 32MB on 32-bit VM, 64MB on 64-bit VM. %% 8 * 1024 * 1024 diff --git a/rel/files/vm.args b/rel/files/vm.args index ef3307fed..9cc798802 100644 --- a/rel/files/vm.args +++ b/rel/files/vm.args @@ -17,7 +17,9 @@ ## Enable kernel poll and a few async threads +K true -+A 16 + +## 12 threads/core. ++A 48 ## max process numbers +P 8192 @@ -28,6 +30,10 @@ ## max atom number ## +t +## Set the distribution buffer busy limit (dist_buf_busy_limit) in kilobytes. +## Valid range is 1-2097151. Default is 1024. +## +zdbbl 8192 + ##------------------------------------------------------------------------- ## Env ##------------------------------------------------------------------------- diff --git a/src/emqttd.app.src b/src/emqttd.app.src index c35f7465d..8ebfbe721 100644 --- a/src/emqttd.app.src +++ b/src/emqttd.app.src @@ -1,7 +1,7 @@ {application, emqttd, [ {id, "emqttd"}, - {vsn, "0.12.3"}, + {vsn, "0.13.0"}, {description, "Erlang MQTT Broker"}, {modules, []}, {registered, []}, diff --git a/src/emqttd_access_control.erl b/src/emqttd_access_control.erl index 320b6f75f..7d15d5904 100644 --- a/src/emqttd_access_control.erl +++ b/src/emqttd_access_control.erl @@ -24,7 +24,6 @@ %%% %%% @end %%%----------------------------------------------------------------------------- - -module(emqttd_access_control). -author("Feng Lee "). @@ -36,14 +35,13 @@ -define(SERVER, ?MODULE). %% API Function Exports --export([start_link/0, - start_link/1, +-export([start_link/0, start_link/1, auth/2, % authentication check_acl/3, % acl check reload_acl/0, % reload acl - register_mod/3, - unregister_mod/2, lookup_mods/1, + register_mod/3, register_mod/4, + unregister_mod/2, stop/0]). %% gen_server callbacks @@ -77,7 +75,7 @@ auth(Client, Password) when is_record(Client, mqtt_client) -> auth(Client, Password, lookup_mods(auth)). auth(_Client, _Password, []) -> {error, "No auth module to check!"}; -auth(Client, Password, [{Mod, State} | Mods]) -> +auth(Client, Password, [{Mod, State, _Seq} | Mods]) -> case Mod:check(Client, Password, State) of ok -> ok; {error, Reason} -> {error, Reason}; @@ -100,7 +98,7 @@ check_acl(Client, PubSub, Topic) when ?IS_PUBSUB(PubSub) -> check_acl(#mqtt_client{client_id = ClientId}, PubSub, Topic, []) -> lager:error("ACL: nomatch when ~s ~s ~s", [ClientId, PubSub, Topic]), allow; -check_acl(Client, PubSub, Topic, [{M, State}|AclMods]) -> +check_acl(Client, PubSub, Topic, [{M, State, _Seq}|AclMods]) -> case M:check_acl({Client, PubSub, Topic}, State) of allow -> allow; deny -> deny; @@ -113,7 +111,7 @@ check_acl(Client, PubSub, Topic, [{M, State}|AclMods]) -> %%------------------------------------------------------------------------------ -spec reload_acl() -> list() | {error, any()}. reload_acl() -> - [M:reload_acl(State) || {M, State} <- lookup_mods(acl)]. + [M:reload_acl(State) || {M, State, _Seq} <- lookup_mods(acl)]. %%------------------------------------------------------------------------------ %% @doc Register authentication or ACL module @@ -121,7 +119,11 @@ reload_acl() -> %%------------------------------------------------------------------------------ -spec register_mod(Type :: auth | acl, Mod :: atom(), Opts :: list()) -> ok | {error, any()}. register_mod(Type, Mod, Opts) when Type =:= auth; Type =:= acl-> - gen_server:call(?SERVER, {register_mod, Type, Mod, Opts}). + register_mod(Type, Mod, Opts, 0). + +-spec register_mod(auth | acl, atom(), list(), pos_integer()) -> ok | {error, any()}. +register_mod(Type, Mod, Opts, Seq) when Type =:= auth; Type =:= acl-> + gen_server:call(?SERVER, {register_mod, Type, Mod, Opts, Seq}). %%------------------------------------------------------------------------------ %% @doc Unregister authentication or ACL module @@ -172,22 +174,26 @@ init_mods(acl, AclMods) -> init_mod(Fun, Name, Opts) -> Module = Fun(Name), {ok, State} = Module:init(Opts), - {Module, State}. + {Module, State, 0}. -handle_call({register_mod, Type, Mod, Opts}, _From, State) -> +handle_call({register_mod, Type, Mod, Opts, Seq}, _From, State) -> Mods = lookup_mods(Type), Reply = case lists:keyfind(Mod, 1, Mods) of - false -> + false -> case catch Mod:init(Opts) of - {ok, ModState} -> - ets:insert(?ACCESS_CONTROL_TAB, {tab_key(Type), [{Mod, ModState}|Mods]}), + {ok, ModState} -> + NewMods = + lists:sort(fun({_, _, Seq1}, {_, _, Seq2}) -> + Seq1 >= Seq2 + end, [{Mod, ModState, Seq} | Mods]), + ets:insert(?ACCESS_CONTROL_TAB, {tab_key(Type), NewMods}), ok; {'EXIT', Error} -> lager:error("Access Control: register ~s error - ~p", [Mod, Error]), {error, Error} end; - _ -> + _ -> {error, existed} end, {reply, Reply, State}; diff --git a/src/emqttd_access_rule.erl b/src/emqttd_access_rule.erl index fab9461b9..30ed8f87b 100644 --- a/src/emqttd_access_rule.erl +++ b/src/emqttd_access_rule.erl @@ -24,7 +24,6 @@ %%% %%% @end %%%----------------------------------------------------------------------------- - -module(emqttd_access_rule). -author("Feng Lee "). @@ -49,17 +48,22 @@ -export([compile/1, match/3]). +-define(ALLOW_DENY(A), ((A =:= allow) orelse (A =:= deny))). + %%------------------------------------------------------------------------------ %% @doc Compile access rule %% @end %%------------------------------------------------------------------------------ -compile({A, all}) when (A =:= allow) orelse (A =:= deny) -> +compile({A, all}) when ?ALLOW_DENY(A) -> {A, all}; -compile({A, Who, Access, TopicFilters}) when (A =:= allow) orelse (A =:= deny) -> +compile({A, Who, Access, Topic}) when ?ALLOW_DENY(A) andalso is_binary(Topic) -> + {A, compile(who, Who), Access, [compile(topic, Topic)]}; + +compile({A, Who, Access, TopicFilters}) when ?ALLOW_DENY(A) -> {A, compile(who, Who), Access, [compile(topic, Topic) || Topic <- TopicFilters]}. -compile(who, all) -> +compile(who, all) -> all; compile(who, {ipaddr, CIDR}) -> {Start, End} = esockd_access:range(CIDR), @@ -72,6 +76,10 @@ compile(who, {user, all}) -> {user, all}; compile(who, {user, Username}) -> {user, bin(Username)}; +compile(who, {'and', Conds}) when is_list(Conds) -> + {'and', [compile(who, Cond) || Cond <- Conds]}; +compile(who, {'or', Conds}) when is_list(Conds) -> + {'or', [compile(who, Cond) || Cond <- Conds]}; compile(topic, {eq, Topic}) -> {eq, emqttd_topic:words(bin(Topic))}; @@ -120,6 +128,14 @@ match_who(#mqtt_client{peername = undefined}, {ipaddr, _Tup}) -> match_who(#mqtt_client{peername = {IP, _}}, {ipaddr, {_CDIR, Start, End}}) -> I = esockd_access:atoi(IP), I >= Start andalso I =< End; +match_who(Client, {'and', Conds}) when is_list(Conds) -> + lists:foldl(fun(Who, Allow) -> + match_who(Client, Who) andalso Allow + end, true, Conds); +match_who(Client, {'or', Conds}) when is_list(Conds) -> + lists:foldl(fun(Who, Allow) -> + match_who(Client, Who) orelse Allow + end, false, Conds); match_who(_Client, _Who) -> false. diff --git a/src/emqttd_app.erl b/src/emqttd_app.erl index fa7799904..4ee9cabd1 100644 --- a/src/emqttd_app.erl +++ b/src/emqttd_app.erl @@ -24,7 +24,6 @@ %%% %%% @end %%%----------------------------------------------------------------------------- - -module(emqttd_app). -author("Feng Lee "). @@ -73,17 +72,17 @@ start_listeners() -> start_servers(Sup) -> Servers = [{"emqttd ctl", emqttd_ctl}, {"emqttd trace", emqttd_trace}, + {"emqttd pubsub", {supervisor, emqttd_pubsub_sup}}, + {"emqttd stats", emqttd_stats}, + {"emqttd metrics", emqttd_metrics}, {"emqttd retained", emqttd_retained}, {"emqttd pooler", {supervisor, emqttd_pooler_sup}}, {"emqttd client manager", {supervisor, emqttd_cm_sup}}, {"emqttd session manager", {supervisor, emqttd_sm_sup}}, {"emqttd session supervisor", {supervisor, emqttd_session_sup}}, - {"emqttd pubsub", {supervisor, emqttd_pubsub_sup}}, - {"emqttd stats", emqttd_stats}, - {"emqttd metrics", emqttd_metrics}, {"emqttd broker", emqttd_broker}, {"emqttd alarm", emqttd_alarm}, - {"emqttd mode supervisor", emqttd_mod_sup}, + {"emqttd mod supervisor", emqttd_mod_sup}, {"emqttd bridge supervisor", {supervisor, emqttd_bridge_sup}}, {"emqttd access control", emqttd_access_control}, {"emqttd system monitor", emqttd_sysmon, emqttd:env(sysmon)}], diff --git a/src/emqttd_auth_username.erl b/src/emqttd_auth_username.erl index 392d32961..855a3d6e6 100644 --- a/src/emqttd_auth_username.erl +++ b/src/emqttd_auth_username.erl @@ -105,7 +105,7 @@ init(Opts) -> mnesia:create_table(?AUTH_USERNAME_TAB, [ {disc_copies, [node()]}, {attributes, record_info(fields, ?AUTH_USERNAME_TAB)}]), - mnesia:add_table_copy(?AUTH_USERNAME_TAB, node(), ram_copies), + mnesia:add_table_copy(?AUTH_USERNAME_TAB, node(), disc_copies), emqttd_ctl:register_cmd(users, {?MODULE, cli}, []), {ok, Opts}. diff --git a/src/emqttd_cli.erl b/src/emqttd_cli.erl index fd70ce4f6..59441d516 100644 --- a/src/emqttd_cli.erl +++ b/src/emqttd_cli.erl @@ -72,8 +72,8 @@ status([]) -> case lists:keysearch(emqttd, 1, application:which_applications()) of false -> ?PRINT_MSG("emqttd is not running~n"); - {value,_Version} -> - ?PRINT_MSG("emqttd is running~n") + {value, {emqttd, _Desc, Vsn}} -> + ?PRINT("emqttd ~s is running~n", [Vsn]) end; status(_) -> ?PRINT_CMD("status", "query broker status"). diff --git a/src/emqttd_client.erl b/src/emqttd_client.erl index de7c10766..a684737a5 100644 --- a/src/emqttd_client.erl +++ b/src/emqttd_client.erl @@ -20,11 +20,10 @@ %%% SOFTWARE. %%%----------------------------------------------------------------------------- %%% @doc -%%% MQTT Client +%%% MQTT Client Connection. %%% %%% @end %%%----------------------------------------------------------------------------- - -module(emqttd_client). -author("Feng Lee "). @@ -33,40 +32,34 @@ -include("emqttd_protocol.hrl"). +-include("emqttd_internal.hrl"). + +-behaviour(gen_server). + %% API Function Exports -export([start_link/2, session/1, info/1, kick/1]). -%% SUB/UNSUB Asynchronously +%% SUB/UNSUB Asynchronously, called by plugins. -export([subscribe/2, unsubscribe/2]). --behaviour(gen_server). - %% gen_server Function Exports -export([init/1, handle_call/3, handle_cast/2, handle_info/2, code_change/3, terminate/2]). -%% Client State... --record(state, {transport, - socket, - peername, - conn_name, - await_recv, - conn_state, - conserve, - parser, - proto_state, - packet_opts, - keepalive}). +%% Client State +-record(client_state, {connection, connname, peername, peerhost, peerport, + await_recv, conn_state, rate_limit, parser_fun, + proto_state, packet_opts, keepalive}). --define(DEBUG(Format, Args, State), - lager:debug("Client(~s): " ++ Format, - [emqttd_net:format(State#state.peername) | Args])). --define(ERROR(Format, Args, State), - lager:error("Client(~s): " ++ Format, - [emqttd_net:format(State#state.peername) | Args])). +-define(INFO_KEYS, [peername, peerhost, peerport, await_recv, conn_state]). -start_link(SockArgs, MqttEnv) -> - {ok, proc_lib:spawn_link(?MODULE, init, [[SockArgs, MqttEnv]])}. +-define(SOCK_STATS, [recv_oct, recv_cnt, send_oct, send_cnt]). + +-define(LOG(Level, Format, Args, State), + lager:Level("Client(~s): " ++ Format, [State#client_state.connname | Args])). + +start_link(Connection, MqttEnv) -> + {ok, proc_lib:spawn_link(?MODULE, init, [[Connection, MqttEnv]])}. session(CPid) -> gen_server:call(CPid, session, infinity). @@ -83,125 +76,158 @@ subscribe(CPid, TopicTable) -> unsubscribe(CPid, Topics) -> gen_server:cast(CPid, {unsubscribe, Topics}). -init([SockArgs = {Transport, Sock, _SockFun}, MqttEnv]) -> - % Transform if ssl. - {ok, NewSock} = esockd_connection:accept(SockArgs), - {ok, Peername} = emqttd_net:peername(Sock), - {ok, ConnStr} = emqttd_net:connection_string(Sock, inbound), - SendFun = fun(Data) -> Transport:send(NewSock, Data) end, +init([Connection0, MqttEnv]) -> + {ok, Connection} = Connection0:wait(), + {PeerHost, PeerPort, PeerName} = + case Connection:peername() of + {ok, Peer = {Host, Port}} -> + {Host, Port, Peer}; + {error, enotconn} -> + Connection:fast_close(), + exit(normal); + {error, Reason} -> + Connection:fast_close(), + exit({shutdown, Reason}) + end, + ConnName = esockd_net:format(PeerName), + SendFun = fun(Data) -> + try Connection:async_send(Data) of + true -> ok + catch + error:Error -> exit({shutdown, Error}) + end + end, PktOpts = proplists:get_value(packet, MqttEnv), - ProtoState = emqttd_protocol:init(Peername, SendFun, PktOpts), - State = control_throttle(#state{transport = Transport, - socket = NewSock, - peername = Peername, - conn_name = ConnStr, - await_recv = false, - conn_state = running, - conserve = false, - packet_opts = PktOpts, - parser = emqttd_parser:new(PktOpts), - proto_state = ProtoState}), + ParserFun = emqttd_parser:new(PktOpts), + ProtoState = emqttd_protocol:init(PeerName, SendFun, PktOpts), + RateLimit = proplists:get_value(rate_limit, Connection:opts()), + State = run_socket(#client_state{connection = Connection, + connname = ConnName, + peername = PeerName, + peerhost = PeerHost, + peerport = PeerPort, + await_recv = false, + conn_state = running, + rate_limit = RateLimit, + parser_fun = ParserFun, + proto_state = ProtoState, + packet_opts = PktOpts}), ClientOpts = proplists:get_value(client, MqttEnv), IdleTimout = proplists:get_value(idle_timeout, ClientOpts, 10), gen_server:enter_loop(?MODULE, [], State, timer:seconds(IdleTimout)). -handle_call(session, _From, State = #state{proto_state = ProtoState}) -> +handle_call(session, _From, State = #client_state{proto_state = ProtoState}) -> {reply, emqttd_protocol:session(ProtoState), State}; -handle_call(info, _From, State = #state{conn_name = ConnName, - proto_state = ProtoState}) -> - {reply, [{conn_name, ConnName} | emqttd_protocol:info(ProtoState)], State}; +handle_call(info, _From, State = #client_state{connection = Connection, + proto_state = ProtoState}) -> + ClientInfo = ?record_to_proplist(client_state, State, ?INFO_KEYS), + ProtoInfo = emqttd_protocol:info(ProtoState), + {ok, SockStats} = Connection:getstat(?SOCK_STATS), + {reply, lists:append([ClientInfo, [{proto_info, ProtoInfo}, + {sock_stats, SockStats}]]), State}; handle_call(kick, _From, State) -> {stop, {shutdown, kick}, ok, State}; handle_call(Req, _From, State) -> - ?ERROR("Unexpected request: ~p", [Req], State), - {reply, {error, unsupported_request}, State}. + ?LOG(critical, "Unexpected request: ~p", [Req], State), + {reply, {error, unsupported_request}, State}. handle_cast({subscribe, TopicTable}, State) -> - with_session(fun(SessPid) -> emqttd_session:subscribe(SessPid, TopicTable) end, State); + with_session(fun(SessPid) -> + emqttd_session:subscribe(SessPid, TopicTable) + end, State); handle_cast({unsubscribe, Topics}, State) -> - with_session(fun(SessPid) -> emqttd_session:unsubscribe(SessPid, Topics) end, State); + with_session(fun(SessPid) -> + emqttd_session:unsubscribe(SessPid, Topics) + end, State); handle_cast(Msg, State) -> - ?ERROR("Unexpected msg: ~p",[Msg], State), - {noreply, State}. + ?LOG(critical, "Unexpected msg: ~p", [Msg], State), + noreply(State). handle_info(timeout, State) -> - stop({shutdown, timeout}, State); - -handle_info({stop, duplicate_id, _NewPid}, State=#state{proto_state = ProtoState, - conn_name = ConnName}) -> - lager:warning("Shutdown for duplicate clientid: ~s, conn:~s", - [emqttd_protocol:clientid(ProtoState), ConnName]), - stop({shutdown, duplicate_id}, State); + shutdown(idle_timeout, State); -handle_info({deliver, Message}, State = #state{proto_state = ProtoState}) -> - {ok, ProtoState1} = emqttd_protocol:send(Message, ProtoState), - noreply(State#state{proto_state = ProtoState1}); +%% Asynchronous SUBACK +handle_info({suback, PacketId, GrantedQos}, State) -> + with_proto_state(fun(ProtoState) -> + Packet = ?SUBACK_PACKET(PacketId, GrantedQos), + emqttd_protocol:send(Packet, ProtoState) + end, State); -handle_info({redeliver, {?PUBREL, PacketId}}, State = #state{proto_state = ProtoState}) -> - {ok, ProtoState1} = emqttd_protocol:redeliver({?PUBREL, PacketId}, ProtoState), - noreply(State#state{proto_state = ProtoState1}); +handle_info({deliver, Message}, State) -> + with_proto_state(fun(ProtoState) -> + emqttd_protocol:send(Message, ProtoState) + end, State); -handle_info({inet_reply, _Ref, ok}, State) -> - noreply(State); +handle_info({redeliver, {?PUBREL, PacketId}}, State) -> + with_proto_state(fun(ProtoState) -> + emqttd_protocol:redeliver({?PUBREL, PacketId}, ProtoState) + end, State); -handle_info({inet_async, Sock, _Ref, {ok, Data}}, State = #state{peername = Peername, socket = Sock}) -> - lager:debug("RECV from ~s: ~p", [emqttd_net:format(Peername), Data]), - emqttd_metrics:inc('bytes/received', size(Data)), - received(Data, control_throttle(State #state{await_recv = false})); +handle_info({shutdown, conflict, {ClientId, NewPid}}, State) -> + ?LOG(warning, "clientid '~s' conflict with ~p", [ClientId, NewPid], State), + shutdown(conflict, State); + +handle_info(activate_sock, State) -> + noreply(run_socket(State#client_state{conn_state = running})); + +handle_info({inet_async, _Sock, _Ref, {ok, Data}}, State) -> + Size = size(Data), + ?LOG(debug, "RECV ~p", [Data], State), + emqttd_metrics:inc('bytes/received', Size), + received(Data, rate_limit(Size, State#client_state{await_recv = false})); handle_info({inet_async, _Sock, _Ref, {error, Reason}}, State) -> - network_error(Reason, State); + shutdown(Reason, State); + +handle_info({inet_reply, _Sock, ok}, State) -> + noreply(State); handle_info({inet_reply, _Sock, {error, Reason}}, State) -> - ?ERROR("Unexpected inet_reply - ~p", [Reason], State), - {noreply, State}; + shutdown(Reason, State); -handle_info({keepalive, start, TimeoutSec}, State = #state{transport = Transport, socket = Socket}) -> - ?DEBUG("Start KeepAlive with ~p seconds", [TimeoutSec], State), +handle_info({keepalive, start, Interval}, State = #client_state{connection = Connection}) -> + ?LOG(debug, "Keepalive at the interval of ~p", [Interval], State), StatFun = fun() -> - case Transport:getstat(Socket, [recv_oct]) of - {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct}; - {error, Error} -> {error, Error} - end - end, - KeepAlive = emqttd_keepalive:start(StatFun, TimeoutSec, {keepalive, check}), - noreply(State#state{keepalive = KeepAlive}); + case Connection:getstat([recv_oct]) of + {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct}; + {error, Error} -> {error, Error} + end + end, + KeepAlive = emqttd_keepalive:start(StatFun, Interval, {keepalive, check}), + noreply(State#client_state{keepalive = KeepAlive}); -handle_info({keepalive, check}, State = #state{keepalive = KeepAlive}) -> +handle_info({keepalive, check}, State = #client_state{keepalive = KeepAlive}) -> case emqttd_keepalive:check(KeepAlive) of - {ok, KeepAlive1} -> - noreply(State#state{keepalive = KeepAlive1}); - {error, timeout} -> - ?DEBUG("Keepalive Timeout!", [], State), - stop({shutdown, keepalive_timeout}, State#state{keepalive = undefined}); - {error, Error} -> - ?DEBUG("Keepalive Error - ~p", [Error], State), - stop({shutdown, keepalive_error}, State#state{keepalive = undefined}) + {ok, KeepAlive1} -> + noreply(State#client_state{keepalive = KeepAlive1}); + {error, timeout} -> + ?LOG(debug, "Keepalive timeout", [], State), + shutdown(keepalive_timeout, State); + {error, Error} -> + ?LOG(warning, "Keepalive error - ~p", [Error], State), + shutdown(Error, State) end; handle_info(Info, State) -> - ?ERROR("Unexpected info: ~p", [Info], State), - {noreply, State}. + ?LOG(critical, "Unexpected info: ~p", [Info], State), + noreply(State). -terminate(Reason, #state{transport = Transport, - socket = Socket, - keepalive = KeepAlive, - proto_state = ProtoState}) -> +terminate(Reason, #client_state{connection = Connection, + keepalive = KeepAlive, + proto_state = ProtoState}) -> + Connection:fast_close(), emqttd_keepalive:cancel(KeepAlive), - if - Reason == {shutdown, conn_closed} -> ok; - true -> Transport:fast_close(Socket) - end, case {ProtoState, Reason} of - {undefined, _} -> ok; + {undefined, _} -> + ok; {_, {shutdown, Error}} -> emqttd_protocol:shutdown(Error, ProtoState); - {_, Reason} -> + {_, Reason} -> emqttd_protocol:shutdown(Reason, ProtoState) end. @@ -212,65 +238,73 @@ code_change(_OldVsn, State, _Extra) -> %%% Internal functions %%%============================================================================= +with_proto_state(Fun, State = #client_state{proto_state = ProtoState}) -> + {ok, ProtoState1} = Fun(ProtoState), + noreply(State#client_state{proto_state = ProtoState1}). + +with_session(Fun, State = #client_state{proto_state = ProtoState}) -> + Fun(emqttd_protocol:session(ProtoState)), + noreply(State). + +%% receive and parse tcp data +received(<<>>, State) -> + noreply(State); + +received(Bytes, State = #client_state{parser_fun = ParserFun, + packet_opts = PacketOpts, + proto_state = ProtoState}) -> + case catch ParserFun(Bytes) of + {more, NewParser} -> + noreply(run_socket(State#client_state{parser_fun = NewParser})); + {ok, Packet, Rest} -> + emqttd_metrics:received(Packet), + case emqttd_protocol:received(Packet, ProtoState) of + {ok, ProtoState1} -> + received(Rest, State#client_state{parser_fun = emqttd_parser:new(PacketOpts), + proto_state = ProtoState1}); + {error, Error} -> + ?LOG(error, "Protocol error - ~p", [Error], State), + shutdown(Error, State); + {error, Error, ProtoState1} -> + shutdown(Error, State#client_state{proto_state = ProtoState1}); + {stop, Reason, ProtoState1} -> + stop(Reason, State#client_state{proto_state = ProtoState1}) + end; + {error, Error} -> + ?LOG(error, "Framing error - ~p", [Error], State), + shutdown(Error, State); + {'EXIT', Reason} -> + ?LOG(error, "Parser failed for ~p", [Reason], State), + ?LOG(error, "Error data: ~p", [Bytes], State), + shutdown(parser_error, State) + end. + +rate_limit(_Size, State = #client_state{rate_limit = undefined}) -> + run_socket(State); +rate_limit(Size, State = #client_state{rate_limit = Rl}) -> + case Rl:check(Size) of + {0, Rl1} -> + run_socket(State#client_state{conn_state = running, rate_limit = Rl1}); + {Pause, Rl1} -> + ?LOG(error, "Rate limiter pause for ~p", [Size, Pause], State), + erlang:send_after(Pause, self(), activate_sock), + State#client_state{conn_state = blocked, rate_limit = Rl1} + end. + +run_socket(State = #client_state{conn_state = blocked}) -> + State; +run_socket(State = #client_state{await_recv = true}) -> + State; +run_socket(State = #client_state{connection = Connection}) -> + Connection:async_recv(0, infinity), + State#client_state{await_recv = true}. + noreply(State) -> {noreply, State, hibernate}. +shutdown(Reason, State) -> + stop({shutdown, Reason}, State). + stop(Reason, State) -> {stop, Reason, State}. -with_session(Fun, State = #state{proto_state = ProtoState}) -> - Fun(emqttd_protocol:session(ProtoState)), noreply(State). - -%% receive and parse tcp data -received(<<>>, State) -> - {noreply, State, hibernate}; - -received(Bytes, State = #state{packet_opts = PacketOpts, - parser = Parser, - proto_state = ProtoState}) -> - case catch Parser(Bytes) of - {more, NewParser} -> - noreply(control_throttle(State#state{parser = NewParser})); - {ok, Packet, Rest} -> - emqttd_metrics:received(Packet), - case emqttd_protocol:received(Packet, ProtoState) of - {ok, ProtoState1} -> - received(Rest, State#state{parser = emqttd_parser:new(PacketOpts), - proto_state = ProtoState1}); - {error, Error} -> - ?ERROR("Protocol error - ~p", [Error], State), - stop({shutdown, Error}, State); - {error, Error, ProtoState1} -> - stop({shutdown, Error}, State#state{proto_state = ProtoState1}); - {stop, Reason, ProtoState1} -> - stop(Reason, State#state{proto_state = ProtoState1}) - end; - {error, Error} -> - ?ERROR("Framing error - ~p", [Error], State), - stop({shutdown, Error}, State); - {'EXIT', Reason} -> - ?ERROR("Parser failed for ~p~nError Frame: ~p", [Reason, Bytes], State), - {stop, {shutdown, frame_error}, State} - end. - -network_error(Reason, State = #state{peername = Peername}) -> - lager:warning("Client(~s): network error - ~p", - [emqttd_net:format(Peername), Reason]), - stop({shutdown, conn_closed}, State). - -run_socket(State = #state{conn_state = blocked}) -> - State; -run_socket(State = #state{await_recv = true}) -> - State; -run_socket(State = #state{transport = Transport, socket = Sock}) -> - Transport:async_recv(Sock, 0, infinity), - State#state{await_recv = true}. - -control_throttle(State = #state{conn_state = Flow, - conserve = Conserve}) -> - case {Flow, Conserve} of - {running, true} -> State #state{conn_state = blocked}; - {blocked, false} -> run_socket(State #state{conn_state = running}); - {_, _} -> run_socket(State) - end. - diff --git a/src/emqttd_cm.erl b/src/emqttd_cm.erl index db198c176..1347e1815 100644 --- a/src/emqttd_cm.erl +++ b/src/emqttd_cm.erl @@ -37,8 +37,6 @@ -behaviour(gen_server2). --define(SERVER, ?MODULE). - %% gen_server Function Exports -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). @@ -47,6 +45,9 @@ -define(CM_POOL, ?MODULE). +-define(LOG(Level, Format, Args, Client), + lager:Level("CM(~s): " ++ Format, [Client#mqtt_client.client_id|Args])). + %%%============================================================================= %%% API %%%============================================================================= @@ -102,15 +103,16 @@ init([Id, StatsFun]) -> handle_call(Req, _From, State) -> lager:error("unexpected request: ~p", [Req]), - {reply, {error, badreq}, State}. + {reply, {error, unsupported_req}, State}. -handle_cast({register, Client = #mqtt_client{client_id = ClientId, client_pid = Pid}}, State) -> +handle_cast({register, Client = #mqtt_client{client_id = ClientId, + client_pid = Pid}}, State) -> case ets:lookup(mqtt_client, ClientId) of [#mqtt_client{client_pid = Pid}] -> - lager:error("ClientId '~s' has been registered with ~p", [ClientId, Pid]), ignore; [#mqtt_client{client_pid = OldPid}] -> - lager:warning("ClientId '~s' is duplicated: pid=~p, oldpid=~p", [ClientId, Pid, OldPid]); + %% TODO: should cancel monitor + ?LOG(warning, "client ~p conflict with ~p", [Pid, OldPid], Client); [] -> ok end, @@ -121,10 +123,10 @@ handle_cast({unregister, ClientId, Pid}, State) -> case ets:lookup(mqtt_client, ClientId) of [#mqtt_client{client_pid = Pid}] -> ets:delete(mqtt_client, ClientId); - [_] -> + [_] -> ignore; [] -> - lager:error("Cannot find clientId '~s' with ~p", [ClientId, Pid]) + lager:warning("CM(~s): Cannot find registered pid ~p", [ClientId, Pid]) end, {noreply, setstats(State)}; @@ -137,7 +139,8 @@ handle_info(Info, State) -> {noreply, State}. terminate(_Reason, #state{id = Id}) -> - gproc_pool:disconnect_worker(?CM_POOL, {?MODULE, Id}), ok. + gproc_pool:disconnect_worker(?CM_POOL, {?MODULE, Id}), + ok. code_change(_OldVsn, State, _Extra) -> {ok, State}. diff --git a/src/emqttd_message.erl b/src/emqttd_message.erl index d8e55e202..9d340b2b3 100644 --- a/src/emqttd_message.erl +++ b/src/emqttd_message.erl @@ -24,7 +24,6 @@ %%% %%% @end %%%----------------------------------------------------------------------------- - -module(emqttd_message). -author("Feng Lee "). @@ -170,14 +169,12 @@ unset_flag(Flag, Msg) when Flag =:= dup orelse Flag =:= retain -> Msg. %% @doc Format MQTT Message %% @end %%------------------------------------------------------------------------------ -format(#mqtt_message{msgid=MsgId, - pktid = PktId, - from = From, - qos=Qos, - retain=Retain, - dup=Dup, - topic=Topic}) -> - io_lib:format("Message(MsgId=~p, PktId=~p, from=~s, " - "Qos=~p, Retain=~s, Dup=~s, Topic=~s)", - [MsgId, PktId, From, Qos, Retain, Dup, Topic]). +format(#mqtt_message{msgid = MsgId, pktid = PktId, from = From, + qos = Qos, retain = Retain, dup = Dup, topic =Topic}) -> + io_lib:format("Message(Q~p, R~p, D~p, MsgId=~p, PktId=~p, From=~s, Topic=~s)", + [i(Qos), i(Retain), i(Dup), MsgId, PktId, From, Topic]). + +i(true) -> 1; +i(false) -> 0; +i(I) when is_integer(I) -> I. diff --git a/src/emqttd_packet.erl b/src/emqttd_packet.erl index 383f8df7d..0bc5190a8 100644 --- a/src/emqttd_packet.erl +++ b/src/emqttd_packet.erl @@ -74,13 +74,15 @@ connack_name(?CONNACK_AUTH) -> 'CONNACK_AUTH'. format(#mqtt_packet{header = Header, variable = Variable, payload = Payload}) -> format_header(Header, format_variable(Variable, Payload)). -format_header(#mqtt_packet_header{type = Type, dup = Dup, qos = QoS, retain = Retain}, S) -> - S1 = - if - S == undefined -> <<>>; - true -> [", ", S] - end, - io_lib:format("~s(Qos=~p, Retain=~s, Dup=~s~s)", [type_name(Type), QoS, Retain, Dup, S1]). +format_header(#mqtt_packet_header{type = Type, + dup = Dup, + qos = QoS, + retain = Retain}, S) -> + S1 = if + S == undefined -> <<>>; + true -> [", ", S] + end, + io_lib:format("~s(Q~p, R~p, D~p~s)", [type_name(Type), QoS, i(Retain), i(Dup), S1]). format_variable(undefined, _) -> undefined; @@ -105,8 +107,8 @@ format_variable(#mqtt_packet_connect{ Format = "ClientId=~s, ProtoName=~s, ProtoVsn=~p, CleanSess=~s, KeepAlive=~p, Username=~s, Password=~s", Args = [ClientId, ProtoName, ProtoVer, CleanSess, KeepAlive, Username, format_password(Password)], {Format1, Args1} = if - WillFlag -> { Format ++ ", Will(Qos=~p, Retain=~s, Topic=~s, Msg=~s)", - Args ++ [ WillQoS, WillRetain, WillTopic, WillMsg ] }; + WillFlag -> { Format ++ ", Will(Q~p, R~p, Topic=~s, Msg=~s)", + Args ++ [WillQoS, i(WillRetain), WillTopic, WillMsg] }; true -> {Format, Args} end, io_lib:format(Format1, Args1); @@ -145,3 +147,6 @@ format_variable(undefined) -> undefined. format_password(undefined) -> undefined; format_password(_Password) -> '******'. +i(true) -> 1; +i(false) -> 0; +i(I) when is_integer(I) -> I. diff --git a/src/emqttd_protocol.erl b/src/emqttd_protocol.erl index ce85ddab3..33571e213 100644 --- a/src/emqttd_protocol.erl +++ b/src/emqttd_protocol.erl @@ -24,7 +24,6 @@ %%% %%% @end %%%----------------------------------------------------------------------------- - -module(emqttd_protocol). -author("Feng Lee "). @@ -33,6 +32,8 @@ -include("emqttd_protocol.hrl"). +-include("emqttd_internal.hrl"). + %% API -export([init/3, info/1, clientid/1, client/1, session/1]). @@ -41,29 +42,26 @@ -export([process/2]). %% Protocol State --record(proto_state, {peername, - sendfun, - connected = false, %received CONNECT action? - proto_ver, - proto_name, - username, - client_id, - clean_sess, - session, - will_msg, - keepalive, - max_clientid_len = ?MAX_CLIENTID_LEN, - client_pid, - ws_initial_headers, %% Headers from first HTTP request for websocket client +-record(proto_state, {peername, sendfun, connected = false, + client_id, client_pid, clean_sess, + proto_ver, proto_name, username, + will_msg, keepalive, max_clientid_len = ?MAX_CLIENTID_LEN, + session, ws_initial_headers, %% Headers from first HTTP request for websocket client connected_at}). -type proto_state() :: #proto_state{}. +-define(INFO_KEYS, [client_id, username, clean_sess, proto_ver, proto_name, + keepalive, will_msg, ws_initial_headers, connected_at]). + +-define(LOG(Level, Format, Args, State), + lager:Level([{client, State#proto_state.client_id}], "Client(~s@~s): " ++ Format, + [State#proto_state.client_id, esockd_net:format(State#proto_state.peername) | Args])). + %%------------------------------------------------------------------------------ %% @doc Init protocol %% @end %%------------------------------------------------------------------------------ - init(Peername, SendFun, Opts) -> MaxLen = emqttd_opts:g(max_clientid_len, Opts, ?MAX_CLIENTID_LEN), WsInitialHeaders = emqttd_opts:g(ws_initial_headers, Opts), @@ -73,38 +71,20 @@ init(Peername, SendFun, Opts) -> client_pid = self(), ws_initial_headers = WsInitialHeaders}. -info(#proto_state{client_id = ClientId, - username = Username, - peername = Peername, - proto_ver = ProtoVer, - proto_name = ProtoName, - keepalive = KeepAlive, - clean_sess = CleanSess, - ws_initial_headers = WsInitialHeaders, - will_msg = WillMsg, - connected_at = ConnectedAt}) -> - [{client_id, ClientId}, - {username, Username}, - {peername, Peername}, - {proto_ver, ProtoVer}, - {proto_name, ProtoName}, - {keepalive, KeepAlive}, - {clean_sess, CleanSess}, - {ws_initial_headers, WsInitialHeaders}, - {will_msg, WillMsg}, - {connected_at, ConnectedAt}]. +info(ProtoState) -> + ?record_to_proplist(proto_state, ProtoState, ?INFO_KEYS). clientid(#proto_state{client_id = ClientId}) -> ClientId. client(#proto_state{client_id = ClientId, + client_pid = ClientPid, peername = Peername, username = Username, clean_sess = CleanSess, proto_ver = ProtoVer, keepalive = Keepalive, will_msg = WillMsg, - client_pid = Pid, ws_initial_headers = WsInitialHeaders, connected_at = Time}) -> WillTopic = if @@ -112,7 +92,7 @@ client(#proto_state{client_id = ClientId, true -> WillMsg#mqtt_message.topic end, #mqtt_client{client_id = ClientId, - client_pid = Pid, + client_pid = ClientPid, username = Username, peername = Peername, clean_sess = CleanSess, @@ -127,7 +107,7 @@ session(#proto_state{session = Session}) -> %% CONNECT – Client requests a connection to a Server -%%A Client can only send the CONNECT Packet once over a Network Connection. +%% A Client can only send the CONNECT Packet once over a Network Connection. -spec received(mqtt_packet(), proto_state()) -> {ok, proto_state()} | {error, any()}. received(Packet = ?PACKET(?CONNECT), State = #proto_state{connected = false}) -> process(Packet, State#proto_state{connected = true}); @@ -135,20 +115,20 @@ received(Packet = ?PACKET(?CONNECT), State = #proto_state{connected = false}) -> received(?PACKET(?CONNECT), State = #proto_state{connected = true}) -> {error, protocol_bad_connect, State}; -%%Received other packets when CONNECT not arrived. +%% Received other packets when CONNECT not arrived. received(_Packet, State = #proto_state{connected = false}) -> {error, protocol_not_connected, State}; received(Packet = ?PACKET(_Type), State) -> trace(recv, Packet, State), case validate_packet(Packet) of - ok -> - process(Packet, State); - {error, Reason} -> - {error, Reason, State} + ok -> + process(Packet, State); + {error, Reason} -> + {error, Reason, State} end. -process(Packet = ?CONNECT_PACKET(Var), State0 = #proto_state{peername = Peername}) -> +process(Packet = ?CONNECT_PACKET(Var), State0) -> #mqtt_packet_connect{proto_ver = ProtoVer, proto_name = ProtoName, @@ -190,10 +170,8 @@ process(Packet = ?CONNECT_PACKET(Var), State0 = #proto_state{peername = Peername exit({shutdown, Error}) end; {error, Reason}-> - lager:error("~s@~s: username '~s' login failed for ~s", - [ClientId, emqttd_net:format(Peername), Username, Reason]), + ?LOG(error, "Username '~s' login failed for ~s", [Username, Reason], State1), {?CONNACK_CREDENTIALS, State1} - end; ReturnCode -> {ReturnCode, State1} @@ -203,19 +181,18 @@ process(Packet = ?CONNECT_PACKET(Var), State0 = #proto_state{peername = Peername %% Send connack send(?CONNACK_PACKET(ReturnCode1), State3); -process(Packet = ?PUBLISH_PACKET(_Qos, Topic, _PacketId, _Payload), - State = #proto_state{client_id = ClientId}) -> - - case check_acl(publish, Topic, State) of +process(Packet = ?PUBLISH_PACKET(_Qos, Topic, _PacketId, _Payload), State) -> + case check_acl(publish, Topic, client(State)) of allow -> publish(Packet, State); - deny -> - lager:error("ACL Deny: ~s cannot publish to ~s", [ClientId, Topic]) + deny -> + ?LOG(error, "Cannot publish to ~s for ACL Deny", [Topic], State) end, {ok, State}; process(?PUBACK_PACKET(?PUBACK, PacketId), State = #proto_state{session = Session}) -> - emqttd_session:puback(Session, PacketId), {ok, State}; + emqttd_session:puback(Session, PacketId), + {ok, State}; process(?PUBACK_PACKET(?PUBREC, PacketId), State = #proto_state{session = Session}) -> emqttd_session:pubrec(Session, PacketId), @@ -228,25 +205,22 @@ process(?PUBACK_PACKET(?PUBREL, PacketId), State = #proto_state{session = Sessio process(?PUBACK_PACKET(?PUBCOMP, PacketId), State = #proto_state{session = Session})-> emqttd_session:pubcomp(Session, PacketId), {ok, State}; -%% protect from empty topic list +%% Protect from empty topic table process(?SUBSCRIBE_PACKET(PacketId, []), State) -> send(?SUBACK_PACKET(PacketId, []), State); -process(?SUBSCRIBE_PACKET(PacketId, TopicTable), - State = #proto_state{client_id = ClientId, session = Session}) -> - AllowDenies = [check_acl(subscribe, Topic, State) || {Topic, _Qos} <- TopicTable], +process(?SUBSCRIBE_PACKET(PacketId, TopicTable), State = #proto_state{session = Session}) -> + Client = client(State), + AllowDenies = [check_acl(subscribe, Topic, Client) || {Topic, _Qos} <- TopicTable], case lists:member(deny, AllowDenies) of true -> - lager:error("SUBSCRIBE from '~s' Denied: ~p", [ClientId, TopicTable]), + ?LOG(error, "Cannot SUBSCRIBE ~p for ACL Deny", [TopicTable], State), send(?SUBACK_PACKET(PacketId, [16#80 || _ <- TopicTable]), State); false -> - AckFun = fun(GrantedQos) -> - send(?SUBACK_PACKET(PacketId, GrantedQos), State) - end, - emqttd_session:subscribe(Session, TopicTable, AckFun), {ok, State} + emqttd_session:subscribe(Session, PacketId, TopicTable), {ok, State} end; -%% protect from empty topic list +%% Protect from empty topic list process(?UNSUBSCRIBE_PACKET(PacketId, []), State) -> send(?UNSUBACK_PACKET(PacketId), State); @@ -258,72 +232,61 @@ process(?PACKET(?PINGREQ), State) -> send(?PACKET(?PINGRESP), State); process(?PACKET(?DISCONNECT), State) -> - % clean willmsg + % Clean willmsg {stop, normal, State#proto_state{will_msg = undefined}}. publish(Packet = ?PUBLISH_PACKET(?QOS_0, _PacketId), #proto_state{client_id = ClientId, session = Session}) -> - Msg = emqttd_message:from_packet(ClientId, Packet), - emqttd_session:publish(Session, Msg); + emqttd_session:publish(Session, emqttd_message:from_packet(ClientId, Packet)); -publish(Packet = ?PUBLISH_PACKET(?QOS_1, PacketId), - State = #proto_state{client_id = ClientId, session = Session}) -> +publish(Packet = ?PUBLISH_PACKET(?QOS_1, _PacketId), State) -> + with_puback(?PUBACK, Packet, State); + +publish(Packet = ?PUBLISH_PACKET(?QOS_2, _PacketId), State) -> + with_puback(?PUBREC, Packet, State). + +with_puback(Type, Packet = ?PUBLISH_PACKET(_Qos, PacketId), + State = #proto_state{client_id = ClientId, session = Session}) -> Msg = emqttd_message:from_packet(ClientId, Packet), case emqttd_session:publish(Session, Msg) of ok -> - send(?PUBACK_PACKET(?PUBACK, PacketId), State); + send(?PUBACK_PACKET(Type, PacketId), State); {error, Error} -> - lager:error("Client(~s): publish qos1 error - ~p", [ClientId, Error]) - end; - -publish(Packet = ?PUBLISH_PACKET(?QOS_2, PacketId), - State = #proto_state{client_id = ClientId, session = Session}) -> - Msg = emqttd_message:from_packet(ClientId, Packet), - case emqttd_session:publish(Session, Msg) of - ok -> - send(?PUBACK_PACKET(?PUBREC, PacketId), State); - {error, Error} -> - lager:error("Client(~s): publish qos2 error - ~p", [ClientId, Error]) + ?LOG(error, "PUBLISH ~p error: ~p", [PacketId, Error], State) end. -spec send(mqtt_message() | mqtt_packet(), proto_state()) -> {ok, proto_state()}. send(Msg, State) when is_record(Msg, mqtt_message) -> send(emqttd_message:to_packet(Msg), State); -send(Packet, State = #proto_state{sendfun = SendFun, peername = Peername}) +send(Packet, State = #proto_state{sendfun = SendFun}) when is_record(Packet, mqtt_packet) -> trace(send, Packet, State), emqttd_metrics:sent(Packet), Data = emqttd_serialiser:serialise(Packet), - lager:debug("SENT to ~s: ~p", [emqttd_net:format(Peername), Data]), + ?LOG(debug, "SEND ~p", [Data], State), emqttd_metrics:inc('bytes/sent', size(Data)), SendFun(Data), {ok, State}. -trace(recv, Packet, #proto_state{peername = Peername, client_id = ClientId}) -> - lager:info([{client, ClientId}], "RECV from ~s@~s: ~s", - [ClientId, emqttd_net:format(Peername), emqttd_packet:format(Packet)]); +trace(recv, Packet, ProtoState) -> + ?LOG(info, "RECV ~s", [emqttd_packet:format(Packet)], ProtoState); -trace(send, Packet, #proto_state{peername = Peername, client_id = ClientId}) -> - lager:info([{client, ClientId}], "SEND to ~s@~s: ~s", - [ClientId, emqttd_net:format(Peername), emqttd_packet:format(Packet)]). +trace(send, Packet, ProtoState) -> + ?LOG(info, "SEND ~s", [emqttd_packet:format(Packet)], ProtoState). %% @doc redeliver PUBREL PacketId redeliver({?PUBREL, PacketId}, State) -> send(?PUBREL_PACKET(PacketId), State). -shutdown(Error, #proto_state{client_id = undefined}) -> - lager:info("Protocol shutdown ~p", [Error]), +shutdown(_Error, #proto_state{client_id = undefined}) -> ignore; -shutdown(duplicate_id, #proto_state{client_id = ClientId}) -> - %% unregister the device +shutdown(conflict, #proto_state{client_id = ClientId}) -> emqttd_cm:unregister(ClientId); -%% TODO: ClientId?? -shutdown(Error, #proto_state{peername = Peername, client_id = ClientId, will_msg = WillMsg}) -> - lager:info([{client, ClientId}], "Client ~s@~s: shutdown ~p", - [ClientId, emqttd_net:format(Peername), Error]), +shutdown(Error, State = #proto_state{client_id = ClientId, will_msg = WillMsg}) -> + ?LOG(info, "Shutdown for ~p", [Error], State), send_willmsg(ClientId, WillMsg), emqttd_broker:foreach_hooks('client.disconnected', [Error, ClientId]), emqttd_cm:unregister(ClientId). @@ -344,7 +307,6 @@ maybe_set_clientid(State) -> send_willmsg(_ClientId, undefined) -> ignore; send_willmsg(ClientId, WillMsg) -> - lager:info("Client ~s send willmsg: ~p", [ClientId, WillMsg]), emqttd_pubsub:publish(WillMsg#mqtt_message{from = ClientId}). start_keepalive(0) -> ignore; @@ -371,52 +333,55 @@ validate_connect(Connect = #mqtt_packet_connect{}, ProtoState) -> validate_protocol(#mqtt_packet_connect{proto_ver = Ver, proto_name = Name}) -> lists:member({Ver, Name}, ?PROTOCOL_NAMES). -validate_clientid(#mqtt_packet_connect{client_id = ClientId}, #proto_state{max_clientid_len = MaxLen}) - when ( size(ClientId) >= 1 ) andalso ( size(ClientId) =< MaxLen ) -> +validate_clientid(#mqtt_packet_connect{client_id = ClientId}, + #proto_state{max_clientid_len = MaxLen}) + when (size(ClientId) >= 1) andalso (size(ClientId) =< MaxLen) -> true; %% MQTT3.1.1 allow null clientId. validate_clientid(#mqtt_packet_connect{proto_ver =?MQTT_PROTO_V311, - client_id = ClientId}, _ProtoState) + client_id = ClientId}, _ProtoState) when size(ClientId) =:= 0 -> true; -validate_clientid(#mqtt_packet_connect{proto_ver = Ver, - clean_sess = CleanSess, - client_id = ClientId}, _ProtoState) -> - lager:warning("Invalid ClientId: ~s, ProtoVer: ~p, CleanSess: ~s", [ClientId, Ver, CleanSess]), +validate_clientid(#mqtt_packet_connect{proto_ver = ProtoVer, + clean_sess = CleanSess}, ProtoState) -> + ?LOG(warning, "Invalid clientId. ProtoVer: ~p, CleanSess: ~s", + [ProtoVer, CleanSess], ProtoState), false. -validate_packet(#mqtt_packet{header = #mqtt_packet_header{type = ?PUBLISH}, - variable = #mqtt_packet_publish{topic_name = Topic}}) -> +validate_packet(?PUBLISH_PACKET(_Qos, Topic, _PacketId, _Payload)) -> case emqttd_topic:validate({name, Topic}) of true -> ok; - false -> lager:warning("Error publish topic: ~p", [Topic]), {error, badtopic} + false -> {error, badtopic} end; -validate_packet(#mqtt_packet{header = #mqtt_packet_header{type = ?SUBSCRIBE}, - variable = #mqtt_packet_subscribe{topic_table = Topics}}) -> - - validate_topics(filter, Topics); - -validate_packet(#mqtt_packet{header = #mqtt_packet_header{type = ?UNSUBSCRIBE}, - variable = #mqtt_packet_subscribe{topic_table = Topics}}) -> +validate_packet(?SUBSCRIBE_PACKET(_PacketId, TopicTable)) -> + validate_topics(filter, TopicTable); +validate_packet(?UNSUBSCRIBE_PACKET(_PacketId, Topics)) -> validate_topics(filter, Topics); validate_packet(_Packet) -> ok. -validate_topics(Type, []) when Type =:= name orelse Type =:= filter -> - lager:error("Empty Topics!"), +validate_topics(_Type, []) -> {error, empty_topics}; -validate_topics(Type, Topics) when Type =:= name orelse Type =:= filter -> - ErrTopics = [Topic || {Topic, Qos} <- Topics, - not (emqttd_topic:validate({Type, Topic}) and validate_qos(Qos))], - case ErrTopics of +validate_topics(Type, TopicTable = [{_Topic, _Qos}|_]) + when Type =:= name orelse Type =:= filter -> + Valid = fun(Topic, Qos) -> + emqttd_topic:validate({Type, Topic}) and validate_qos(Qos) + end, + case [Topic || {Topic, Qos} <- TopicTable, not Valid(Topic, Qos)] of [] -> ok; - _ -> lager:error("Error Topics: ~p", [ErrTopics]), {error, badtopic} + _ -> {error, badtopic} + end; + +validate_topics(Type, Topics = [Topic0|_]) when is_binary(Topic0) -> + case [Topic || Topic <- Topics, not emqttd_topic:validate({Type, Topic})] of + [] -> ok; + _ -> {error, badtopic} end. validate_qos(undefined) -> @@ -426,17 +391,17 @@ validate_qos(Qos) when ?IS_QOS(Qos) -> validate_qos(_) -> false. -%% publish ACL is cached in process dictionary. -check_acl(publish, Topic, State) -> +%% PUBLISH ACL is cached in process dictionary. +check_acl(publish, Topic, Client) -> case get({acl, publish, Topic}) of undefined -> - AllowDeny = emqttd_access_control:check_acl(client(State), publish, Topic), + AllowDeny = emqttd_access_control:check_acl(Client, publish, Topic), put({acl, publish, Topic}, AllowDeny), AllowDeny; AllowDeny -> AllowDeny end; -check_acl(subscribe, Topic, State) -> - emqttd_access_control:check_acl(client(State), subscribe, Topic). +check_acl(subscribe, Topic, Client) -> + emqttd_access_control:check_acl(Client, subscribe, Topic). diff --git a/src/emqttd_pubsub.erl b/src/emqttd_pubsub.erl index 478db721d..decb23d1f 100644 --- a/src/emqttd_pubsub.erl +++ b/src/emqttd_pubsub.erl @@ -24,7 +24,6 @@ %%% %%% @end %%%----------------------------------------------------------------------------- - -module(emqttd_pubsub). -author("Feng Lee "). @@ -48,8 +47,7 @@ publish/1]). %% Local node --export([dispatch/2, - match/1]). +-export([dispatch/2, match/1]). -behaviour(gen_server2). diff --git a/src/emqttd_retained.erl b/src/emqttd_retained.erl index 8e797255b..685006200 100644 --- a/src/emqttd_retained.erl +++ b/src/emqttd_retained.erl @@ -98,7 +98,7 @@ retain(Msg = #mqtt_message{topic = Topic, retain = true, payload = Payload}) -> case {TabSize < limit(table), size(Payload) < limit(payload)} of {true, true} -> Retained = #mqtt_retained{topic = Topic, message = Msg}, - lager:debug("Retained ~s", [emqttd_message:format(Msg)]), + lager:debug("RETAIN ~s", [emqttd_message:format(Msg)]), mnesia:async_dirty(fun mnesia:write/3, [retained, Retained, write]), emqttd_metrics:set('messages/retained', mnesia:table_info(retained, size)); {false, _}-> diff --git a/src/emqttd_session.erl b/src/emqttd_session.erl index d2163d3f4..a1b4ab47a 100644 --- a/src/emqttd_session.erl +++ b/src/emqttd_session.erl @@ -44,7 +44,6 @@ %%% %%% @end %%%----------------------------------------------------------------------------- - -module(emqttd_session). -author("Feng Lee "). @@ -53,16 +52,15 @@ -include("emqttd_protocol.hrl"). +-behaviour(gen_server2). + %% Session API --export([start_link/3, resume/3, destroy/2]). +-export([start_link/3, resume/3, info/1, destroy/2]). %% PubSub APIs --export([publish/2, - puback/2, pubrec/2, pubrel/2, pubcomp/2, +-export([publish/2, puback/2, pubrec/2, pubrel/2, pubcomp/2, subscribe/2, subscribe/3, unsubscribe/2]). --behaviour(gen_server2). - %% gen_server Function Exports -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). @@ -81,9 +79,6 @@ %% Client Pid bind with session client_pid :: pid(), - %% Client Monitor - client_mon :: reference(), - %% Last packet id of the session packet_id = 1, @@ -138,6 +133,10 @@ -define(PUBSUB_TIMEOUT, 60000). +-define(LOG(Level, Format, Args, State), + lager:Level([{client, State#session.client_id}], + "Session(~s): " ++ Format, [State#session.client_id | Args])). + %%------------------------------------------------------------------------------ %% @doc Start a session. %% @end @@ -154,6 +153,13 @@ start_link(CleanSess, ClientId, ClientPid) -> resume(SessPid, ClientId, ClientPid) -> gen_server2:cast(SessPid, {resume, ClientId, ClientPid}). +%%------------------------------------------------------------------------------ +%% @doc Session Info. +%% @end +%%------------------------------------------------------------------------------ +info(SessPid) -> + gen_server2:call(SessPid, info). + %%------------------------------------------------------------------------------ %% @doc Destroy a session. %% @end @@ -170,8 +176,12 @@ destroy(SessPid, ClientId) -> subscribe(SessPid, TopicTable) -> subscribe(SessPid, TopicTable, fun(_) -> ok end). --spec subscribe(pid(), [{binary(), mqtt_qos()}], AckFun :: fun()) -> ok. -subscribe(SessPid, TopicTable, AckFun) -> +-spec subscribe(pid(), mqtt_packet_id(), [{binary(), mqtt_qos()}]) -> ok. +subscribe(SessPid, PacketId, TopicTable) -> + From = self(), + AckFun = fun(GrantedQos) -> + From ! {suback, PacketId, GrantedQos} + end, gen_server2:cast(SessPid, {subscribe, TopicTable, AckFun}). %%------------------------------------------------------------------------------ @@ -224,7 +234,8 @@ unsubscribe(SessPid, Topics) -> %%%============================================================================= init([CleanSess, ClientId, ClientPid]) -> - %% process_flag(trap_exit, true), + process_flag(trap_exit, true), + true = link(ClientPid), QEnv = emqttd:env(mqtt, queue), SessEnv = emqttd:env(mqtt, session), Session = #session{ @@ -244,14 +255,15 @@ init([CleanSess, ClientId, ClientPid]) -> expired_after = emqttd_opts:g(expired_after, SessEnv) * 3600, collect_interval = emqttd_opts:g(collect_interval, SessEnv, 0), timestamp = os:timestamp()}, - emqttd_sm:register_session(CleanSess, ClientId, info(Session)), - %% monitor client - MRef = erlang:monitor(process, ClientPid), + emqttd_sm:register_session(CleanSess, ClientId, sess_info(Session)), %% start statistics - {ok, start_collector(Session#session{client_mon = MRef}), hibernate}. + {ok, start_collector(Session), hibernate}. prioritise_call(Msg, _From, _Len, _State) -> - case Msg of _ -> 0 end. + case Msg of + info -> 10; + _ -> 0 + end. prioritise_cast(Msg, _Len, _State) -> case Msg of @@ -268,15 +280,17 @@ prioritise_cast(Msg, _Len, _State) -> prioritise_info(Msg, _Len, _State) -> case Msg of - {'DOWN', _, _, _, _} -> 10; {'EXIT', _, _} -> 10; - session_expired -> 10; + expired -> 10; {timeout, _, _} -> 5; collect_info -> 2; {dispatch, _} -> 1; _ -> 0 end. +handle_call(info, _From, State) -> + {reply, sess_info(State), State}; + handle_call({publish, Msg = #mqtt_message{qos = ?QOS_2, pktid = PktId}}, _From, Session = #session{client_id = ClientId, awaiting_rel = AwaitingRel, @@ -293,38 +307,36 @@ handle_call({publish, Msg = #mqtt_message{qos = ?QOS_2, pktid = PktId}}, _From, end; handle_call(Req, _From, State) -> - lager:error("Unexpected Request: ~p", [Req]), - {reply, ok, State}. + ?LOG(critical, "Unexpected Request: ~p", [Req], State), + {reply, {error, unsupported_req}, State}. -handle_cast({subscribe, TopicTable0, AckFun}, Session = #session{ - client_id = ClientId, subscriptions = Subscriptions}) -> +handle_cast({subscribe, TopicTable0, AckFun}, Session = #session{client_id = ClientId, + subscriptions = Subscriptions}) -> TopicTable = emqttd_broker:foldl_hooks('client.subscribe', [ClientId], TopicTable0), case TopicTable -- Subscriptions of [] -> - catch AckFun([Qos || {_, Qos} <- TopicTable]), + AckFun([Qos || {_, Qos} <- TopicTable]), noreply(Session); _ -> %% subscribe first and don't care if the subscriptions have been existed {ok, GrantedQos} = emqttd_pubsub:subscribe(TopicTable), - catch AckFun(GrantedQos), + AckFun(GrantedQos), emqttd_broker:foreach_hooks('client.subscribe.after', [ClientId, TopicTable]), - lager:info([{client, ClientId}], "Session(~s): subscribe ~p, Granted QoS: ~p", - [ClientId, TopicTable, GrantedQos]), + ?LOG(info, "Subscribe ~p, Granted QoS: ~p", [TopicTable, GrantedQos], Session), Subscriptions1 = lists:foldl(fun({Topic, Qos}, Acc) -> case lists:keyfind(Topic, 1, Acc) of {Topic, Qos} -> - lager:warning([{client, ClientId}], "Session(~s): " - "resubscribe ~s, qos = ~w", [ClientId, Topic, Qos]), Acc; + ?LOG(warning, "resubscribe ~s, qos = ~w", [Topic, Qos], Session), + Acc; {Topic, OldQos} -> - lager:warning([{client, ClientId}], "Session(~s): " - "resubscribe ~s, old qos=~w, new qos=~w", [ClientId, Topic, OldQos, Qos]), + ?LOG(warning, "resubscribe ~s, old qos=~w, new qos=~w", [Topic, OldQos, Qos], Session), lists:keyreplace(Topic, 1, Acc, {Topic, Qos}); false -> %%TODO: the design is ugly, rewrite later...:( @@ -354,44 +366,35 @@ handle_cast({unsubscribe, Topics0}, Session = #session{client_id = ClientId, {Topic, _Qos} -> lists:keydelete(Topic, 1, Acc); false -> - lager:warning([{client, ClientId}], "Session(~s) not subscribe ~s", [ClientId, Topic]), Acc + Acc end end, Subscriptions, Topics), noreply(Session#session{subscriptions = Subscriptions1}); handle_cast({destroy, ClientId}, Session = #session{client_id = ClientId}) -> - lager:warning([{client, ClientId}], "Session(~s) destroyed", [ClientId]), + ?LOG(warning, "destroyed", [], Session), {stop, {shutdown, destroy}, Session}; -handle_cast({resume, ClientId, ClientPid}, Session) -> +handle_cast({resume, ClientId, ClientPid}, Session = #session{client_id = ClientId, + client_pid = OldClientPid, + inflight_queue = InflightQ, + awaiting_ack = AwaitingAck, + awaiting_comp = AwaitingComp, + expired_timer = ETimer} = Session) -> - #session{client_id = ClientId, - client_pid = OldClientPid, - client_mon = MRef, - inflight_queue = InflightQ, - awaiting_ack = AwaitingAck, - awaiting_comp = AwaitingComp, - expired_timer = ETimer} = Session, + ?LOG(info, "resumed by ~p", [ClientPid], Session), - lager:info([{client, ClientId}], "Session(~s) resumed by ~p", [ClientId, ClientPid]), - - %% cancel expired timer + %% Cancel expired timer cancel_timer(ETimer), - %% Kickout old client - if - OldClientPid == undefined -> - ok; - OldClientPid == ClientPid -> - ok; %% ?? - true -> - lager:error([{client, ClientId}], "Session(~s): ~p kickout ~p", - [ClientId, ClientPid, OldClientPid]), - OldClientPid ! {stop, duplicate_id, ClientPid}, - erlang:demonitor(MRef, [flush]) + case kick(ClientId, OldClientPid, ClientPid) of + ok -> ?LOG(warning, "~p kickout ~p", [ClientPid, OldClientPid], Session); + ignore -> ok end, + true = link(ClientPid), + %% Redeliver PUBREL [ClientPid ! {redeliver, {?PUBREL, PktId}} || PktId <- maps:keys(AwaitingComp)], @@ -402,7 +405,6 @@ handle_cast({resume, ClientId, ClientPid}, Session) -> [cancel_timer(TRef) || TRef <- maps:values(AwaitingComp)], Session1 = Session#session{client_pid = ClientPid, - client_mon = erlang:monitor(process, ClientPid), awaiting_ack = #{}, awaiting_comp = #{}, expired_timer = undefined}, @@ -417,19 +419,18 @@ handle_cast({resume, ClientId, ClientPid}, Session) -> noreply(dequeue(Session2)); %% PUBACK -handle_cast({puback, PktId}, Session = #session{client_id = ClientId, awaiting_ack = AwaitingAck}) -> +handle_cast({puback, PktId}, Session = #session{awaiting_ack = AwaitingAck}) -> case maps:find(PktId, AwaitingAck) of {ok, TRef} -> cancel_timer(TRef), noreply(dequeue(acked(PktId, Session))); error -> - lager:error([{client, ClientId}], "Session(~s) cannot find PUBACK ~w", [ClientId, PktId]), + ?LOG(error, "Cannot find PUBACK: ~p", [PktId], Session), noreply(Session) end; %% PUBREC -handle_cast({pubrec, PktId}, Session = #session{client_id = ClientId, - awaiting_ack = AwaitingAck, +handle_cast({pubrec, PktId}, Session = #session{awaiting_ack = AwaitingAck, awaiting_comp = AwaitingComp, await_rel_timeout = Timeout}) -> case maps:find(PktId, AwaitingAck) of @@ -440,37 +441,36 @@ handle_cast({pubrec, PktId}, Session = #session{client_id = ClientId, Session1 = acked(PktId, Session#session{awaiting_comp = AwaitingComp1}), noreply(dequeue(Session1)); error -> - lager:error([{client, ClientId}], "Session(~s) cannot find PUBREC ~w", [ClientId, PktId]), + ?LOG(error, "Cannot find PUBREC: ~p", [PktId], Session), noreply(Session) end; %% PUBREL -handle_cast({pubrel, PktId}, Session = #session{client_id = ClientId, - awaiting_rel = AwaitingRel}) -> +handle_cast({pubrel, PktId}, Session = #session{awaiting_rel = AwaitingRel}) -> case maps:find(PktId, AwaitingRel) of {ok, {Msg, TRef}} -> cancel_timer(TRef), emqttd_pubsub:publish(Msg), noreply(Session#session{awaiting_rel = maps:remove(PktId, AwaitingRel)}); error -> - lager:error([{client, ClientId}], "Session(~s) cannot find PUBREL ~w", [ClientId, PktId]), + ?LOG(error, "Cannot find PUBREL: ~p", [PktId], Session), noreply(Session) end; %% PUBCOMP -handle_cast({pubcomp, PktId}, Session = #session{client_id = ClientId, awaiting_comp = AwaitingComp}) -> +handle_cast({pubcomp, PktId}, Session = #session{awaiting_comp = AwaitingComp}) -> case maps:find(PktId, AwaitingComp) of {ok, TRef} -> cancel_timer(TRef), noreply(Session#session{awaiting_comp = maps:remove(PktId, AwaitingComp)}); error -> - lager:error("Session(~s) cannot find PUBCOMP ~w", [ClientId, PktId]), + ?LOG(error, "Cannot find PUBCOMP: ~p", [PktId], Session), noreply(Session) end; handle_cast(Msg, State) -> - lager:error("Unexpected Msg: ~p, State: ~p", [Msg, State]), - {noreply, State}. + ?LOG(critical, "Unexpected Msg: ~p", [Msg], State), + noreply(State). %% Queue messages when client is offline handle_info({dispatch, Msg}, Session = #session{client_pid = undefined, @@ -484,14 +484,15 @@ handle_info({dispatch, Msg = #mqtt_message{qos = ?QOS_0}}, ClientPid ! {deliver, Msg}, noreply(Session); -handle_info({dispatch, Msg = #mqtt_message{qos = QoS}}, Session = #session{message_queue = MsgQ}) +handle_info({dispatch, Msg = #mqtt_message{qos = QoS}}, + Session = #session{message_queue = MsgQ}) when QoS =:= ?QOS_1 orelse QoS =:= ?QOS_2 -> case check_inflight(Session) of - true -> - {noreply, deliver(Msg, Session)}; + true -> + noreply(deliver(Msg, Session)); false -> - {noreply, Session#session{message_queue = emqttd_mqueue:in(Msg, MsgQ)}} + noreply(Session#session{message_queue = emqttd_mqueue:in(Msg, MsgQ)}) end; handle_info({timeout, awaiting_ack, PktId}, Session = #session{client_pid = undefined, @@ -499,78 +500,70 @@ handle_info({timeout, awaiting_ack, PktId}, Session = #session{client_pid = unde %% just remove awaiting noreply(Session#session{awaiting_ack = maps:remove(PktId, AwaitingAck)}); -handle_info({timeout, awaiting_ack, PktId}, Session = #session{client_id = ClientId, - inflight_queue = InflightQ, +handle_info({timeout, awaiting_ack, PktId}, Session = #session{inflight_queue = InflightQ, awaiting_ack = AwaitingAck}) -> - lager:info("Awaiting Ack Timeout: ~p:", [PktId]), case maps:find(PktId, AwaitingAck) of {ok, _TRef} -> case lists:keyfind(PktId, 1, InflightQ) of {_, Msg} -> noreply(redeliver(Msg, Session)); false -> - lager:error([{client, ClientId}], "Session(~s):" - "Awaiting timeout but Cannot find PktId :~p", [ClientId, PktId]), + ?LOG(error, "AwaitingAck timeout but Cannot find PktId: ~p", [PktId], Session), noreply(dequeue(Session)) end; error -> - lager:error([{client, ClientId}], "Session(~s):" - "Cannot find Awaiting Ack:~p", [ClientId, PktId]), + ?LOG(error, "Cannot find AwaitingAck: ~p", [PktId], Session), noreply(Session) end; -handle_info({timeout, awaiting_rel, PktId}, Session = #session{client_id = ClientId, - awaiting_rel = AwaitingRel}) -> +handle_info({timeout, awaiting_rel, PktId}, Session = #session{awaiting_rel = AwaitingRel}) -> case maps:find(PktId, AwaitingRel) of - {ok, {Msg, _TRef}} -> - lager:error([{client, ClientId}], "Session(~s) AwaitingRel Timout!~n" - "Drop Message:~p", [ClientId, Msg]), + {ok, {_Msg, _TRef}} -> + ?LOG(error, "AwaitingRel Timout: ~p, Drop Message!", [PktId], Session), noreply(Session#session{awaiting_rel = maps:remove(PktId, AwaitingRel)}); error -> - lager:error([{client, ClientId}], "Session(~s) cannot find AwaitingRel ~w", [ClientId, PktId]), - {noreply, Session, hibernate} + ?LOG(error, "Cannot find AwaitingRel: ~p", [PktId], Session), + noreply(Session) end; -handle_info({timeout, awaiting_comp, PktId}, Session = #session{client_id = ClientId, - awaiting_comp = Awaiting}) -> +handle_info({timeout, awaiting_comp, PktId}, Session = #session{awaiting_comp = Awaiting}) -> case maps:find(PktId, Awaiting) of {ok, _TRef} -> - lager:error([{client, ClientId}], "Session(~s) " - "Awaiting PUBCOMP Timout: PktId=~p!", [ClientId, PktId]), + ?LOG(error, "Awaiting PUBCOMP Timout: ~p", [PktId], Session), noreply(Session#session{awaiting_comp = maps:remove(PktId, Awaiting)}); error -> - lager:error([{client, ClientId}], "Session(~s) " - "Cannot find Awaiting PUBCOMP: PktId=~p", [ClientId, PktId]), + ?LOG(error, "Cannot find Awaiting PUBCOMP: ~p", [PktId], Session), noreply(Session) end; handle_info(collect_info, Session = #session{clean_sess = CleanSess, client_id = ClientId}) -> - emqttd_sm:register_session(CleanSess, ClientId, info(Session)), - {noreply, start_collector(Session), hibernate}; + emqttd_sm:register_session(CleanSess, ClientId, sess_info(Session)), + noreply(start_collector(Session)); -handle_info({'DOWN', _MRef, process, ClientPid, _}, Session = #session{clean_sess = true, - client_pid = ClientPid}) -> +handle_info({'EXIT', ClientPid, _Reason}, Session = #session{clean_sess = true, + client_pid = ClientPid}) -> {stop, normal, Session}; -handle_info({'DOWN', _MRef, process, ClientPid, _}, Session = #session{clean_sess = false, - client_pid = ClientPid, - expired_after = Expires}) -> - TRef = timer(Expires, session_expired), - noreply(Session#session{client_pid = undefined, client_mon = undefined, expired_timer = TRef}); +handle_info({'EXIT', ClientPid, Reason}, Session = #session{clean_sess = false, + client_pid = ClientPid, + expired_after = Expires}) -> + ?LOG(info, "Client ~p EXIT for ~p", [ClientPid, Reason], Session), + TRef = timer(Expires, expired), + erlang:garbage_collect(), %%TODO: ??? + noreply(Session#session{client_pid = undefined, expired_timer = TRef}); -handle_info({'DOWN', _MRef, process, Pid, Reason}, Session = #session{client_id = ClientId, - client_pid = ClientPid}) -> - lager:error([{client, ClientId}], "Session(~s): unexpected DOWN: " - "client_pid=~p, down_pid=~p, reason=~p", - [ClientId, ClientPid, Pid, Reason]), +handle_info({'EXIT', Pid, Reason}, Session = #session{client_pid = ClientPid}) -> + + ?LOG(error, "Unexpected EXIT: client_pid=~p, exit_pid=~p, reason=~p", + [ClientPid, Pid, Reason], Session), noreply(Session); -handle_info(session_expired, Session = #session{client_id = ClientId}) -> - lager:error("Session(~s) expired, shutdown now.", [ClientId]), +handle_info(expired, Session) -> + ?LOG(info, "expired, shutdown now.", [], Session), {stop, {shutdown, expired}, Session}; -handle_info(Info, Session = #session{client_id = ClientId}) -> - lager:error("Session(~s) unexpected info: ~p", [ClientId, Info]), +handle_info(Info, Session) -> + ?LOG(critical, "Unexpected info: ~p", [Info], Session), {noreply, Session}. terminate(_Reason, #session{clean_sess = CleanSess, client_id = ClientId}) -> @@ -583,6 +576,17 @@ code_change(_OldVsn, Session, _Extra) -> %%% Internal functions %%%============================================================================= +%%------------------------------------------------------------------------------ +%% Kick old client out +%%------------------------------------------------------------------------------ +kick(_ClientId, undefined, _Pid) -> + ignore; +kick(_ClientId, Pid, Pid) -> + ignore; +kick(ClientId, OldPid, Pid) -> + unlink(OldPid), + OldPid ! {shutdown, conflict, {ClientId, Pid}}. + %%------------------------------------------------------------------------------ %% Check inflight and awaiting_rel %%------------------------------------------------------------------------------ @@ -656,7 +660,7 @@ acked(PktId, Session = #session{client_id = ClientId, {_, Msg} -> emqttd_broker:foreach_hooks('message.acked', [ClientId, Msg]); false -> - lager:error("Session(~s): Cannot find acked message: ~p", [PktId]) + ?LOG(error, "Cannot find acked pktid: ~p", [PktId], Session) end, Session#session{awaiting_ack = maps:remove(PktId, Awaiting), inflight_queue = lists:keydelete(PktId, 1, InflightQ)}. @@ -685,15 +689,15 @@ start_collector(Session = #session{collect_interval = Interval}) -> TRef = erlang:send_after(timer:seconds(Interval), self(), collect_info), Session#session{collect_timer = TRef}. -info(#session{clean_sess = CleanSess, - subscriptions = Subscriptions, - inflight_queue = InflightQueue, - max_inflight = MaxInflight, - message_queue = MessageQueue, - awaiting_rel = AwaitingRel, - awaiting_ack = AwaitingAck, - awaiting_comp = AwaitingComp, - timestamp = CreatedAt}) -> +sess_info(#session{clean_sess = CleanSess, + subscriptions = Subscriptions, + inflight_queue = InflightQueue, + max_inflight = MaxInflight, + message_queue = MessageQueue, + awaiting_rel = AwaitingRel, + awaiting_ack = AwaitingAck, + awaiting_comp = AwaitingComp, + timestamp = CreatedAt}) -> Stats = emqttd_mqueue:stats(MessageQueue), [{clean_sess, CleanSess}, {subscriptions, Subscriptions}, diff --git a/src/emqttd_sm.erl b/src/emqttd_sm.erl index 719ac0ca1..d62e85c86 100644 --- a/src/emqttd_sm.erl +++ b/src/emqttd_sm.erl @@ -57,7 +57,10 @@ -define(SM_POOL, ?MODULE). --define(SESSION_TIMEOUT, 60000). +-define(CALL_TIMEOUT, 60000). + +-define(LOG(Level, Format, Args, Session), + lager:Level("SM(~s): " ++ Format, [Session#mqtt_session.client_id | Args])). %%%============================================================================= %%% Mnesia callbacks @@ -113,7 +116,7 @@ start_session(CleanSess, ClientId) -> lookup_session(ClientId) -> case mnesia:dirty_read(session, ClientId) of [Session] -> Session; - [] -> undefined + [] -> undefined end. %%------------------------------------------------------------------------------ @@ -137,13 +140,11 @@ register_session(CleanSess, ClientId, Info) -> unregister_session(CleanSess, ClientId) -> ets:delete(sesstab(CleanSess), {ClientId, self()}). -sesstab(true) -> - mqtt_transient_session; -sesstab(false) -> - mqtt_persistent_session. +sesstab(true) -> mqtt_transient_session; +sesstab(false) -> mqtt_persistent_session. call(SM, Req) -> - gen_server2:call(SM, Req, ?SESSION_TIMEOUT). %%infinity). + gen_server2:call(SM, Req, ?CALL_TIMEOUT). %%infinity). %%%============================================================================= %%% gen_server callbacks @@ -223,8 +224,8 @@ create_session(CleanSess, ClientId, ClientPid) -> case insert_session(Session) of {aborted, {conflict, ConflictPid}} -> %% Conflict with othe node? - lager:error("Session(~s): Conflict with ~p!", [ClientId, ConflictPid]), - {error, conflict}; + lager:error("SM(~s): Conflict with ~p", [ClientId, ConflictPid]), + {error, mnesia_conflict}; {atomic, ok} -> erlang:monitor(process, SessPid), {ok, SessPid} @@ -245,8 +246,8 @@ insert_session(Session = #mqtt_session{client_id = ClientId}) -> end). %% Local node -resume_session(#mqtt_session{client_id = ClientId, - sess_pid = SessPid}, ClientPid) +resume_session(Session = #mqtt_session{client_id = ClientId, + sess_pid = SessPid}, ClientPid) when node(SessPid) =:= node() -> case is_process_alive(SessPid) of @@ -254,7 +255,7 @@ resume_session(#mqtt_session{client_id = ClientId, emqttd_session:resume(SessPid, ClientId, ClientPid), {ok, SessPid}; false -> - lager:error("Session(~s): Cannot resume ~p, it seems already dead!", [ClientId, SessPid]), + ?LOG(error, "Cannot resume ~p which seems already dead!", [SessPid], Session), {error, session_died} end; @@ -265,12 +266,11 @@ resume_session(Session = #mqtt_session{client_id = ClientId, sess_pid = SessPid} ok -> {ok, SessPid}; {badrpc, nodedown} -> - lager:error("Session(~s): Died for node ~s down!", [ClientId, Node]), + ?LOG(error, "Session died for node '~s' down", [Node], Session), remove_session(Session), {error, session_nodedown}; {badrpc, Reason} -> - lager:error("Session(~s): Failed to resume from node ~s for ~p", - [ClientId, Node, Reason]), + ?LOG(error, "Failed to resume from node ~s for ~p", [Node, Reason], Session), {error, Reason} end. @@ -288,11 +288,11 @@ destroy_session(Session = #mqtt_session{client_id = ClientId, ok -> remove_session(Session); {badrpc, nodedown} -> - lager:error("Session(~s): Died for node ~s down!", [ClientId, Node]), + ?LOG(error, "Node '~s' down", [Node], Session), remove_session(Session); {badrpc, Reason} -> - lager:error("Session(~s): Failed to destory ~p on remote node ~p for ~s", - [ClientId, SessPid, Node, Reason]), + ?LOG(error, "Failed to destory ~p on remote node ~p for ~s", + [SessPid, Node, Reason], Session), {error, Reason} end. diff --git a/src/emqttd_ws_client.erl b/src/emqttd_ws_client.erl index 3d06e8432..cb9a49e74 100644 --- a/src/emqttd_ws_client.erl +++ b/src/emqttd_ws_client.erl @@ -24,7 +24,6 @@ %%% %%% @end %%%----------------------------------------------------------------------------- - -module(emqttd_ws_client). -author("Feng Lee "). @@ -46,10 +45,13 @@ terminate/2, code_change/3]). %% WebSocket Loop State --record(wsocket_state, {request, client_pid, packet_opts, parser}). +-record(wsocket_state, {request, client_pid, packet_opts, parser_fun}). -%% Client State --record(client_state, {ws_pid, request, proto_state, keepalive}). +%% WebSocket Client State +-record(wsclient_state, {ws_pid, request, proto_state, keepalive}). + +-define(WSLOG(Level, Format, Args, Req), + lager:Level("WsClient(~s): " ++ Format, [Req:get(peer) | Args])). %%------------------------------------------------------------------------------ %% @doc Start WebSocket client. @@ -57,12 +59,14 @@ %%------------------------------------------------------------------------------ start_link(Req) -> PktOpts = emqttd:env(mqtt, packet), + ParserFun = emqttd_parser:new(PktOpts), {ReentryWs, ReplyChannel} = upgrade(Req), - {ok, ClientPid} = gen_server:start_link(?MODULE, [self(), Req, ReplyChannel, PktOpts], []), - ReentryWs(#wsocket_state{request = Req, - client_pid = ClientPid, - packet_opts = PktOpts, - parser = emqttd_parser:new(PktOpts)}). + Params = [self(), Req, ReplyChannel, PktOpts], + {ok, ClientPid} = gen_server:start_link(?MODULE, Params, []), + ReentryWs(#wsocket_state{request = Req, + client_pid = ClientPid, + packet_opts = PktOpts, + parser_fun = ParserFun}). session(CPid) -> gen_server:call(CPid, session, infinity). @@ -97,25 +101,28 @@ ws_loop([<<>>], State, _ReplyChannel) -> State; ws_loop(Data, State = #wsocket_state{request = Req, client_pid = ClientPid, - parser = Parser}, ReplyChannel) -> - Peer = Req:get(peer), - lager:debug("RECV from ~s(WebSocket): ~p", [Peer, Data]), - case Parser(iolist_to_binary(Data)) of + parser_fun = ParserFun}, ReplyChannel) -> + ?WSLOG(debug, "RECV ~p", [Data], Req), + case catch ParserFun(iolist_to_binary(Data)) of {more, NewParser} -> - State#wsocket_state{parser = NewParser}; + State#wsocket_state{parser_fun = NewParser}; {ok, Packet, Rest} -> gen_server:cast(ClientPid, {received, Packet}), ws_loop(Rest, reset_parser(State), ReplyChannel); {error, Error} -> - lager:error("MQTT(WebSocket) frame error ~p for connection ~s", [Error, Peer]), - exit({shutdown, Error}) + ?WSLOG(error, "Frame error: ~p", [Error], Req), + exit({shutdown, Error}); + {'EXIT', Reason} -> + ?WSLOG(error, "Frame error: ~p", [Reason], Req), + ?WSLOG(error, "Error data: ~p", [Data], Req), + exit({shutdown, parser_error}) end. reset_parser(State = #wsocket_state{packet_opts = PktOpts}) -> - State#wsocket_state{parser = emqttd_parser:new(PktOpts)}. + State#wsocket_state{parser_fun = emqttd_parser:new(PktOpts)}. %%%============================================================================= -%%% gen_fsm callbacks +%%% gen_server callbacks %%%============================================================================= init([WsPid, Req, ReplyChannel, PktOpts]) -> @@ -125,93 +132,105 @@ init([WsPid, Req, ReplyChannel, PktOpts]) -> Headers = mochiweb_request:get(headers, Req), HeadersList = mochiweb_headers:to_list(Headers), ProtoState = emqttd_protocol:init(Peername, SendFun, - [{ws_initial_headers, HeadersList}|PktOpts]), - {ok, #client_state{ws_pid = WsPid, request = Req, proto_state = ProtoState}}. + [{ws_initial_headers, HeadersList} | PktOpts]), + {ok, #wsclient_state{ws_pid = WsPid, request = Req, proto_state = ProtoState}}. -handle_call(session, _From, State = #client_state{proto_state = ProtoState}) -> +handle_call(session, _From, State = #wsclient_state{proto_state = ProtoState}) -> {reply, emqttd_protocol:session(ProtoState), State}; -handle_call(info, _From, State = #client_state{request = Req, - proto_state = ProtoState}) -> - {reply, [{websocket, true}, {peer, Req:get(peer)} - | emqttd_protocol:info(ProtoState)], State}; +handle_call(info, _From, State = #wsclient_state{request = Req, + proto_state = ProtoState}) -> + ProtoInfo = emqttd_protocol:info(ProtoState), + {reply, [{websocket, true}, {peer, Req:get(peer)}| ProtoInfo], State}; handle_call(kick, _From, State) -> {stop, {shutdown, kick}, ok, State}; -handle_call(_Req, _From, State) -> - {reply, error, State}. +handle_call(Req, _From, State = #wsclient_state{request = HttpReq}) -> + ?WSLOG(critical, "Unexpected request: ~p", [Req], HttpReq), + {reply, {error, unsupported_request}, State}. handle_cast({subscribe, TopicTable}, State) -> - with_session(fun(SessPid) -> emqttd_session:subscribe(SessPid, TopicTable) end, State); + with_session(fun(SessPid) -> + emqttd_session:subscribe(SessPid, TopicTable) + end, State); handle_cast({unsubscribe, Topics}, State) -> - with_session(fun(SessPid) -> emqttd_session:unsubscribe(SessPid, Topics) end, State); + with_session(fun(SessPid) -> + emqttd_session:unsubscribe(SessPid, Topics) + end, State); -handle_cast({received, Packet}, State = #client_state{proto_state = ProtoState}) -> +handle_cast({received, Packet}, State = #wsclient_state{request = Req, + proto_state = ProtoState}) -> case emqttd_protocol:received(Packet, ProtoState) of {ok, ProtoState1} -> - noreply(State#client_state{proto_state = ProtoState1}); + noreply(State#wsclient_state{proto_state = ProtoState1}); {error, Error} -> - lager:error("MQTT protocol error ~p", [Error]), - stop({shutdown, Error}, State); + ?WSLOG(error, "Protocol error - ~p", [Error], Req), + shutdown(Error, State); {error, Error, ProtoState1} -> - stop({shutdown, Error}, State#client_state{proto_state = ProtoState1}); + shutdown(Error, State#wsclient_state{proto_state = ProtoState1}); {stop, Reason, ProtoState1} -> - stop(Reason, State#client_state{proto_state = ProtoState1}) + stop(Reason, State#wsclient_state{proto_state = ProtoState1}) end; -handle_cast(_Msg, State) -> +handle_cast(Msg, State = #wsclient_state{request = Req}) -> + ?WSLOG(critical, "Unexpected msg: ~p", [Msg], Req), {noreply, State}. -handle_info({deliver, Message}, State = #client_state{proto_state = ProtoState}) -> - {ok, ProtoState1} = emqttd_protocol:send(Message, ProtoState), - noreply(State#client_state{proto_state = ProtoState1}); +handle_info({suback, PacketId, GrantedQos}, State) -> + with_proto_state(fun(ProtoState) -> + Packet = ?SUBACK_PACKET(PacketId, GrantedQos), + emqttd_protocol:send(Packet, ProtoState) + end, State); -handle_info({redeliver, {?PUBREL, PacketId}}, State = #client_state{proto_state = ProtoState}) -> - {ok, ProtoState1} = emqttd_protocol:redeliver({?PUBREL, PacketId}, ProtoState), - noreply(State#client_state{proto_state = ProtoState1}); +handle_info({deliver, Message}, State) -> + with_proto_state(fun(ProtoState) -> + emqttd_protocol:send(Message, ProtoState) + end, State); -handle_info({stop, duplicate_id, _NewPid}, State = #client_state{proto_state = ProtoState}) -> - lager:error("Shutdown for duplicate clientid: ~s", [emqttd_protocol:clientid(ProtoState)]), - stop({shutdown, duplicate_id}, State); +handle_info({redeliver, {?PUBREL, PacketId}}, State) -> + with_proto_state(fun(ProtoState) -> + emqttd_protocol:redeliver({?PUBREL, PacketId}, ProtoState) + end, State); -handle_info({keepalive, start, TimeoutSec}, State = #client_state{request = Req}) -> - lager:debug("Client(WebSocket) ~s: Start KeepAlive with ~p seconds", [Req:get(peer), TimeoutSec]), - Socket = Req:get(socket), +handle_info({shutdown, conflict, {ClientId, NewPid}}, State = #wsclient_state{request = Req}) -> + ?WSLOG(warning, "clientid '~s' conflict with ~p", [ClientId, NewPid], Req), + shutdown(conflict, State); + +handle_info({keepalive, start, Interval}, State = #wsclient_state{request = Req}) -> + ?WSLOG(debug, "Keepalive at the interval of ~p", [Interval], Req), + Conn = Req:get(connection), StatFun = fun() -> - case esockd_transport:getstat(Socket, [recv_oct]) of + case Conn:getstat([recv_oct]) of {ok, [{recv_oct, RecvOct}]} -> {ok, RecvOct}; {error, Error} -> {error, Error} end end, - KeepAlive = emqttd_keepalive:start(StatFun, TimeoutSec, {keepalive, check}), - noreply(State#client_state{keepalive = KeepAlive}); + KeepAlive = emqttd_keepalive:start(StatFun, Interval, {keepalive, check}), + noreply(State#wsclient_state{keepalive = KeepAlive}); -handle_info({keepalive, check}, State = #client_state{request = Req, keepalive = KeepAlive}) -> +handle_info({keepalive, check}, State = #wsclient_state{request = Req, + keepalive = KeepAlive}) -> case emqttd_keepalive:check(KeepAlive) of {ok, KeepAlive1} -> - noreply(State#client_state{keepalive = KeepAlive1}); + noreply(State#wsclient_state{keepalive = KeepAlive1}); {error, timeout} -> - lager:debug("Client(WebSocket) ~s: Keepalive Timeout!", [Req:get(peer)]), - stop({shutdown, keepalive_timeout}, State#client_state{keepalive = undefined}); + ?WSLOG(debug, "Keepalive Timeout!", [], Req), + shutdown(keepalive_timeout, State); {error, Error} -> - lager:debug("Client(WebSocket) ~s: Keepalive Error: ~p", [Req:get(peer), Error]), - stop({shutdown, keepalive_error}, State#client_state{keepalive = undefined}) + ?WSLOG(warning, "Keepalive error - ~p", [Error], Req), + shutdown(keepalive_error, State) end; -handle_info({'EXIT', WsPid, Reason}, State = #client_state{ws_pid = WsPid, - proto_state = ProtoState}) -> - ClientId = emqttd_protocol:clientid(ProtoState), - lager:warning("Websocket client ~s exit: reason=~p", [ClientId, Reason]), - stop({shutdown, websocket_closed}, State); +handle_info({'EXIT', WsPid, Reason}, State = #wsclient_state{ws_pid = WsPid}) -> + stop(Reason, State); -handle_info(Info, State = #client_state{request = Req}) -> - lager:error("Client(WebSocket) ~s: Unexpected Info - ~p", [Req:get(peer), Info]), +handle_info(Info, State = #wsclient_state{request = Req}) -> + ?WSLOG(error, "Unexpected Info: ~p", [Info], Req), noreply(State). -terminate(Reason, #client_state{proto_state = ProtoState, keepalive = KeepAlive}) -> - lager:info("WebSocket client terminated: ~p", [Reason]), +terminate(Reason, #wsclient_state{proto_state = ProtoState, keepalive = KeepAlive}) -> emqttd_keepalive:cancel(KeepAlive), case Reason of {shutdown, Error} -> @@ -227,12 +246,19 @@ code_change(_OldVsn, State, _Extra) -> %%% Internal functions %%%============================================================================= +with_proto_state(Fun, State = #wsclient_state{proto_state = ProtoState}) -> + {ok, ProtoState1} = Fun(ProtoState), + noreply(State#wsclient_state{proto_state = ProtoState1}). + +with_session(Fun, State = #wsclient_state{proto_state = ProtoState}) -> + Fun(emqttd_protocol:session(ProtoState)), noreply(State). + noreply(State) -> {noreply, State, hibernate}. +shutdown(Reason, State) -> + stop({shutdown, Reason}, State). + stop(Reason, State ) -> {stop, Reason, State}. -with_session(Fun, State = #client_state{proto_state = ProtoState}) -> - Fun(emqttd_protocol:session(ProtoState)), noreply(State). - diff --git a/test/emqttd_access_control_tests.erl b/test/emqttd_access_control_tests.erl index 7db0490ef..5da45f4a8 100644 --- a/test/emqttd_access_control_tests.erl +++ b/test/emqttd_access_control_tests.erl @@ -42,30 +42,33 @@ register_mod_test() -> with_acl( fun() -> emqttd_access_control:register_mod(acl, emqttd_acl_test_mod, []), - ?assertMatch([{emqttd_acl_test_mod, _}, {emqttd_acl_internal, _}], + ?assertMatch([{emqttd_acl_test_mod, _, 0}, {emqttd_acl_internal, _, 0}], emqttd_access_control:lookup_mods(acl)), emqttd_access_control:register_mod(auth, emqttd_auth_anonymous_test_mod,[]), - ?assertMatch([{emqttd_auth_anonymous_test_mod, _}, {emqttd_auth_anonymous, _}], - emqttd_access_control:lookup_mods(auth)) + emqttd_access_control:register_mod(auth, emqttd_auth_dashboard, [], 99), + ?assertMatch([{emqttd_auth_dashboard, _, 99}, + {emqttd_auth_anonymous_test_mod, _, 0}, + {emqttd_auth_anonymous, _, 0}], + emqttd_access_control:lookup_mods(auth)) end). unregister_mod_test() -> with_acl( fun() -> - emqttd_access_control:register_mod(acl,emqttd_acl_test_mod, []), - ?assertMatch([{emqttd_acl_test_mod, _}, {emqttd_acl_internal, _}], + emqttd_access_control:register_mod(acl, emqttd_acl_test_mod, []), + ?assertMatch([{emqttd_acl_test_mod, _, 0}, {emqttd_acl_internal, _, 0}], emqttd_access_control:lookup_mods(acl)), emqttd_access_control:unregister_mod(acl, emqttd_acl_test_mod), timer:sleep(5), - ?assertMatch([{emqttd_acl_internal, _}], emqttd_access_control:lookup_mods(acl)), + ?assertMatch([{emqttd_acl_internal, _, 0}], emqttd_access_control:lookup_mods(acl)), emqttd_access_control:register_mod(auth, emqttd_auth_anonymous_test_mod,[]), - ?assertMatch([{emqttd_auth_anonymous_test_mod, _}, {emqttd_auth_anonymous, _}], + ?assertMatch([{emqttd_auth_anonymous_test_mod, _, 0}, {emqttd_auth_anonymous, _, 0}], emqttd_access_control:lookup_mods(auth)), emqttd_access_control:unregister_mod(auth, emqttd_auth_anonymous_test_mod), timer:sleep(5), - ?assertMatch([{emqttd_auth_anonymous, _}], emqttd_access_control:lookup_mods(auth)) + ?assertMatch([{emqttd_auth_anonymous, _, 0}], emqttd_access_control:lookup_mods(auth)) end). check_acl_test() -> @@ -83,7 +86,7 @@ check_acl_test() -> with_acl(Fun) -> process_flag(trap_exit, true), - AclOpts = [ + AclOpts = [ {auth, [ %% Authentication with username, password %{username, []}, diff --git a/test/emqttd_access_rule_tests.erl b/test/emqttd_access_rule_tests.erl index 142beeaeb..f46f23ce4 100644 --- a/test/emqttd_access_rule_tests.erl +++ b/test/emqttd_access_rule_tests.erl @@ -35,6 +35,14 @@ -include_lib("eunit/include/eunit.hrl"). compile_test() -> + + ?assertMatch({allow, {'and', [{ipaddr, {"127.0.0.1", _I, _I}}, + {user, <<"user">>}]}, subscribe, [ [<<"$SYS">>, '#'], ['#'] ]}, + compile({allow, {'and', [{ipaddr, "127.0.0.1"}, {user, <<"user">>}]}, subscribe, ["$SYS/#", "#"]})), + ?assertMatch({allow, {'or', [{ipaddr, {"127.0.0.1", _I, _I}}, + {user, <<"user">>}]}, subscribe, [ [<<"$SYS">>, '#'], ['#'] ]}, + compile({allow, {'or', [{ipaddr, "127.0.0.1"}, {user, <<"user">>}]}, subscribe, ["$SYS/#", "#"]})), + ?assertMatch({allow, {ipaddr, {"127.0.0.1", _I, _I}}, subscribe, [ [<<"$SYS">>, '#'], ['#'] ]}, compile({allow, {ipaddr, "127.0.0.1"}, subscribe, ["$SYS/#", "#"]})), ?assertMatch({allow, {user, <<"testuser">>}, subscribe, [ [<<"a">>, <<"b">>, <<"c">>], [<<"d">>, <<"e">>, <<"f">>, '#'] ]}, @@ -69,10 +77,15 @@ match_test() -> ?assertMatch({matched, allow}, match(User, <<"clients/testClient">>, compile({allow, all, pubsub, ["clients/$c"]}))), ?assertMatch({matched, allow}, match(#mqtt_client{username = <<"user2">>}, <<"users/user2/abc/def">>, - compile({allow, all, subscribe, ["users/$u/#"]}))), - ?assertMatch({matched, deny}, - match(User, <<"d/e/f">>, - compile({deny, all, subscribe, ["$SYS/#", "#"]}))). + compile({allow, all, subscribe, ["users/$u/#"]}))), + ?assertMatch({matched, deny}, match(User, <<"d/e/f">>, + compile({deny, all, subscribe, ["$SYS/#", "#"]}))), + Rule = compile({allow, {'and', [{ipaddr, "127.0.0.1"}, {user, <<"WrongUser">>}]}, publish, <<"Topic">>}), + ?assertMatch(nomatch, match(User, <<"Topic">>, Rule)), + AndRule = compile({allow, {'and', [{ipaddr, "127.0.0.1"}, {user, <<"TestUser">>}]}, publish, <<"Topic">>}), + ?assertMatch({matched, allow}, match(User, <<"Topic">>, AndRule)), + OrRule = compile({allow, {'or', [{ipaddr, "127.0.0.1"}, {user, <<"WrongUser">>}]}, publish, ["Topic"]}), + ?assertMatch({matched, allow}, match(User, <<"Topic">>, OrRule)). -endif. diff --git a/test/emqttd_auth_dashboard.erl b/test/emqttd_auth_dashboard.erl new file mode 100644 index 000000000..ea9aca7e1 --- /dev/null +++ b/test/emqttd_auth_dashboard.erl @@ -0,0 +1,14 @@ + +-module(emqttd_auth_dashboard). + +%% Auth callbacks +-export([init/1, check/3, description/0]). + +init(Opts) -> + {ok, Opts}. + +check(_Client, _Password, _Opts) -> + allow. + +description() -> + "Test emqttd_auth_dashboard Mod". diff --git a/test/emqttd_retained_tests.erl b/test/emqttd_retained_tests.erl new file mode 100644 index 000000000..b541a0731 --- /dev/null +++ b/test/emqttd_retained_tests.erl @@ -0,0 +1,14 @@ +-module(emqttd_retained_tests). + +-include("emqttd.hrl"). + +-ifdef(TEST). + +-include_lib("eunit/include/eunit.hrl"). + +retain_test() -> + mnesia:start(), + emqttd_retained:mnesia(boot), + mnesia:stop(). + +-endif.