fix(mqtt-sn): fix clean session false reconnect topic-id badmatch

This commit is contained in:
Turtle 2021-06-03 09:46:17 +08:00 committed by turtleDeng
parent 0ef1a1f53f
commit bfb02fe8c3
6 changed files with 169 additions and 164 deletions

View File

@ -43,7 +43,8 @@
start(_Type, _Args) -> start(_Type, _Args) ->
Addr = application:get_env(emqx_sn, port, 1884), Addr = application:get_env(emqx_sn, port, 1884),
GwId = application:get_env(emqx_sn, gateway_id, 1), GwId = application:get_env(emqx_sn, gateway_id, 1),
{ok, Sup} = emqx_sn_sup:start_link(Addr, GwId), PredefTopics = application:get_env(emqx_sn, predefined, []),
{ok, Sup} = emqx_sn_sup:start_link(Addr, GwId, PredefTopics),
start_listeners(), start_listeners(),
{ok, Sup}. {ok, Sup}.
@ -57,13 +58,7 @@ stop(_State) ->
-spec start_listeners() -> ok. -spec start_listeners() -> ok.
start_listeners() -> start_listeners() ->
PredefTopics = application:get_env(emqx_sn, predefined, []), lists:foreach(fun start_listener/1, listeners_confs()).
ListenCfs = [begin
TabName = tabname(Proto, ListenOn),
{ok, RegistryPid} = emqx_sn_sup:start_registry_proc(emqx_sn_sup, TabName, PredefTopics),
{Proto, ListenOn, [{registry, {TabName, RegistryPid}} | Options]}
end || {Proto, ListenOn, Options} <- listeners_confs()],
lists:foreach(fun start_listener/1, ListenCfs).
-spec start_listener(listener()) -> ok. -spec start_listener(listener()) -> ok.
start_listener({Proto, ListenOn, Options}) -> start_listener({Proto, ListenOn, Options}) ->
@ -151,7 +146,3 @@ format({Addr, Port}) when is_list(Addr) ->
io_lib:format("~s:~w", [Addr, Port]); io_lib:format("~s:~w", [Addr, Port]);
format({Addr, Port}) when is_tuple(Addr) -> format({Addr, Port}) when is_tuple(Addr) ->
io_lib:format("~s:~w", [inet:ntoa(Addr), Port]). io_lib:format("~s:~w", [inet:ntoa(Addr), Port]).
tabname(Proto, ListenOn) ->
list_to_atom(lists:flatten(["emqx_sn_registry__", atom_to_list(Proto), "_", format(ListenOn)])).

View File

