From d3107facf9dd78adcd7ed44dba4002bcb92c9b6b Mon Sep 17 00:00:00 2001 From: Feng Lee Date: Thu, 14 Nov 2019 09:46:56 +0800 Subject: [PATCH] Validate packet id if strict mode. --- src/emqx_frame.erl | 48 +++++++++++++++++++++------------------ test/emqx_frame_SUITE.erl | 40 ++++++++++++++++++++++++++++---- 2 files changed, 61 insertions(+), 27 deletions(-) diff --git a/src/emqx_frame.erl b/src/emqx_frame.erl index b402ad989..91788227b 100644 --- a/src/emqx_frame.erl +++ b/src/emqx_frame.erl @@ -89,12 +89,12 @@ parse(<<>>, {none, Options}) -> parse(<>, {none, Options = #{strict_mode := StrictMode}}) -> %% Validate header if strict mode. + StrictMode andalso validate_header(Type, Dup, QoS, Retain), Header = #mqtt_packet_header{type = Type, dup = bool(Dup), qos = QoS, retain = bool(Retain) }, - StrictMode andalso validate_header(Type, Dup, QoS, Retain), Header1 = case fixqos(Type, QoS) of QoS -> Header; FixedQoS -> Header#mqtt_packet_header{qos = FixedQoS} @@ -164,7 +164,8 @@ packet(Header, Variable, Payload) -> parse_packet(#mqtt_packet_header{type = ?CONNECT}, FrameBin, _Options) -> {ProtoName, Rest} = parse_utf8_string(FrameBin), <> = Rest, - % Note: Crash when reserved flag doesn't equal to 0, there is no strict compliance with the MQTT5.0. + % Note: Crash when reserved flag doesn't equal to 0, there is no strict + % compliance with the MQTT5.0. < +parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, + #{strict_mode := StrictMode, version := Ver}) -> {TopicName, Rest} = parse_utf8_string(Bin), {PacketId, Rest1} = case QoS of ?QOS_0 -> {undefined, Rest}; _ -> parse_packet_id(Rest) end, - (PacketId =/= undefined) andalso validate_packet_id(PacketId), + (PacketId =/= undefined) andalso + StrictMode andalso validate_packet_id(PacketId), {Properties, Payload} = parse_properties(Rest1, Ver), Publish = #mqtt_packet_publish{topic_name = TopicName, packet_id = PacketId, @@ -215,15 +218,15 @@ parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, #{version := }, {Publish, Payload}; -parse_packet(#mqtt_packet_header{type = PubAck}, <>, _Options) - when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP -> - ok = validate_packet_id(PacketId), +parse_packet(#mqtt_packet_header{type = PubAck}, <>, #{strict_mode := StrictMode}) + when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP -> + StrictMode andalso validate_packet_id(PacketId), #mqtt_packet_puback{packet_id = PacketId, reason_code = 0}; parse_packet(#mqtt_packet_header{type = PubAck}, <>, - #{version := Ver = ?MQTT_PROTO_V5}) - when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP -> - ok = validate_packet_id(PacketId), + #{strict_mode := StrictMode, version := Ver = ?MQTT_PROTO_V5}) + when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP -> + StrictMode andalso validate_packet_id(PacketId), {Properties, <<>>} = parse_properties(Rest, Ver), #mqtt_packet_puback{packet_id = PacketId, reason_code = ReasonCode, @@ -231,8 +234,8 @@ parse_packet(#mqtt_packet_header{type = PubAck}, <>, - #{version := Ver}) -> - ok = validate_packet_id(PacketId), + #{strict_mode := StrictMode, version := Ver}) -> + StrictMode andalso validate_packet_id(PacketId), {Properties, Rest1} = parse_properties(Rest, Ver), TopicFilters = parse_topic_filters(subscribe, Rest1), ok = validate_subqos([QoS || {_, #{qos := QoS}} <- TopicFilters]), @@ -242,8 +245,8 @@ parse_packet(#mqtt_packet_header{type = ?SUBSCRIBE}, <>, - #{version := Ver}) -> - ok = validate_packet_id(PacketId), + #{strict_mode := StrictMode, version := Ver}) -> + StrictMode andalso validate_packet_id(PacketId), {Properties, Rest1} = parse_properties(Rest, Ver), ReasonCodes = parse_reason_codes(Rest1), #mqtt_packet_suback{packet_id = PacketId, @@ -252,8 +255,8 @@ parse_packet(#mqtt_packet_header{type = ?SUBACK}, <>, - #{version := Ver}) -> - ok = validate_packet_id(PacketId), + #{strict_mode := StrictMode, version := Ver}) -> + StrictMode andalso validate_packet_id(PacketId), {Properties, Rest1} = parse_properties(Rest, Ver), TopicFilters = parse_topic_filters(unsubscribe, Rest1), #mqtt_packet_unsubscribe{packet_id = PacketId, @@ -261,13 +264,14 @@ parse_packet(#mqtt_packet_header{type = ?UNSUBSCRIBE}, <>, _Options) -> - ok = validate_packet_id(PacketId), +parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <>, + #{strict_mode := StrictMode}) -> + StrictMode andalso validate_packet_id(PacketId), #mqtt_packet_unsuback{packet_id = PacketId}; parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <>, - #{version := Ver}) -> - ok = validate_packet_id(PacketId), + #{strict_mode := StrictMode, version := Ver}) -> + StrictMode andalso validate_packet_id(PacketId), {Properties, Rest1} = parse_properties(Rest, Ver), ReasonCodes = parse_reason_codes(Rest1), #mqtt_packet_unsuback{packet_id = PacketId, @@ -296,8 +300,7 @@ parse_will_message(Packet = #mqtt_packet_connect{will_flag = true, will_topic = Topic, will_payload = Payload }, Rest2}; -parse_will_message(Packet, Bin) -> - {Packet, Bin}. +parse_will_message(Packet, Bin) -> {Packet, Bin}. -compile({inline, [parse_packet_id/1]}). parse_packet_id(<>) -> @@ -720,6 +723,7 @@ validate_header(?DISCONNECT, 0, 0, 0) -> ok; validate_header(?AUTH, 0, 0, 0) -> ok; validate_header(_Type, _Dup, _QoS, _Rt) -> error(bad_frame_header). +-compile({inline, [validate_packet_id/1]}). validate_packet_id(0) -> error(bad_packet_id); validate_packet_id(_) -> ok. diff --git a/test/emqx_frame_SUITE.erl b/test/emqx_frame_SUITE.erl index f70a815fe..35e8d1975 100644 --- a/test/emqx_frame_SUITE.erl +++ b/test/emqx_frame_SUITE.erl @@ -40,7 +40,8 @@ all() -> {group, unsuback}, {group, ping}, {group, disconnect}, - {group, auth}]. + {group, auth} + ]. groups() -> [{parse, [parallel], @@ -333,7 +334,10 @@ t_serialize_parse_qos1_publish(_) -> payload = <<"haha">>}, ?assertEqual(Bin, serialize_to_binary(Packet)), ?assertMatch(Packet, parse_to_packet(Bin, #{strict_mode => true})), - ?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>))). + %% strict_mode = true + ?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>))), + %% strict_mode = false + _ = parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>), #{strict_mode => false}). t_serialize_parse_qos2_publish(_) -> Packet = ?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 1, <<>>), @@ -341,7 +345,10 @@ t_serialize_parse_qos2_publish(_) -> ?assertEqual(Packet, parse_serialize(Packet)), ?assertEqual(Bin, serialize_to_binary(Packet)), ?assertMatch(Packet, parse_to_packet(Bin, #{strict_mode => true})), - ?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>))). + %% strict_mode = true + ?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>))), + %% strict_mode = false + _ = parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>), #{strict_mode => false}). t_serialize_parse_publish_v5(_) -> Props = #{'Payload-Format-Indicator' => 1, @@ -358,7 +365,10 @@ t_serialize_parse_puback(_) -> Packet = ?PUBACK_PACKET(1), ?assertEqual(<<64,2,0,1>>, serialize_to_binary(Packet)), ?assertEqual(Packet, parse_serialize(Packet)), - ?catch_error(bad_packet_id, parse_serialize(?PUBACK_PACKET(0))). + %% strict_mode = true + ?catch_error(bad_packet_id, parse_serialize(?PUBACK_PACKET(0))), + %% strict_mode = false + ?PUBACK_PACKET(0) = parse_serialize(?PUBACK_PACKET(0), #{strict_mode => false}). t_serialize_parse_puback_v3_4(_) -> Bin = <<64,2,0,1>>, @@ -376,7 +386,10 @@ t_serialize_parse_pubrec(_) -> Packet = ?PUBREC_PACKET(1), ?assertEqual(<<5:4,0:4,2,0,1>>, serialize_to_binary(Packet)), ?assertEqual(Packet, parse_serialize(Packet)), - ?catch_error(bad_packet_id, parse_serialize(?PUBREC_PACKET(0))). + %% strict_mode = true + ?catch_error(bad_packet_id, parse_serialize(?PUBREC_PACKET(0))), + %% strict_mode = false + ?PUBREC_PACKET(0) = parse_serialize(?PUBREC_PACKET(0), #{strict_mode => false}). t_serialize_parse_pubrec_v5(_) -> Packet = ?PUBREC_PACKET(16, ?RC_SUCCESS, #{'Reason-String' => <<"success">>}), @@ -391,6 +404,9 @@ t_serialize_parse_pubrel(_) -> Bin0 = <<6:4,0:4,2,0,1>>, ?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})), ?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})), + %% strict_mode = false + ?PUBREL_PACKET(0) = parse_serialize(?PUBREL_PACKET(0), #{strict_mode => false}), + %% strict_mode = true ?catch_error(bad_packet_id, parse_serialize(?PUBREL_PACKET(0))). t_serialize_parse_pubrel_v5(_) -> @@ -402,6 +418,9 @@ t_serialize_parse_pubcomp(_) -> Bin = serialize_to_binary(Packet), ?assertEqual(<<7:4,0:4,2,0,1>>, Bin), ?assertEqual(Packet, parse_serialize(Packet)), + %% strict_mode = false + ?PUBCOMP_PACKET(0) = parse_serialize(?PUBCOMP_PACKET(0), #{strict_mode => false}), + %% strict_mode = true ?catch_error(bad_packet_id, parse_serialize(?PUBCOMP_PACKET(0))). t_serialize_parse_pubcomp_v5(_) -> @@ -419,7 +438,12 @@ t_serialize_parse_subscribe(_) -> %% SUBSCRIBE with bad qos 0 Bin0 = <>, ?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})), + %% strict_mode = false + _ = parse_to_packet(Bin0, #{strict_mode => false}), ?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})), + %% strict_mode = false + _ = parse_serialize(?SUBSCRIBE_PACKET(0, TopicFilters), #{strict_mode => false}), + %% strict_mode = true ?catch_error(bad_packet_id, parse_serialize(?SUBSCRIBE_PACKET(0, TopicFilters))), ?catch_error(bad_subqos, parse_serialize(?SUBSCRIBE_PACKET(1, [{<<"t">>, #{qos => 3}}]))). @@ -432,6 +456,9 @@ t_serialize_parse_subscribe_v5(_) -> t_serialize_parse_suback(_) -> Packet = ?SUBACK_PACKET(10, [?QOS_0, ?QOS_1, 128]), ?assertEqual(Packet, parse_serialize(Packet)), + %% strict_mode = false + _ = parse_serialize(?SUBACK_PACKET(0, [?QOS_0]), #{strict_mode => false}), + %% strict_mode = true ?catch_error(bad_packet_id, parse_serialize(?SUBACK_PACKET(0, [?QOS_0]))). t_serialize_parse_suback_v5(_) -> @@ -451,6 +478,9 @@ t_serialize_parse_unsubscribe(_) -> Bin0 = <>, ?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})), ?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})), + %% strict_mode = false + _ = parse_serialize(?UNSUBSCRIBE_PACKET(0, [<<"TopicA">>]), #{strict_mode => false}), + %% strict_mode = true ?catch_error(bad_packet_id, parse_serialize(?UNSUBSCRIBE_PACKET(0, [<<"TopicA">>]))). t_serialize_parse_unsubscribe_v5(_) ->