Validate packet id if strict mode.

This commit is contained in:
Feng Lee 2019-11-14 09:46:56 +08:00
parent d6ebbb7cce
commit d3107facf9
2 changed files with 61 additions and 27 deletions

View File

@ -89,12 +89,12 @@ parse(<<>>, {none, Options}) ->
parse(<<Type:4, Dup:1, QoS:2, Retain:1, Rest/binary>>,
{none, Options = #{strict_mode := StrictMode}}) ->
%% Validate header if strict mode.
StrictMode andalso validate_header(Type, Dup, QoS, Retain),
Header = #mqtt_packet_header{type = Type,
dup = bool(Dup),
qos = QoS,
retain = bool(Retain)
},
StrictMode andalso validate_header(Type, Dup, QoS, Retain),
Header1 = case fixqos(Type, QoS) of
QoS -> Header;
FixedQoS -> Header#mqtt_packet_header{qos = FixedQoS}
@ -164,7 +164,8 @@ packet(Header, Variable, Payload) ->
parse_packet(#mqtt_packet_header{type = ?CONNECT}, FrameBin, _Options) ->
{ProtoName, Rest} = parse_utf8_string(FrameBin),
<<BridgeTag:4, ProtoVer:4, Rest1/binary>> = Rest,
% Note: Crash when reserved flag doesn't equal to 0, there is no strict compliance with the MQTT5.0.
% Note: Crash when reserved flag doesn't equal to 0, there is no strict
% compliance with the MQTT5.0.
<<UsernameFlag : 1,
PasswordFlag : 1,
WillRetain : 1,
@ -201,13 +202,15 @@ parse_packet(#mqtt_packet_header{type = ?CONNACK},
properties = Properties
};
parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, #{version := Ver}) ->
parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin,
#{strict_mode := StrictMode, version := Ver}) ->
{TopicName, Rest} = parse_utf8_string(Bin),
{PacketId, Rest1} = case QoS of
?QOS_0 -> {undefined, Rest};
_ -> parse_packet_id(Rest)
end,
(PacketId =/= undefined) andalso validate_packet_id(PacketId),
(PacketId =/= undefined) andalso
StrictMode andalso validate_packet_id(PacketId),
{Properties, Payload} = parse_properties(Rest1, Ver),
Publish = #mqtt_packet_publish{topic_name = TopicName,
packet_id = PacketId,
@ -215,15 +218,15 @@ parse_packet(#mqtt_packet_header{type = ?PUBLISH, qos = QoS}, Bin, #{version :=
},
{Publish, Payload};
parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big>>, _Options)
when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP ->
ok = validate_packet_id(PacketId),
parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big>>, #{strict_mode := StrictMode})
when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP ->
StrictMode andalso validate_packet_id(PacketId),
#mqtt_packet_puback{packet_id = PacketId, reason_code = 0};
parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big, ReasonCode, Rest/binary>>,
#{version := Ver = ?MQTT_PROTO_V5})
when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP ->
ok = validate_packet_id(PacketId),
#{strict_mode := StrictMode, version := Ver = ?MQTT_PROTO_V5})
when ?PUBACK =< PubAck, PubAck =< ?PUBCOMP ->
StrictMode andalso validate_packet_id(PacketId),
{Properties, <<>>} = parse_properties(Rest, Ver),
#mqtt_packet_puback{packet_id = PacketId,
reason_code = ReasonCode,
@ -231,8 +234,8 @@ parse_packet(#mqtt_packet_header{type = PubAck}, <<PacketId:16/big, ReasonCode,
};
parse_packet(#mqtt_packet_header{type = ?SUBSCRIBE}, <<PacketId:16/big, Rest/binary>>,
#{version := Ver}) ->
ok = validate_packet_id(PacketId),
#{strict_mode := StrictMode, version := Ver}) ->
StrictMode andalso validate_packet_id(PacketId),
{Properties, Rest1} = parse_properties(Rest, Ver),
TopicFilters = parse_topic_filters(subscribe, Rest1),
ok = validate_subqos([QoS || {_, #{qos := QoS}} <- TopicFilters]),
@ -242,8 +245,8 @@ parse_packet(#mqtt_packet_header{type = ?SUBSCRIBE}, <<PacketId:16/big, Rest/bin
};
parse_packet(#mqtt_packet_header{type = ?SUBACK}, <<PacketId:16/big, Rest/binary>>,
#{version := Ver}) ->
ok = validate_packet_id(PacketId),
#{strict_mode := StrictMode, version := Ver}) ->
StrictMode andalso validate_packet_id(PacketId),
{Properties, Rest1} = parse_properties(Rest, Ver),
ReasonCodes = parse_reason_codes(Rest1),
#mqtt_packet_suback{packet_id = PacketId,
@ -252,8 +255,8 @@ parse_packet(#mqtt_packet_header{type = ?SUBACK}, <<PacketId:16/big, Rest/binary
};
parse_packet(#mqtt_packet_header{type = ?UNSUBSCRIBE}, <<PacketId:16/big, Rest/binary>>,
#{version := Ver}) ->
ok = validate_packet_id(PacketId),
#{strict_mode := StrictMode, version := Ver}) ->
StrictMode andalso validate_packet_id(PacketId),
{Properties, Rest1} = parse_properties(Rest, Ver),
TopicFilters = parse_topic_filters(unsubscribe, Rest1),
#mqtt_packet_unsubscribe{packet_id = PacketId,
@ -261,13 +264,14 @@ parse_packet(#mqtt_packet_header{type = ?UNSUBSCRIBE}, <<PacketId:16/big, Rest/b
topic_filters = TopicFilters
};
parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <<PacketId:16/big>>, _Options) ->
ok = validate_packet_id(PacketId),
parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <<PacketId:16/big>>,
#{strict_mode := StrictMode}) ->
StrictMode andalso validate_packet_id(PacketId),
#mqtt_packet_unsuback{packet_id = PacketId};
parse_packet(#mqtt_packet_header{type = ?UNSUBACK}, <<PacketId:16/big, Rest/binary>>,
#{version := Ver}) ->
ok = validate_packet_id(PacketId),
#{strict_mode := StrictMode, version := Ver}) ->
StrictMode andalso validate_packet_id(PacketId),
{Properties, Rest1} = parse_properties(Rest, Ver),
ReasonCodes = parse_reason_codes(Rest1),
#mqtt_packet_unsuback{packet_id = PacketId,
@ -296,8 +300,7 @@ parse_will_message(Packet = #mqtt_packet_connect{will_flag = true,
will_topic = Topic,
will_payload = Payload
}, Rest2};
parse_will_message(Packet, Bin) ->
{Packet, Bin}.
parse_will_message(Packet, Bin) -> {Packet, Bin}.
-compile({inline, [parse_packet_id/1]}).
parse_packet_id(<<PacketId:16/big, Rest/binary>>) ->
@ -720,6 +723,7 @@ validate_header(?DISCONNECT, 0, 0, 0) -> ok;
validate_header(?AUTH, 0, 0, 0) -> ok;
validate_header(_Type, _Dup, _QoS, _Rt) -> error(bad_frame_header).
-compile({inline, [validate_packet_id/1]}).
validate_packet_id(0) -> error(bad_packet_id);
validate_packet_id(_) -> ok.

View File

@ -40,7 +40,8 @@ all() ->
{group, unsuback},
{group, ping},
{group, disconnect},
{group, auth}].
{group, auth}
].
groups() ->
[{parse, [parallel],
@ -333,7 +334,10 @@ t_serialize_parse_qos1_publish(_) ->
payload = <<"haha">>},
?assertEqual(Bin, serialize_to_binary(Packet)),
?assertMatch(Packet, parse_to_packet(Bin, #{strict_mode => true})),
?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>))).
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>))),
%% strict_mode = false
_ = parse_serialize(?PUBLISH_PACKET(?QOS_1, <<"Topic">>, 0, <<>>), #{strict_mode => false}).
t_serialize_parse_qos2_publish(_) ->
Packet = ?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 1, <<>>),
@ -341,7 +345,10 @@ t_serialize_parse_qos2_publish(_) ->
?assertEqual(Packet, parse_serialize(Packet)),
?assertEqual(Bin, serialize_to_binary(Packet)),
?assertMatch(Packet, parse_to_packet(Bin, #{strict_mode => true})),
?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>))).
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>))),
%% strict_mode = false
_ = parse_serialize(?PUBLISH_PACKET(?QOS_2, <<"Topic">>, 0, <<>>), #{strict_mode => false}).
t_serialize_parse_publish_v5(_) ->
Props = #{'Payload-Format-Indicator' => 1,
@ -358,7 +365,10 @@ t_serialize_parse_puback(_) ->
Packet = ?PUBACK_PACKET(1),
?assertEqual(<<64,2,0,1>>, serialize_to_binary(Packet)),
?assertEqual(Packet, parse_serialize(Packet)),
?catch_error(bad_packet_id, parse_serialize(?PUBACK_PACKET(0))).
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBACK_PACKET(0))),
%% strict_mode = false
?PUBACK_PACKET(0) = parse_serialize(?PUBACK_PACKET(0), #{strict_mode => false}).
t_serialize_parse_puback_v3_4(_) ->
Bin = <<64,2,0,1>>,
@ -376,7 +386,10 @@ t_serialize_parse_pubrec(_) ->
Packet = ?PUBREC_PACKET(1),
?assertEqual(<<5:4,0:4,2,0,1>>, serialize_to_binary(Packet)),
?assertEqual(Packet, parse_serialize(Packet)),
?catch_error(bad_packet_id, parse_serialize(?PUBREC_PACKET(0))).
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBREC_PACKET(0))),
%% strict_mode = false
?PUBREC_PACKET(0) = parse_serialize(?PUBREC_PACKET(0), #{strict_mode => false}).
t_serialize_parse_pubrec_v5(_) ->
Packet = ?PUBREC_PACKET(16, ?RC_SUCCESS, #{'Reason-String' => <<"success">>}),
@ -391,6 +404,9 @@ t_serialize_parse_pubrel(_) ->
Bin0 = <<6:4,0:4,2,0,1>>,
?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})),
?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})),
%% strict_mode = false
?PUBREL_PACKET(0) = parse_serialize(?PUBREL_PACKET(0), #{strict_mode => false}),
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBREL_PACKET(0))).
t_serialize_parse_pubrel_v5(_) ->
@ -402,6 +418,9 @@ t_serialize_parse_pubcomp(_) ->
Bin = serialize_to_binary(Packet),
?assertEqual(<<7:4,0:4,2,0,1>>, Bin),
?assertEqual(Packet, parse_serialize(Packet)),
%% strict_mode = false
?PUBCOMP_PACKET(0) = parse_serialize(?PUBCOMP_PACKET(0), #{strict_mode => false}),
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?PUBCOMP_PACKET(0))).
t_serialize_parse_pubcomp_v5(_) ->
@ -419,7 +438,12 @@ t_serialize_parse_subscribe(_) ->
%% SUBSCRIBE with bad qos 0
Bin0 = <<?SUBSCRIBE:4,0:4,11,0,2,0,6,84,111,112,105,99,65,2>>,
?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})),
%% strict_mode = false
_ = parse_to_packet(Bin0, #{strict_mode => false}),
?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})),
%% strict_mode = false
_ = parse_serialize(?SUBSCRIBE_PACKET(0, TopicFilters), #{strict_mode => false}),
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?SUBSCRIBE_PACKET(0, TopicFilters))),
?catch_error(bad_subqos, parse_serialize(?SUBSCRIBE_PACKET(1, [{<<"t">>, #{qos => 3}}]))).
@ -432,6 +456,9 @@ t_serialize_parse_subscribe_v5(_) ->
t_serialize_parse_suback(_) ->
Packet = ?SUBACK_PACKET(10, [?QOS_0, ?QOS_1, 128]),
?assertEqual(Packet, parse_serialize(Packet)),
%% strict_mode = false
_ = parse_serialize(?SUBACK_PACKET(0, [?QOS_0]), #{strict_mode => false}),
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?SUBACK_PACKET(0, [?QOS_0]))).
t_serialize_parse_suback_v5(_) ->
@ -451,6 +478,9 @@ t_serialize_parse_unsubscribe(_) ->
Bin0 = <<?UNSUBSCRIBE:4,0:4,10,0,2,0,6,84,111,112,105,99,65>>,
?assertMatch(Packet, parse_to_packet(Bin0, #{strict_mode => false})),
?catch_error(bad_frame_header, parse_to_packet(Bin0, #{strict_mode => true})),
%% strict_mode = false
_ = parse_serialize(?UNSUBSCRIBE_PACKET(0, [<<"TopicA">>]), #{strict_mode => false}),
%% strict_mode = true
?catch_error(bad_packet_id, parse_serialize(?UNSUBSCRIBE_PACKET(0, [<<"TopicA">>]))).
t_serialize_parse_unsubscribe_v5(_) ->