@ -82,7 +82,6 @@
sockname :: {inet:ip_address(), inet:port()}, sockname :: {inet:ip_address(), inet:port()},
peername :: {inet:ip_address(), inet:port()}, peername :: {inet:ip_address(), inet:port()},
channel :: maybe(emqx_channel:channel()), channel :: maybe(emqx_channel:channel()),
registry :: emqx_sn_registry:registry(),
clientid :: maybe(binary()), clientid :: maybe(binary()),
username :: maybe(binary()), username :: maybe(binary()),
password :: maybe(binary()), password :: maybe(binary()),
@ -147,7 +146,6 @@ kick(GwPid) ->
init([{_, SockPid, Sock}, Peername, Options]) -> init([{_, SockPid, Sock}, Peername, Options]) ->
GwId = proplists:get_value(gateway_id, Options), GwId = proplists:get_value(gateway_id, Options),
Registry = proplists:get_value(registry, Options),
Username = proplists:get_value(username, Options, undefined), Username = proplists:get_value(username, Options, undefined),
Password = proplists:get_value(password, Options, undefined), Password = proplists:get_value(password, Options, undefined),
EnableQos3 = proplists:get_value(enable_qos3, Options, false), EnableQos3 = proplists:get_value(enable_qos3, Options, false),
@ -165,7 +163,6 @@ init([{_, SockPid, Sock}, Peername, Options]) ->
sockname = Sockname, sockname = Sockname,
peername = Peername, peername = Peername,
channel = Channel, channel = Channel,
registry = Registry,
asleep_timer = emqx_sn_asleep_timer:init(), asleep_timer = emqx_sn_asleep_timer:init(),
enable_stats = EnableStats, enable_stats = EnableStats,
enable_qos3 = EnableQos3, enable_qos3 = EnableQos3,
@ -205,9 +202,9 @@ idle(cast, {incoming, ?SN_PUBLISH_MSG(_Flag, _TopicId, _MsgId, _Data)}, State =
idle(cast, {incoming, ?SN_PUBLISH_MSG(#mqtt_sn_flags{qos = ?QOS_NEG1, idle(cast, {incoming, ?SN_PUBLISH_MSG(#mqtt_sn_flags{qos = ?QOS_NEG1,
topic_id_type = TopicIdType topic_id_type = TopicIdType
}, TopicId, _MsgId, Data)}, }, TopicId, _MsgId, Data)},
State = #state{clientid = ClientId, registry = Registry}) -> State = #state{clientid = ClientId}) ->
TopicName = case (TopicIdType =:= ?SN_SHORT_TOPIC) of TopicName = case (TopicIdType =:= ?SN_SHORT_TOPIC) of
false -> emqx_sn_registry:lookup_topic(Registry, ClientId, TopicId); false -> emqx_sn_registry:lookup_topic(ClientId, TopicId);
true -> <<TopicId:16>> true -> <<TopicId:16>>
end, end,
_ = case TopicName =/= undefined of _ = case TopicName =/= undefined of
@ -292,9 +289,9 @@ wait_for_will_msg(EventType, EventContent, State) ->
handle_event(EventType, EventContent, wait_for_will_msg, State). handle_event(EventType, EventContent, wait_for_will_msg, State).
connected(cast, {incoming, ?SN_REGISTER_MSG(_TopicId, MsgId, TopicName)}, connected(cast, {incoming, ?SN_REGISTER_MSG(_TopicId, MsgId, TopicName)},
State = #state{clientid = ClientId, registry = Registry}) -> State = #state{clientid = ClientId}) ->
State0 = State0 =
case emqx_sn_registry:register_topic(Registry, ClientId, TopicName) of case emqx_sn_registry:register_topic(ClientId, TopicName) of
TopicId when is_integer(TopicId) -> TopicId when is_integer(TopicId) ->
?LOG(debug, "register ClientId=~p, TopicName=~p, TopicId=~p", [ClientId, TopicName, TopicId]), ?LOG(debug, "register ClientId=~p, TopicName=~p, TopicId=~p", [ClientId, TopicName, TopicId]),
send_message(?SN_REGACK_MSG(TopicId, MsgId, ?SN_RC_ACCEPTED), State); send_message(?SN_REGACK_MSG(TopicId, MsgId, ?SN_RC_ACCEPTED), State);
@ -580,14 +577,13 @@ handle_event(EventType, EventContent, StateName, State) ->
[StateName, {EventType, EventContent}]), [StateName, {EventType, EventContent}]),
{keep_state, State}. {keep_state, State}.
terminate(Reason, _StateName, #state{channel = Channel, terminate(Reason, _StateName, #state{channel = Channel}) ->
registry = Registry}) ->
ClientId = emqx_channel:info(clientid, Channel), ClientId = emqx_channel:info(clientid, Channel),
case Reason of case Reason of
{shutdown, takeovered} -> {shutdown, takeovered} ->
ok; ok;
_ -> _ ->
emqx_sn_registry:unregister_topic(Registry, ClientId) emqx_sn_registry:unregister_topic(ClientId)
end, end,
emqx_channel:terminate(Reason, Channel), emqx_channel:terminate(Reason, Channel),
ok. ok.
@ -723,13 +719,12 @@ mqtt2sn(?PUBCOMP_PACKET(MsgId), _State) ->
mqtt2sn(?UNSUBACK_PACKET(MsgId), _State)-> mqtt2sn(?UNSUBACK_PACKET(MsgId), _State)->
?SN_UNSUBACK_MSG(MsgId); ?SN_UNSUBACK_MSG(MsgId);
mqtt2sn(?PUBLISH_PACKET(QoS, Topic, PacketId, Payload), #state{registry = Registry, mqtt2sn(?PUBLISH_PACKET(QoS, Topic, PacketId, Payload), #state{channel = Channel}) ->
channel = Channel}) ->
NewPacketId = if QoS =:= ?QOS_0 -> 0; NewPacketId = if QoS =:= ?QOS_0 -> 0;
true -> PacketId true -> PacketId
end, end,
ClientId = emqx_channel:info(clientid, Channel), ClientId = emqx_channel:info(clientid, Channel),
{TopicIdType, TopicContent} = case emqx_sn_registry:lookup_topic_id(Registry, ClientId, Topic) of {TopicIdType, TopicContent} = case emqx_sn_registry:lookup_topic_id(ClientId, Topic) of
{predef, PredefTopicId} -> {predef, PredefTopicId} ->
{?SN_PREDEFINED_TOPIC, PredefTopicId}; {?SN_PREDEFINED_TOPIC, PredefTopicId};
TopicId when is_integer(TopicId) -> TopicId when is_integer(TopicId) ->
@ -850,14 +845,13 @@ do_connect(ClientId, CleanStart, WillFlag, Duration, State) ->
do_2nd_connect(Flags, Duration, ClientId, State = #state{sockname = Sockname, do_2nd_connect(Flags, Duration, ClientId, State = #state{sockname = Sockname,
peername = Peername, peername = Peername,
registry = Registry,
channel = Channel}) -> channel = Channel}) ->
emqx_logger:set_metadata_clientid(ClientId), emqx_logger:set_metadata_clientid(ClientId),
#mqtt_sn_flags{will = Will, clean_start = CleanStart} = Flags, #mqtt_sn_flags{will = Will, clean_start = CleanStart} = Flags,
NChannel = case CleanStart of NChannel = case CleanStart of
true -> true ->
emqx_channel:terminate(normal, Channel), emqx_channel:terminate(normal, Channel),
emqx_sn_registry:unregister_topic(Registry, ClientId), emqx_sn_registry:unregister_topic(ClientId),
emqx_channel:init(#{socktype => udp, emqx_channel:init(#{socktype => udp,
sockname => Sockname, sockname => Sockname,
peername => Peername, peername => Peername,
@ -870,9 +864,9 @@ do_2nd_connect(Flags, Duration, ClientId, State = #state{sockname = Sockname,
do_connect(ClientId, CleanStart, Will, Duration, NState). do_connect(ClientId, CleanStart, Will, Duration, NState).
handle_subscribe(?SN_NORMAL_TOPIC, TopicName, QoS, MsgId, handle_subscribe(?SN_NORMAL_TOPIC, TopicName, QoS, MsgId,
State=#state{registry = Registry, channel = Channel}) -> State=#state{channel = Channel}) ->
ClientId = emqx_channel:info(clientid, Channel), ClientId = emqx_channel:info(clientid, Channel),
case emqx_sn_registry:register_topic(Registry, ClientId, TopicName) of case emqx_sn_registry:register_topic(ClientId, TopicName) of
{error, too_large} -> {error, too_large} ->
State0 = send_message(?SN_SUBACK_MSG(#mqtt_sn_flags{qos = QoS}, State0 = send_message(?SN_SUBACK_MSG(#mqtt_sn_flags{qos = QoS},
?SN_INVALID_TOPIC_ID, ?SN_INVALID_TOPIC_ID,
@ -886,9 +880,9 @@ handle_subscribe(?SN_NORMAL_TOPIC, TopicName, QoS, MsgId,
end; end;
handle_subscribe(?SN_PREDEFINED_TOPIC, TopicId, QoS, MsgId, handle_subscribe(?SN_PREDEFINED_TOPIC, TopicId, QoS, MsgId,
State = #state{registry = Registry, channel = Channel}) -> State = #state{channel = Channel}) ->
ClientId = emqx_channel:info(clientid, Channel), ClientId = emqx_channel:info(clientid, Channel),
case emqx_sn_registry:lookup_topic(Registry, ClientId, TopicId) of case emqx_sn_registry:lookup_topic(ClientId, TopicId) of
undefined -> undefined ->
State0 = send_message(?SN_SUBACK_MSG(#mqtt_sn_flags{qos = QoS}, State0 = send_message(?SN_SUBACK_MSG(#mqtt_sn_flags{qos = QoS},
TopicId, TopicId,
@ -917,9 +911,9 @@ handle_unsubscribe(?SN_NORMAL_TOPIC, TopicId, MsgId, State) ->
proto_unsubscribe(TopicId, MsgId, State); proto_unsubscribe(TopicId, MsgId, State);
handle_unsubscribe(?SN_PREDEFINED_TOPIC, TopicId, MsgId, handle_unsubscribe(?SN_PREDEFINED_TOPIC, TopicId, MsgId,
State = #state{registry = Registry, channel = Channel}) -> State = #state{channel = Channel}) ->
ClientId = emqx_channel:info(clientid, Channel), ClientId = emqx_channel:info(clientid, Channel),
case emqx_sn_registry:lookup_topic(Registry, ClientId, TopicId) of case emqx_sn_registry:lookup_topic(ClientId, TopicId) of
undefined -> undefined ->
{keep_state, send_message(?SN_UNSUBACK_MSG(MsgId), State)}; {keep_state, send_message(?SN_UNSUBACK_MSG(MsgId), State)};
PredefinedTopic -> PredefinedTopic ->
@ -941,11 +935,11 @@ do_publish(?SN_NORMAL_TOPIC, TopicName, Data, Flags, MsgId, State) ->
<<TopicId:16>> = TopicName, <<TopicId:16>> = TopicName,
do_publish(?SN_PREDEFINED_TOPIC, TopicId, Data, Flags, MsgId, State); do_publish(?SN_PREDEFINED_TOPIC, TopicId, Data, Flags, MsgId, State);
do_publish(?SN_PREDEFINED_TOPIC, TopicId, Data, Flags, MsgId, do_publish(?SN_PREDEFINED_TOPIC, TopicId, Data, Flags, MsgId,
State=#state{registry = Registry, channel = Channel}) -> State=#state{channel = Channel}) ->
#mqtt_sn_flags{qos = QoS, dup = Dup, retain = Retain} = Flags, #mqtt_sn_flags{qos = QoS, dup = Dup, retain = Retain} = Flags,
NewQoS = get_corrected_qos(QoS), NewQoS = get_corrected_qos(QoS),
ClientId = emqx_channel:info(clientid, Channel), ClientId = emqx_channel:info(clientid, Channel),
case emqx_sn_registry:lookup_topic(Registry, ClientId, TopicId) of case emqx_sn_registry:lookup_topic(ClientId, TopicId) of
undefined -> undefined ->
{keep_state, maybe_send_puback(NewQoS, TopicId, MsgId, ?SN_RC_INVALID_TOPIC_ID, {keep_state, maybe_send_puback(NewQoS, TopicId, MsgId, ?SN_RC_INVALID_TOPIC_ID,
State)}; State)};
@ -984,13 +978,13 @@ do_publish_will(#state{will_msg = WillMsg, clientid = ClientId}) ->
ok. ok.
do_puback(TopicId, MsgId, ReturnCode, StateName, do_puback(TopicId, MsgId, ReturnCode, StateName,
State=#state{registry = Registry, channel = Channel}) -> State=#state{channel = Channel}) ->
case ReturnCode of case ReturnCode of
?SN_RC_ACCEPTED -> ?SN_RC_ACCEPTED ->
handle_incoming(?PUBACK_PACKET(MsgId), StateName, State); handle_incoming(?PUBACK_PACKET(MsgId), StateName, State);
?SN_RC_INVALID_TOPIC_ID -> ?SN_RC_INVALID_TOPIC_ID ->
ClientId = emqx_channel:info(clientid, Channel), ClientId = emqx_channel:info(clientid, Channel),
case emqx_sn_registry:lookup_topic(Registry, ClientId, TopicId) of case emqx_sn_registry:lookup_topic(ClientId, TopicId) of
undefined -> {keep_state, State}; undefined -> {keep_state, State};
TopicName -> TopicName ->
%%notice that this TopicName maybe normal or predefined, %%notice that this TopicName maybe normal or predefined,
@ -1079,10 +1073,10 @@ handle_outgoing(Packets, State) when is_list(Packets) ->
end, State, Packets); end, State, Packets);
handle_outgoing(PubPkt = ?PUBLISH_PACKET(_, TopicName, _, _), handle_outgoing(PubPkt = ?PUBLISH_PACKET(_, TopicName, _, _),
State = #state{registry = Registry, channel = Channel}) -> State = #state{channel = Channel}) ->
?LOG(debug, "Handle outgoing publish: ~0p", [PubPkt]), ?LOG(debug, "Handle outgoing publish: ~0p", [PubPkt]),
ClientId = emqx_channel:info(clientid, Channel), ClientId = emqx_channel:info(clientid, Channel),
TopicId = emqx_sn_registry:lookup_topic_id(Registry, ClientId, TopicName), TopicId = emqx_sn_registry:lookup_topic_id(ClientId, TopicName),
case (TopicId == undefined) andalso (byte_size(TopicName) =/= 2) of case (TopicId == undefined) andalso (byte_size(TopicName) =/= 2) of
true -> register_and_notify_client(PubPkt, State); true -> register_and_notify_client(PubPkt, State);
false -> send_message(mqtt2sn(PubPkt, State), State) false -> send_message(mqtt2sn(PubPkt, State), State)
@ -1106,11 +1100,11 @@ replay_no_reg_pending_publishes(TopicId, #state{pending_topic_ids = Pendings} =
State#state{pending_topic_ids = maps:remove(TopicId, Pendings)}. State#state{pending_topic_ids = maps:remove(TopicId, Pendings)}.
register_and_notify_client(?PUBLISH_PACKET(QoS, TopicName, PacketId, Payload) = PubPkt, register_and_notify_client(?PUBLISH_PACKET(QoS, TopicName, PacketId, Payload) = PubPkt,
State = #state{registry = Registry, pending_topic_ids = Pendings, channel = Channel}) -> State = #state{pending_topic_ids = Pendings, channel = Channel}) ->
MsgId = message_id(PacketId), MsgId = message_id(PacketId),
#mqtt_packet{header = #mqtt_packet_header{dup = Dup, retain = Retain}} = PubPkt, #mqtt_packet{header = #mqtt_packet_header{dup = Dup, retain = Retain}} = PubPkt,
ClientId = emqx_channel:info(clientid, Channel), ClientId = emqx_channel:info(clientid, Channel),
TopicId = emqx_sn_registry:register_topic(Registry, ClientId, TopicName), TopicId = emqx_sn_registry:register_topic(ClientId, TopicName),
?LOG(debug, "Register TopicId=~p, TopicName=~p, Payload=~p, Dup=~p, QoS=~p, " ?LOG(debug, "Register TopicId=~p, TopicName=~p, Payload=~p, Dup=~p, QoS=~p, "
"Retain=~p, MsgId=~p", [TopicId, TopicName, Payload, Dup, QoS, Retain, MsgId]), "Retain=~p, MsgId=~p", [TopicId, TopicName, Payload, Dup, QoS, Retain, MsgId]),
NewPendings = cache_no_reg_publish_message(Pendings, TopicId, PubPkt, State), NewPendings = cache_no_reg_publish_message(Pendings, TopicId, PubPkt, State),

View File

@ -23,16 +23,16 @@
-define(LOG(Level, Format, Args), -define(LOG(Level, Format, Args),
emqx_logger:Level("MQTT-SN(registry): " ++ Format, Args)). emqx_logger:Level("MQTT-SN(registry): " ++ Format, Args)).
-export([ start_link/2 -export([ start_link/1
, stop/1 , stop/0
]). ]).
-export([ register_topic/3 -export([ register_topic/2
, unregister_topic/2 , unregister_topic/1
]). ]).
-export([ lookup_topic/3 -export([ lookup_topic/2
, lookup_topic_id/3 , lookup_topic_id/2
]). ]).
%% gen_server callbacks %% gen_server callbacks
@ -46,25 +46,45 @@
-define(TAB, ?MODULE). -define(TAB, ?MODULE).
-record(state, {tab, max_predef_topic_id = 0}). -record(state, {max_predef_topic_id = 0}).
-type(registry() :: {ets:tab(), pid()}). -record(emqx_sn_registry, {key, value}).
%% Mnesia bootstrap
-export([mnesia/1]).
-boot_mnesia({mnesia, [boot]}).
-copy_mnesia({mnesia, [copy]}).
%% @doc Create or replicate tables.
-spec(mnesia(boot | copy) -> ok).
mnesia(boot) ->
%% Optimize storage
StoreProps = [{ets, [{read_concurrency, true}]}],
ok = ekka_mnesia:create_table(?MODULE, [
{attributes, record_info(fields, emqx_sn_registry)},
{ram_copies, [node()]},
{storage_properties, StoreProps}]);
mnesia(copy) ->
ok = ekka_mnesia:copy_table(?MODULE, ram_copies).
%%----------------------------------------------------------------------------- %%-----------------------------------------------------------------------------
-spec(start_link(atom(), list()) -> {ok, pid()} | ignore | {error, Reason :: term()}). -spec(start_link(list()) -> {ok, pid()} | ignore | {error, Reason :: term()}).
start_link(Tab, PredefTopics) -> start_link(PredefTopics) ->
gen_server:start_link(?MODULE, [Tab, PredefTopics], []). gen_server:start_link({local, ?MODULE}, ?MODULE, [PredefTopics], []).
-spec(stop(registry()) -> ok). -spec(stop() -> ok).
stop({_Tab, Pid}) -> stop() ->
gen_server:stop(Pid, normal, infinity). gen_server:stop(?MODULE, normal, infinity).
-spec(register_topic(registry(), binary(), binary()) -> integer() | {error, term()}). -spec(register_topic(binary(), binary()) -> integer() | {error, term()}).
register_topic({_, Pid}, ClientId, TopicName) when is_binary(TopicName) -> register_topic(ClientId, TopicName) when is_binary(TopicName) ->
case emqx_topic:wildcard(TopicName) of case emqx_topic:wildcard(TopicName) of
false -> false ->
gen_server:call(Pid, {register, ClientId, TopicName}); gen_server:call(?MODULE, {register, ClientId, TopicName});
%% TopicId: in case of accepted the value that will be used as topic %% TopicId: in case of accepted the value that will be used as topic
%% id by the gateway when sending PUBLISH messages to the client (not %% id by the gateway when sending PUBLISH messages to the client (not
%% relevant in case of subscriptions to a short topic name or to a topic %% relevant in case of subscriptions to a short topic name or to a topic
@ -72,22 +92,22 @@ register_topic({_, Pid}, ClientId, TopicName) when is_binary(TopicName) ->
true -> {error, wildcard_topic} true -> {error, wildcard_topic}
end. end.
-spec(lookup_topic(registry(), binary(), pos_integer()) -> undefined | binary()). -spec(lookup_topic(binary(), pos_integer()) -> undefined | binary()).
lookup_topic({Tab, _Pid}, ClientId, TopicId) when is_integer(TopicId) -> lookup_topic(ClientId, TopicId) when is_integer(TopicId) ->
case lookup_element(Tab, {predef, TopicId}, 2) of case lookup_element(?TAB, {predef, TopicId}, 3) of
undefined -> undefined ->
lookup_element(Tab, {ClientId, TopicId}, 2); lookup_element(?TAB, {ClientId, TopicId}, 3);
Topic -> Topic Topic -> Topic
end. end.
-spec(lookup_topic_id(registry(), binary(), binary()) -spec(lookup_topic_id(binary(), binary())
-> undefined -> undefined
| pos_integer() | pos_integer()
| {predef, integer()}). | {predef, integer()}).
lookup_topic_id({Tab, _Pid}, ClientId, TopicName) when is_binary(TopicName) -> lookup_topic_id(ClientId, TopicName) when is_binary(TopicName) ->
case lookup_element(Tab, {predef, TopicName}, 2) of case lookup_element(?TAB, {predef, TopicName}, 3) of
undefined -> undefined ->
lookup_element(Tab, {ClientId, TopicName}, 2); lookup_element(?TAB, {ClientId, TopicName}, 3);
TopicId -> TopicId ->
{predef, TopicId} {predef, TopicId}
end. end.
@ -96,47 +116,59 @@ lookup_topic_id({Tab, _Pid}, ClientId, TopicName) when is_binary(TopicName) ->
lookup_element(Tab, Key, Pos) -> lookup_element(Tab, Key, Pos) ->
try ets:lookup_element(Tab, Key, Pos) catch error:badarg -> undefined end. try ets:lookup_element(Tab, Key, Pos) catch error:badarg -> undefined end.
-spec(unregister_topic(registry(), binary()) -> ok). -spec(unregister_topic(binary()) -> ok).
unregister_topic({_Tab, Pid}, ClientId) -> unregister_topic(ClientId) ->
gen_server:call(Pid, {unregister, ClientId}). gen_server:call(?MODULE, {unregister, ClientId}).
%%----------------------------------------------------------------------------- %%-----------------------------------------------------------------------------
init([Tab, PredefTopics]) -> init([PredefTopics]) ->
%% {predef, TopicId} -> TopicName %% {predef, TopicId} -> TopicName
%% {predef, TopicName} -> TopicId %% {predef, TopicName} -> TopicId
%% {ClientId, TopicId} -> TopicName %% {ClientId, TopicId} -> TopicName
%% {ClientId, TopicName} -> TopicId %% {ClientId, TopicName} -> TopicId
_ = ets:new(Tab, [set, public, named_table, {read_concurrency, true}]),
MaxPredefId = lists:foldl( MaxPredefId = lists:foldl(
fun({TopicId, TopicName}, AccId) -> fun({TopicId, TopicName}, AccId) ->
_ = ets:insert(Tab, {{predef, TopicId}, TopicName}), mnesia:dirty_write(#emqx_sn_registry{key = {predef, TopicId},
_ = ets:insert(Tab, {{predef, TopicName}, TopicId}), value = TopicName}),
mnesia:dirty_write(#emqx_sn_registry{key = {predef, TopicName},
value = TopicId}),
if TopicId > AccId -> TopicId; true -> AccId end if TopicId > AccId -> TopicId; true -> AccId end
end, 0, PredefTopics), end, 0, PredefTopics),
{ok, #state{tab = Tab, max_predef_topic_id = MaxPredefId}}. {ok, #state{max_predef_topic_id = MaxPredefId}}.
handle_call({register, ClientId, TopicName}, _From, handle_call({register, ClientId, TopicName}, _From,
State = #state{tab = Tab, max_predef_topic_id = PredefId}) -> State = #state{max_predef_topic_id = PredefId}) ->
case lookup_topic_id({Tab, self()}, ClientId, TopicName) of case lookup_topic_id(ClientId, TopicName) of
{predef, PredefTopicId} when is_integer(PredefTopicId) -> {predef, PredefTopicId} when is_integer(PredefTopicId) ->
{reply, PredefTopicId, State}; {reply, PredefTopicId, State};
TopicId when is_integer(TopicId) -> TopicId when is_integer(TopicId) ->
{reply, TopicId, State}; {reply, TopicId, State};
undefined -> undefined ->
case next_topic_id(Tab, PredefId, ClientId) of case next_topic_id(?TAB, PredefId, ClientId) of
TopicId when TopicId >= 16#FFFF -> TopicId when TopicId >= 16#FFFF ->
{reply, {error, too_large}, State}; {reply, {error, too_large}, State};
TopicId -> TopicId ->
_ = ets:insert(Tab, {{ClientId, next_topic_id}, TopicId + 1}), Fun = fun() ->
_ = ets:insert(Tab, {{ClientId, TopicName}, TopicId}), mnesia:write(#emqx_sn_registry{key = {ClientId, next_topic_id},
_ = ets:insert(Tab, {{ClientId, TopicId}, TopicName}), value = TopicId + 1}),
{reply, TopicId, State} mnesia:write(#emqx_sn_registry{key = {ClientId, TopicName},
value = TopicId}),
mnesia:write(#emqx_sn_registry{key = {ClientId, TopicId},
value = TopicName})
end,
case mnesia:transaction(Fun) of
{atomic, ok} ->
{reply, TopicId, State};
{aborted, Error} ->
{reply, {error, Error}, State}
end
end end
end; end;
handle_call({unregister, ClientId}, _From, State = #state{tab = Tab}) -> handle_call({unregister, ClientId}, _From, State) ->
ets:match_delete(Tab, {{ClientId, '_'}, '_'}), Registry = mnesia:dirty_match_object({?TAB, {ClientId, '_'}, '_'}),
lists:foreach(fun(R) -> mnesia:dirty_delete_object(R) end, Registry),
{reply, ok, State}; {reply, ok, State};
handle_call(Req, _From, State) -> handle_call(Req, _From, State) ->
@ -160,7 +192,7 @@ code_change(_OldVsn, State, _Extra) ->
%%----------------------------------------------------------------------------- %%-----------------------------------------------------------------------------
next_topic_id(Tab, PredefId, ClientId) -> next_topic_id(Tab, PredefId, ClientId) ->
case ets:lookup(Tab, {ClientId, next_topic_id}) of case mnesia:dirty_read(Tab, {ClientId, next_topic_id}) of
[{_, Id}] -> Id; [#emqx_sn_registry{value = Id}] -> Id;
[] -> PredefId + 1 [] -> PredefId + 1
end. end.

View File

@ -18,32 +18,26 @@
-behaviour(supervisor). -behaviour(supervisor).
-export([ start_link/2 -export([ start_link/3
, start_registry_proc/3
, init/1 , init/1
]). ]).
start_registry_proc(Sup, TabName, PredefTopics) -> start_link(Addr, GwId, PredefTopics) ->
Registry = #{id => TabName, supervisor:start_link({local, ?MODULE}, ?MODULE, [Addr, GwId, PredefTopics]).
start => {emqx_sn_registry, start_link, [TabName, PredefTopics]},
restart => permanent,
shutdown => 5000,
type => worker,
modules => [emqx_sn_registry]},
handle_ret(supervisor:start_child(Sup, Registry)).
start_link(Addr, GwId) -> init([{_Ip, Port}, GwId, PredefTopics]) ->
supervisor:start_link({local, ?MODULE}, ?MODULE, [Addr, GwId]).
init([{_Ip, Port}, GwId]) ->
Broadcast = #{id => emqx_sn_broadcast, Broadcast = #{id => emqx_sn_broadcast,
start => {emqx_sn_broadcast, start_link, [GwId, Port]}, start => {emqx_sn_broadcast, start_link, [GwId, Port]},
restart => permanent, restart => permanent,
shutdown => brutal_kill, shutdown => brutal_kill,
type => worker, type => worker,
modules => [emqx_sn_broadcast]}, modules => [emqx_sn_broadcast]},
{ok, {{one_for_one, 10, 3600}, [Broadcast]}}. Registry = #{id => emqx_sn_registry,
start => {emqx_sn_registry, start_link, [PredefTopics]},
restart => permanent,
shutdown => brutal_kill,
type => worker,
modules => [emqx_sn_registry]},
{ok, {{one_for_one, 10, 3600}, [Broadcast, Registry]}}.
handle_ret({ok, Pid, _Info}) -> {ok, Pid};
handle_ret(Ret) -> Ret.

View File

@ -1084,7 +1084,7 @@ t_asleep_test03_to_awake_qos1_dl_msg(_) ->
{ok, C} = emqtt:start_link(), {ok, C} = emqtt:start_link(),
{ok, _} = emqtt:connect(C), {ok, _} = emqtt:connect(C),
{ok, _} = emqtt:publish(C, TopicName1, Payload1, QoS), {ok, _} = emqtt:publish(C, TopicName1, Payload1, QoS),
timer:sleep(500), timer:sleep(100),
ok = emqtt:disconnect(C), ok = emqtt:disconnect(C),
timer:sleep(50), timer:sleep(50),

View File

@ -16,12 +16,9 @@
-module(emqx_sn_registry_SUITE). -module(emqx_sn_registry_SUITE).
-import(proplists, [get_value/2]).
-compile(export_all). -compile(export_all).
-compile(nowarn_export_all). -compile(nowarn_export_all).
-include_lib("emqx_sn/include/emqx_sn.hrl").
-include_lib("eunit/include/eunit.hrl"). -include_lib("eunit/include/eunit.hrl").
-define(REGISTRY, emqx_sn_registry). -define(REGISTRY, emqx_sn_registry).
@ -44,84 +41,81 @@ end_per_suite(_Config) ->
ok. ok.
init_per_testcase(_TestCase, Config) -> init_per_testcase(_TestCase, Config) ->
ekka_mnesia:start(),
emqx_sn_registry:mnesia(boot),
mnesia:clear_table(emqx_sn_registry),
PredefTopics = application:get_env(emqx_sn, predefined, []), PredefTopics = application:get_env(emqx_sn, predefined, []),
TabName = emqx_sn_registry, {ok, _Pid} = ?REGISTRY:start_link(PredefTopics),
{ok, Pid} = ?REGISTRY:start_link(TabName, PredefTopics), Config.
[{registray, {TabName, Pid}} | Config].
end_per_testcase(_TestCase, Config) -> end_per_testcase(_TestCase, Config) ->
?REGISTRY:stop(get_value(registray, Config)), ?REGISTRY:stop(),
Config. Config.
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Test cases %% Test cases
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
t_register(Config) -> t_register(_Config) ->
Registry = get_value(registray, Config), ?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:register_topic(<<"ClientId">>, <<"Topic1">>)),
?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"Topic1">>)), ?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:register_topic(<<"ClientId">>, <<"Topic2">>)),
?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"Topic2">>)), ?assertEqual(<<"Topic1">>, ?REGISTRY:lookup_topic(<<"ClientId">>, ?MAX_PREDEF_ID+1)),
?assertEqual(<<"Topic1">>, ?REGISTRY:lookup_topic(Registry, <<"ClientId">>, ?MAX_PREDEF_ID+1)), ?assertEqual(<<"Topic2">>, ?REGISTRY:lookup_topic(<<"ClientId">>, ?MAX_PREDEF_ID+2)),
?assertEqual(<<"Topic2">>, ?REGISTRY:lookup_topic(Registry, <<"ClientId">>, ?MAX_PREDEF_ID+2)), ?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:lookup_topic_id(<<"ClientId">>, <<"Topic1">>)),
?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, <<"Topic1">>)), ?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:lookup_topic_id(<<"ClientId">>, <<"Topic2">>)),
?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, <<"Topic2">>)), emqx_sn_registry:unregister_topic(<<"ClientId">>),
emqx_sn_registry:unregister_topic(Registry, <<"ClientId">>), ?assertEqual(undefined, ?REGISTRY:lookup_topic(<<"ClientId">>, ?MAX_PREDEF_ID+1)),
?assertEqual(undefined, ?REGISTRY:lookup_topic(Registry, <<"ClientId">>, ?MAX_PREDEF_ID+1)), ?assertEqual(undefined, ?REGISTRY:lookup_topic(<<"ClientId">>, ?MAX_PREDEF_ID+2)),
?assertEqual(undefined, ?REGISTRY:lookup_topic(Registry, <<"ClientId">>, ?MAX_PREDEF_ID+2)), ?assertEqual(undefined, ?REGISTRY:lookup_topic_id(<<"ClientId">>, <<"Topic1">>)),
?assertEqual(undefined, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, <<"Topic1">>)), ?assertEqual(undefined, ?REGISTRY:lookup_topic_id(<<"ClientId">>, <<"Topic2">>)).
?assertEqual(undefined, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, <<"Topic2">>)).
t_register_case2(Config) -> t_register_case2(_Config) ->
Registry = get_value(registray, Config), ?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:register_topic(<<"ClientId">>, <<"Topic1">>)),
?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"Topic1">>)), ?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:register_topic(<<"ClientId">>, <<"Topic2">>)),
?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"Topic2">>)), ?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:register_topic(<<"ClientId">>, <<"Topic1">>)),
?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"Topic1">>)), ?assertEqual(<<"Topic1">>, ?REGISTRY:lookup_topic(<<"ClientId">>, ?MAX_PREDEF_ID+1)),
?assertEqual(<<"Topic1">>, ?REGISTRY:lookup_topic(Registry, <<"ClientId">>, ?MAX_PREDEF_ID+1)), ?assertEqual(<<"Topic2">>, ?REGISTRY:lookup_topic(<<"ClientId">>, ?MAX_PREDEF_ID+2)),
?assertEqual(<<"Topic2">>, ?REGISTRY:lookup_topic(Registry, <<"ClientId">>, ?MAX_PREDEF_ID+2)), ?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:lookup_topic_id(<<"ClientId">>, <<"Topic1">>)),
?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, <<"Topic1">>)), ?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:lookup_topic_id(<<"ClientId">>, <<"Topic2">>)),
?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, <<"Topic2">>)), ?assertEqual(undefined, ?REGISTRY:lookup_topic_id(<<"ClientId">>, <<"Topic3">>)),
?assertEqual(undefined, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, <<"Topic3">>)), ?REGISTRY:unregister_topic(<<"ClientId">>),
?REGISTRY:unregister_topic(Registry, <<"ClientId">>), ?assertEqual(undefined, ?REGISTRY:lookup_topic(<<"ClientId">>, ?MAX_PREDEF_ID+1)),
?assertEqual(undefined, ?REGISTRY:lookup_topic(Registry, <<"ClientId">>, ?MAX_PREDEF_ID+1)), ?assertEqual(undefined, ?REGISTRY:lookup_topic(<<"ClientId">>, ?MAX_PREDEF_ID+2)),
?assertEqual(undefined, ?REGISTRY:lookup_topic(Registry, <<"ClientId">>, ?MAX_PREDEF_ID+2)), ?assertEqual(undefined, ?REGISTRY:lookup_topic_id(<<"ClientId">>, <<"Topic1">>)),
?assertEqual(undefined, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, <<"Topic1">>)), ?assertEqual(undefined, ?REGISTRY:lookup_topic_id(<<"ClientId">>, <<"Topic2">>)).
?assertEqual(undefined, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, <<"Topic2">>)).
t_reach_maximum(Config) -> t_reach_maximum(_Config) ->
Registry = get_value(registray, Config), register_a_lot(?MAX_PREDEF_ID+1, 16#ffff),
register_a_lot(Registry, ?MAX_PREDEF_ID+1, 16#ffff), ?assertEqual({error, too_large}, ?REGISTRY:register_topic(<<"ClientId">>, <<"TopicABC">>)),
?assertEqual({error, too_large}, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"TopicABC">>)),
Topic1 = iolist_to_binary(io_lib:format("Topic~p", [?MAX_PREDEF_ID+1])), Topic1 = iolist_to_binary(io_lib:format("Topic~p", [?MAX_PREDEF_ID+1])),
Topic2 = iolist_to_binary(io_lib:format("Topic~p", [?MAX_PREDEF_ID+2])), Topic2 = iolist_to_binary(io_lib:format("Topic~p", [?MAX_PREDEF_ID+2])),
?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, Topic1)), ?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:lookup_topic_id(<<"ClientId">>, Topic1)),
?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, Topic2)), ?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:lookup_topic_id(<<"ClientId">>, Topic2)),
?REGISTRY:unregister_topic(Registry, <<"ClientId">>), ?REGISTRY:unregister_topic(<<"ClientId">>),
?assertEqual(undefined, ?REGISTRY:lookup_topic(Registry, <<"ClientId">>, ?MAX_PREDEF_ID+1)), ?assertEqual(undefined, ?REGISTRY:lookup_topic(<<"ClientId">>, ?MAX_PREDEF_ID+1)),
?assertEqual(undefined, ?REGISTRY:lookup_topic(Registry, <<"ClientId">>, ?MAX_PREDEF_ID+2)), ?assertEqual(undefined, ?REGISTRY:lookup_topic(<<"ClientId">>, ?MAX_PREDEF_ID+2)),
?assertEqual(undefined, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, Topic1)), ?assertEqual(undefined, ?REGISTRY:lookup_topic_id(<<"ClientId">>, Topic1)),
?assertEqual(undefined, ?REGISTRY:lookup_topic_id(Registry, <<"ClientId">>, Topic2)). ?assertEqual(undefined, ?REGISTRY:lookup_topic_id(<<"ClientId">>, Topic2)).
t_register_case4(Config) -> t_register_case4(_Config) ->
Registry = get_value(registray, Config), ?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:register_topic(<<"ClientId">>, <<"TopicA">>)),
?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"TopicA">>)), ?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:register_topic(<<"ClientId">>, <<"TopicB">>)),
?assertEqual(?MAX_PREDEF_ID+2, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"TopicB">>)), ?assertEqual(?MAX_PREDEF_ID+3, ?REGISTRY:register_topic(<<"ClientId">>, <<"TopicC">>)),
?assertEqual(?MAX_PREDEF_ID+3, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"TopicC">>)), ?REGISTRY:unregister_topic(<<"ClientId">>),
?REGISTRY:unregister_topic(Registry, <<"ClientId">>), ?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:register_topic(<<"ClientId">>, <<"TopicD">>)).
?assertEqual(?MAX_PREDEF_ID+1, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"TopicD">>)).
t_deny_wildcard_topic(Config) -> t_deny_wildcard_topic(_Config) ->
Registry = get_value(registray, Config), ?assertEqual({error, wildcard_topic}, ?REGISTRY:register_topic(<<"ClientId">>, <<"/TopicA/#">>)),
?assertEqual({error, wildcard_topic}, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"/TopicA/#">>)), ?assertEqual({error, wildcard_topic}, ?REGISTRY:register_topic(<<"ClientId">>, <<"/+/TopicB">>)).
?assertEqual({error, wildcard_topic}, ?REGISTRY:register_topic(Registry, <<"ClientId">>, <<"/+/TopicB">>)).
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
%% Helper funcs %% Helper funcs
%%-------------------------------------------------------------------- %%--------------------------------------------------------------------
register_a_lot(_, Max, Max) -> register_a_lot(Max, Max) ->
ok; ok;
register_a_lot(Registry, N, Max) when N < Max -> register_a_lot(N, Max) when N < Max ->
Topic = iolist_to_binary(["Topic", integer_to_list(N)]), Topic = iolist_to_binary(["Topic", integer_to_list(N)]),
?assertEqual(N, ?REGISTRY:register_topic(Registry, <<"ClientId">>, Topic)), ?assertEqual(N, ?REGISTRY:register_topic(<<"ClientId">>, Topic)),
register_a_lot(Registry, N+1, Max). register_a_lot(N+1, Max).