diff --git a/src/emqtt_client.erl b/src/emqtt_client.erl index 390872c4c..9325b63e6 100644 --- a/src/emqtt_client.erl +++ b/src/emqtt_client.erl @@ -12,6 +12,7 @@ terminate/2]). -include("emqtt.hrl"). +-include("emqtt_frame.hrl"). -define(CLIENT_ID_MAXLEN, 23). @@ -53,20 +54,38 @@ handle_call({go, Sock}, _From, _State) -> inet_error, fun () -> emqtt_net:tune_buffer_size(Sock) end), {ok, ConnStr} = emqtt_net:connection_string(Sock, inbound), error_logger:info_msg("accepting MQTT connection (~s)~n", [ConnStr]), - control_throttle( + {reply, ok, + control_throttle( #state{ socket = Sock, conn_name = ConnStr, await_recv = false, connection_state = running, conserve = false, parse_state = emqtt_frame:initial_state(), - proc_state = emqtt_processor:initial_state(Sock) }). + proc_state = emqtt_processor:initial_state(Sock) })}. handle_cast(Msg, State) -> {stop, {badmsg, Msg}, State}. -handle_info({route, Msg}, State) -> - emqtt_processor:send_client(Msg), +handle_info({route, Msg}, #state{proc_state=PState} = State) -> + #mqtt_msg{ retain = Retain, + qos = Qos, + topic = Topic, + dup = Dup, + message_id = MessageId, + payload = Payload } = Msg, + + Frame = #mqtt_frame{fixed = #mqtt_frame_fixed{ + type = ?PUBLISH, + qos = Qos, + retain = Retain, + dup = Dup }, + variable = #mqtt_frame_publish{ + topic_name = Topic, + message_id = 1}, + payload = Payload }, + + emqtt_processor:send_client(Frame, PState), {noreply, State}; handle_info({inet_reply, _Ref, ok}, State) -> diff --git a/src/emqtt_processor.erl b/src/emqtt_processor.erl index 798ac6ca3..cc277bc04 100644 --- a/src/emqtt_processor.erl +++ b/src/emqtt_processor.erl @@ -19,7 +19,7 @@ -module(emqtt_processor). -export([info/2, initial_state/1, - process_frame/2, send_will/1]). + process_frame/2, send_client/2, send_will/1]). -include("emqtt.hrl"). -include("emqtt_frame.hrl"). @@ -53,11 +53,6 @@ initial_state(Socket) -> info(client_id, #proc_state{ client_id = ClientId }) -> ClientId. -process_frame(#mqtt_frame{ fixed = #mqtt_frame_fixed{ type = Type }}, - PState = #proc_state{ connection = undefined } ) - when Type =/= ?CONNECT -> - {err, connect_expected, PState}; - process_frame(Frame = #mqtt_frame{ fixed = #mqtt_frame_fixed{ type = Type }}, PState ) -> process_request(Type, Frame, PState). @@ -69,6 +64,7 @@ process_request(?CONNECT, proto_ver = ProtoVersion, clean_sess = CleanSess, client_id = ClientId } = Var}, PState) -> + error_logger:info_msg("connect frame: ~p~n", [Var]), {ReturnCode, PState1} = case {ProtoVersion =:= ?MQTT_PROTO_MAJOR, emqtt_util:valid_client_id(ClientId)} of @@ -110,7 +106,7 @@ process_request(?PUBLISH, dup = Dup }, variable = #mqtt_frame_publish{ topic_name = Topic, message_id = MessageId }, - payload = Payload }, PState) -> + payload = Payload }, #proc_state{ message_id = MsgId } = PState) -> Msg = #mqtt_msg{ retain = Retain, qos = Qos, topic = Topic, @@ -118,10 +114,10 @@ process_request(?PUBLISH, message_id = MessageId, payload = Payload }, emqtt_router:route(Msg), - + send_client( #mqtt_frame{ fixed = #mqtt_frame_fixed{ type = ?PUBACK }, - variable = #mqtt_frame_publish{ message_id = MessageId }}, + variable = #mqtt_frame_publish{ message_id = MsgId}}, PState), {ok, PState}; @@ -141,7 +137,7 @@ process_request(?SUBSCRIBE, [emqtt_topic:insert(Name) || #mqtt_topic{name=Name} <- Topics], - [emqtt_router:insert(#subscriber{topic=Name, pid=self()}) + [emqtt_router:insert(#subscriber{topic=emqtt_util:binary(Name), pid=self()}) || #mqtt_topic{name=Name} <- Topics], send_client(#mqtt_frame{ fixed = #mqtt_frame_fixed{ type = ?SUBACK }, diff --git a/src/emqtt_router.erl b/src/emqtt_router.erl index 4d59458fe..76eb4c79f 100644 --- a/src/emqtt_router.erl +++ b/src/emqtt_router.erl @@ -1,10 +1,12 @@ -module(emqtt_router). -include("emqtt.hrl"). +-include("emqtt_frame.hrl"). -export([start_link/0]). --export([route/2, +-export([route/1, + route/2, insert/1, delete/1]). @@ -22,6 +24,16 @@ start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). +binary(S) when is_list(S) -> + list_to_binary(S); + +binary(B) when is_binary(B) -> + B. + +route(#mqtt_msg{topic=Topic}=Msg) when is_record(Msg, mqtt_msg) -> + error_logger:info_msg("route msg: ~p~n", [Msg]), + [ Pid ! {route, Msg} || #subscriber{pid=Pid} <- ets:lookup(subscriber, binary(Topic)) ]. + route(Topic, Msg) -> [ Pid ! {route, Msg} || #subscriber{pid=Pid} <- ets:lookup(subscriber, Topic) ]. @@ -32,8 +44,8 @@ delete(Sub) when is_record(Sub, subscriber) -> gen_server:cast(?MODULE, {delete, Sub}). init([]) -> - ets:new(subscriber, [bag, protected, {keypos, 2}]), - error_logger:info_msg("emqtt_router is started."), + Res = ets:new(subscriber, [bag, protected, named_table, {keypos, 2}]), + error_logger:info_msg("emqtt_router is started: ~p~n", [Res]), {ok, #state{}}. handle_call({insert, Sub}, _From, State) -> diff --git a/src/emqtt_topic.erl b/src/emqtt_topic.erl index 2a6bb8d51..211bf9eba 100644 --- a/src/emqtt_topic.erl +++ b/src/emqtt_topic.erl @@ -22,7 +22,8 @@ start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). -match(Topic) when is_binary(Topic) -> +match(Topic0) -> + Topic = emqtt_util:binary(Topic0), Words = topic_split(Topic), DirectMatches = mnesia:dirty_read(direct_topic, Words), WildcardMatches = lists:append([ @@ -32,11 +33,11 @@ match(Topic) when is_binary(Topic) -> ]), DirectMatches ++ WildcardMatches. -insert(Topic) when is_binary(Topic) -> - gen_server:call(?MODULE, {insert, Topic}). +insert(Topic) -> + gen_server:call(?MODULE, {insert, emqtt_util:binary(Topic)}). -delete(Topic) when is_binary(Topic) -> - gen_server:cast(?MODULE, {delete, Topic}). +delete(Topic) -> + gen_server:cast(?MODULE, {delete, emqtt_util:binary(Topic)}). init([]) -> {atomic, ok} = mnesia:create_table( diff --git a/src/emqtt_util.erl b/src/emqtt_util.erl index 378eb22c0..46c320e7b 100644 --- a/src/emqtt_util.erl +++ b/src/emqtt_util.erl @@ -6,6 +6,9 @@ -compile(export_all). +binary(L) when is_list(L) -> list_to_binary(L); +binary(B) when is_binary(B) -> B. + subcription_queue_name(ClientId) -> Base = "mqtt-subscription-" ++ ClientId ++ "qos", {list_to_binary(Base ++ "0"), list_to_binary(Base ++ "1")}.