diff --git a/include/emqx.hrl b/include/emqx.hrl index d676be9cf..ee048f1fa 100644 --- a/include/emqx.hrl +++ b/include/emqx.hrl @@ -64,7 +64,7 @@ %% Message flags flags :: #{atom() => boolean()}, %% Message headers, or MQTT 5.0 Properties - headers = #{}, + headers :: map(), %% Topic that the message is published to topic :: binary(), %% Message Payload diff --git a/src/emqx_access_rule.erl b/src/emqx_access_rule.erl index dbde32984..e48caa5c0 100644 --- a/src/emqx_access_rule.erl +++ b/src/emqx_access_rule.erl @@ -108,9 +108,9 @@ match_who(#{client_id := ClientId}, {client, ClientId}) -> true; match_who(#{username := Username}, {user, Username}) -> true; -match_who(#{peername := undefined}, {ipaddr, _Tup}) -> +match_who(#{peerhost := undefined}, {ipaddr, _Tup}) -> false; -match_who(#{peername := {IP, _}}, {ipaddr, CIDR}) -> +match_who(#{peerhost := IP}, {ipaddr, CIDR}) -> esockd_cidr:match(IP, CIDR); match_who(Client, {'and', Conds}) when is_list(Conds) -> lists:foldl(fun(Who, Allow) -> diff --git a/src/emqx_banned.erl b/src/emqx_banned.erl index c24bb4294..7e3e959e3 100644 --- a/src/emqx_banned.erl +++ b/src/emqx_banned.erl @@ -30,7 +30,7 @@ -boot_mnesia({mnesia, [boot]}). -copy_mnesia({mnesia, [copy]}). --export([start_link/0]). +-export([start_link/0, stop/0]). -export([ check/1 , add/1 @@ -69,11 +69,14 @@ mnesia(copy) -> start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). +%% for tests +-spec(stop() -> ok). +stop() -> gen_server:stop(?MODULE). + -spec(check(emqx_types:client()) -> boolean()). check(#{client_id := ClientId, username := Username, - peername := {IPAddr, _} - }) -> + peerhost := IPAddr}) -> ets:member(?BANNED_TAB, {client_id, ClientId}) orelse ets:member(?BANNED_TAB, {username, Username}) orelse ets:member(?BANNED_TAB, {ipaddr, IPAddr}). @@ -82,11 +85,10 @@ check(#{client_id := ClientId, add(Banned) when is_record(Banned, banned) -> mnesia:dirty_write(?BANNED_TAB, Banned). --spec(delete({client_id, emqx_types:client_id()} | - {username, emqx_types:username()} | - {peername, emqx_types:peername()}) -> ok). -delete(Key) -> - mnesia:dirty_delete(?BANNED_TAB, Key). +-spec(delete({client_id, emqx_types:client_id()} + | {username, emqx_types:username()} + | {peerhost, emqx_types:peerhost()}) -> ok). +delete(Key) -> mnesia:dirty_delete(?BANNED_TAB, Key). info(InfoKey) -> mnesia:table_info(?BANNED_TAB, InfoKey). @@ -107,8 +109,7 @@ handle_cast(Msg, State) -> {noreply, State}. handle_info({timeout, TRef, expire}, State = #{expiry_timer := TRef}) -> - mnesia:async_dirty(fun expire_banned_items/1, - [erlang:system_time(second)]), + mnesia:async_dirty(fun expire_banned_items/1, [erlang:system_time(second)]), {noreply, ensure_expiry_timer(State), hibernate}; handle_info(Info, State) -> @@ -127,7 +128,7 @@ code_change(_OldVsn, State, _Extra) -> -ifdef(TEST). ensure_expiry_timer(State) -> - State#{expiry_timer := emqx_misc:start_timer(timer:seconds(1), expire)}. + State#{expiry_timer := emqx_misc:start_timer(10, expire)}. -else. ensure_expiry_timer(State) -> State#{expiry_timer := emqx_misc:start_timer(timer:minutes(1), expire)}. diff --git a/src/emqx_channel.erl b/src/emqx_channel.erl index f317dbd8d..117cb3fe2 100644 --- a/src/emqx_channel.erl +++ b/src/emqx_channel.erl @@ -24,8 +24,6 @@ -logger_header("[Channel]"). --export([init/2]). - -export([ info/1 , info/2 , attrs/1 @@ -36,12 +34,13 @@ %% Exports for unit tests:( -export([set_field/3]). --export([ handle_in/2 +-export([ init/2 + , handle_in/2 , handle_out/2 , handle_call/2 , handle_cast/2 , handle_info/2 - , timeout/3 + , handle_timeout/3 , terminate/2 ]). @@ -58,19 +57,25 @@ -export_type([channel/0]). -record(channel, { - %% MQTT Client + %% MQTT ConnInfo + conninfo :: emqx_types:conninfo(), + %% MQTT ClientInfo client :: emqx_types:client(), %% MQTT Session session :: emqx_session:session(), - %% MQTT Protocol - protocol :: emqx_protocol:protocol(), %% Keepalive keepalive :: emqx_keepalive:keepalive(), + %% MQTT Will Msg + will_msg :: emqx_types:message(), + %% MQTT Topic Aliases + topic_aliases :: maybe(map()), + %% MQTT Topic Alias Maximum + alias_maximum :: maybe(map()), %% Timers timers :: #{atom() => disabled | maybe(reference())}, %% GC State gc_state :: maybe(emqx_gc:gc_state()), - %% OOM Policy + %% OOM Policy TODO: should be removed from channel. oom_policy :: maybe(emqx_oom:oom_policy()), %% Connected connected :: undefined | boolean(), @@ -97,53 +102,8 @@ will_timer => will_message }). --define(ATTR_KEYS, [client, session, protocol, connected, connected_at, disconnected_at]). - --define(INFO_KEYS, ?ATTR_KEYS ++ [keepalive, gc_state, disconnected_at]). - -%%-------------------------------------------------------------------- -%% Init the channel -%%-------------------------------------------------------------------- - --spec(init(emqx_types:conninfo(), proplists:proplist()) -> channel()). -init(ConnInfo, Options) -> - Zone = proplists:get_value(zone, Options), - Peercert = maps:get(peercert, ConnInfo, undefined), - Username = case peer_cert_as_username(Options) of - cn -> esockd_peercert:common_name(Peercert); - dn -> esockd_peercert:subject(Peercert); - crt -> Peercert; - _ -> undefined - end, - MountPoint = emqx_zone:get_env(Zone, mountpoint), - Client = maps:merge(#{zone => Zone, - username => Username, - client_id => <<>>, - mountpoint => MountPoint, - is_bridge => false, - is_superuser => false - }, ConnInfo), - EnableStats = emqx_zone:get_env(Zone, enable_stats, true), - StatsTimer = if - EnableStats -> undefined; - ?Otherwise -> disabled - end, - GcState = maybe_apply(fun emqx_gc:init/1, - emqx_zone:get_env(Zone, force_gc_policy)), - OomPolicy = maybe_apply(fun emqx_oom:init/1, - emqx_zone:get_env(Zone, force_shutdown_policy)), - #channel{client = Client, - gc_state = GcState, - oom_policy = OomPolicy, - timers = #{stats_timer => StatsTimer}, - connected = undefined, - takeover = false, - resuming = false, - pendings = [] - }. - -peer_cert_as_username(Options) -> - proplists:get_value(peer_cert_as_username, Options). +-define(ATTR_KEYS, [conninfo, client, session, connected, connected_at, disconnected_at]). +-define(INFO_KEYS, ?ATTR_KEYS ++ [keepalive, topic_aliases, alias_maximum, gc_state, disconnected_at]). %%-------------------------------------------------------------------- %% Info, Attrs and Caps @@ -157,14 +117,18 @@ info(Channel) -> -spec(info(list(atom())|atom(), channel()) -> term()). info(Keys, Channel) when is_list(Keys) -> [{Key, info(Key, Channel)} || Key <- Keys]; -info(client, #channel{client = Client}) -> - Client; +info(conninfo, #channel{conninfo = ConnInfo}) -> + ConnInfo; +info(client, #channel{client = ClientInfo}) -> + ClientInfo; info(session, #channel{session = Session}) -> maybe_apply(fun emqx_session:info/1, Session); -info(protocol, #channel{protocol = Protocol}) -> - maybe_apply(fun emqx_protocol:info/1, Protocol); info(keepalive, #channel{keepalive = Keepalive}) -> maybe_apply(fun emqx_keepalive:info/1, Keepalive); +info(topic_aliases, #channel{topic_aliases = Aliases}) -> + Aliases; +info(alias_maximum, #channel{alias_maximum = Limits}) -> + Limits; info(gc_state, #channel{gc_state = GcState}) -> maybe_apply(fun emqx_gc:info/1, GcState); info(oom_policy, #channel{oom_policy = OomPolicy}) -> @@ -181,8 +145,8 @@ info(disconnected_at, #channel{disconnected_at = DisconnectedAt}) -> attrs(Channel) -> maps:from_list([{Key, attr(Key, Channel)} || Key <- ?ATTR_KEYS]). -attr(protocol, #channel{protocol = Proto}) -> - maybe_apply(fun emqx_protocol:attrs/1, Proto); +attr(conninfo, #channel{conninfo = ConnInfo}) -> + ConnInfo; attr(session, #channel{session = Session}) -> maybe_apply(fun emqx_session:attrs/1, Session); attr(Key, Channel) -> info(Key, Channel). @@ -201,6 +165,54 @@ set_field(Name, Val, Channel) -> Pos = emqx_misc:index_of(Name, Fields), setelement(Pos+1, Channel, Val). +%%-------------------------------------------------------------------- +%% Init the channel +%%-------------------------------------------------------------------- + +-spec(init(emqx_types:conninfo(), proplists:proplist()) -> channel()). +init(ConnInfo = #{peername := {PeerHost, _Port}}, Options) -> + Zone = proplists:get_value(zone, Options), + Peercert = maps:get(peercert, ConnInfo, undefined), + Username = case peer_cert_as_username(Options) of + cn -> esockd_peercert:common_name(Peercert); + dn -> esockd_peercert:subject(Peercert); + crt -> Peercert; + _ -> undefined + end, + MountPoint = emqx_zone:get_env(Zone, mountpoint), + ClientInfo = #{zone => Zone, + peerhost => PeerHost, + peercert => Peercert, + client_id => undefined, + username => Username, + mountpoint => MountPoint, + is_bridge => false, + is_superuser => false + }, + StatsTimer = case emqx_zone:enable_stats(Zone) of + true -> undefined; + false -> disabled + end, + #channel{conninfo = ConnInfo, + client = ClientInfo, + gc_state = init_gc_state(Zone), + oom_policy = init_oom_policy(Zone), + timers = #{stats_timer => StatsTimer}, + connected = undefined, + takeover = false, + resuming = false, + pendings = [] + }. + +peer_cert_as_username(Options) -> + proplists:get_value(peer_cert_as_username, Options). + +init_gc_state(Zone) -> + maybe_apply(fun emqx_gc:init/1, emqx_zone:force_gc_policy(Zone)). + +init_oom_policy(Zone) -> + maybe_apply(fun emqx_oom:init/1, emqx_zone:force_shutdown_policy(Zone)). + %%-------------------------------------------------------------------- %% Handle incoming packet %%-------------------------------------------------------------------- @@ -215,8 +227,8 @@ handle_in(?CONNECT_PACKET(_), Channel = #channel{connected = true}) -> handle_out({disconnect, ?RC_PROTOCOL_ERROR}, Channel); handle_in(?CONNECT_PACKET(ConnPkt), Channel) -> - case pipeline([fun check_connpkt/2, - fun init_protocol/2, + case pipeline([fun enrich_conninfo/2, + fun check_connect/2, fun enrich_client/2, fun set_logger_meta/2, fun check_banned/2, @@ -225,31 +237,25 @@ handle_in(?CONNECT_PACKET(ConnPkt), Channel) -> {ok, NConnPkt, NChannel} -> process_connect(NConnPkt, NChannel); {error, ReasonCode, NChannel} -> - handle_out({connack, ReasonCode}, NChannel) + handle_out({connack, ReasonCode, ConnPkt}, NChannel) end; -handle_in(Packet = ?PUBLISH_PACKET(_QoS, Topic, _PacketId), - Channel = #channel{protocol = Protocol}) -> - case pipeline([fun emqx_packet:check/1, - fun process_alias/2, - fun check_publish/2], Packet, Channel) of - {ok, NPacket, NChannel} -> - process_publish(NPacket, NChannel); - {error, ReasonCode, NChannel} -> - ProtoVer = emqx_protocol:info(proto_ver, Protocol), - ?LOG(warning, "Cannot publish message to ~s due to ~s", - [Topic, emqx_reason_codes:text(ReasonCode, ProtoVer)]), - handle_out({disconnect, ReasonCode}, NChannel) +handle_in(Packet = ?PUBLISH_PACKET(_QoS), Channel) -> + case emqx_packet:check(Packet) of + ok -> + handle_publish(Packet, Channel); + {error, ReasonCode} -> + handle_out({disconnect, ReasonCode}, Channel) end; handle_in(?PUBACK_PACKET(PacketId, _ReasonCode), - Channel = #channel{client = Client, session = Session}) -> + Channel = #channel{client = ClientInfo, session = Session}) -> case emqx_session:puback(PacketId, Session) of {ok, Msg, Publishes, NSession} -> - ok = emqx_hooks:run('message.acked', [Client, Msg]), + ok = emqx_hooks:run('message.acked', [ClientInfo, Msg]), handle_out({publish, Publishes}, Channel#channel{session = NSession}); {ok, Msg, NSession} -> - ok = emqx_hooks:run('message.acked', [Client, Msg]), + ok = emqx_hooks:run('message.acked', [ClientInfo, Msg]), {ok, Channel#channel{session = NSession}}; {error, ?RC_PACKET_IDENTIFIER_IN_USE} -> ?LOG(warning, "The PUBACK PacketId ~w is inuse.", [PacketId]), @@ -262,10 +268,10 @@ handle_in(?PUBACK_PACKET(PacketId, _ReasonCode), end; handle_in(?PUBREC_PACKET(PacketId, _ReasonCode), - Channel = #channel{client = Client, session = Session}) -> + Channel = #channel{client = ClientInfo, session = Session}) -> case emqx_session:pubrec(PacketId, Session) of {ok, Msg, NSession} -> - ok = emqx_hooks:run('message.acked', [Client, Msg]), + ok = emqx_hooks:run('message.acked', [ClientInfo, Msg]), NChannel = Channel#channel{session = NSession}, handle_out({pubrel, PacketId, ?RC_SUCCESS}, NChannel); {error, RC = ?RC_PACKET_IDENTIFIER_IN_USE} -> @@ -301,10 +307,10 @@ handle_in(?PUBCOMP_PACKET(PacketId, _ReasonCode), Channel = #channel{session = S end; handle_in(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), - Channel = #channel{client = Client}) -> + Channel = #channel{client = ClientInfo}) -> case emqx_packet:check(Packet) of ok -> TopicFilters1 = emqx_hooks:run_fold('client.subscribe', - [Client, Properties], + [ClientInfo, Properties], parse_topic_filters(TopicFilters)), TopicFilters2 = enrich_subid(Properties, TopicFilters1), {ReasonCodes, NChannel} = process_subscribe(TopicFilters2, Channel), @@ -314,10 +320,10 @@ handle_in(Packet = ?SUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), end; handle_in(Packet = ?UNSUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), - Channel = #channel{client = Client}) -> + Channel = #channel{client = ClientInfo}) -> case emqx_packet:check(Packet) of ok -> TopicFilters1 = emqx_hooks:run_fold('client.unsubscribe', - [Client, Properties], + [ClientInfo, Properties], parse_topic_filters(TopicFilters)), {ReasonCodes, NChannel} = process_unsubscribe(TopicFilters1, Channel), handle_out({unsuback, PacketId, ReasonCodes}, NChannel); @@ -328,9 +334,10 @@ handle_in(Packet = ?UNSUBSCRIBE_PACKET(PacketId, Properties, TopicFilters), handle_in(?PACKET(?PINGREQ), Channel) -> {ok, ?PACKET(?PINGRESP), Channel}; -handle_in(?DISCONNECT_PACKET(ReasonCode, Properties), Channel = #channel{session = Session, protocol = Protocol}) -> +handle_in(?DISCONNECT_PACKET(ReasonCode, Properties), Channel = #channel{session = Session, + conninfo = ConnInfo = #{expiry_interval := OldInterval}}) -> OldInterval = emqx_session:info(expiry_interval, Session), - Interval = get_property('Session-Expiry-Interval', Properties, OldInterval), + Interval = emqx_mqtt_props:get('Session-Expiry-Interval', Props, OldInterval), case OldInterval =:= 0 andalso Interval =/= OldInterval of true -> handle_out({disconnect, ?RC_PROTOCOL_ERROR}, Channel); @@ -361,28 +368,43 @@ handle_in(Packet, Channel) -> %% Process Connect %%-------------------------------------------------------------------- -process_connect(ConnPkt, Channel) -> - case open_session(ConnPkt, Channel) of +process_connect(ConnPkt = #mqtt_packet_connect{clean_start = CleanStart}, + Channel = #channel{conninfo = ConnInfo, client = ClientInfo}) -> + case emqx_cm:open_session(CleanStart, ClientInfo, ConnInfo) of {ok, #{session := Session, present := false}} -> NChannel = Channel#channel{session = Session}, - handle_out({connack, ?RC_SUCCESS, sp(false)}, NChannel); + handle_out({connack, ?RC_SUCCESS, sp(false), ConnPkt}, NChannel); {ok, #{session := Session, present := true, pendings := Pendings}} -> %%TODO: improve later. NPendings = lists:usort(lists:append(Pendings, emqx_misc:drain_deliver())), NChannel = Channel#channel{session = Session, resuming = true, pendings = NPendings}, - handle_out({connack, ?RC_SUCCESS, sp(true)}, NChannel); + handle_out({connack, ?RC_SUCCESS, sp(true), ConnPkt}, NChannel); {error, Reason} -> %% TODO: Unknown error? ?LOG(error, "Failed to open session: ~p", [Reason]), - handle_out({connack, ?RC_UNSPECIFIED_ERROR}, Channel) + handle_out({connack, ?RC_UNSPECIFIED_ERROR, ConnPkt}, Channel) end. %%-------------------------------------------------------------------- %% Process Publish %%-------------------------------------------------------------------- +handle_publish(Packet = ?PUBLISH_PACKET(_QoS, Topic, _PacketId), + Channel = #channel{conninfo = #{proto_ver := ProtoVer}}) -> + case pipeline([fun process_alias/2, + fun check_pub_acl/2, + fun check_pub_alias/2, + fun check_pub_caps/2], Packet, Channel) of + {ok, NPacket, NChannel} -> + process_publish(NPacket, NChannel); + {error, ReasonCode, NChannel} -> + ?LOG(warning, "Cannot publish message to ~s due to ~s", + [Topic, emqx_reason_codes:text(ReasonCode, ProtoVer)]), + handle_out({disconnect, ReasonCode}, NChannel) + end. + process_publish(Packet = ?PUBLISH_PACKET(_QoS, _Topic, PacketId), Channel) -> Msg = publish_to_msg(Packet, Channel), process_publish(PacketId, Msg, Channel). @@ -417,11 +439,10 @@ process_publish(PacketId, Msg = #message{qos = ?QOS_2}, handle_out({pubrec, PacketId, RC}, Channel) end. -publish_to_msg(Packet, #channel{client = Client = #{mountpoint := MountPoint}, - protocol = Protocol}) -> - Msg = emqx_packet:to_message(Client, Packet), +publish_to_msg(Packet, #channel{conninfo = #{proto_ver := ProtoVer}, + client = ClientInfo = #{mountpoint := MountPoint}}) -> + Msg = emqx_packet:to_message(ClientInfo, Packet), Msg1 = emqx_message:set_flag(dup, false, Msg), - ProtoVer = emqx_protocol:info(proto_ver, Protocol), Msg2 = emqx_message:set_header(proto_ver, ProtoVer, Msg1), emqx_mountpoint:mount(MountPoint, Msg2). @@ -440,13 +461,13 @@ process_subscribe([{TopicFilter, SubOpts}|More], Acc, Channel) -> process_subscribe(More, [RC|Acc], NChannel). do_subscribe(TopicFilter, SubOpts = #{qos := QoS}, Channel = - #channel{client = Client = #{mountpoint := MountPoint}, + #channel{client = ClientInfo = #{mountpoint := MountPoint}, session = Session}) -> case check_subscribe(TopicFilter, SubOpts, Channel) of ok -> TopicFilter1 = emqx_mountpoint:mount(MountPoint, TopicFilter), SubOpts1 = enrich_subopts(maps:merge(?DEFAULT_SUBOPTS, SubOpts), Channel), - case emqx_session:subscribe(Client, TopicFilter1, SubOpts1, Session) of + case emqx_session:subscribe(ClientInfo, TopicFilter1, SubOpts1, Session) of {ok, NSession} -> {QoS, Channel#channel{session = NSession}}; {error, RC} -> {RC, Channel} @@ -470,10 +491,10 @@ process_unsubscribe([{TopicFilter, SubOpts}|More], Acc, Channel) -> process_unsubscribe(More, [RC|Acc], NChannel). do_unsubscribe(TopicFilter, _SubOpts, Channel = - #channel{client = Client = #{mountpoint := MountPoint}, + #channel{client = ClientInfo = #{mountpoint := MountPoint}, session = Session}) -> TopicFilter1 = emqx_mountpoint:mount(MountPoint, TopicFilter), - case emqx_session:unsubscribe(Client, TopicFilter1, Session) of + case emqx_session:unsubscribe(ClientInfo, TopicFilter1, Session) of {ok, NSession} -> {?RC_SUCCESS, Channel#channel{session = NSession}}; {error, RC} -> {RC, Channel} @@ -484,35 +505,37 @@ do_unsubscribe(TopicFilter, _SubOpts, Channel = %%-------------------------------------------------------------------- %%TODO: RunFold or Pipeline -handle_out({connack, ?RC_SUCCESS, SP}, Channel = #channel{client = Client}) -> +handle_out({connack, ?RC_SUCCESS, SP, ConnPkt}, + Channel = #channel{conninfo = ConnInfo, client = ClientInfo}) -> AckProps = run_fold([fun enrich_caps/2, fun enrich_server_keepalive/2, fun enrich_assigned_clientid/2 ], #{}, Channel), - Channel1 = ensure_keepalive(AckProps, ensure_connected(Channel)), - ok = emqx_hooks:run('client.connected', [Client, ?RC_SUCCESS, attrs(Channel1)]), + Channel1 = Channel#channel{will_msg = emqx_packet:will_msg(ConnPkt), + alias_maximum = init_alias_maximum(ConnPkt, ClientInfo), + connected = true, + connected_at = os:timestamp() + }, + Channel2 = ensure_keepalive(AckProps, Channel1), + ok = emqx_hooks:run('client.connected', [ClientInfo, ?RC_SUCCESS, ConnInfo]), AckPacket = ?CONNACK_PACKET(?RC_SUCCESS, SP, AckProps), - case maybe_resume_session(Channel1) of - ignore -> {ok, AckPacket, Channel1}; + case maybe_resume_session(Channel2) of + ignore -> + {ok, AckPacket, Channel2}; {ok, Publishes, NSession} -> - Channel2 = Channel1#channel{session = NSession, + Channel3 = Channel2#channel{session = NSession, resuming = false, pendings = []}, - {ok, Packets, _} = handle_out({publish, Publishes}, Channel2), + {ok, Packets, _} = handle_out({publish, Publishes}, Channel3), {ok, [AckPacket|Packets], Channel2} end; -handle_out({connack, ReasonCode}, Channel = #channel{client = Client, - protocol = Protocol - }) -> - ok = emqx_hooks:run('client.connected', [Client, ReasonCode, attrs(Channel)]), - ProtoVer = case Protocol of - undefined -> ?MQTT_PROTO_V5; - _ -> emqx_protocol:info(proto_ver, Protocol) - end, - ReasonCode1 = if - ProtoVer == ?MQTT_PROTO_V5 -> ReasonCode; - true -> emqx_reason_codes:compat(connack, ReasonCode) +handle_out({connack, ReasonCode, _ConnPkt}, Channel = #channel{conninfo = ConnInfo, + client = ClientInfo}) -> + ok = emqx_hooks:run('client.connected', [ClientInfo, ReasonCode, ConnInfo]), + ReasonCode1 = case ProtoVer = maps:get(proto_ver, ConnInfo) of + ?MQTT_PROTO_V5 -> ReasonCode; + _Ver -> emqx_reason_codes:compat(connack, ReasonCode) end, Reason = emqx_reason_codes:name(ReasonCode1, ProtoVer), {stop, {shutdown, Reason}, ?CONNACK_PACKET(ReasonCode1), Channel}; @@ -553,9 +576,9 @@ handle_out({publish, _PacketId, #message{from = ClientId, {ok, Channel}; handle_out({publish, PacketId, Msg}, Channel = - #channel{client = Client = #{mountpoint := MountPoint}}) -> + #channel{client = ClientInfo = #{mountpoint := MountPoint}}) -> Msg1 = emqx_message:update_expiry(Msg), - Msg2 = emqx_hooks:run_fold('message.delivered', [Client], Msg1), + Msg2 = emqx_hooks:run_fold('message.delivered', [ClientInfo], Msg1), Msg3 = emqx_mountpoint:unmount(MountPoint, Msg2), {ok, emqx_message:to_packet(PacketId, Msg3), Channel}; @@ -571,24 +594,23 @@ handle_out({pubrec, PacketId, ReasonCode}, Channel) -> handle_out({pubcomp, PacketId, ReasonCode}, Channel) -> {ok, ?PUBCOMP_PACKET(PacketId, ReasonCode), Channel}; -handle_out({suback, PacketId, ReasonCodes}, Channel = #channel{protocol = Protocol}) -> - ReasonCodes1 = case emqx_protocol:info(proto_ver, Protocol) of - ?MQTT_PROTO_V5 -> ReasonCodes; - _Ver -> - [emqx_reason_codes:compat(suback, RC) || RC <- ReasonCodes] - end, +handle_out({suback, PacketId, ReasonCodes}, + Channel = #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}}) -> + {ok, ?SUBACK_PACKET(PacketId, ReasonCodes), Channel}; + +handle_out({suback, PacketId, ReasonCodes}, Channel) -> + ReasonCodes1 = [emqx_reason_codes:compat(suback, RC) || RC <- ReasonCodes], {ok, ?SUBACK_PACKET(PacketId, ReasonCodes1), Channel}; -handle_out({unsuback, PacketId, ReasonCodes}, Channel = #channel{protocol = Protocol}) -> - Unsuback = case emqx_protocol:info(proto_ver, Protocol) of - ?MQTT_PROTO_V5 -> - ?UNSUBACK_PACKET(PacketId, ReasonCodes); - _Ver -> ?UNSUBACK_PACKET(PacketId) - end, - {ok, Unsuback, Channel}; +handle_out({unsuback, PacketId, ReasonCodes}, + Channel = #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}}) -> + {ok, ?UNSUBACK_PACKET(PacketId, ReasonCodes), Channel}; -handle_out({disconnect, ReasonCode}, Channel = #channel{protocol = Protocol}) -> - case emqx_protocol:info(proto_ver, Protocol) of +handle_out({unsuback, PacketId, _ReasonCodes}, Channel) -> + {ok, ?UNSUBACK_PACKET(PacketId), Channel}; + +handle_out({disconnect, ReasonCode}, Channel = #channel{conninfo = ConnInfo}) -> + case maps:get(proto_ver, ConnInfo) of ?MQTT_PROTO_V5 -> Reason = emqx_reason_codes:name(ReasonCode), Packet = ?DISCONNECT_PACKET(ReasonCode), @@ -650,16 +672,16 @@ handle_cast(Msg, Channel) -> -spec(handle_info(Info :: term(), channel()) -> {ok, channel()} | {stop, Reason :: term(), channel()}). -handle_info({subscribe, TopicFilters}, Channel = #channel{client = Client}) -> +handle_info({subscribe, TopicFilters}, Channel = #channel{client = ClientInfo}) -> TopicFilters1 = emqx_hooks:run_fold('client.subscribe', - [Client, #{'Internal' => true}], + [ClientInfo, #{'Internal' => true}], parse_topic_filters(TopicFilters)), {_ReasonCodes, NChannel} = process_subscribe(TopicFilters1, Channel), {ok, NChannel}; -handle_info({unsubscribe, TopicFilters}, Channel = #channel{client = Client}) -> +handle_info({unsubscribe, TopicFilters}, Channel = #channel{client = ClientInfo}) -> TopicFilters1 = emqx_hooks:run_fold('client.unsubscribe', - [Client, #{'Internal' => true}], + [ClientInfo, #{'Internal' => true}], parse_topic_filters(TopicFilters)), {_ReasonCodes, NChannel} = process_unsubscribe(TopicFilters1, Channel), {ok, NChannel}; @@ -670,19 +692,20 @@ handle_info(disconnected, Channel = #channel{connected = undefined}) -> handle_info(disconnected, Channel = #channel{connected = false}) -> {ok, Channel}; -handle_info(disconnected, Channel = #channel{protocol = Protocol, +handle_info(disconnected, Channel = #channel{conninfo = #{expiry_interval := ExpiryInterval}, + client = ClientInfo = #{zone := Zone}, session = Session, - client = Client = #{zone := Zone}}) -> - emqx_zone:enable_flapping_detect(Zone) andalso emqx_flapping:detect(Client), + will_msg = WillMsg}) -> + emqx_zone:enable_flapping_detect(Zone) andalso emqx_flapping:detect(ClientInfo), Channel1 = ensure_disconnected(Channel), - Channel2 = case timer:seconds(emqx_protocol:info(will_delay_interval, Protocol)) of + Channel2 = case timer:seconds(will_delay_interval(WillMsg)) of 0 -> - publish_will_msg(emqx_protocol:info(will_msg, Protocol)), - Channel1#channel{protocol = emqx_protocol:clear_will_msg(Protocol)}; + publish_will_msg(WillMsg), + Channel1#channel{will_msg = undefined}; _ -> ensure_timer(will_timer, Channel1) end, - case emqx_session:info(expiry_interval, Session) of + case ExpiryInterval of ?UINT_MAX -> {ok, Channel2}; Int when Int > 0 -> @@ -699,20 +722,19 @@ handle_info(Info, Channel) -> %% Handle timeout %%-------------------------------------------------------------------- --spec(timeout(reference(), Msg :: term(), channel()) +-spec(handle_timeout(reference(), Msg :: term(), channel()) -> {ok, channel()} | {ok, Result :: term(), channel()} | {stop, Reason :: term(), channel()}). -timeout(TRef, {emit_stats, Stats}, - Channel = #channel{client = #{client_id := ClientId}, - timers = #{stats_timer := TRef} - }) -> +handle_timeout(TRef, {emit_stats, Stats}, + Channel = #channel{client = #{client_id := ClientId}, + timers = #{stats_timer := TRef}}) -> ok = emqx_cm:set_chan_stats(ClientId, Stats), {ok, clean_timer(stats_timer, Channel)}; -timeout(TRef, {keepalive, StatVal}, Channel = #channel{keepalive = Keepalive, - timers = #{alive_timer := TRef} - }) -> +handle_timeout(TRef, {keepalive, StatVal}, + Channel = #channel{keepalive = Keepalive, + timers = #{alive_timer := TRef}}) -> case emqx_keepalive:check(StatVal, Keepalive) of {ok, NKeepalive} -> NChannel = Channel#channel{keepalive = NKeepalive}, @@ -721,9 +743,9 @@ timeout(TRef, {keepalive, StatVal}, Channel = #channel{keepalive = Keepalive, {wait_session_expire, {shutdown, keepalive_timeout}, Channel} end; -timeout(TRef, retry_delivery, Channel = #channel{session = Session, - timers = #{retry_timer := TRef} - }) -> +handle_timeout(TRef, retry_delivery, + Channel = #channel{session = Session, + timers = #{retry_timer := TRef}}) -> case emqx_session:retry(Session) of {ok, NSession} -> {ok, clean_timer(retry_timer, Channel#channel{session = NSession})}; @@ -735,8 +757,9 @@ timeout(TRef, retry_delivery, Channel = #channel{session = Session, handle_out({publish, Publishes}, reset_timer(retry_timer, Timeout, NChannel)) end; -timeout(TRef, expire_awaiting_rel, Channel = #channel{session = Session, - timers = #{await_timer := TRef}}) -> +handle_timeout(TRef, expire_awaiting_rel, + Channel = #channel{session = Session, + timers = #{await_timer := TRef}}) -> case emqx_session:expire(awaiting_rel, Session) of {ok, Session} -> {ok, clean_timer(await_timer, Channel#channel{session = Session})}; @@ -744,15 +767,15 @@ timeout(TRef, expire_awaiting_rel, Channel = #channel{session = Session, {ok, reset_timer(await_timer, Timeout, Channel#channel{session = Session})} end; -timeout(TRef, expire_session, Channel = #channel{timers = #{expire_timer := TRef}}) -> +handle_timeout(TRef, expire_session, Channel = #channel{timers = #{expire_timer := TRef}}) -> shutdown(expired, Channel); -timeout(TRef, will_message, Channel = #channel{protocol = Protocol, - timers = #{will_timer := TRef}}) -> - publish_will_msg(emqx_protocol:info(will_msg, Protocol)), - {ok, clean_timer(will_timer, Channel#channel{protocol = emqx_protocol:clear_will_msg(Protocol)})}; +handle_timeout(TRef, will_message, Channel = #channel{will_msg = WillMsg, + timers = #{will_timer := TRef}}) -> + publish_will_msg(WillMsg), + {ok, clean_timer(will_timer, Channel#channel{will_msg = undefined})}; -timeout(_TRef, Msg, Channel) -> +handle_timeout(_TRef, Msg, Channel) -> ?LOG(error, "Unexpected timeout: ~p~n", [Msg]), {ok, Channel}. @@ -796,24 +819,28 @@ interval(retry_timer, #channel{session = Session}) -> emqx_session:info(retry_interval, Session); interval(await_timer, #channel{session = Session}) -> emqx_session:info(await_rel_timeout, Session); -interval(expire_timer, #channel{session = Session}) -> - timer:seconds(emqx_session:info(expiry_interval, Session)); -interval(will_timer, #channel{protocol = Protocol}) -> - timer:seconds(emqx_protocol:info(will_delay_interval, Protocol)). +interval(expire_timer, #channel{conninfo = ConnInfo}) -> + timer:seconds(maps:get(expiry_interval, ConnInfo)); +interval(will_timer, #channel{will_msg = WillMsg}) -> + %% TODO: Ensure the header exists. + timer:seconds(will_delay_interval(WillMsg)). + +will_delay_interval(undefined) -> 0; +will_delay_interval(WillMsg) -> + emqx_message:get_header('Will-Delay-Interval', WillMsg, 0). %%-------------------------------------------------------------------- %% Terminate %%-------------------------------------------------------------------- -terminate(normal, #channel{client = Client}) -> - ok = emqx_hooks:run('client.disconnected', [Client, normal]); -terminate({shutdown, Reason}, #channel{client = Client}) +terminate(normal, #channel{conninfo = ConnInfo, client = ClientInfo}) -> + ok = emqx_hooks:run('client.disconnected', [ClientInfo, normal, ConnInfo]); +terminate({shutdown, Reason}, #channel{conninfo = ConnInfo, client = ClientInfo,}) when Reason =:= kicked orelse Reason =:= discarded orelse Reason =:= takeovered -> - ok = emqx_hooks:run('client.disconnected', [Client, Reason]); -terminate(Reason, #channel{client = Client, - protocol = Protocol - }) -> - ok = emqx_hooks:run('client.disconnected', [Client, Reason]), + ok = emqx_hooks:run('client.disconnected', [ClientInfo, Reason, ConnInfo]); +terminate(Reason, #channel{conninfo = ConnInfo, client = ClientInfo, will_msg = WillMsg}) -> + publish_will_msg(WillMsg), + ok = emqx_hooks:run('client.disconnected', [ClientInfo, Reason, ConnInfo]). if Protocol == undefined -> ok; true -> publish_will_msg(emqx_protocol:info(will_msg, Protocol)) @@ -833,51 +860,78 @@ publish_will_msg(undefined) -> publish_will_msg(Msg) -> emqx_broker:publish(Msg). +%% @doc Enrich MQTT Connect Info. +enrich_conninfo(#mqtt_packet_connect{ + proto_name = ProtoName, + proto_ver = ProtoVer, + clean_start = CleanStart, + keepalive = Keepalive, + properties = ConnProps, + client_id = ClientId, + username = Username}, Channel) -> + #channel{conninfo = ConnInfo, client = #{zone := Zone}} = Channel, + MaxInflight = emqx_mqtt_props:get('Receive-Maximum', + ConnProps, emqx_zone:max_inflight(Zone)), + Interval = if ProtoVer == ?MQTT_PROTO_V5 -> + emqx_mqtt_props:get('Session-Expiry-Interval', ConnProps, 0); + true -> case CleanStart of + true -> 0; + false -> emqx_zone:session_expiry_interval(Zone) + end + end, + NConnInfo = ConnInfo#{proto_name => ProtoName, + proto_ver => ProtoVer, + clean_start => CleanStart, + keepalive => Keepalive, + client_id => ClientId, + username => Username, + conn_props => ConnProps, + receive_maximum => MaxInflight, + expiry_interval => Interval + }, + {ok, Channel#channel{conninfo = NConnInfo}}. + %% @doc Check connect packet. -check_connpkt(ConnPkt, #channel{client = #{zone := Zone}}) -> +check_connect(ConnPkt, #channel{client = #{zone := Zone}}) -> emqx_packet:check(ConnPkt, emqx_mqtt_caps:get_caps(Zone)). -%% @doc Init protocol record. -init_protocol(ConnPkt, Channel = #channel{client = #{zone := Zone}}) -> - {ok, Channel#channel{protocol = emqx_protocol:init(ConnPkt, Zone)}}. - %% @doc Enrich client -enrich_client(ConnPkt, Channel = #channel{client = Client}) -> - {ok, NConnPkt, NClient} = pipeline([fun set_username/2, - fun set_bridge_mode/2, - fun maybe_username_as_clientid/2, - fun maybe_assign_clientid/2, - fun fix_mountpoint/2 - ], ConnPkt, Client), - {ok, NConnPkt, Channel#channel{client = NClient}}. +enrich_client(ConnPkt, Channel = #channel{client = ClientInfo}) -> + {ok, NConnPkt, NClientInfo} = + pipeline([fun set_username/2, + fun set_bridge_mode/2, + fun maybe_username_as_clientid/2, + fun maybe_assign_clientid/2, + fun fix_mountpoint/2], ConnPkt, ClientInfo), + {ok, NConnPkt, Channel#channel{client = NClientInfo}}. -set_username(#mqtt_packet_connect{username = Username}, Client = #{username := undefined}) -> - {ok, Client#{username => Username}}; -set_username(_ConnPkt, Client) -> - {ok, Client}. +set_username(#mqtt_packet_connect{username = Username}, + ClientInfo = #{username := undefined}) -> + {ok, ClientInfo#{username => Username}}; +set_username(_ConnPkt, ClientInfo) -> + {ok, ClientInfo}. -set_bridge_mode(#mqtt_packet_connect{is_bridge = true}, Client) -> - {ok, Client#{is_bridge => true}}; -set_bridge_mode(_ConnPkt, _Client) -> ok. +set_bridge_mode(#mqtt_packet_connect{is_bridge = true}, ClientInfo) -> + {ok, ClientInfo#{is_bridge => true}}; +set_bridge_mode(_ConnPkt, _ClientInfo) -> ok. -maybe_username_as_clientid(_ConnPkt, Client = #{username := undefined}) -> - {ok, Client}; -maybe_username_as_clientid(_ConnPkt, Client = #{zone := Zone, username := Username}) -> +maybe_username_as_clientid(_ConnPkt, ClientInfo = #{username := undefined}) -> + {ok, ClientInfo}; +maybe_username_as_clientid(_ConnPkt, ClientInfo = #{zone := Zone, username := Username}) -> case emqx_zone:use_username_as_clientid(Zone) of - true -> {ok, Client#{client_id => Username}}; + true -> {ok, ClientInfo#{client_id => Username}}; false -> ok end. -maybe_assign_clientid(#mqtt_packet_connect{client_id = <<>>}, Client) -> +maybe_assign_clientid(#mqtt_packet_connect{client_id = <<>>}, ClientInfo) -> %% Generate a rand clientId - RandId = emqx_guid:to_base62(emqx_guid:gen()), - {ok, Client#{client_id => RandId}}; -maybe_assign_clientid(#mqtt_packet_connect{client_id = ClientId}, Client) -> - {ok, Client#{client_id => ClientId}}. + {ok, ClientInfo#{client_id => emqx_guid:to_base62(emqx_guid:gen())}}; +maybe_assign_clientid(#mqtt_packet_connect{client_id = ClientId}, ClientInfo) -> + {ok, ClientInfo#{client_id => ClientId}}. fix_mountpoint(_ConnPkt, #{mountpoint := undefined}) -> ok; -fix_mountpoint(_ConnPkt, Client = #{mountpoint := Mountpoint}) -> - {ok, Client#{mountpoint := emqx_mountpoint:replvar(Mountpoint, Client)}}. +fix_mountpoint(_ConnPkt, ClientInfo = #{mountpoint := Mountpoint}) -> + {ok, ClientInfo#{mountpoint := emqx_mountpoint:replvar(Mountpoint, ClientInfo)}}. %% @doc Set logger metadata. set_logger_meta(_ConnPkt, #channel{client = #{client_id := ClientId}}) -> @@ -887,15 +941,15 @@ set_logger_meta(_ConnPkt, #channel{client = #{client_id := ClientId}}) -> %% Check banned/flapping %%-------------------------------------------------------------------- -check_banned(_ConnPkt, #channel{client = Client = #{zone := Zone}}) -> - case emqx_zone:enable_banned(Zone) andalso emqx_banned:check(Client) of +check_banned(_ConnPkt, #channel{client = ClientInfo = #{zone := Zone}}) -> + case emqx_zone:enable_ban(Zone) andalso emqx_banned:check(ClientInfo) of true -> {error, ?RC_BANNED}; false -> ok end. -check_flapping(_ConnPkt, #channel{client = Client = #{zone := Zone}}) -> +check_flapping(_ConnPkt, #channel{client = ClientInfo = #{zone := Zone}}) -> case emqx_zone:enable_flapping_detect(Zone) - andalso emqx_flapping:check(Client) of + andalso emqx_flapping:check(ClientInfo) of true -> {error, ?RC_CONNECTION_RATE_EXCEEDED}; false -> ok end. @@ -907,38 +961,16 @@ check_flapping(_ConnPkt, #channel{client = Client = #{zone := Zone}}) -> auth_connect(#mqtt_packet_connect{client_id = ClientId, username = Username, password = Password}, - Channel = #channel{client = Client}) -> - case emqx_access_control:authenticate(Client#{password => Password}) of + Channel = #channel{client = ClientInfo}) -> + case emqx_access_control:authenticate(ClientInfo#{password => Password}) of {ok, AuthResult} -> - {ok, Channel#channel{client = maps:merge(Client, AuthResult)}}; + {ok, Channel#channel{client = maps:merge(ClientInfo, AuthResult)}}; {error, Reason} -> ?LOG(warning, "Client ~s (Username: '~s') login failed for ~0p", [ClientId, Username, Reason]), {error, emqx_reason_codes:connack_error(Reason)} end. -%%-------------------------------------------------------------------- -%% Open session -%%-------------------------------------------------------------------- - -open_session(#mqtt_packet_connect{clean_start = CleanStart, - properties = ConnProps}, - #channel{client = Client = #{zone := Zone}, protocol = Protocol}) -> - MaxInflight = get_property('Receive-Maximum', ConnProps, - emqx_zone:get_env(Zone, max_inflight, 65535)), - Interval = - case emqx_protocol:info(proto_ver, Protocol) of - ?MQTT_PROTO_V5 -> get_property('Session-Expiry-Interval', ConnProps, 0); - _ -> - case CleanStart of - true -> 0; - false -> emqx_zone:get_env(Zone, session_expiry_interval, 0) - end - end, - emqx_cm:open_session(CleanStart, Client, #{max_inflight => MaxInflight, - expiry_interval => Interval - }). - %%-------------------------------------------------------------------- %% Process publish message: Client -> Broker %%-------------------------------------------------------------------- @@ -948,8 +980,8 @@ process_alias(Packet = #mqtt_packet{ properties = #{'Topic-Alias' := AliasId} } = Publish }, - Channel = #channel{protocol = Protocol}) -> - case emqx_protocol:find_alias(AliasId, Protocol) of + Channel = #channel{topic_aliases = Aliases}) -> + case find_alias(AliasId, Aliases) of {ok, Topic} -> {ok, Packet#mqtt_packet{ variable = Publish#mqtt_packet_publish{ @@ -961,23 +993,23 @@ process_alias(#mqtt_packet{ variable = #mqtt_packet_publish{topic_name = Topic, properties = #{'Topic-Alias' := AliasId} } - }, Channel = #channel{protocol = Protocol}) -> - {ok, Channel#channel{protocol = emqx_protocol:save_alias(AliasId, Topic, Protocol)}}; + }, Channel = #channel{topic_aliases = Aliases}) -> + {ok, Channel#channel{topic_aliases = save_alias(AliasId, Topic, Aliases)}}; process_alias(_Packet, Channel) -> {ok, Channel}. -%% Check Publish -check_publish(Packet, Channel) -> - pipeline([fun check_pub_acl/2, - fun check_pub_alias/2, - fun check_pub_caps/2], Packet, Channel). +find_alias(_AliasId, undefined) -> false; +find_alias(AliasId, Aliases) -> maps:find(AliasId, Aliases). + +save_alias(AliasId, Topic, undefined) -> #{AliasId => Topic}; +save_alias(AliasId, Topic, Aliases) -> maps:put(AliasId, Topic, Aliases). %% Check Pub ACL check_pub_acl(#mqtt_packet{variable = #mqtt_packet_publish{topic_name = Topic}}, - #channel{client = Client}) -> - case is_acl_enabled(Client) andalso - emqx_access_control:check_acl(Client, publish, Topic) of + #channel{client = ClientInfo}) -> + case is_acl_enabled(ClientInfo) andalso + emqx_access_control:check_acl(ClientInfo, publish, Topic) of false -> ok; allow -> ok; deny -> {error, ?RC_NOT_AUTHORIZED} @@ -989,9 +1021,8 @@ check_pub_alias(#mqtt_packet{ properties = #{'Topic-Alias' := AliasId} } }, - #channel{protocol = Protocol}) -> + #channel{alias_maximum = Limits}) -> %% TODO: Move to Protocol - Limits = emqx_protocol:info(alias_maximum, Protocol), case (Limits == undefined) orelse (Max = maps:get(inbound, Limits, 0)) == 0 orelse (AliasId > Max) of @@ -1016,9 +1047,9 @@ check_subscribe(TopicFilter, SubOpts, Channel) -> end. %% Check Sub ACL -check_sub_acl(TopicFilter, #channel{client = Client}) -> - case is_acl_enabled(Client) andalso - emqx_access_control:check_acl(Client, subscribe, TopicFilter) of +check_sub_acl(TopicFilter, #channel{client = ClientInfo}) -> + case is_acl_enabled(ClientInfo) andalso + emqx_access_control:check_acl(ClientInfo, subscribe, TopicFilter) of false -> allow; Result -> Result end. @@ -1032,64 +1063,63 @@ enrich_subid(#{'Subscription-Identifier' := SubId}, TopicFilters) -> enrich_subid(_Properties, TopicFilters) -> TopicFilters. -enrich_subopts(SubOpts, #channel{client = Client, protocol = Proto}) -> - #{zone := Zone, is_bridge := IsBridge} = Client, - case emqx_protocol:info(proto_ver, Proto) of - ?MQTT_PROTO_V5 -> SubOpts; - _Ver -> Rap = flag(IsBridge), - Nl = flag(emqx_zone:get_env(Zone, ignore_loop_deliver, false)), - SubOpts#{rap => Rap, nl => Nl} - end. +enrich_subopts(SubOpts, #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}}) -> + SubOpts; -enrich_caps(AckProps, #channel{client = #{zone := Zone}, protocol = Protocol}) -> - case emqx_protocol:info(proto_ver, Protocol) of - ?MQTT_PROTO_V5 -> - #{max_packet_size := MaxPktSize, - max_qos_allowed := MaxQoS, - retain_available := Retain, - max_topic_alias := MaxAlias, - shared_subscription := Shared, - wildcard_subscription := Wildcard - } = emqx_mqtt_caps:get_caps(Zone), - AckProps#{'Retain-Available' => flag(Retain), - 'Maximum-Packet-Size' => MaxPktSize, - 'Topic-Alias-Maximum' => MaxAlias, - 'Wildcard-Subscription-Available' => flag(Wildcard), - 'Subscription-Identifier-Available' => 1, - 'Shared-Subscription-Available' => flag(Shared), - 'Maximum-QoS' => MaxQoS - }; - _Ver -> AckProps - end. +enrich_subopts(SubOpts, #channel{client = #{zone := Zone, is_bridge := IsBridge}}) -> + NL = flag(emqx_zone:ignore_loop_deliver(Zone)), + SubOpts#{rap => flag(IsBridge), nl => NL}. + +enrich_caps(AckProps, #channel{conninfo = #{proto_ver := ?MQTT_PROTO_V5}, + client = #{zone := Zone}}) -> + #{max_packet_size := MaxPktSize, + max_qos_allowed := MaxQoS, + retain_available := Retain, + max_topic_alias := MaxAlias, + shared_subscription := Shared, + wildcard_subscription := Wildcard + } = emqx_mqtt_caps:get_caps(Zone), + AckProps#{'Retain-Available' => flag(Retain), + 'Maximum-Packet-Size' => MaxPktSize, + 'Topic-Alias-Maximum' => MaxAlias, + 'Wildcard-Subscription-Available' => flag(Wildcard), + 'Subscription-Identifier-Available' => 1, + 'Shared-Subscription-Available' => flag(Shared), + 'Maximum-QoS' => MaxQoS + }; +enrich_caps(AckProps, _Channel) -> + AckProps. enrich_server_keepalive(AckProps, #channel{client = #{zone := Zone}}) -> - case emqx_zone:get_env(Zone, server_keepalive) of + case emqx_zone:server_keepalive(Zone) of undefined -> AckProps; Keepalive -> AckProps#{'Server-Keep-Alive' => Keepalive} end. -enrich_assigned_clientid(AckProps, #channel{client = #{client_id := ClientId}, - protocol = Protocol}) -> - case emqx_protocol:info(client_id, Protocol) of - <<>> -> %% Original ClientId. +enrich_assigned_clientid(AckProps, #channel{conninfo = ConnInfo, + client = #{client_id := ClientId} + }) -> + case maps:get(client_id, ConnInfo) of + <<>> -> %% Original ClientId is null. AckProps#{'Assigned-Client-Identifier' => ClientId}; _Origin -> AckProps end. -ensure_connected(Channel) -> - Channel#channel{connected = true, connected_at = os:timestamp(), disconnected_at = undefined}. +init_alias_maximum(#mqtt_packet_connect{proto_ver = ?MQTT_PROTO_V5, + properties = Properties}, #{zone := Zone}) -> + #{outbound => emqx_mqtt_props:get('Topic-Alias-Maximum', Properties, 0), + inbound => emqx_mqtt_caps:get_caps(Zone, max_topic_alias, 0)}; +init_alias_maximum(_ConnPkt, _ClientInfo) -> undefined. ensure_disconnected(Channel) -> Channel#channel{connected = false, disconnected_at = os:timestamp()}. ensure_keepalive(#{'Server-Keep-Alive' := Interval}, Channel) -> ensure_keepalive_timer(Interval, Channel); -ensure_keepalive(_AckProp, Channel = #channel{protocol = Protocol}) -> - case emqx_protocol:info(keepalive, Protocol) of - 0 -> Channel; - Interval -> ensure_keepalive_timer(Interval, Channel) - end. +ensure_keepalive(_AckProps, Channel = #channel{conninfo = ConnInfo}) -> + ensure_keepalive_timer(maps:get(keepalive, ConnInfo), Channel). +ensure_keepalive_timer(0, Channel) -> Channel; ensure_keepalive_timer(Interval, Channel = #channel{client = #{zone := Zone}}) -> Backoff = emqx_zone:get_env(Zone, keepalive_backoff, 0.75), Keepalive = emqx_keepalive:init(round(timer:seconds(Interval) * Backoff)), @@ -1140,11 +1170,6 @@ check_oom(OomPolicy) -> %% Helper functions %%-------------------------------------------------------------------- -get_property(_Name, undefined, Default) -> - Default; -get_property(Name, Props, Default) -> - maps:get(Name, Props, Default). - sp(true) -> 1; sp(false) -> 0. diff --git a/src/emqx_cm.erl b/src/emqx_cm.erl index 5f003807f..061d9aa4b 100644 --- a/src/emqx_cm.erl +++ b/src/emqx_cm.erl @@ -161,15 +161,15 @@ set_chan_stats(ClientId, ChanPid, Stats) -> present := boolean(), pendings => list()}} | {error, Reason :: term()}). -open_session(true, Client = #{client_id := ClientId}, Options) -> +open_session(true, ClientInfo = #{client_id := ClientId}, ConnInfo) -> CleanStart = fun(_) -> ok = discard_session(ClientId), - Session = emqx_session:init(Client, Options), + Session = emqx_session:init(ClientInfo, ConnInfo), {ok, #{session => Session, present => false}} end, emqx_cm_locker:trans(ClientId, CleanStart); -open_session(false, Client = #{client_id := ClientId}, Options) -> +open_session(false, ClientInfo = #{client_id := ClientId}, ConnInfo) -> ResumeStart = fun(_) -> case takeover_session(ClientId) of {ok, ConnMod, ChanPid, Session} -> @@ -179,7 +179,7 @@ open_session(false, Client = #{client_id := ClientId}, Options) -> present => true, pendings => Pendings}}; {error, not_found} -> - Session = emqx_session:init(Client, Options), + Session = emqx_session:init(ClientInfo, ConnInfo), {ok, #{session => Session, present => false}} end end, diff --git a/src/emqx_connection.erl b/src/emqx_connection.erl index 338d3e85d..f4aeede00 100644 --- a/src/emqx_connection.erl +++ b/src/emqx_connection.erl @@ -102,9 +102,9 @@ start_link(Transport, Socket, Options) -> info(CPid) when is_pid(CPid) -> call(CPid, info); info(Conn = #connection{chan_state = ChanState}) -> - ConnInfo = info(?INFO_KEYS, Conn), ChanInfo = emqx_channel:info(ChanState), - maps:merge(ChanInfo, #{conninfo => maps:from_list(ConnInfo)}). + SockInfo = maps:from_list(info(?INFO_KEYS, Conn)), + maps:merge(ChanInfo, #{sockinfo => SockInfo}). info(Keys, Conn) when is_list(Keys) -> [{Key, info(Key, Conn)} || Key <- Keys]; @@ -133,9 +133,9 @@ limit_info(Limit) -> attrs(CPid) when is_pid(CPid) -> call(CPid, attrs); attrs(Conn = #connection{chan_state = ChanState}) -> - ConnAttrs = info(?ATTR_KEYS, Conn), ChanAttrs = emqx_channel:attrs(ChanState), - maps:merge(ChanAttrs, #{conninfo => maps:from_list(ConnAttrs)}). + SockAttrs = maps:from_list(info(?ATTR_KEYS, Conn)), + maps:merge(ChanAttrs, #{sockinfo => SockAttrs}). %% @doc Get stats of the channel. -spec(stats(pid()|connection()) -> emqx_types:stats()). @@ -219,7 +219,7 @@ idle(timeout, _Timeout, State) -> idle(cast, {incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State) -> #mqtt_packet_connect{proto_ver = ProtoVer, properties = Properties} = ConnPkt, - MaxPacketSize = emqx_mqtt_props:get_property('Maximum-Packet-Size', Properties, undefined), + MaxPacketSize = emqx_mqtt_props:get('Maximum-Packet-Size', Properties, undefined), NState = State#connection{serialize = serialize_fun(ProtoVer, MaxPacketSize)}, SuccFun = fun(NewSt) -> {next_state, connected, NewSt} end, handle_incoming(Packet, SuccFun, NState); @@ -507,7 +507,7 @@ serialize_fun(ProtoVer, MaxPacketSize) -> false -> ?LOG(warning, "DROP ~s due to oversize packet size", [emqx_packet:format(Packet)]), <<"">> - end + end end. %%-------------------------------------------------------------------- @@ -529,7 +529,7 @@ send(IoData, SuccFun, State = #connection{transport = Transport, %% Handle timeout handle_timeout(TRef, Msg, State = #connection{chan_state = ChanState}) -> - case emqx_channel:timeout(TRef, Msg, ChanState) of + case emqx_channel:handle_timeout(TRef, Msg, ChanState) of {ok, NChanState} -> keep_state(State#connection{chan_state = NChanState}); {ok, Packets, NChanState} -> diff --git a/src/emqx_ctl.erl b/src/emqx_ctl.erl index dbde2b60b..3b050a51f 100644 --- a/src/emqx_ctl.erl +++ b/src/emqx_ctl.erl @@ -18,11 +18,12 @@ -behaviour(gen_server). +-include("types.hrl"). -include("logger.hrl"). -logger_header("[Ctl]"). --export([start_link/0]). +-export([start_link/0, stop/0]). -export([ register_command/2 , register_command/3 @@ -32,6 +33,7 @@ -export([ run_command/1 , run_command/2 , lookup_command/1 + , get_commands/0 ]). -export([ print/1 @@ -40,7 +42,7 @@ , usage/2 ]). -%% format/1,2 and format_usage/1,2 are exported mainly for test cases +%% Exports mainly for test cases -export([ format/1 , format/2 , format_usage/1 @@ -59,38 +61,44 @@ -record(state, {seq = 0}). -type(cmd() :: atom()). +-type(cmd_params() :: string()). -type(cmd_descr() :: string()). --type(cmd_usage() :: {cmd(), cmd_descr()}). +-type(cmd_usage() :: {cmd_params(), cmd_descr()}). -define(SERVER, ?MODULE). --define(TAB, emqx_command). +-define(CMD_TAB, emqx_command). +-spec(start_link() -> startlink_ret()). start_link() -> gen_server:start_link({local, ?SERVER}, ?MODULE, [], []). +-spec(stop() -> ok). +stop() -> gen_server:stop(?SERVER). + -spec(register_command(cmd(), {module(), atom()}) -> ok). register_command(Cmd, MF) when is_atom(Cmd) -> register_command(Cmd, MF, []). -spec(register_command(cmd(), {module(), atom()}, list()) -> ok). register_command(Cmd, MF, Opts) when is_atom(Cmd) -> - cast({register_command, Cmd, MF, Opts}). + call({register_command, Cmd, MF, Opts}). -spec(unregister_command(cmd()) -> ok). unregister_command(Cmd) when is_atom(Cmd) -> cast({unregister_command, Cmd}). -cast(Msg) -> - gen_server:cast(?SERVER, Msg). +call(Req) -> gen_server:call(?SERVER, Req). +cast(Msg) -> gen_server:cast(?SERVER, Msg). + +-spec(run_command(list(string())) -> ok | {error, term()}). run_command([]) -> run_command(help, []); run_command([Cmd | Args]) -> run_command(list_to_atom(Cmd), Args). --spec(run_command(cmd(), [string()]) -> ok | {error, term()}). -run_command(help, []) -> - help(); +-spec(run_command(cmd(), list(string())) -> ok | {error, term()}). +run_command(help, []) -> help(); run_command(Cmd, Args) when is_atom(Cmd) -> case lookup_command(Cmd) of [{Mod, Fun}] -> @@ -107,15 +115,19 @@ run_command(Cmd, Args) when is_atom(Cmd) -> -spec(lookup_command(cmd()) -> [{module(), atom()}]). lookup_command(Cmd) when is_atom(Cmd) -> - case ets:match(?TAB, {{'_', Cmd}, '$1', '_'}) of + case ets:match(?CMD_TAB, {{'_', Cmd}, '$1', '_'}) of [El] -> El; [] -> [] end. +-spec(get_commands() -> list({cmd(), module(), atom()})). +get_commands() -> + [{Cmd, M, F} || {{_Seq, Cmd}, {M, F}, _Opts} <- ets:tab2list(?CMD_TAB)]. + help() -> print("Usage: ~s~n", [?MODULE]), [begin print("~80..-s~n", [""]), Mod:Cmd(usage) end - || {_, {Mod, Cmd}, _} <- ets:tab2list(?TAB)]. + || {_, {Mod, Cmd}, _} <- ets:tab2list(?CMD_TAB)]. -spec(print(io:format()) -> ok). print(Msg) -> @@ -129,13 +141,13 @@ print(Format, Args) -> usage(UsageList) -> io:format(format_usage(UsageList)). --spec(usage(cmd(), cmd_descr()) -> ok). -usage(Cmd, Desc) -> - io:format(format_usage(Cmd, Desc)). +-spec(usage(cmd_params(), cmd_descr()) -> ok). +usage(CmdParams, Desc) -> + io:format(format_usage(CmdParams, Desc)). -spec(format(io:format()) -> string()). format(Msg) -> - lists:flatten(io_lib:format("~p", [Msg])). + lists:flatten(io_lib:format("~s", [Msg])). -spec(format(io:format(), [term()]) -> string()). format(Format, Args) -> @@ -144,42 +156,41 @@ format(Format, Args) -> -spec(format_usage([cmd_usage()]) -> ok). format_usage(UsageList) -> lists:map( - fun({Cmd, Desc}) -> - format_usage(Cmd, Desc) + fun({CmdParams, Desc}) -> + format_usage(CmdParams, Desc) end, UsageList). --spec(format_usage(cmd(), cmd_descr()) -> string()). -format_usage(Cmd, Desc) -> - CmdLines = split_cmd(Cmd), +-spec(format_usage(cmd_params(), cmd_descr()) -> string()). +format_usage(CmdParams, Desc) -> + CmdLines = split_cmd(CmdParams), DescLines = split_cmd(Desc), - lists:foldl( - fun({CmdStr, DescStr}, Usage) -> - Usage ++ format("~-48s# ~s~n", [CmdStr, DescStr]) - end, "", zip_cmd(CmdLines, DescLines)). + lists:foldl(fun({CmdStr, DescStr}, Usage) -> + Usage ++ format("~-48s# ~s~n", [CmdStr, DescStr]) + end, "", zip_cmd(CmdLines, DescLines)). -%%------------------------------------------------------------------------------ +%%-------------------------------------------------------------------- %% gen_server callbacks -%%------------------------------------------------------------------------------ +%%-------------------------------------------------------------------- init([]) -> - ok = emqx_tables:new(?TAB, [protected, ordered_set]), + ok = emqx_tables:new(?CMD_TAB, [protected, ordered_set]), {ok, #state{seq = 0}}. +handle_call({register_command, Cmd, MF, Opts}, _From, State = #state{seq = Seq}) -> + case ets:match(?CMD_TAB, {{'$1', Cmd}, '_', '_'}) of + [] -> ets:insert(?CMD_TAB, {{Seq, Cmd}, MF, Opts}); + [[OriginSeq] | _] -> + ?LOG(warning, "CMD ~s is overidden by ~p", [Cmd, MF]), + true = ets:insert(?CMD_TAB, {{OriginSeq, Cmd}, MF, Opts}) + end, + {reply, ok, next_seq(State)}; + handle_call(Req, _From, State) -> ?LOG(error, "Unexpected call: ~p", [Req]), {reply, ignored, State}. -handle_cast({register_command, Cmd, MF, Opts}, State = #state{seq = Seq}) -> - case ets:match(?TAB, {{'$1', Cmd}, '_', '_'}) of - [] -> ets:insert(?TAB, {{Seq, Cmd}, MF, Opts}); - [[OriginSeq] | _] -> - ?LOG(warning, "CMD ~s is overidden by ~p", [Cmd, MF]), - ets:insert(?TAB, {{OriginSeq, Cmd}, MF, Opts}) - end, - noreply(next_seq(State)); - handle_cast({unregister_command, Cmd}, State) -> - ets:match_delete(?TAB, {{'_', Cmd}, '_', '_'}), + ets:match_delete(?CMD_TAB, {{'_', Cmd}, '_', '_'}), noreply(State); handle_cast(Msg, State) -> @@ -214,3 +225,4 @@ zip_cmd([X | Xs], [Y | Ys]) -> [{X, Y} | zip_cmd(Xs, Ys)]; zip_cmd([X | Xs], []) -> [{X, ""} | zip_cmd(Xs, [])]; zip_cmd([], [Y | Ys]) -> [{"", Y} | zip_cmd([], Ys)]; zip_cmd([], []) -> []. + diff --git a/src/emqx_flapping.erl b/src/emqx_flapping.erl index ee898acbd..ca0e411a0 100644 --- a/src/emqx_flapping.erl +++ b/src/emqx_flapping.erl @@ -52,7 +52,7 @@ -record(flapping, { client_id :: emqx_types:client_id(), - peername :: emqx_types:peername(), + peerhost :: emqx_types:peerhost(), started_at :: pos_integer(), detect_cnt :: pos_integer(), banned_at :: pos_integer() @@ -84,7 +84,7 @@ check(ClientId, #{banned_interval := Interval}) -> -spec(detect(emqx_types:client()) -> boolean()). detect(Client) -> detect(Client, get_policy()). -detect(#{client_id := ClientId, peername := Peername}, +detect(#{client_id := ClientId, peerhost := PeerHost}, Policy = #{threshold := Threshold}) -> try ets:update_counter(?FLAPPING_TAB, ClientId, {#flapping.detect_cnt, 1}) of Cnt when Cnt < Threshold -> false; @@ -98,7 +98,7 @@ detect(#{client_id := ClientId, peername := Peername}, error:badarg -> %% Create a flapping record. Flapping = #flapping{client_id = ClientId, - peername = Peername, + peerhost = PeerHost, started_at = emqx_time:now_ms(), detect_cnt = 1 }, @@ -132,7 +132,7 @@ handle_call(Req, _From, State) -> {reply, ignored, State}. handle_cast({detected, Flapping = #flapping{client_id = ClientId, - peername = Peername, + peerhost = PeerHost, started_at = StartedAt, detect_cnt = DetectCnt}, #{duration := Duration}}, State) -> @@ -140,7 +140,7 @@ handle_cast({detected, Flapping = #flapping{client_id = ClientId, true -> %% Flapping happened:( %% Log first ?LOG(error, "Flapping detected: ~s(~s) disconnected ~w times in ~wms", - [ClientId, esockd_net:format(Peername), DetectCnt, Duration]), + [ClientId, esockd_net:ntoa(PeerHost), DetectCnt, Duration]), %% Banned. BannedFlapping = Flapping#flapping{client_id = {banned, ClientId}, banned_at = emqx_time:now_ms() @@ -149,7 +149,7 @@ handle_cast({detected, Flapping = #flapping{client_id = ClientId, ets:insert(?FLAPPING_TAB, BannedFlapping); false -> ?LOG(warning, "~s(~s) disconnected ~w times in ~wms", - [ClientId, esockd_net:format(Peername), DetectCnt, Interval]), + [ClientId, esockd_net:ntoa(PeerHost), DetectCnt, Interval]), ets:delete_object(?FLAPPING_TAB, Flapping) end, {noreply, State}; diff --git a/src/emqx_gc.erl b/src/emqx_gc.erl index 7d47b7071..5f939bfe5 100644 --- a/src/emqx_gc.erl +++ b/src/emqx_gc.erl @@ -34,7 +34,7 @@ , reset/1 ]). --export_type([gc_state/0]). +-export_type([opts/0, gc_state/0]). -type(opts() :: #{count => integer(), bytes => integer()}). diff --git a/src/emqx_message.erl b/src/emqx_message.erl index 6fbd3bbf5..ca8a433d0 100644 --- a/src/emqx_message.erl +++ b/src/emqx_message.erl @@ -38,6 +38,7 @@ %% Flags -export([ get_flag/2 , get_flag/3 + , get_flags/1 , set_flag/2 , set_flag/3 , set_flags/2 @@ -85,6 +86,7 @@ make(From, QoS, Topic, Payload) when ?QOS_0 =< QoS, QoS =< ?QOS_2 -> qos = QoS, from = From, flags = #{dup => false}, + headers = #{}, topic = Topic, payload = Payload, timestamp = os:timestamp()}. @@ -119,6 +121,9 @@ get_flag(Flag, Msg) -> get_flag(Flag, #message{flags = Flags}, Default) -> maps:get(Flag, Flags, Default). +-spec(get_flags(emqx_types:message()) -> maybe(map())). +get_flags(#message{flags = Flags}) -> Flags. + -spec(set_flag(flag(), emqx_types:message()) -> emqx_types:message()). set_flag(Flag, Msg = #message{flags = undefined}) when is_atom(Flag) -> Msg#message{flags = #{Flag => true}}; @@ -144,8 +149,7 @@ unset_flag(Flag, Msg = #message{flags = Flags}) -> set_headers(Headers, Msg = #message{headers = undefined}) when is_map(Headers) -> Msg#message{headers = Headers}; set_headers(New, Msg = #message{headers = Old}) when is_map(New) -> - Msg#message{headers = maps:merge(Old, New)}; -set_headers(undefined, Msg) -> Msg. + Msg#message{headers = maps:merge(Old, New)}. -spec(get_headers(emqx_types:message()) -> map()). get_headers(Msg) -> diff --git a/src/emqx_mod_presence.erl b/src/emqx_mod_presence.erl index 3947c2c8c..84d69b48b 100644 --- a/src/emqx_mod_presence.erl +++ b/src/emqx_mod_presence.erl @@ -23,70 +23,78 @@ -logger_header("[Presence]"). -%% APIs --export([ on_client_connected/4 - , on_client_disconnected/3 - ]). - %% emqx_gen_mod callbacks -export([ load/1 , unload/1 ]). -%%-------------------------------------------------------------------- -%% APIs -%%-------------------------------------------------------------------- +-export([ on_client_connected/4 + , on_client_disconnected/4 + ]). -load(_Env) -> - ok. - %% emqx_hooks:add('client.connected', {?MODULE, on_client_connected, [Env]}), - %% emqx_hooks:add('client.disconnected', {?MODULE, on_client_disconnected, [Env]}). +-ifdef(TEST). +-export([ reason/1 ]). +-endif. -on_client_connected(#{client_id := ClientId, - username := Username, - peername := {IpAddr, _} - }, ConnAck, - #{session := Session, - proto_name := ProtoName, - proto_ver := ProtoVer, - keepalive := Keepalive - }, Env) -> - case emqx_json:safe_encode(maps:merge(#{clientid => ClientId, - username => Username, - ipaddress => iolist_to_binary(esockd_net:ntoa(IpAddr)), - proto_name => ProtoName, - proto_ver => ProtoVer, - keepalive => Keepalive, - connack => ConnAck, - ts => erlang:system_time(millisecond) - }, maps:with([clean_start, expiry_interval], Session))) of - {ok, Payload} -> - emqx:publish(message(qos(Env), topic(connected, ClientId), Payload)); - {error, Reason} -> - ?LOG(error, "Encoding connected event error: ~p", [Reason]) - end. - - - - -on_client_disconnected(#{client_id := ClientId, - username := Username}, Reason, Env) -> - case emqx_json:safe_encode(#{clientid => ClientId, - username => Username, - reason => reason(Reason), - ts => erlang:system_time(millisecond) - }) of - {ok, Payload} -> - emqx_broker:publish(message(qos(Env), topic(disconnected, ClientId), Payload)); - {error, Reason} -> - ?LOG(error, "Encoding disconnected event error: ~p", [Reason]) - end. +load(Env) -> + emqx_hooks:add('client.connected', {?MODULE, on_client_connected, [Env]}), + emqx_hooks:add('client.disconnected', {?MODULE, on_client_disconnected, [Env]}). unload(_Env) -> emqx_hooks:del('client.connected', {?MODULE, on_client_connected}), emqx_hooks:del('client.disconnected', {?MODULE, on_client_disconnected}). -message(QoS, Topic, Payload) -> +on_client_connected(ClientInfo, ConnAck, ConnInfo, Env) -> + #{peerhost := PeerHost} = ClientInfo, + #{clean_start := CleanStart, + proto_name := ProtoName, + proto_ver := ProtoVer, + keepalive := Keepalive, + expiry_interval := ExpiryInterval} = ConnInfo, + ClientId = clientid(ClientInfo, ConnInfo), + Username = username(ClientInfo, ConnInfo), + Presence = #{clientid => ClientId, + username => Username, + ipaddress => ntoa(PeerHost), + proto_name => ProtoName, + proto_ver => ProtoVer, + keepalive => Keepalive, + connack => ConnAck, + clean_start => CleanStart, + expiry_interval => ExpiryInterval, + ts => emqx_time:now_ms() + }, + case emqx_json:safe_encode(Presence) of + {ok, Payload} -> + emqx_broker:safe_publish( + make_msg(qos(Env), topic(connected, ClientId), Payload)); + {error, _Reason} -> + ?LOG(error, "Failed to encode 'connected' presence: ~p", [Presence]) + end. + +on_client_disconnected(ClientInfo, Reason, ConnInfo, Env) -> + ClientId = clientid(ClientInfo, ConnInfo), + Username = username(ClientInfo, ConnInfo), + Presence = #{clientid => ClientId, + username => Username, + reason => reason(Reason), + ts => emqx_time:now_ms() + }, + case emqx_json:safe_encode(Presence) of + {ok, Payload} -> + emqx_broker:safe_publish( + make_msg(qos(Env), topic(disconnected, ClientId), Payload)); + {error, _Reason} -> + ?LOG(error, "Failed to encode 'disconnected' presence: ~p", [Presence]) + end. + +clientid(#{client_id := undefined}, #{client_id := ClientId}) -> ClientId; +clientid(#{client_id := ClientId}, _ConnInfo) -> ClientId. + +username(#{username := undefined}, #{username := Username}) -> Username; +username(#{username := Username}, _ConnInfo) -> Username. + +make_msg(QoS, Topic, Payload) -> emqx_message:set_flag( sys, emqx_message:make( ?MODULE, QoS, Topic, iolist_to_binary(Payload))). @@ -99,6 +107,10 @@ topic(disconnected, ClientId) -> qos(Env) -> proplists:get_value(qos, Env, 0). reason(Reason) when is_atom(Reason) -> Reason; +reason({shutdown, Reason}) when is_atom(Reason) -> Reason; reason({Error, _}) when is_atom(Error) -> Error; reason(_) -> internal_error. +-compile({inline, [ntoa/1]}). +ntoa(IpAddr) -> iolist_to_binary(esockd_net:ntoa(IpAddr)). + diff --git a/src/emqx_mod_rewrite.erl b/src/emqx_mod_rewrite.erl index 17cff974a..0d32e9aee 100644 --- a/src/emqx_mod_rewrite.erl +++ b/src/emqx_mod_rewrite.erl @@ -22,8 +22,9 @@ -include_lib("emqx_mqtt.hrl"). -ifdef(TEST). --compile(export_all). --compile(nowarn_export_all). +-export([ compile/1 + , match_and_rewrite/2 + ]). -endif. %% APIs @@ -47,14 +48,14 @@ load(RawRules) -> emqx_hooks:add('client.unsubscribe', {?MODULE, rewrite_unsubscribe, [Rules]}), emqx_hooks:add('message.publish', {?MODULE, rewrite_publish, [Rules]}). -rewrite_subscribe(_Client, _Properties, TopicFilters, Rules) -> - {ok, [{match_rule(Topic, Rules), Opts} || {Topic, Opts} <- TopicFilters]}. +rewrite_subscribe(_ClientInfo, _Properties, TopicFilters, Rules) -> + {ok, [{match_and_rewrite(Topic, Rules), Opts} || {Topic, Opts} <- TopicFilters]}. -rewrite_unsubscribe(_Client, _Properties, TopicFilters, Rules) -> - {ok, [{match_rule(Topic, Rules), Opts} || {Topic, Opts} <- TopicFilters]}. +rewrite_unsubscribe(_ClientInfo, _Properties, TopicFilters, Rules) -> + {ok, [{match_and_rewrite(Topic, Rules), Opts} || {Topic, Opts} <- TopicFilters]}. rewrite_publish(Message = #message{topic = Topic}, Rules) -> - {ok, Message#message{topic = match_rule(Topic, Rules)}}. + {ok, Message#message{topic = match_and_rewrite(Topic, Rules)}}. unload(_) -> emqx_hooks:del('client.subscribe', {?MODULE, rewrite_subscribe}), @@ -65,16 +66,22 @@ unload(_) -> %% Internal functions %%-------------------------------------------------------------------- -match_rule(Topic, []) -> +compile(Rules) -> + lists:map(fun({rewrite, Topic, Re, Dest}) -> + {ok, MP} = re:compile(Re), + {rewrite, Topic, MP, Dest} + end, Rules). + +match_and_rewrite(Topic, []) -> Topic; -match_rule(Topic, [{rewrite, Filter, MP, Dest} | Rules]) -> +match_and_rewrite(Topic, [{rewrite, Filter, MP, Dest} | Rules]) -> case emqx_topic:match(Topic, Filter) of - true -> match_regx(Topic, MP, Dest); - false -> match_rule(Topic, Rules) + true -> rewrite(Topic, MP, Dest); + false -> match_and_rewrite(Topic, Rules) end. -match_regx(Topic, MP, Dest) -> +rewrite(Topic, MP, Dest) -> case re:run(Topic, MP, [{capture, all_but_first, list}]) of {match, Captured} -> Vars = lists:zip(["\\$" ++ integer_to_list(I) @@ -86,8 +93,3 @@ match_regx(Topic, MP, Dest) -> nomatch -> Topic end. -compile(Rules) -> - lists:map(fun({rewrite, Topic, Re, Dest}) -> - {ok, MP} = re:compile(Re), - {rewrite, Topic, MP, Dest} - end, Rules). diff --git a/src/emqx_mod_subscription.erl b/src/emqx_mod_subscription.erl index f46ef0f0b..a42234856 100644 --- a/src/emqx_mod_subscription.erl +++ b/src/emqx_mod_subscription.erl @@ -21,14 +21,14 @@ -include_lib("emqx.hrl"). -include_lib("emqx_mqtt.hrl"). -%% APIs --export([on_client_connected/4]). - %% emqx_gen_mod callbacks -export([ load/1 , unload/1 ]). +%% APIs +-export([on_client_connected/4]). + %%-------------------------------------------------------------------- %% Load/Unload Hook %%-------------------------------------------------------------------- @@ -37,7 +37,7 @@ load(Topics) -> emqx_hooks:add('client.connected', {?MODULE, on_client_connected, [Topics]}). on_client_connected(#{client_id := ClientId, - username := Username}, ?RC_SUCCESS, _ConnAttrs, Topics) -> + username := Username}, ?RC_SUCCESS, _ConnInfo, Topics) -> Replace = fun(Topic) -> rep(<<"%u">>, Username, rep(<<"%c">>, ClientId, Topic)) end, diff --git a/src/emqx_mqtt_caps.erl b/src/emqx_mqtt_caps.erl index ec7f55330..25d2ee5e4 100644 --- a/src/emqx_mqtt_caps.erl +++ b/src/emqx_mqtt_caps.erl @@ -26,6 +26,7 @@ -export([ get_caps/1 , get_caps/2 + , get_caps/3 ]). -export([default/0]). @@ -114,10 +115,13 @@ get_caps(Zone) -> -spec(get_caps(emqx_zone:zone(), publish|subscribe) -> caps()). get_caps(Zone, publish) -> with_env(Zone, '$mqtt_pub_caps', fun pub_caps/1); - get_caps(Zone, subscribe) -> with_env(Zone, '$mqtt_sub_caps', fun sub_caps/1). +-spec(get_caps(emqx_zone:zone(), atom(), term()) -> term()). +get_caps(Zone, Cap, Def) -> + emqx_zone:get_env(Zone, Cap, Def). + pub_caps(Zone) -> filter_caps(?PUBCAP_KEYS, get_caps(Zone)). diff --git a/src/emqx_mqtt_props.erl b/src/emqx_mqtt_props.erl index 241db9a2e..0377b6d0c 100644 --- a/src/emqx_mqtt_props.erl +++ b/src/emqx_mqtt_props.erl @@ -28,8 +28,8 @@ %% For tests -export([all/0]). --export([ set_property/3 - , get_property/3 +-export([ set/3 + , get/3 ]). -type(prop_name() :: atom()). @@ -183,13 +183,13 @@ validate_value(_Type, _Val) -> false. -spec(all() -> map()). all() -> ?PROPS_TABLE. -set_property(Name, Value, undefined) -> +set(Name, Value, undefined) -> #{Name => Value}; -set_property(Name, Value, Props) -> +set(Name, Value, Props) -> Props#{Name => Value}. -get_property(_Name, undefined, Default) -> +get(_Name, undefined, Default) -> Default; -get_property(Name, Props, Default) -> +get(Name, Props, Default) -> maps:get(Name, Props, Default). diff --git a/src/emqx_oom.erl b/src/emqx_oom.erl index 8d3344402..efc0a4c69 100644 --- a/src/emqx_oom.erl +++ b/src/emqx_oom.erl @@ -28,7 +28,7 @@ , info/1 ]). --export_type([oom_policy/0]). +-export_type([opts/0, oom_policy/0]). -type(opts() :: #{message_queue_len => non_neg_integer(), max_heap_size => non_neg_integer() diff --git a/src/emqx_packet.erl b/src/emqx_packet.erl index 1113c15d7..6b7663b75 100644 --- a/src/emqx_packet.erl +++ b/src/emqx_packet.erl @@ -89,7 +89,7 @@ proto_name(#mqtt_packet_connect{proto_name = Name}) -> %% @doc Protocol version of the CONNECT Packet. -spec(proto_ver(emqx_types:packet()|connect()) -> emqx_types:version()). -proto_ver(?CONNACK_PACKET(ConnPkt)) -> +proto_ver(?CONNECT_PACKET(ConnPkt)) -> proto_ver(ConnPkt); proto_ver(#mqtt_packet_connect{proto_ver = Ver}) -> Ver. @@ -241,7 +241,7 @@ validate_topic_filters(TopicFilters) -> %% @doc Publish Packet to Message. -spec(to_message(emqx_types:client(), emqx_ypes:packet()) -> emqx_types:message()). -to_message(#{client_id := ClientId, username := Username, peername := Peername}, +to_message(#{client_id := ClientId, username := Username, peerhost := PeerHost}, #mqtt_packet{header = #mqtt_packet_header{type = ?PUBLISH, retain = Retain, qos = QoS, @@ -252,7 +252,7 @@ to_message(#{client_id := ClientId, username := Username, peername := Peername}, Msg = emqx_message:make(ClientId, QoS, Topic, Payload), Msg#message{flags = #{dup => Dup, retain => Retain}, headers = merge_props(#{username => Username, - peername => Peername}, Props)}. + peerhost => PeerHost}, Props)}. -spec(will_msg(#mqtt_packet_connect{}) -> emqx_types:message()). will_msg(#mqtt_packet_connect{will_flag = false}) -> diff --git a/src/emqx_protocol.erl b/src/emqx_protocol.erl deleted file mode 100644 index 54ff6056c..000000000 --- a/src/emqx_protocol.erl +++ /dev/null @@ -1,133 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - -%% MQTT Protocol --module(emqx_protocol). - --include("types.hrl"). --include("emqx_mqtt.hrl"). - --export([ init/2 - , info/1 - , info/2 - , attrs/1 - ]). - --export([ find_alias/2 - , save_alias/3 - , clear_will_msg/1 - ]). - --export_type([protocol/0]). - --record(protocol, { - %% MQTT Proto Name - proto_name :: binary(), - %% MQTT Proto Version - proto_ver :: emqx_types:ver(), - %% Clean Start Flag - clean_start :: boolean(), - %% MQTT Keepalive interval - keepalive :: non_neg_integer(), - %% ClientId in CONNECT Packet - client_id :: emqx_types:client_id(), - %% Username in CONNECT Packet - username :: emqx_types:username(), - %% MQTT Will Msg - will_msg :: emqx_types:message(), - %% MQTT Topic Aliases - topic_aliases :: maybe(map()), - %% MQTT Topic Alias Maximum - alias_maximum :: maybe(map()) - }). - --opaque(protocol() :: #protocol{}). - --define(INFO_KEYS, record_info(fields, protocol)). - --define(ATTR_KEYS, [proto_name, proto_ver, clean_start, keepalive]). - --spec(init(#mqtt_packet_connect{}, atom()) -> protocol()). -init(#mqtt_packet_connect{proto_name = ProtoName, - proto_ver = ProtoVer, - clean_start = CleanStart, - keepalive = Keepalive, - properties = Properties, - client_id = ClientId, - username = Username} = ConnPkt, Zone) -> - WillMsg = emqx_packet:will_msg(ConnPkt), - #protocol{proto_name = ProtoName, - proto_ver = ProtoVer, - clean_start = CleanStart, - keepalive = Keepalive, - client_id = ClientId, - username = Username, - will_msg = WillMsg, - alias_maximum = #{outbound => emqx_mqtt_props:get_property('Topic-Alias-Maximum', Properties, 0), - inbound => maps:get(max_topic_alias, emqx_mqtt_caps:get_caps(Zone), 0)} - }. - --spec(info(protocol()) -> emqx_types:infos()). -info(Proto) -> - maps:from_list(info(?INFO_KEYS, Proto)). - --spec(info(atom()|list(atom()), protocol()) -> term()). -info(Keys, Proto) when is_list(Keys) -> - [{Key, info(Key, Proto)} || Key <- Keys]; -info(proto_name, #protocol{proto_name = ProtoName}) -> - ProtoName; -info(proto_ver, #protocol{proto_ver = ProtoVer}) -> - ProtoVer; -info(clean_start, #protocol{clean_start = CleanStart}) -> - CleanStart; -info(keepalive, #protocol{keepalive = Keepalive}) -> - Keepalive; -info(client_id, #protocol{client_id = ClientId}) -> - ClientId; -info(username, #protocol{username = Username}) -> - Username; -info(will_msg, #protocol{will_msg = WillMsg}) -> - WillMsg; -info(will_delay_interval, #protocol{will_msg = undefined}) -> - 0; -info(will_delay_interval, #protocol{will_msg = WillMsg}) -> - emqx_message:get_header('Will-Delay-Interval', WillMsg, 0); -info(topic_aliases, #protocol{topic_aliases = Aliases}) -> - Aliases; -info(alias_maximum, #protocol{alias_maximum = AliasMaximum}) -> - AliasMaximum. - --spec(attrs(protocol()) -> emqx_types:attrs()). -attrs(Proto) -> - maps:from_list(info(?ATTR_KEYS, Proto)). - --spec(find_alias(emqx_types:alias_id(), protocol()) - -> {ok, emqx_types:topic()} | false). -find_alias(_AliasId, #protocol{topic_aliases = undefined}) -> - false; -find_alias(AliasId, #protocol{topic_aliases = Aliases}) -> - maps:find(AliasId, Aliases). - --spec(save_alias(emqx_types:alias_id(), emqx_types:topic(), protocol()) - -> protocol()). -save_alias(AliasId, Topic, Proto = #protocol{topic_aliases = undefined}) -> - Proto#protocol{topic_aliases = #{AliasId => Topic}}; -save_alias(AliasId, Topic, Proto = #protocol{topic_aliases = Aliases}) -> - Proto#protocol{topic_aliases = maps:put(AliasId, Topic, Aliases)}. - -clear_will_msg(Protocol) -> - Protocol#protocol{will_msg = undefined}. - diff --git a/src/emqx_session.erl b/src/emqx_session.erl index 59cc71ea9..434f5f5ed 100644 --- a/src/emqx_session.erl +++ b/src/emqx_session.erl @@ -61,8 +61,6 @@ %% Exports for unit tests -export([set_field/3]). --export([update_expiry_interval/2]). - -export([ subscribe/4 , unsubscribe/3 ]). @@ -116,13 +114,10 @@ max_awaiting_rel :: non_neg_integer(), %% Awaiting PUBREL Timeout await_rel_timeout :: timeout(), - %% Session Expiry Interval - expiry_interval :: timeout(), %% Enqueue Count enqueue_cnt :: non_neg_integer(), %% Created at created_at :: erlang:timestamp() - }). -opaque(session() :: #session{}). @@ -130,11 +125,12 @@ -type(publish() :: {publish, emqx_types:packet_id(), emqx_types:message()}). -define(DEFAULT_BATCH_N, 1000). --define(ATTR_KEYS, [expiry_interval, created_at]). +-define(ATTR_KEYS, [max_inflight, max_mqueue, retry_interval, + max_awaiting_rel, await_rel_timeout, created_at]). -define(INFO_KEYS, [subscriptions, max_subscriptions, upgrade_qos, inflight, max_inflight, retry_interval, mqueue_len, max_mqueue, mqueue_dropped, next_pkt_id, awaiting_rel, max_awaiting_rel, - await_rel_timeout, expiry_interval, created_at]). + await_rel_timeout, created_at]). -define(STATS_KEYS, [subscriptions_cnt, max_subscriptions, inflight, max_inflight, mqueue_len, max_mqueue, mqueue_dropped, awaiting_rel, max_awaiting_rel, enqueue_cnt]). @@ -145,8 +141,7 @@ %% @doc Init a session. -spec(init(emqx_types:client(), Options :: map()) -> session()). -init(#{zone := Zone}, #{max_inflight := MaxInflight, - expiry_interval := ExpiryInterval}) -> +init(#{zone := Zone}, #{receive_maximum := MaxInflight}) -> #session{max_subscriptions = get_env(Zone, max_subscriptions, 0), subscriptions = #{}, upgrade_qos = get_env(Zone, upgrade_qos, false), @@ -157,7 +152,6 @@ init(#{zone := Zone}, #{max_inflight := MaxInflight, awaiting_rel = #{}, max_awaiting_rel = get_env(Zone, max_awaiting_rel, 100), await_rel_timeout = get_env(Zone, await_rel_timeout, 3600*1000), - expiry_interval = ExpiryInterval, enqueue_cnt = 0, created_at = os:timestamp() }. @@ -213,8 +207,6 @@ info(max_awaiting_rel, #session{max_awaiting_rel = MaxAwaitingRel}) -> MaxAwaitingRel; info(await_rel_timeout, #session{await_rel_timeout = Timeout}) -> Timeout; -info(expiry_interval, #session{expiry_interval = Interval}) -> - Interval; info(enqueue_cnt, #session{enqueue_cnt = Cnt}) -> Cnt; info(created_at, #session{created_at = CreatedAt}) -> @@ -226,9 +218,6 @@ set_field(Name, Val, Channel) -> Pos = emqx_misc:index_of(Name, Fields), setelement(Pos+1, Channel, Val). -update_expiry_interval(ExpiryInterval, Session) -> - Session#session{expiry_interval = ExpiryInterval}. - -spec(takeover(session()) -> ok). takeover(#session{subscriptions = Subs}) -> lists:foreach(fun({TopicFilter, _SubOpts}) -> diff --git a/src/emqx_types.erl b/src/emqx_types.erl index 86c682dbe..58b438010 100644 --- a/src/emqx_types.erl +++ b/src/emqx_types.erl @@ -36,6 +36,7 @@ , client_id/0 , username/0 , password/0 + , peerhost/0 , peername/0 , protocol/0 ]). @@ -102,9 +103,7 @@ atom() => term() }). -type(client() :: #{zone := zone(), - conn_mod := maybe(module()), - peername := peername(), - sockname := peername(), + peerhost := peerhost(), client_id := client_id(), username := username(), peercert := esockd_peercert:peercert(), @@ -117,9 +116,10 @@ anonymous => boolean(), atom() => term() }). --type(client_id() :: binary() | atom()). +-type(client_id() :: binary()|atom()). -type(username() :: maybe(binary())). -type(password() :: maybe(binary())). +-type(peerhost() :: inet:ip_address()). -type(peername() :: {inet:ip_address(), inet:port_number()}). -type(protocol() :: mqtt | 'mqtt-sn' | coap | stomp | none | atom()). -type(auth_result() :: success diff --git a/src/emqx_ws_connection.erl b/src/emqx_ws_connection.erl index c99829356..03e15cce6 100644 --- a/src/emqx_ws_connection.erl +++ b/src/emqx_ws_connection.erl @@ -74,9 +74,9 @@ info(WsPid) when is_pid(WsPid) -> call(WsPid, info); info(WsConn = #ws_connection{chan_state = ChanState}) -> - ConnInfo = info(?INFO_KEYS, WsConn), ChanInfo = emqx_channel:info(ChanState), - maps:merge(ChanInfo, #{conninfo => maps:from_list(ConnInfo)}). + SockInfo = maps:from_list(info(?INFO_KEYS, WsConn)), + maps:merge(ChanInfo, #{sockinfo => SockInfo}). info(Keys, WsConn) when is_list(Keys) -> [{Key, info(Key, WsConn)} || Key <- Keys]; @@ -95,9 +95,9 @@ info(chan_state, #ws_connection{chan_state = ChanState}) -> attrs(WsPid) when is_pid(WsPid) -> call(WsPid, attrs); attrs(WsConn = #ws_connection{chan_state = ChanState}) -> - ConnAttrs = info(?ATTR_KEYS, WsConn), ChanAttrs = emqx_channel:attrs(ChanState), - maps:merge(ChanAttrs, #{conninfo => maps:from_list(ConnAttrs)}). + SockAttrs = maps:from_list(info(?ATTR_KEYS, WsConn)), + maps:merge(ChanAttrs, #{sockinfo => SockAttrs}). -spec(stats(pid()|ws_connection()) -> emqx_types:stats()). stats(WsPid) when is_pid(WsPid) -> @@ -273,7 +273,7 @@ websocket_info({incoming, {error, _Reason}}, State = #ws_connection{fsm_state = websocket_info({incoming, Packet = ?CONNECT_PACKET(ConnPkt)}, State = #ws_connection{fsm_state = idle}) -> #mqtt_packet_connect{proto_ver = ProtoVer, properties = Properties} = ConnPkt, - MaxPacketSize = emqx_mqtt_props:get_property('Maximum-Packet-Size', Properties, undefined), + MaxPacketSize = emqx_mqtt_props:get('Maximum-Packet-Size', Properties, undefined), NState = State#ws_connection{serialize = serialize_fun(ProtoVer, MaxPacketSize)}, handle_incoming(Packet, fun connected/1, NState); @@ -341,7 +341,7 @@ disconnected(State) -> %% Handle timeout handle_timeout(TRef, Msg, State = #ws_connection{chan_state = ChanState}) -> - case emqx_channel:timeout(TRef, Msg, ChanState) of + case emqx_channel:handle_timeout(TRef, Msg, ChanState) of {ok, NChanState} -> {ok, State#ws_connection{chan_state = NChanState}}; {ok, Packets, NChanState} -> diff --git a/src/emqx_zone.erl b/src/emqx_zone.erl index 2b696c937..9d45c64e1 100644 --- a/src/emqx_zone.erl +++ b/src/emqx_zone.erl @@ -28,18 +28,26 @@ -export([start_link/0, stop/0]). -export([ use_username_as_clientid/1 + , enable_stats/1 , enable_acl/1 - , enable_banned/1 + , enable_ban/1 , enable_flapping_detect/1 + , ignore_loop_deliver/1 + , server_keepalive/1 + , max_inflight/1 + , session_expiry_interval/1 + , force_gc_policy/1 + , force_shutdown_policy/1 ]). -export([ get_env/2 , get_env/3 , set_env/3 , unset_env/2 - , force_reload/0 ]). +-export([force_reload/0]). + %% gen_server callbacks -export([ init/1 , handle_call/3 @@ -72,18 +80,46 @@ start_link() -> use_username_as_clientid(Zone) -> get_env(Zone, use_username_as_clientid, false). +-spec(enable_stats(zone()) -> boolean()). +enable_stats(Zone) -> + get_env(Zone, enable_stats, true). + -spec(enable_acl(zone()) -> boolean()). enable_acl(Zone) -> get_env(Zone, enable_acl, true). --spec(enable_banned(zone()) -> boolean()). -enable_banned(Zone) -> - get_env(Zone, enable_banned, false). +-spec(enable_ban(zone()) -> boolean()). +enable_ban(Zone) -> + get_env(Zone, enable_ban, false). -spec(enable_flapping_detect(zone()) -> boolean()). enable_flapping_detect(Zone) -> get_env(Zone, enable_flapping_detect, false). +-spec(ignore_loop_deliver(zone()) -> boolean()). +ignore_loop_deliver(Zone) -> + get_env(Zone, ignore_loop_deliver, false). + +-spec(server_keepalive(zone()) -> pos_integer()). +server_keepalive(Zone) -> + get_env(Zone, server_keepalive). + +-spec(max_inflight(zone()) -> 0..65535). +max_inflight(Zone) -> + get_env(Zone, max_inflight, 65535). + +-spec(session_expiry_interval(zone()) -> non_neg_integer()). +session_expiry_interval(Zone) -> + get_env(Zone, session_expiry_interval, 0). + +-spec(force_gc_policy(zone()) -> maybe(emqx_gc:opts())). +force_gc_policy(Zone) -> + get_env(Zone, force_gc_policy). + +-spec(force_shutdown_policy(zone()) -> maybe(emqx_oom:opts())). +force_shutdown_policy(Zone) -> + get_env(Zone, force_shutdown_policy). + -spec(get_env(maybe(zone()), atom()) -> maybe(term())). get_env(undefined, Key) -> emqx:get_env(Key); get_env(Zone, Key) -> diff --git a/test/emqx_access_SUITE.erl b/test/emqx_access_SUITE.erl index f276e4929..5b1694e74 100644 --- a/test/emqx_access_SUITE.erl +++ b/test/emqx_access_SUITE.erl @@ -344,35 +344,35 @@ t_compile_rule(_) -> {deny, all} = compile({deny, all}). t_match_rule(_) -> - Client1 = #{zone => external, - client_id => <<"testClient">>, - username => <<"TestUser">>, - peername => {{127,0,0,1}, 2948} - }, - Client2 = #{zone => external, - client_id => <<"testClient">>, - username => <<"TestUser">>, - peername => {{192,168,0,10}, 3028} - }, - {matched, allow} = match(Client1, <<"Test/Topic">>, {allow, all}), - {matched, deny} = match(Client1, <<"Test/Topic">>, {deny, all}), - {matched, allow} = match(Client1, <<"Test/Topic">>, + ClientInfo1 = #{zone => external, + client_id => <<"testClient">>, + username => <<"TestUser">>, + peerhost => {127,0,0,1} + }, + ClientInfo2 = #{zone => external, + client_id => <<"testClient">>, + username => <<"TestUser">>, + peerhost => {192,168,0,10} + }, + {matched, allow} = match(ClientInfo1, <<"Test/Topic">>, {allow, all}), + {matched, deny} = match(ClientInfo1, <<"Test/Topic">>, {deny, all}), + {matched, allow} = match(ClientInfo1, <<"Test/Topic">>, compile({allow, {ipaddr, "127.0.0.1"}, subscribe, ["$SYS/#", "#"]})), - {matched, allow} = match(Client2, <<"Test/Topic">>, + {matched, allow} = match(ClientInfo2, <<"Test/Topic">>, compile({allow, {ipaddr, "192.168.0.1/24"}, subscribe, ["$SYS/#", "#"]})), - {matched, allow} = match(Client1, <<"d/e/f/x">>, + {matched, allow} = match(ClientInfo1, <<"d/e/f/x">>, compile({allow, {user, "TestUser"}, subscribe, ["a/b/c", "d/e/f/#"]})), - nomatch = match(Client1, <<"d/e/f/x">>, compile({allow, {user, "admin"}, pubsub, ["d/e/f/#"]})), - {matched, allow} = match(Client1, <<"testTopics/testClient">>, + nomatch = match(ClientInfo1, <<"d/e/f/x">>, compile({allow, {user, "admin"}, pubsub, ["d/e/f/#"]})), + {matched, allow} = match(ClientInfo1, <<"testTopics/testClient">>, compile({allow, {client, "testClient"}, publish, ["testTopics/testClient"]})), - {matched, allow} = match(Client1, <<"clients/testClient">>, compile({allow, all, pubsub, ["clients/%c"]})), + {matched, allow} = match(ClientInfo1, <<"clients/testClient">>, compile({allow, all, pubsub, ["clients/%c"]})), {matched, allow} = match(#{username => <<"user2">>}, <<"users/user2/abc/def">>, compile({allow, all, subscribe, ["users/%u/#"]})), - {matched, deny} = match(Client1, <<"d/e/f">>, compile({deny, all, subscribe, ["$SYS/#", "#"]})), + {matched, deny} = match(ClientInfo1, <<"d/e/f">>, compile({deny, all, subscribe, ["$SYS/#", "#"]})), Rule = compile({allow, {'and', [{ipaddr, "127.0.0.1"}, {user, <<"WrongUser">>}]}, publish, <<"Topic">>}), - nomatch = match(Client1, <<"Topic">>, Rule), + nomatch = match(ClientInfo1, <<"Topic">>, Rule), AndRule = compile({allow, {'and', [{ipaddr, "127.0.0.1"}, {user, <<"TestUser">>}]}, publish, <<"Topic">>}), - {matched, allow} = match(Client1, <<"Topic">>, AndRule), + {matched, allow} = match(ClientInfo1, <<"Topic">>, AndRule), OrRule = compile({allow, {'or', [{ipaddr, "127.0.0.1"}, {user, <<"WrongUser">>}]}, publish, ["Topic"]}), - {matched, allow} = match(Client1, <<"Topic">>, OrRule). + {matched, allow} = match(ClientInfo1, <<"Topic">>, OrRule). diff --git a/test/emqx_banned_SUITE.erl b/test/emqx_banned_SUITE.erl index 1dc8e0dcb..99c5df3d3 100644 --- a/test/emqx_banned_SUITE.erl +++ b/test/emqx_banned_SUITE.erl @@ -27,6 +27,8 @@ all() -> emqx_ct:all(?MODULE). init_per_suite(Config) -> application:load(emqx), ok = ekka:start(), + %% for coverage + ok = emqx_banned:mnesia(copy), Config. end_per_suite(_Config) -> @@ -51,32 +53,43 @@ t_check(_) -> ok = emqx_banned:add(#banned{who = {username, <<"BannedUser">>}}), ok = emqx_banned:add(#banned{who = {ipaddr, {192,168,0,1}}}), ?assertEqual(3, emqx_banned:info(size)), - Client1 = #{client_id => <<"BannedClient">>, - username => <<"user">>, - peername => {{127,0,0,1}, 5000} - }, - Client2 = #{client_id => <<"client">>, - username => <<"BannedUser">>, - peername => {{127,0,0,1}, 5000} - }, - Client3 = #{client_id => <<"client">>, - username => <<"user">>, - peername => {{192,168,0,1}, 5000} - }, - Client4 = #{client_id => <<"client">>, - username => <<"user">>, - peername => {{127,0,0,1}, 5000} - }, - ?assert(emqx_banned:check(Client1)), - ?assert(emqx_banned:check(Client2)), - ?assert(emqx_banned:check(Client3)), - ?assertNot(emqx_banned:check(Client4)), + ClientInfo1 = #{client_id => <<"BannedClient">>, + username => <<"user">>, + peerhost => {127,0,0,1} + }, + ClientInfo2 = #{client_id => <<"client">>, + username => <<"BannedUser">>, + peerhost => {127,0,0,1} + }, + ClientInfo3 = #{client_id => <<"client">>, + username => <<"user">>, + peerhost => {192,168,0,1} + }, + ClientInfo4 = #{client_id => <<"client">>, + username => <<"user">>, + peerhost => {127,0,0,1} + }, + ?assert(emqx_banned:check(ClientInfo1)), + ?assert(emqx_banned:check(ClientInfo2)), + ?assert(emqx_banned:check(ClientInfo3)), + ?assertNot(emqx_banned:check(ClientInfo4)), ok = emqx_banned:delete({client_id, <<"BannedClient">>}), ok = emqx_banned:delete({username, <<"BannedUser">>}), ok = emqx_banned:delete({ipaddr, {192,168,0,1}}), - ?assertNot(emqx_banned:check(Client1)), - ?assertNot(emqx_banned:check(Client2)), - ?assertNot(emqx_banned:check(Client3)), - ?assertNot(emqx_banned:check(Client4)), + ?assertNot(emqx_banned:check(ClientInfo1)), + ?assertNot(emqx_banned:check(ClientInfo2)), + ?assertNot(emqx_banned:check(ClientInfo3)), + ?assertNot(emqx_banned:check(ClientInfo4)), ?assertEqual(0, emqx_banned:info(size)). +t_unused(_) -> + {ok, Banned} = emqx_banned:start_link(), + ok = emqx_banned:add(#banned{who = {client_id, <<"BannedClient">>}, + until = erlang:system_time(second) + }), + ?assertEqual(ignored, gen_server:call(Banned, unexpected_req)), + ?assertEqual(ok, gen_server:cast(Banned, unexpected_msg)), + ?assertEqual(ok, Banned ! ok), + timer:sleep(500), %% expiry timer + ok = emqx_banned:stop(). + diff --git a/test/emqx_channel_SUITE.erl b/test/emqx_channel_SUITE.erl index 977cd397e..e13a76686 100644 --- a/test/emqx_channel_SUITE.erl +++ b/test/emqx_channel_SUITE.erl @@ -176,13 +176,20 @@ t_handle_deliver(_) -> %%-------------------------------------------------------------------- t_handle_connack(_) -> + ConnPkt = #mqtt_packet_connect{ + proto_name = <<"MQTT">>, + proto_ver = ?MQTT_PROTO_V4, + clean_start = true, + properties = #{}, + client_id = <<"clientid">> + }, with_channel( fun(Channel) -> {ok, ?CONNACK_PACKET(?RC_SUCCESS, SP, _), _} - = handle_out({connack, ?RC_SUCCESS, 0}, Channel), + = handle_out({connack, ?RC_SUCCESS, 0, ConnPkt}, Channel), {stop, {shutdown, not_authorized}, ?CONNACK_PACKET(?RC_NOT_AUTHORIZED), _} - = handle_out({connack, ?RC_NOT_AUTHORIZED}, Channel) + = handle_out({connack, ?RC_NOT_AUTHORIZED, ConnPkt}, Channel) end). t_handle_out_publish(_) -> @@ -271,30 +278,31 @@ t_terminate(_) -> %% Helper functions %%-------------------------------------------------------------------- -with_channel(Fun) -> - ConnInfo = #{peername => {{127,0,0,1}, 3456}, - sockname => {{127,0,0,1}, 1883}, +with_channel(TestFun) -> + ConnInfo = #{peername => {{127,0,0,1}, 3456}, + sockname => {{127,0,0,1}, 1883}, + conn_mod => emqx_connection, + proto_name => <<"MQTT">>, + proto_ver => ?MQTT_PROTO_V5, + clean_start => true, + keepalive => 30, client_id => <<"clientid">>, - username => <<"username">> + username => <<"username">>, + conn_props => #{}, + receive_maximum => 100, + expiry_interval => 60 }, - Options = [{zone, testing}], - Channel = emqx_channel:init(ConnInfo, Options), - ConnPkt = #mqtt_packet_connect{ - proto_name = <<"MQTT">>, - proto_ver = ?MQTT_PROTO_V5, - clean_start = true, - keepalive = 30, - properties = #{}, - client_id = <<"clientid">>, - username = <<"username">>, - password = <<"passwd">> - }, - Protocol = emqx_protocol:init(ConnPkt, testing), - Session = emqx_session:init(#{zone => testing}, - #{max_inflight => 100, - expiry_interval => 0 - }), - Fun(emqx_channel:set_field(protocol, Protocol, - emqx_channel:set_field( - session, Session, Channel))). + ClientInfo = #{zone => <<"external">>, + peerhost => {127,0,0,1}, + client_id => <<"clientid">>, + username => <<"username">>, + peercert => undefined, + is_bridge => false, + is_superuser => false, + mountpoint => undefined + }, + Channel = emqx_channel:init(ConnInfo, [{zone, testing}]), + Session = emqx_session:init(ClientInfo, ConnInfo), + Channel1 = emqx_channel:set_field(client, ClientInfo, Channel), + TestFun(emqx_channel:set_field(session, Session, Channel1)). diff --git a/test/emqx_mod_subscription_SUITE.erl b/test/emqx_cm_SUITE.erl similarity index 53% rename from test/emqx_mod_subscription_SUITE.erl rename to test/emqx_cm_SUITE.erl index 0c9c9c678..cd91ed1f3 100644 --- a/test/emqx_mod_subscription_SUITE.erl +++ b/test/emqx_cm_SUITE.erl @@ -14,39 +14,47 @@ %% limitations under the License. %%-------------------------------------------------------------------- --module(emqx_mod_subscription_SUITE). +-module(emqx_cm_SUITE). -compile(export_all). -compile(nowarn_export_all). --include("emqx_mqtt.hrl"). -include("emqx.hrl"). - -include_lib("eunit/include/eunit.hrl"). --include_lib("common_test/include/ct.hrl"). all() -> emqx_ct:all(?MODULE). init_per_suite(Config) -> - emqx_ct_helpers:boot_modules(all), - emqx_ct_helpers:start_apps([emqx]), + emqx_ct_helpers:start_apps([]), Config. end_per_suite(_Config) -> - emqx_ct_helpers:stop_apps([emqx]). + emqx_ct_helpers:stop_apps([]). + +t_reg_unreg_channel(_) -> + error(not_implemented). + +t_get_set_chan_attrs(_) -> + error(not_implemented). + +t_get_set_chan_stats(_) -> + error(not_implemented). + +t_open_session(_) -> + error(not_implemented). + +t_discard_session(_) -> + error(not_implemented). + +t_takeover_session(_) -> + error(not_implemented). + +t_lookup_channels(_) -> + error(not_implemented). + +t_lock_clientid(_) -> + error(not_implemented). + +t_unlock_clientid(_) -> + error(not_implemented). -t_mod_subscription(_) -> - emqx_mod_subscription:load([{<<"connected/%c/%u">>, ?QOS_0}]), - {ok, C} = emqtt:start_link([{host, "localhost"}, {client_id, "myclient"}, {username, "admin"}]), - {ok, _} = emqtt:connect(C), - % ct:sleep(100), - emqtt:publish(C, <<"connected/myclient/admin">>, <<"Hello world">>, ?QOS_0), - receive - {publish, #{topic := Topic, payload := Payload}} -> - ?assertEqual(<<"connected/myclient/admin">>, Topic), - ?assertEqual(<<"Hello world">>, Payload) - after 100 -> - ct:fail("no_message") - end, - ok = emqtt:disconnect(C), - emqx_mod_subscription:unload([]). diff --git a/test/emqx_ctl_SUITE.erl b/test/emqx_ctl_SUITE.erl index 5e18f19c2..cc06fa092 100644 --- a/test/emqx_ctl_SUITE.erl +++ b/test/emqx_ctl_SUITE.erl @@ -25,30 +25,101 @@ all() -> emqx_ct:all(?MODULE). init_per_suite(Config) -> - emqx_ct_helpers:boot_modules([]), - emqx_ct_helpers:start_apps([]), Config. end_per_suite(_Config) -> - emqx_ct_helpers:stop_apps([]). + ok. -t_command(_) -> - emqx_ctl:start_link(), - emqx_ctl:register_command(test, {?MODULE, test}), - ct:sleep(50), - ?assertEqual([{emqx_ctl_SUITE,test}], emqx_ctl:lookup_command(test)), - ?assertEqual(ok, emqx_ctl:run_command(["test", "ok"])), - ?assertEqual({error, test_failed}, emqx_ctl:run_command(["test", "error"])), - ?assertEqual({error, cmd_not_found}, emqx_ctl:run_command(["test2", "ok"])), - emqx_ctl:unregister_command(test), - ct:sleep(50), - ?assertEqual([], emqx_ctl:lookup_command(test)). +%%-------------------------------------------------------------------- +%% Test cases +%%-------------------------------------------------------------------- -test(["ok"]) -> - ok; -test(["error"]) -> - error(test_failed); -test(_) -> - io:format("Hello world"). +t_reg_unreg_command(_) -> + with_ctl_server( + fun(_CtlSrv) -> + emqx_ctl:register_command(cmd1, {?MODULE, cmd1_fun}), + emqx_ctl:register_command(cmd2, {?MODULE, cmd2_fun}), + ?assertEqual([{?MODULE, cmd1_fun}], emqx_ctl:lookup_command(cmd1)), + ?assertEqual([{?MODULE, cmd2_fun}], emqx_ctl:lookup_command(cmd2)), + ?assertEqual([{cmd1, ?MODULE, cmd1_fun}, {cmd2, ?MODULE, cmd2_fun}], + emqx_ctl:get_commands()), + emqx_ctl:unregister_command(cmd1), + emqx_ctl:unregister_command(cmd2), + ct:sleep(100), + ?assertEqual([], emqx_ctl:lookup_command(cmd1)), + ?assertEqual([], emqx_ctl:lookup_command(cmd2)), + ?assertEqual([], emqx_ctl:get_commands()) + end). +t_run_commands(_) -> + with_ctl_server( + fun(_CtlSrv) -> + ?assertEqual({error, cmd_not_found}, emqx_ctl:run_command(["cmd", "arg"])), + emqx_ctl:register_command(cmd1, {?MODULE, cmd1_fun}), + emqx_ctl:register_command(cmd2, {?MODULE, cmd2_fun}), + ok = emqx_ctl:run_command(["cmd1", "arg"]), + {error, badarg} = emqx_ctl:run_command(["cmd1", "badarg"]), + ok = emqx_ctl:run_command(["cmd2", "arg1", "arg2"]), + {error, badarg} = emqx_ctl:run_command(["cmd2", "arg1", "badarg"]) + end). +t_print(_) -> + ok = emqx_ctl:print("help"), + ok = emqx_ctl:print("~s", [help]), + % - check the output of the usage + print_mock(), + ?assertEqual("help", emqx_ctl:print("help")), + ?assertEqual("help", emqx_ctl:print("~s", [help])). + +t_usage(_) -> + CmdParams1 = "emqx_cmd_1 param1 param2", + CmdDescr1 = "emqx_cmd_1 is a test command means nothing", + Output1 = "emqx_cmd_1 param1 param2 # emqx_cmd_1 is a test command means nothing\n", + % - usage/1,2 should return ok + ok = emqx_ctl:usage([{CmdParams1, CmdDescr1}, {CmdParams1, CmdDescr1}]), + ok = emqx_ctl:usage(CmdParams1, CmdDescr1), + + % - check the output of the usage + print_mock(), + ?assertEqual(Output1, emqx_ctl:usage(CmdParams1, CmdDescr1)), + ?assertEqual([Output1, Output1], emqx_ctl:usage([{CmdParams1, CmdDescr1}, {CmdParams1, CmdDescr1}])), + + % - for the commands or descriptions have multi-lines + CmdParams2 = "emqx_cmd_2 param1 param2", + CmdDescr2 = "emqx_cmd_2 is a test command\nmeans nothing", + Output2 = "emqx_cmd_2 param1 param2 # emqx_cmd_2 is a test command\n" + " ""# means nothing\n", + ?assertEqual(Output2, emqx_ctl:usage(CmdParams2, CmdDescr2)), + ?assertEqual([Output2, Output2], emqx_ctl:usage([{CmdParams2, CmdDescr2}, {CmdParams2, CmdDescr2}])). + +t_unexpected(_) -> + with_ctl_server( + fun(CtlSrv) -> + ignored = gen_server:call(CtlSrv, unexpected_call), + ok = gen_server:cast(CtlSrv, unexpected_cast), + CtlSrv ! unexpected_info, + ?assert(is_process_alive(CtlSrv)) + end). + +%%-------------------------------------------------------------------- +%% Cmds for test +%%-------------------------------------------------------------------- + +cmd1_fun(["arg"]) -> ok; +cmd1_fun(["badarg"]) -> error(badarg). + +cmd2_fun(["arg1", "arg2"]) -> ok; +cmd2_fun(["arg1", "badarg"]) -> error(badarg). + +with_ctl_server(Fun) -> + {ok, Pid} = emqx_ctl:start_link(), + _ = Fun(Pid), + ok = emqx_ctl:stop(). + +print_mock() -> + %% proxy usage/1,2 and print/1,2 to format_xx/1,2 funcs + meck:new(emqx_ctl, [non_strict, passthrough]), + meck:expect(emqx_ctl, print, fun(Arg) -> emqx_ctl:format(Arg) end), + meck:expect(emqx_ctl, print, fun(Msg, Arg) -> emqx_ctl:format(Msg, Arg) end), + meck:expect(emqx_ctl, usage, fun(Usages) -> emqx_ctl:format_usage(Usages) end), + meck:expect(emqx_ctl, usage, fun(CmdParams, CmdDescr) -> emqx_ctl:format_usage(CmdParams, CmdDescr) end). diff --git a/test/emqx_ctl_SUTIES.erl b/test/emqx_ctl_SUTIES.erl deleted file mode 100644 index a3ce8e8b0..000000000 --- a/test/emqx_ctl_SUTIES.erl +++ /dev/null @@ -1,17 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - --module(emqx_ctl_SUTIES). diff --git a/test/emqx_flapping_SUITE.erl b/test/emqx_flapping_SUITE.erl index cb5d0321c..234fc1950 100644 --- a/test/emqx_flapping_SUITE.erl +++ b/test/emqx_flapping_SUITE.erl @@ -40,18 +40,18 @@ end_per_suite(_Config) -> ok. t_detect_check(_) -> - Client = #{zone => external, - client_id => <<"clientid">>, - peername => {{127,0,0,1}, 5000} - }, - false = emqx_flapping:detect(Client), - false = emqx_flapping:check(Client), - false = emqx_flapping:detect(Client), - false = emqx_flapping:check(Client), - true = emqx_flapping:detect(Client), + ClientInfo = #{zone => external, + client_id => <<"clientid">>, + peerhost => {127,0,0,1} + }, + false = emqx_flapping:detect(ClientInfo), + false = emqx_flapping:check(ClientInfo), + false = emqx_flapping:detect(ClientInfo), + false = emqx_flapping:check(ClientInfo), + true = emqx_flapping:detect(ClientInfo), timer:sleep(50), - true = emqx_flapping:check(Client), + true = emqx_flapping:check(ClientInfo), timer:sleep(300), - false = emqx_flapping:check(Client), + false = emqx_flapping:check(ClientInfo), ok = emqx_flapping:stop(). diff --git a/test/emqx_inflight_SUITE.erl b/test/emqx_inflight_SUITE.erl index 7a015a443..4d7065486 100644 --- a/test/emqx_inflight_SUITE.erl +++ b/test/emqx_inflight_SUITE.erl @@ -84,8 +84,9 @@ t_is_empty(_) -> ?assert(emqx_inflight:is_empty(Inflight1)). t_window(_) -> + ?assertEqual([], emqx_inflight:window(emqx_inflight:new(0))), Inflight = emqx_inflight:insert( b, 2, emqx_inflight:insert( a, 1, emqx_inflight:new(2))), - [a, b] = emqx_inflight:window(Inflight). + ?assertEqual([a, b], emqx_inflight:window(Inflight)). diff --git a/test/emqx_message_SUITE.erl b/test/emqx_message_SUITE.erl index e9f51db08..0c1cad6c7 100644 --- a/test/emqx_message_SUITE.erl +++ b/test/emqx_message_SUITE.erl @@ -19,22 +19,10 @@ -compile(export_all). -compile(nowarn_export_all). +-include("emqx.hrl"). -include("emqx_mqtt.hrl"). -include_lib("eunit/include/eunit.hrl"). --export([ t_make/1 - , t_flag/1 - , t_header/1 - , t_format/1 - , t_expired/1 - , t_to_packet/1 - , t_to_map/1 - ]). - --export([ all/0 - , suite/0 - ]). - all() -> emqx_ct:all(?MODULE). suite() -> @@ -55,7 +43,12 @@ t_make(_) -> ?assertEqual(<<"topic">>, emqx_message:topic(Msg2)), ?assertEqual(<<"payload">>, emqx_message:payload(Msg2)). -t_flag(_) -> +t_get_set_flags(_) -> + Msg = #message{id = <<"id">>, qos = ?QOS_1, flags = undefined}, + Msg1 = emqx_message:set_flags(#{retain => true}, Msg), + ?assertEqual(#{retain => true}, emqx_message:get_flags(Msg1)). + +t_get_set_flag(_) -> Msg = emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>), Msg2 = emqx_message:set_flag(retain, false, Msg), Msg3 = emqx_message:set_flag(dup, Msg2), @@ -63,32 +56,62 @@ t_flag(_) -> ?assertNot(emqx_message:get_flag(retain, Msg3)), Msg4 = emqx_message:unset_flag(dup, Msg3), Msg5 = emqx_message:unset_flag(retain, Msg4), + Msg5 = emqx_message:unset_flag(badflag, Msg5), ?assertEqual(undefined, emqx_message:get_flag(dup, Msg5, undefined)), ?assertEqual(undefined, emqx_message:get_flag(retain, Msg5, undefined)), Msg6 = emqx_message:set_flags(#{dup => true, retain => true}, Msg5), ?assert(emqx_message:get_flag(dup, Msg6)), - ?assert(emqx_message:get_flag(retain, Msg6)). + ?assert(emqx_message:get_flag(retain, Msg6)), + Msg7 = #message{id = <<"id">>, qos = ?QOS_1, flags = undefined}, + Msg8 = emqx_message:set_flag(retain, Msg7), + Msg9 = emqx_message:set_flag(retain, true, Msg7), + ?assertEqual(#{retain => true}, emqx_message:get_flags(Msg8)), + ?assertEqual(#{retain => true}, emqx_message:get_flags(Msg9)). -t_header(_) -> +t_get_set_headers(_) -> Msg = emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>), Msg1 = emqx_message:set_headers(#{a => 1, b => 2}, Msg), - Msg2 = emqx_message:set_header(c, 3, Msg1), - ?assertEqual(1, emqx_message:get_header(a, Msg2)), + Msg2 = emqx_message:set_headers(#{c => 3}, Msg1), + ?assertEqual(#{a => 1, b => 2, c => 3}, emqx_message:get_headers(Msg2)). + +t_get_set_header(_) -> + Msg = emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>), + Msg1 = emqx_message:set_header(a, 1, Msg), + Msg2 = emqx_message:set_header(b, 2, Msg1), + Msg3 = emqx_message:set_header(c, 3, Msg2), + ?assertEqual(1, emqx_message:get_header(a, Msg3)), ?assertEqual(4, emqx_message:get_header(d, Msg2, 4)), - Msg3 = emqx_message:remove_header(a, Msg2), - ?assertEqual(#{b => 2, c => 3}, emqx_message:get_headers(Msg3)). + Msg4 = emqx_message:remove_header(a, Msg3), + Msg4 = emqx_message:remove_header(a, Msg4), + ?assertEqual(#{b => 2, c => 3}, emqx_message:get_headers(Msg4)). + +t_undefined_headers(_) -> + Msg = #message{id = <<"id">>, qos = ?QOS_0, headers = undefined}, + Msg1 = emqx_message:set_headers(#{a => 1, b => 2}, Msg), + ?assertEqual(1, emqx_message:get_header(a, Msg1)), + Msg2 = emqx_message:set_header(c, 3, Msg), + ?assertEqual(3, emqx_message:get_header(c, Msg2)). t_format(_) -> - io:format("~s", [emqx_message:format(emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>))]). + Msg = emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>), + io:format("~s~n", [emqx_message:format(Msg)]), + Msg1 = #message{id = <<"id">>, + qos = ?QOS_0, + flags = undefined, + headers = undefined + }, + io:format("~s~n", [emqx_message:format(Msg1)]). t_expired(_) -> Msg = emqx_message:make(<<"clientid">>, <<"topic">>, <<"payload">>), + ?assertNot(emqx_message:is_expired(Msg)), Msg1 = emqx_message:set_headers(#{'Message-Expiry-Interval' => 1}, Msg), timer:sleep(500), ?assertNot(emqx_message:is_expired(Msg1)), timer:sleep(600), ?assert(emqx_message:is_expired(Msg1)), timer:sleep(1000), + Msg = emqx_message:update_expiry(Msg), Msg2 = emqx_message:update_expiry(Msg1), ?assertEqual(1, emqx_message:get_header('Message-Expiry-Interval', Msg2)). diff --git a/test/emqx_mod_rewrite_SUITE.erl b/test/emqx_mod_rewrite_SUITE.erl deleted file mode 100644 index 158d93235..000000000 --- a/test/emqx_mod_rewrite_SUITE.erl +++ /dev/null @@ -1,50 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - --module(emqx_mod_rewrite_SUITE). - --compile(export_all). --compile(nowarn_export_all). - --include("emqx_mqtt.hrl"). --include_lib("eunit/include/eunit.hrl"). - --define(rules, [{rewrite,<<"x/#">>,<<"^x/y/(.+)$">>,<<"z/y/$1">>}, - {rewrite,<<"y/+/z/#">>,<<"^y/(.+)/z/(.+)$">>,<<"y/z/$2">>}]). - -all() -> emqx_ct:all(?MODULE). - -t_rewrite_rule(_Config) -> - {ok, _} = emqx_hooks:start_link(), - ok = emqx_mod_rewrite:load(?rules), - RawTopicFilters = [{<<"x/y/2">>, opts}, - {<<"x/1/2">>, opts}, - {<<"y/a/z/b">>, opts}, - {<<"y/def">>, opts}], - SubTopicFilters = emqx_hooks:run_fold('client.subscribe', [client, properties], RawTopicFilters), - UnSubTopicFilters = emqx_hooks:run_fold('client.unsubscribe', [client, properties], RawTopicFilters), - Messages = [emqx_hooks:run_fold('message.publish', [], emqx_message:make(Topic, <<"payload">>)) - || {Topic, _Opts} <- RawTopicFilters], - ExpectedTopicFilters = [{<<"z/y/2">>, opts}, - {<<"x/1/2">>, opts}, - {<<"y/z/b">>, opts}, - {<<"y/def">>, opts}], - ?assertEqual(ExpectedTopicFilters, SubTopicFilters), - ?assertEqual(ExpectedTopicFilters, UnSubTopicFilters), - [?assertEqual(ExpectedTopic, emqx_message:topic(Message)) - || {{ExpectedTopic, _opts}, Message} <- lists:zip(ExpectedTopicFilters, Messages)], - ok = emqx_mod_rewrite:unload(?rules), - ok = emqx_hooks:stop(). diff --git a/test/emqx_modules_SUITE.erl b/test/emqx_modules_SUITE.erl new file mode 100644 index 000000000..12f304f54 --- /dev/null +++ b/test/emqx_modules_SUITE.erl @@ -0,0 +1,154 @@ +%%-------------------------------------------------------------------- +%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%-------------------------------------------------------------------- + +-module(emqx_modules_SUITE). + +%% API +-compile(export_all). +-compile(nowarn_export_all). + +-include("emqx.hrl"). +-include("emqx_mqtt.hrl"). + +%%-include_lib("proper/include/proper.hrl"). +-include_lib("common_test/include/ct.hrl"). +-include_lib("eunit/include/eunit.hrl"). + +%%-define(PROPTEST(M,F), true = proper:quickcheck(M:F())). + +-define(RULES, [{rewrite,<<"x/#">>,<<"^x/y/(.+)$">>,<<"z/y/$1">>}, + {rewrite,<<"y/+/z/#">>,<<"^y/(.+)/z/(.+)$">>,<<"y/z/$2">>} + ]). + +all() -> emqx_ct:all(?MODULE). + +suite() -> + [{ct_hooks,[cth_surefire]}, {timetrap, {seconds, 30}}]. + +init_per_suite(Config) -> + emqx_ct_helpers:boot_modules(all), + emqx_ct_helpers:start_apps([emqx]), + %% Ensure all the modules unloaded. + ok = emqx_modules:unload(), + Config. + +end_per_suite(_Config) -> + emqx_ct_helpers:stop_apps([emqx]). + +%%-------------------------------------------------------------------- +%% Test cases +%%-------------------------------------------------------------------- + +%% Test case for emqx_mod_presence +t_mod_presence(_) -> + ok = emqx_mod_presence:load([{qos, ?QOS_1}]), + {ok, C1} = emqtt:start_link([{client_id, <<"monsys">>}]), + {ok, _} = emqtt:connect(C1), + {ok, _Props, [?QOS_1]} = emqtt:subscribe(C1, <<"$SYS/brokers/+/clients/#">>, qos1), + %% Connected Presence + {ok, C2} = emqtt:start_link([{client_id, <<"clientid">>}, + {username, <<"username">>}]), + {ok, _} = emqtt:connect(C2), + ok = recv_and_check_presence(<<"clientid">>, <<"connected">>), + %% Disconnected Presence + ok = emqtt:disconnect(C2), + ok = recv_and_check_presence(<<"clientid">>, <<"disconnected">>), + ok = emqtt:disconnect(C1), + ok = emqx_mod_presence:unload([{qos, ?QOS_1}]). + +t_mod_presence_reason(_) -> + ?assertEqual(normal, emqx_mod_presence:reason(normal)), + ?assertEqual(discarded, emqx_mod_presence:reason({shutdown, discarded})), + ?assertEqual(tcp_error, emqx_mod_presence:reason({tcp_error, einval})), + ?assertEqual(internal_error, emqx_mod_presence:reason(<<"unknown error">>)). + +recv_and_check_presence(ClientId, Presence) -> + {ok, #{qos := ?QOS_1, topic := Topic, payload := Payload}} = receive_publish(100), + ?assertMatch([<<"$SYS">>, <<"brokers">>, _Node, <<"clients">>, ClientId, Presence], + binary:split(Topic, <<"/">>, [global])), + case Presence of + <<"connected">> -> + ?assertMatch(#{clientid := <<"clientid">>, + username := <<"username">>, + ipaddress := <<"127.0.0.1">>, + proto_name := <<"MQTT">>, + proto_ver := ?MQTT_PROTO_V4, + connack := ?RC_SUCCESS, + clean_start := true}, emqx_json:decode(Payload, [{labels, atom}, return_maps])); + <<"disconnected">> -> + ?assertMatch(#{clientid := <<"clientid">>, + username := <<"username">>, + reason := <<"normal">>}, emqx_json:decode(Payload, [{labels, atom}, return_maps])) + end. + +%% Test case for emqx_mod_subscription +t_mod_subscription(_) -> + emqx_mod_subscription:load([{<<"connected/%c/%u">>, ?QOS_0}]), + {ok, C} = emqtt:start_link([{host, "localhost"}, + {client_id, "myclient"}, + {username, "admin"}]), + {ok, _} = emqtt:connect(C), + emqtt:publish(C, <<"connected/myclient/admin">>, <<"Hello world">>, ?QOS_0), + {ok, #{topic := Topic, payload := Payload}} = receive_publish(100), + ?assertEqual(<<"connected/myclient/admin">>, Topic), + ?assertEqual(<<"Hello world">>, Payload), + ok = emqtt:disconnect(C), + emqx_mod_subscription:unload([]). + +%% Test case for emqx_mod_write +t_mod_rewrite(_Config) -> + ok = emqx_mod_rewrite:load(?RULES), + {ok, C} = emqtt:start_link([{client_id, <<"rewrite_client">>}]), + {ok, _} = emqtt:connect(C), + OrigTopics = [<<"x/y/2">>, <<"x/1/2">>, <<"y/a/z/b">>, <<"y/def">>], + DestTopics = [<<"z/y/2">>, <<"x/1/2">>, <<"y/z/b">>, <<"y/def">>], + %% Subscribe + {ok, _Props, _} = emqtt:subscribe(C, [{Topic, ?QOS_1} || Topic <- OrigTopics]), + timer:sleep(100), + Subscriptions = emqx_broker:subscriptions(<<"rewrite_client">>), + ?assertEqual(DestTopics, [Topic || {Topic, _SubOpts} <- Subscriptions]), + %% Publish + RecvTopics = [begin + ok = emqtt:publish(C, Topic, <<"payload">>), + {ok, #{topic := RecvTopic}} = receive_publish(100), + RecvTopic + end || Topic <- OrigTopics], + ?assertEqual(DestTopics, RecvTopics), + %% Unsubscribe + {ok, _, _} = emqtt:unsubscribe(C, OrigTopics), + timer:sleep(100), + ?assertEqual([], emqx_broker:subscriptions(<<"rewrite_client">>)), + ok = emqtt:disconnect(C), + ok = emqx_mod_rewrite:unload(?RULES). + +t_rewrite_rule(_Config) -> + Rules = emqx_mod_rewrite:compile(?RULES), + ?assertEqual(<<"z/y/2">>, emqx_mod_rewrite:match_and_rewrite(<<"x/y/2">>, Rules)), + ?assertEqual(<<"x/1/2">>, emqx_mod_rewrite:match_and_rewrite(<<"x/1/2">>, Rules)), + ?assertEqual(<<"y/z/b">>, emqx_mod_rewrite:match_and_rewrite(<<"y/a/z/b">>, Rules)), + ?assertEqual(<<"y/def">>, emqx_mod_rewrite:match_and_rewrite(<<"y/def">>, Rules)). + +%%-------------------------------------------------------------------- +%% Internal functions +%%-------------------------------------------------------------------- + +receive_publish(Timeout) -> + receive + {publish, Publish} -> {ok, Publish} + after + Timeout -> {error, timeout} + end. + diff --git a/test/emqx_packet_SUITE.erl b/test/emqx_packet_SUITE.erl index f802e7718..bc3caf5ab 100644 --- a/test/emqx_packet_SUITE.erl +++ b/test/emqx_packet_SUITE.erl @@ -24,45 +24,41 @@ -include_lib("eunit/include/eunit.hrl"). +-define(PACKETS, + [{?CONNECT, 'CONNECT', ?CONNECT_PACKET(#mqtt_packet_connect{})}, + {?CONNACK, 'CONNACK', ?CONNACK_PACKET(?RC_SUCCESS)}, + {?PUBLISH, 'PUBLISH', ?PUBLISH_PACKET(?QOS_1)}, + {?PUBACK, 'PUBACK', ?PUBACK_PACKET(1)}, + {?PUBREC, 'PUBREC', ?PUBREC_PACKET(1)}, + {?PUBREL, 'PUBREL', ?PUBREL_PACKET(1)}, + {?PUBCOMP, 'PUBCOMP', ?PUBCOMP_PACKET(1)}, + {?SUBSCRIBE, 'SUBSCRIBE', ?SUBSCRIBE_PACKET(1, [])}, + {?SUBACK, 'SUBACK', ?SUBACK_PACKET(1, [0])}, + {?UNSUBSCRIBE, 'UNSUBSCRIBE', ?UNSUBSCRIBE_PACKET(1, [])}, + {?UNSUBACK, 'UNSUBACK', ?UNSUBACK_PACKET(1)}, + {?DISCONNECT, 'DISCONNECT', ?DISCONNECT_PACKET(?RC_SUCCESS)}, + {?AUTH, 'AUTH', ?AUTH_PACKET()} + ]). + all() -> emqx_ct:all(?MODULE). t_type(_) -> - ?assertEqual(?CONNECT, emqx_packet:type(?CONNECT_PACKET(#mqtt_packet_connect{}))), - ?assertEqual(?CONNACK, emqx_packet:type(?CONNACK_PACKET(?RC_SUCCESS))), - ?assertEqual(?PUBLISH, emqx_packet:type(?PUBLISH_PACKET(?QOS_1))), - ?assertEqual(?PUBACK, emqx_packet:type(?PUBACK_PACKET(1))), - ?assertEqual(?PUBREC, emqx_packet:type(?PUBREC_PACKET(1))), - ?assertEqual(?PUBREL, emqx_packet:type(?PUBREL_PACKET(1))), - ?assertEqual(?PUBCOMP, emqx_packet:type(?PUBCOMP_PACKET(1))), - ?assertEqual(?SUBSCRIBE, emqx_packet:type(?SUBSCRIBE_PACKET(1, []))), - ?assertEqual(?SUBACK, emqx_packet:type(?SUBACK_PACKET(1, [0]))), - ?assertEqual(?UNSUBSCRIBE, emqx_packet:type(?UNSUBSCRIBE_PACKET(1, []))), - ?assertEqual(?UNSUBACK, emqx_packet:type(?UNSUBACK_PACKET(1))), - ?assertEqual(?DISCONNECT, emqx_packet:type(?DISCONNECT_PACKET(?RC_SUCCESS))), - ?assertEqual(?AUTH, emqx_packet:type(?AUTH_PACKET())). + lists:foreach(fun({Type, _Name, Packet}) -> + ?assertEqual(Type, emqx_packet:type(Packet)) + end, ?PACKETS). t_type_name(_) -> - ?assertEqual('CONNECT', emqx_packet:type_name(?CONNECT_PACKET(#mqtt_packet_connect{}))), - ?assertEqual('CONNACK', emqx_packet:type_name(?CONNACK_PACKET(?RC_SUCCESS))), - ?assertEqual('PUBLISH', emqx_packet:type_name(?PUBLISH_PACKET(?QOS_1))), - ?assertEqual('PUBACK', emqx_packet:type_name(?PUBACK_PACKET(1))), - ?assertEqual('PUBREC', emqx_packet:type_name(?PUBREC_PACKET(1))), - ?assertEqual('PUBREL', emqx_packet:type_name(?PUBREL_PACKET(1))), - ?assertEqual('PUBCOMP', emqx_packet:type_name(?PUBCOMP_PACKET(1))), - ?assertEqual('SUBSCRIBE', emqx_packet:type_name(?SUBSCRIBE_PACKET(1, []))), - ?assertEqual('SUBACK', emqx_packet:type_name(?SUBACK_PACKET(1, [0]))), - ?assertEqual('UNSUBSCRIBE', emqx_packet:type_name(?UNSUBSCRIBE_PACKET(1, []))), - ?assertEqual('UNSUBACK', emqx_packet:type_name(?UNSUBACK_PACKET(1))), - ?assertEqual('DISCONNECT', emqx_packet:type_name(?DISCONNECT_PACKET(?RC_SUCCESS))), - ?assertEqual('AUTH', emqx_packet:type_name(?AUTH_PACKET())). + lists:foreach(fun({_Type, Name, Packet}) -> + ?assertEqual(Name, emqx_packet:type_name(Packet)) + end, ?PACKETS). t_dup(_) -> ?assertEqual(false, emqx_packet:dup(?PUBLISH_PACKET(?QOS_1))). t_qos(_) -> - ?assertEqual(?QOS_0, emqx_packet:qos(?PUBLISH_PACKET(?QOS_0))), - ?assertEqual(?QOS_1, emqx_packet:qos(?PUBLISH_PACKET(?QOS_1))), - ?assertEqual(?QOS_2, emqx_packet:qos(?PUBLISH_PACKET(?QOS_2))). + lists:foreach(fun(QoS) -> + ?assertEqual(QoS, emqx_packet:qos(?PUBLISH_PACKET(QoS))) + end, [?QOS_0, ?QOS_1, ?QOS_2]). t_retain(_) -> ?assertEqual(false, emqx_packet:retain(?PUBLISH_PACKET(?QOS_1))). @@ -78,15 +74,16 @@ t_proto_name(_) -> t_proto_ver(_) -> lists:foreach( fun(Ver) -> - ?assertEqual(Ver, emqx_packet:proto_ver(#mqtt_packet_connect{proto_ver = Ver})) + ConnPkt = ?CONNECT_PACKET(#mqtt_packet_connect{proto_ver = Ver}), + ?assertEqual(Ver, emqx_packet:proto_ver(ConnPkt)) end, [?MQTT_PROTO_V3, ?MQTT_PROTO_V4, ?MQTT_PROTO_V5]). t_check_publish(_) -> Props = #{'Response-Topic' => <<"responsetopic">>, 'Topic-Alias' => 1}, ok = emqx_packet:check(?PUBLISH_PACKET(?QOS_1, <<"topic">>, 1, Props, <<"payload">>)), ok = emqx_packet:check(#mqtt_packet_publish{packet_id = 1, topic_name = <<"t">>}), - {error, ?RC_PROTOCOL_ERROR} = emqx_packet:check(?PUBLISH_PACKET(1,<<>>,1,#{},<<"payload">>)), - {error, ?RC_TOPIC_NAME_INVALID} = emqx_packet:check(?PUBLISH_PACKET(1, <<"+/+">>, 1, #{}, <<"payload">>)), + {error, ?RC_TOPIC_NAME_INVALID} = emqx_packet:check(?PUBLISH_PACKET(?QOS_1, <<>>, 1, #{}, <<"payload">>)), + {error, ?RC_TOPIC_NAME_INVALID} = emqx_packet:check(?PUBLISH_PACKET(?QOS_1, <<"+/+">>, 1, #{}, <<"payload">>)), {error, ?RC_TOPIC_ALIAS_INVALID} = emqx_packet:check(?PUBLISH_PACKET(1, <<"topic">>, 1, #{'Topic-Alias' => 0}, <<"payload">>)), %% TODO:: %% {error, ?RC_PROTOCOL_ERROR} = emqx_packet:check(?PUBLISH_PACKET(1, <<"topic">>, 1, #{'Subscription-Identifier' => 10}, <<"payload">>)), @@ -143,10 +140,10 @@ t_check_connect(_) -> properties = #{'Receive-Maximum' => 0}}), Opts). t_from_to_message(_) -> - ExpectedMsg = emqx_message:set_headers( - #{peername => {{127,0,0,1}, 9527}, username => <<"test">>}, - emqx_message:make(<<"clientid">>, ?QOS_0, <<"topic">>, <<"payload">>)), + ExpectedMsg = emqx_message:make(<<"clientid">>, ?QOS_0, <<"topic">>, <<"payload">>), ExpectedMsg1 = emqx_message:set_flag(retain, false, ExpectedMsg), + ExpectedMsg2 = emqx_message:set_headers(#{peerhost => {127,0,0,1}, + username => <<"test">>}, ExpectedMsg1), Pkt = #mqtt_packet{header = #mqtt_packet_header{type = ?PUBLISH, qos = ?QOS_0, retain = false, @@ -157,8 +154,8 @@ t_from_to_message(_) -> payload = <<"payload">>}, MsgFromPkt = emqx_packet:to_message(#{client_id => <<"clientid">>, username => <<"test">>, - peername => {{127,0,0,1}, 9527}}, Pkt), - ?assertEqual(ExpectedMsg1, MsgFromPkt#message{id = emqx_message:id(ExpectedMsg), + peerhost => {127,0,0,1}}, Pkt), + ?assertEqual(ExpectedMsg2, MsgFromPkt#message{id = emqx_message:id(ExpectedMsg), timestamp = emqx_message:timestamp(ExpectedMsg) }). diff --git a/test/emqx_protocol_SUITE.erl b/test/emqx_protocol_SUITE.erl deleted file mode 100644 index 2c7e9479f..000000000 --- a/test/emqx_protocol_SUITE.erl +++ /dev/null @@ -1,81 +0,0 @@ -%%-------------------------------------------------------------------- -%% Copyright (c) 2019 EMQ Technologies Co., Ltd. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%%-------------------------------------------------------------------- - --module(emqx_protocol_SUITE). - --compile(export_all). --compile(nowarn_export_all). - --include("emqx_mqtt.hrl"). --include_lib("eunit/include/eunit.hrl"). - -all() -> emqx_ct:all(?MODULE). - -init_per_suite(Config) -> - [{proto, init_protocol()}|Config]. - -init_protocol() -> - emqx_protocol:init(#mqtt_packet_connect{ - proto_name = <<"MQTT">>, - proto_ver = ?MQTT_PROTO_V5, - is_bridge = false, - clean_start = true, - keepalive = 30, - properties = #{}, - client_id = <<"clientid">>, - username = <<"username">>, - password = <<"passwd">> - }, testing). - -end_per_suite(_Config) -> ok. - -t_init_info_1(Config) -> - Proto = proplists:get_value(proto, Config), - ?assertEqual(#{proto_name => <<"MQTT">>, - proto_ver => ?MQTT_PROTO_V5, - clean_start => true, - keepalive => 30, - will_msg => undefined, - client_id => <<"clientid">>, - username => <<"username">>, - topic_aliases => undefined, - alias_maximum => #{outbound => 0, inbound => 0} - }, emqx_protocol:info(Proto)). - -t_init_info_2(Config) -> - Proto = proplists:get_value(proto, Config), - ?assertEqual(<<"MQTT">>, emqx_protocol:info(proto_name, Proto)), - ?assertEqual(?MQTT_PROTO_V5, emqx_protocol:info(proto_ver, Proto)), - ?assertEqual(true, emqx_protocol:info(clean_start, Proto)), - ?assertEqual(30, emqx_protocol:info(keepalive, Proto)), - ?assertEqual(<<"clientid">>, emqx_protocol:info(client_id, Proto)), - ?assertEqual(<<"username">>, emqx_protocol:info(username, Proto)), - ?assertEqual(undefined, emqx_protocol:info(will_msg, Proto)), - ?assertEqual(0, emqx_protocol:info(will_delay_interval, Proto)), - ?assertEqual(undefined, emqx_protocol:info(topic_aliases, Proto)), - ?assertEqual(#{outbound => 0, inbound => 0}, emqx_protocol:info(alias_maximum, Proto)). - -t_find_save_alias(Config) -> - Proto = proplists:get_value(proto, Config), - ?assertEqual(undefined, emqx_protocol:info(topic_aliases, Proto)), - ?assertEqual(false, emqx_protocol:find_alias(1, Proto)), - Proto1 = emqx_protocol:save_alias(1, <<"t1">>, Proto), - Proto2 = emqx_protocol:save_alias(2, <<"t2">>, Proto1), - ?assertEqual(#{1 => <<"t1">>, 2 => <<"t2">>}, - emqx_protocol:info(topic_aliases, Proto2)), - ?assertEqual({ok, <<"t1">>}, emqx_protocol:find_alias(1, Proto2)), - ?assertEqual({ok, <<"t2">>}, emqx_protocol:find_alias(2, Proto2)). - diff --git a/test/emqx_session_SUITE.erl b/test/emqx_session_SUITE.erl index 3e30ff040..59bd53fe9 100644 --- a/test/emqx_session_SUITE.erl +++ b/test/emqx_session_SUITE.erl @@ -273,7 +273,7 @@ max_inflight() -> choose(0, 10). expiry_interval() -> ?LET(EI, choose(1, 10), EI * 3600). option() -> - ?LET(Option, [{max_inflight, max_inflight()}, + ?LET(Option, [{receive_maximum , max_inflight()}, {expiry_interval, expiry_interval()}], maps:from_list(Option)). diff --git a/test/emqx_zone_SUITE.erl b/test/emqx_zone_SUITE.erl index c9502d0a7..c56a68c51 100644 --- a/test/emqx_zone_SUITE.erl +++ b/test/emqx_zone_SUITE.erl @@ -21,29 +21,81 @@ -include_lib("eunit/include/eunit.hrl"). --define(OPTS, [{enable_acl, true}, - {enable_banned, false} +-define(ENVS, [{use_username_as_clientid, false}, + {server_keepalive, 60}, + {upgrade_qos, false}, + {session_expiry_interval, 7200}, + {retry_interval, 20000}, + {mqueue_store_qos0, true}, + {mqueue_priorities, none}, + {mqueue_default_priority, highest}, + {max_subscriptions, 0}, + {max_mqueue_len, 1000}, + {max_inflight, 32}, + {max_awaiting_rel, 100}, + {keepalive_backoff, 0.75}, + {ignore_loop_deliver, false}, + {idle_timeout, 15000}, + {force_shutdown_policy, #{max_heap_size => 838860800, + message_queue_len => 8000}}, + {force_gc_policy, #{bytes => 1048576, count => 1000}}, + {enable_stats, true}, + {enable_flapping_detect, false}, + {enable_ban, true}, + {enable_acl, true}, + {await_rel_timeout, 300000}, + {acl_deny_action, ignore} ]). all() -> emqx_ct:all(?MODULE). -t_set_get_env(_) -> +init_per_suite(Config) -> _ = application:load(emqx), - application:set_env(emqx, zones, [{external, ?OPTS}]), - {ok, _} = emqx_zone:start_link(), - ?assert(emqx_zone:get_env(external, enable_acl)), - ?assertNot(emqx_zone:get_env(external, enable_banned)), + application:set_env(emqx, zone_env, val), + application:set_env(emqx, zones, [{zone, ?ENVS}]), + Config. + +end_per_suite(_Config) -> + application:unset_env(emqx, zone_env), + application:unset_env(emqx, zones). + +t_zone_env_func(_) -> + lists:foreach(fun({Env, Val}) -> + case erlang:function_exported(emqx_zone, Env, 1) of + true -> + ?assertEqual(Val, erlang:apply(emqx_zone, Env, [zone])); + false -> ok + end + end, ?ENVS). + +t_get_env(_) -> + ?assertEqual(val, emqx_zone:get_env(undefined, zone_env)), + ?assertEqual(val, emqx_zone:get_env(undefined, zone_env, def)), + ?assert(emqx_zone:get_env(zone, enable_acl)), + ?assert(emqx_zone:get_env(zone, enable_ban)), ?assertEqual(defval, emqx_zone:get_env(extenal, key, defval)), ?assertEqual(undefined, emqx_zone:get_env(external, key)), ?assertEqual(undefined, emqx_zone:get_env(internal, key)), - ?assertEqual(def, emqx_zone:get_env(internal, key, def)), - emqx_zone:stop(). + ?assertEqual(def, emqx_zone:get_env(internal, key, def)). + +t_get_set_env(_) -> + ok = emqx_zone:set_env(zone, key, val), + ?assertEqual(val, emqx_zone:get_env(zone, key)), + true = emqx_zone:unset_env(zone, key), + ?assertEqual(undefined, emqx_zone:get_env(zone, key)). t_force_reload(_) -> {ok, _} = emqx_zone:start_link(), - application:set_env(emqx, zones, [{zone, [{key, val}]}]), - ?assertEqual(undefined, emqx_zone:get_env(zone, key)), + ?assertEqual(undefined, emqx_zone:get_env(xzone, key)), + application:set_env(emqx, zones, [{xzone, [{key, val}]}]), ok = emqx_zone:force_reload(), - ?assertEqual(val, emqx_zone:get_env(zone, key)), + ?assertEqual(val, emqx_zone:get_env(xzone, key)), + emqx_zone:stop(). + +t_uncovered_func(_) -> + {ok, Pid} = emqx_zone:start_link(), + ignored = gen_server:call(Pid, unexpected_call), + ok = gen_server:cast(Pid, unexpected_cast), + ok = Pid ! ok, emqx_zone:stop().