diff --git a/apps/emqx/src/emqx_frame.erl b/apps/emqx/src/emqx_frame.erl index c55c4d275..e570ed7cf 100644 --- a/apps/emqx/src/emqx_frame.erl +++ b/apps/emqx/src/emqx_frame.erl @@ -301,7 +301,7 @@ parse_packet( Bin, #{strict_mode := StrictMode, version := Ver} ) -> - {TopicName, Rest} = parse_topic_name(Bin, StrictMode), + {TopicName, Rest} = parse_utf8_string(Bin, StrictMode), {PacketId, Rest1} = case QoS of ?QOS_0 -> {undefined, Rest}; @@ -310,6 +310,7 @@ parse_packet( (PacketId =/= undefined) andalso StrictMode andalso validate_packet_id(PacketId), {Properties, Payload} = parse_properties(Rest1, Ver, StrictMode), + ok = ensure_topic_name_valid(StrictMode, TopicName, Properties), Publish = #mqtt_packet_publish{ topic_name = TopicName, packet_id = PacketId, @@ -422,8 +423,9 @@ parse_will_message( StrictMode ) -> {Props, Rest} = parse_properties(Bin, Ver, StrictMode), - {Topic, Rest1} = parse_topic_name(Rest, StrictMode), + {Topic, Rest1} = parse_utf8_string(Rest, StrictMode), {Payload, Rest2} = parse_binary_data(Rest1), + ok = ensure_topic_name_valid(StrictMode, Topic, Props), { Packet#mqtt_packet_connect{ will_props = Props, @@ -621,13 +623,14 @@ parse_binary_data(Bin) when -> ?PARSE_ERR(malformed_binary_data_length). -parse_topic_name(Bin, false) -> - parse_utf8_string(Bin, false); -parse_topic_name(Bin, true) -> - case parse_utf8_string(Bin, true) of - {<<>>, _Rest} -> ?PARSE_ERR(empty_topic_name); - Result -> Result - end. +ensure_topic_name_valid(false, _TopicName, _Properties) -> + ok; +ensure_topic_name_valid(true, TopicName, _Properties) when TopicName =/= <<>> -> + ok; +ensure_topic_name_valid(true, <<>>, #{'Topic-Alias' := _}) -> + ok; +ensure_topic_name_valid(true, <<>>, _) -> + error(empty_topic_name). %%-------------------------------------------------------------------- %% Serialize MQTT Packet diff --git a/apps/emqx/test/emqx_frame_SUITE.erl b/apps/emqx/test/emqx_frame_SUITE.erl index 99f815abf..8dad58243 100644 --- a/apps/emqx/test/emqx_frame_SUITE.erl +++ b/apps/emqx/test/emqx_frame_SUITE.erl @@ -158,12 +158,15 @@ t_parse_malformed_utf8_string(_) -> ?ASSERT_FRAME_THROW(utf8_string_invalid, emqx_frame:parse(MalformedPacket, ParseState)). t_parse_empty_topic_name(_) -> - Packet = <<48, 4, 0, 0, 0, 1>>, - NormalState = emqx_frame:initial_parse_state(#{strict_mode => false}), - ?assertMatch({_, _}, emqx_frame:parse(Packet, NormalState)), + Packet = ?PUBLISH_PACKET(?QOS_1, <<>>, 1, #{}, <<>>), + ?assertEqual(Packet, parse_serialize(Packet, #{strict_mode => false})), + ?ASSERT_FRAME_THROW(empty_topic_name, parse_serialize(Packet, #{strict_mode => true})). - StrictState = emqx_frame:initial_parse_state(#{strict_mode => true}), - ?ASSERT_FRAME_THROW(empty_topic_name, emqx_frame:parse(Packet, StrictState)). +t_parse_empty_topic_name_with_alias(_) -> + Props = #{'Topic-Alias' => 16#AB}, + Packet = ?PUBLISH_PACKET(?QOS_1, <<>>, 1, Props, <<>>), + ?assertEqual(Packet, parse_serialize(Packet, #{strict_mode => false})), + ?assertEqual(Packet, parse_serialize(Packet, #{strict_mode => true})). t_serialize_parse_v3_connect(_) -> Bin =