Validate packet id if strict mode.

This commit is contained in:
Feng Lee 2019-11-14 09:46:56 +08:00
parent d6ebbb7cce
commit d3107facf9
2 changed files with 61 additions and 27 deletions

View File

@ -89,12 +89,12 @@ parse(<<>>, {none, Options}) ->
parse(<<Type:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, parse(<<Type:4, Dup:1, QoS:2, Retain:1, Rest/binary>>,
{none, Options = #{strict_mode := StrictMode}}) -> {none, Options = #{strict_mode := StrictMode}}) ->
%% Validate header if strict mode. %% Validate header if strict mode.
StrictMode andalso validate_header(Type, Dup, QoS, Retain),
Header = #mqtt_packet_header{type = Type, Header = #mqtt_packet_header{type = Type,
dup = bool(Dup), dup = bool(Dup),
qos = QoS, qos = QoS,
retain = bool(Retain) retain = bool(Retain)
}, },
StrictMode andalso validate_header(Type, Dup, QoS, Retain),
Header1 = case fixqos(Type, QoS) of Header1 = case fixqos(Type, QoS) of
QoS -> Header; QoS -> Header;
FixedQoS -> Header#mqtt_packet_header{qos = FixedQoS} FixedQoS -> Header#mqtt_packet_header{qos = FixedQoS}
@ -164,7 +164,8 @@ packet(Header, Variable, Payload) ->
parse_packet(#mqtt_packet_header{type = ?CONNECT}, FrameBin, _Options) -> parse_packet(#mqtt_packet_header{type = ?CONNECT}, FrameBin, _Options) ->
{ProtoName, Rest} = parse_utf8_string(FrameBin), {ProtoName, Rest} = parse_utf8_string(FrameBin),
<<BridgeTag:4, ProtoVer:4, Rest1/binary>> = Rest, <<BridgeTag:4, ProtoVer:4, Rest1/binary>> = Rest,
% Note: Crash when reserved flag doesn't equal to 0, there is no strict compliance with the MQTT5.0. % Note: Crash when reserved flag doesn't equal to 0, there is no strict
% compliance with the MQTT5.0.
<<UsernameFlag : 1, <<UsernameFlag : 1,
PasswordFlag : 1, PasswordFlag : 1,
WillRetain : 1, WillRetain : 1,
@ -201,13 +202,15 @@ parse_packet(#mqtt_packet_header{type = ?CONNACK},
properties = Properties properties = Properties
}; };
parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, #{version := Ver}) -> parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin,
#{strict_mode := StrictMode, version := Ver}) ->
{TopicName, Rest} = parse_utf8_string(Bin), {TopicName, Rest} = parse_utf8_string(Bin),
{PacketId, Rest1} = case QoS of {PacketId, Rest1} = case QoS of
?QOS_0 -> {undefined, Rest}; ?QOS_0 -> {undefined, Rest};
_ -> parse_packet_id(Rest) _ -> parse_packet_id(Rest)
end, end,
(PacketId =/= undefined) andalso validate_packet_id(PacketId), (PacketId =/= undefined) andalso
StrictMode andalso validate_packet_id(PacketId),
{Properties, Payload} = parse_properties(Rest1, Ver), {Properties, Payload} = parse_properties(Rest1, Ver),
Publish = #mqtt_packet_publish{topic_name = TopicName, Publish = #mqtt_packet_publish{topic_name = TopicName,
packet_id = PacketId, packet_id = PacketId,
@ -215,15 +218,15 @@ parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, #{version :=
}, },
{Publish, Payload}; {Publish, Payload};
parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big>>, _Options) parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big>>, #{strict_mode := StrictMode})
when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP -> when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP ->
ok = validate_packet_id(PacketId), StrictMode andalso validate_packet_id(PacketId),
#mqtt_packet_puback{packet_id = PacketId, reason_code = 0}; #mqtt_packet_puback{packet_id = PacketId, reason_code = 0};
parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big, ReasonCode, Rest/binary>>, parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big, ReasonCode, Rest/binary>>,
#{version := Ver = ?MQTT_PROTO_V5}) #{strict_mode := StrictMode, version := Ver = ?MQTT_PROTO_V5})
when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP -> when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP ->
ok = validate_packet_id(PacketId), StrictMode andalso validate_packet_id(PacketId),
{Properties, <<>>} = parse_properties(Rest, Ver), {Properties, <<>>} = parse_properties(Rest, Ver),
#mqtt_packet_puback{packet_id = PacketId, #mqtt_packet_puback{packet_id = PacketId,
reason_code = ReasonCode, reason_code = ReasonCode,
@ -231,8 +234,8 @@ parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big, ReasonCode,
}; };
parse_packet(#mqtt_packet_header{type = ?SUBSCRIBE}, <<PacketId:16/big, Rest/binary>>, parse_packet(#mqtt_packet_header{type = ?SUBSCRIBE}, <<PacketId:16/big, Rest/binary>>,
#{version := Ver}) -> #{strict_mode := StrictMode, version := Ver}) ->
ok = validate_packet_id(PacketId), StrictMode andalso validate_packet_id(PacketId),
{Properties, Rest1} = parse_properties(Rest, Ver), {Properties, Rest1} = parse_properties(Rest, Ver),
TopicFilters = parse_topic_filters(subscribe, Rest1), TopicFilters = parse_topic_filters(subscribe, Rest1),
ok = validate_subqos([QoS || {_, #{qos := QoS}} <- TopicFilters]), ok = validate_subqos([QoS || {_, #{qos := QoS}} <- TopicFilters]),
@ -242,8 +245,8 @@ parse_packet(#mqtt_packet_header{type = ?SUBSCRIBE}, <<PacketId:16/big, Rest/bin
}; };
parse_packet(#mqtt_packet_header{type = ?SUBACK}, <<PacketId:16/big, Rest/binary>>, parse_packet(#mqtt_packet_header{type = ?SUBACK}, <<PacketId:16/big, Rest/binary>>,
#{version := Ver}) -> #{strict_mode := StrictMode, version := Ver}) ->
ok = validate_packet_id(PacketId), StrictMode andalso validate_packet_id(PacketId),
{Properties, Rest1} = parse_properties(Rest, Ver), {Properties, Rest1} = parse_properties(Rest, Ver),
ReasonCodes = parse_reason_codes(Rest1), ReasonCodes = parse_reason_codes(Rest1),
#mqtt_packet_suback{packet_id = PacketId, #mqtt_packet_suback{packet_id = PacketId,
@ -252,8 +255,8 @@ parse_packet(#mqtt_packet_header{type = ?SUBACK}, <<PacketId:16/big, Rest/binary
}; };
parse_packet(#mqtt_packet_header{type = ?UNSUBSCRIBE}, <<PacketId:16/big, Rest/binary>>, parse_packet(#mqtt_packet_header{type = ?UNSUBSCRIBE}, <<PacketId:16/big, Rest/binary>>,
#{version := Ver}) -> #{strict_mode := StrictMode, version := Ver}) ->
ok = validate_packet_id(PacketId), StrictMode andalso validate_packet_id(PacketId),
{Properties, Rest1} = parse_properties(Rest, Ver), {Properties, Rest1} = parse_properties(Rest, Ver),
TopicFilters = parse_topic_filters(unsubscribe, Rest1), TopicFilters = parse_topic_filters(unsubscribe, Rest1),
#mqtt_packet_unsubscribe{packet_id = PacketId, #mqtt_packet_unsubscribe{packet_id = PacketId,
@ -261,13 +264,14 @@ parse_packet(#mqtt_packet_header{type = ?UNSUBSCRIBE}, <<PacketId:16/big, Rest/b
topic_filters = TopicFilters topic_filters = TopicFilters
}; };
parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <<PacketId:16/big>>, _Options) -> parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <<PacketId:16/big>>,
ok = validate_packet_id(PacketId), #{strict_mode := StrictMode}) ->
StrictMode andalso validate_packet_id(PacketId),
#mqtt_packet_unsuback{packet_id = PacketId}; #mqtt_packet_unsuback{packet_id = PacketId};
parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <<PacketId:16/big, Rest/binary>>, parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <<PacketId:16/big, Rest/binary>>,
#{version := Ver}) -> #{strict_mode := StrictMode, version := Ver}) ->
ok = validate_packet_id(PacketId), StrictMode andalso validate_packet_id(PacketId),
{Properties, Rest1} = parse_properties(Rest, Ver), {Properties, Rest1} = parse_properties(Rest, Ver),
ReasonCodes = parse_reason_codes(Rest1), ReasonCodes = parse_reason_codes(Rest1),
#mqtt_packet_unsuback{packet_id = PacketId, #mqtt_packet_unsuback{packet_id = PacketId,
@ -296,8 +300,7 @@ parse_will_message(Packet = #mqtt_packet_connect{will_flag = true,
will_topic = Topic, will_topic = Topic,
will_payload = Payload will_payload = Payload
}, Rest2}; }, Rest2};
parse_will_message(Packet, Bin) -> parse_will_message(Packet, Bin) -> {Packet, Bin}.
{Packet, Bin}.
-compile({inline, [parse_packet_id/1]}). -compile({inline, [parse_packet_id/1]}).
parse_packet_id(<<PacketId:16/big, Rest/binary>>) -> parse_packet_id(<<PacketId:16/big, Rest/binary>>) ->
@ -720,6 +723,7 @@ validate_header(?DISCONNECT, 0, 0, 0) -> ok;
validate_header(?AUTH, 0, 0, 0) -> ok; validate_header(?AUTH, 0, 0, 0) -> ok;
validate_header(_Type, _Dup, _QoS, _Rt) -> error(bad_frame_header). validate_header(_Type, _Dup, _QoS, _Rt) -> error(bad_frame_header).
-compile({inline, [validate_packet_id/1]}).
validate_packet_id(0) -> error(bad_packet_id); validate_packet_id(0) -> error(bad_packet_id);
validate_packet_id(_) -> ok. validate_packet_id(_) -> ok.

View File

@ -40,7 +40,8 @@ all() ->
{group, unsuback}, {group, unsuback},
{group, ping}, {group, ping},
{group, disconnect}, {group, disconnect},
{group, auth}]. {group, auth}
].
groups() -> groups() ->
[{parse, [parallel], [{parse, [parallel],
@ -333,7 +334,10 @@ t_serialize_parse_qos1_publish(_) ->
payload = <<"haha">>}, payload = <<"haha">>},
?assertEqual(Bin, serialize_to_binary(Packet)), ?assertEqual(Bin, serialize_to_binary(Packet)),
?assertMatch(Packet, parse_to_packet(Bin, #{strict_mode => true})), ?assertMatch(Packet, parse_to_packet(Bin, #{strict_mode => true})),
?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>))). %% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>))),
%% strict_mode = false
_ = parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>), #{strict_mode => false}).
t_serialize_parse_qos2_publish(_) -> t_serialize_parse_qos2_publish(_) ->
Packet = ?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 1, <<>>), Packet = ?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 1, <<>>),
@ -341,7 +345,10 @@ t_serialize_parse_qos2_publish(_) ->
?assertEqual(Packet, parse_serialize(Packet)), ?assertEqual(Packet, parse_serialize(Packet)),
?assertEqual(Bin, serialize_to_binary(Packet)), ?assertEqual(Bin, serialize_to_binary(Packet)),
?assertMatch(Packet, parse_to_packet(Bin, #{strict_mode => true})), ?assertMatch(Packet, parse_to_packet(Bin, #{strict_mode => true})),
?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>))). %% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>))),
%% strict_mode = false
_ = parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>), #{strict_mode => false}).
t_serialize_parse_publish_v5(_) -> t_serialize_parse_publish_v5(_) ->
Props = #{'Payload-Format-Indicator' => 1, Props = #{'Payload-Format-Indicator' => 1,
@ -358,7 +365,10 @@ t_serialize_parse_puback(_) ->
Packet = ?PUBACK_PACKET(1), Packet = ?PUBACK_PACKET(1),
?assertEqual(<<64,2,0,1>>, serialize_to_binary(Packet)), ?assertEqual(<<64,2,0,1>>, serialize_to_binary(Packet)),
?assertEqual(Packet, parse_serialize(Packet)), ?assertEqual(Packet, parse_serialize(Packet)),
?catch_error(bad_packet_id, parse_serialize(?PUBACK_PACKET(0))). %% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBACK_PACKET(0))),
%% strict_mode = false
?PUBACK_PACKET(0) = parse_serialize(?PUBACK_PACKET(0), #{strict_mode => false}).
t_serialize_parse_puback_v3_4(_) -> t_serialize_parse_puback_v3_4(_) ->
Bin = <<64,2,0,1>>, Bin = <<64,2,0,1>>,
@ -376,7 +386,10 @@ t_serialize_parse_pubrec(_) ->
Packet = ?PUBREC_PACKET(1), Packet = ?PUBREC_PACKET(1),
?assertEqual(<<5:4,0:4,2,0,1>>, serialize_to_binary(Packet)), ?assertEqual(<<5:4,0:4,2,0,1>>, serialize_to_binary(Packet)),
?assertEqual(Packet, parse_serialize(Packet)), ?assertEqual(Packet, parse_serialize(Packet)),
?catch_error(bad_packet_id, parse_serialize(?PUBREC_PACKET(0))). %% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBREC_PACKET(0))),
%% strict_mode = false
?PUBREC_PACKET(0) = parse_serialize(?PUBREC_PACKET(0), #{strict_mode => false}).
t_serialize_parse_pubrec_v5(_) -> t_serialize_parse_pubrec_v5(_) ->
Packet = ?PUBREC_PACKET(16, ?RC_SUCCESS, #{'Reason-String' => <<"success">>}), Packet = ?PUBREC_PACKET(16, ?RC_SUCCESS, #{'Reason-String' => <<"success">>}),
@ -391,6 +404,9 @@ t_serialize_parse_pubrel(_) ->
Bin0 = <<6:4,0:4,2,0,1>>, Bin0 = <<6:4,0:4,2,0,1>>,
?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})), ?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})),
?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})), ?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})),
%% strict_mode = false
?PUBREL_PACKET(0) = parse_serialize(?PUBREL_PACKET(0), #{strict_mode => false}),
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBREL_PACKET(0))). ?catch_error(bad_packet_id, parse_serialize(?PUBREL_PACKET(0))).
t_serialize_parse_pubrel_v5(_) -> t_serialize_parse_pubrel_v5(_) ->
@ -402,6 +418,9 @@ t_serialize_parse_pubcomp(_) ->
Bin = serialize_to_binary(Packet), Bin = serialize_to_binary(Packet),
?assertEqual(<<7:4,0:4,2,0,1>>, Bin), ?assertEqual(<<7:4,0:4,2,0,1>>, Bin),
?assertEqual(Packet, parse_serialize(Packet)), ?assertEqual(Packet, parse_serialize(Packet)),
%% strict_mode = false
?PUBCOMP_PACKET(0) = parse_serialize(?PUBCOMP_PACKET(0), #{strict_mode => false}),
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBCOMP_PACKET(0))). ?catch_error(bad_packet_id, parse_serialize(?PUBCOMP_PACKET(0))).
t_serialize_parse_pubcomp_v5(_) -> t_serialize_parse_pubcomp_v5(_) ->
@ -419,7 +438,12 @@ t_serialize_parse_subscribe(_) ->
%% SUBSCRIBE with bad qos 0 %% SUBSCRIBE with bad qos 0
Bin0 = <<?SUBSCRIBE:4,0:4,11,0,2,0,6,84,111,112,105,99,65,2>>, Bin0 = <<?SUBSCRIBE:4,0:4,11,0,2,0,6,84,111,112,105,99,65,2>>,
?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})), ?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})),
%% strict_mode = false
_ = parse_to_packet(Bin0, #{strict_mode => false}),
?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})), ?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})),
%% strict_mode = false
_ = parse_serialize(?SUBSCRIBE_PACKET(0, TopicFilters), #{strict_mode => false}),
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?SUBSCRIBE_PACKET(0, TopicFilters))), ?catch_error(bad_packet_id, parse_serialize(?SUBSCRIBE_PACKET(0, TopicFilters))),
?catch_error(bad_subqos, parse_serialize(?SUBSCRIBE_PACKET(1, [{<<"t">>, #{qos => 3}}]))). ?catch_error(bad_subqos, parse_serialize(?SUBSCRIBE_PACKET(1, [{<<"t">>, #{qos => 3}}]))).
@ -432,6 +456,9 @@ t_serialize_parse_subscribe_v5(_) ->
t_serialize_parse_suback(_) -> t_serialize_parse_suback(_) ->
Packet = ?SUBACK_PACKET(10, [?QOS_0, ?QOS_1, 128]), Packet = ?SUBACK_PACKET(10, [?QOS_0, ?QOS_1, 128]),
?assertEqual(Packet, parse_serialize(Packet)), ?assertEqual(Packet, parse_serialize(Packet)),
%% strict_mode = false
_ = parse_serialize(?SUBACK_PACKET(0, [?QOS_0]), #{strict_mode => false}),
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?SUBACK_PACKET(0, [?QOS_0]))). ?catch_error(bad_packet_id, parse_serialize(?SUBACK_PACKET(0, [?QOS_0]))).
t_serialize_parse_suback_v5(_) -> t_serialize_parse_suback_v5(_) ->
@ -451,6 +478,9 @@ t_serialize_parse_unsubscribe(_) ->
Bin0 = <<?UNSUBSCRIBE:4,0:4,10,0,2,0,6,84,111,112,105,99,65>>, Bin0 = <<?UNSUBSCRIBE:4,0:4,10,0,2,0,6,84,111,112,105,99,65>>,
?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})), ?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})),
?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})), ?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})),
%% strict_mode = false
_ = parse_serialize(?UNSUBSCRIBE_PACKET(0, [<<"TopicA">>]), #{strict_mode => false}),
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?UNSUBSCRIBE_PACKET(0, [<<"TopicA">>]))). ?catch_error(bad_packet_id, parse_serialize(?UNSUBSCRIBE_PACKET(0, [<<"TopicA">>]))).
t_serialize_parse_unsubscribe_v5(_) -> t_serialize_parse_unsubscribe_v5(_) ->