diff --git a/src/emqttd_protocol.erl b/src/emqttd_protocol.erl index b87274455..433825253 100644 --- a/src/emqttd_protocol.erl +++ b/src/emqttd_protocol.erl @@ -27,7 +27,7 @@ -import(proplists, [get_value/2, get_value/3]). %% API --export([init/3, info/1, stats/1, clientid/1, client/1, session/1]). +-export([init/3, init/4, info/1, stats/1, clientid/1, client/1, session/1]). -export([subscribe/2, unsubscribe/2, pubrel/2, shutdown/2]). @@ -43,12 +43,12 @@ -record(proto_state, {peername, sendfun, connected = false, client_id, client_pid, clean_sess, proto_ver, proto_name, username, is_superuser, will_msg, keepalive, max_clientid_len, session, stats_data, - ws_initial_headers, connected_at}). + mountpoint, ws_initial_headers, connected_at}). -type(proto_state() :: #proto_state{}). -define(INFO_KEYS, [client_id, username, clean_sess, proto_ver, proto_name, - keepalive, will_msg, ws_initial_headers, connected_at]). + keepalive, will_msg, ws_initial_headers, mountpoint, connected_at]). -define(STATS_KEYS, [recv_pkt, recv_msg, send_pkt, send_msg]). @@ -63,12 +63,22 @@ init(Peername, SendFun, Opts) -> WsInitialHeaders = get_value(ws_initial_headers, Opts), #proto_state{peername = Peername, sendfun = SendFun, - client_pid = self(), max_clientid_len = MaxLen, is_superuser = false, + client_pid = self(), ws_initial_headers = WsInitialHeaders, stats_data = #proto_stats{enable_stats = EnableStats}}. +init(Conn, Peername, SendFun, Opts) -> + enrich_opt(Conn:opts(), Conn, init(Peername, SendFun, Opts)). + +enrich_opt([], _Conn, State) -> + State; +enrich_opt([{mountpoint, MountPoint} | ConnOpts], Conn, State) -> + enrich_opt(ConnOpts, Conn, State#proto_state{mountpoint = MountPoint}); +enrich_opt([_ | ConnOpts], Conn, State) -> + enrich_opt(ConnOpts, Conn, State). + info(ProtoState) -> ?record_to_proplist(proto_state, ProtoState, ?INFO_KEYS). @@ -87,6 +97,7 @@ client(#proto_state{client_id = ClientId, keepalive = Keepalive, will_msg = WillMsg, ws_initial_headers = WsInitialHeaders, + mountpoint = MountPoint, connected_at = Time}) -> WillTopic = if WillMsg =:= undefined -> undefined; @@ -101,6 +112,7 @@ client(#proto_state{client_id = ClientId, keepalive = Keepalive, will_topic = WillTopic, ws_initial_headers = WsInitialHeaders, + mountpoint = MountPoint, connected_at = Time}. session(#proto_state{session = Session}) -> @@ -167,13 +179,13 @@ process(?CONNECT_PACKET(Var), State0) -> keep_alive = KeepAlive, client_id = ClientId} = Var, - State1 = State0#proto_state{proto_ver = ProtoVer, - proto_name = ProtoName, - username = Username, - client_id = ClientId, - clean_sess = CleanSess, - keepalive = KeepAlive, - will_msg = willmsg(Var), + State1 = State0#proto_state{proto_ver = ProtoVer, + proto_name = ProtoName, + username = Username, + client_id = ClientId, + clean_sess = CleanSess, + keepalive = KeepAlive, + will_msg = willmsg(Var, State0), connected_at = os:timestamp()}, {ReturnCode1, SessPresent, State3} = @@ -240,10 +252,11 @@ process(?SUBSCRIBE_PACKET(PacketId, []), State) -> %% TODO: refactor later... process(?SUBSCRIBE_PACKET(PacketId, RawTopicTable), - State = #proto_state{session = Session, - client_id = ClientId, + State = #proto_state{client_id = ClientId, username = Username, - is_superuser = IsSuperuser}) -> + is_superuser = IsSuperuser, + mountpoint = MountPoint, + session = Session}) -> Client = client(State), TopicTable = parse_topic_table(RawTopicTable), AllowDenies = if IsSuperuser -> []; @@ -256,7 +269,8 @@ process(?SUBSCRIBE_PACKET(PacketId, RawTopicTable), false -> case emqttd_hooks:run('client.subscribe', [ClientId, Username], TopicTable) of {ok, TopicTable1} -> - emqttd_session:subscribe(Session, PacketId, TopicTable1), {ok, State}; + emqttd_session:subscribe(Session, PacketId, mount(MountPoint, TopicTable1)), + {ok, State}; {stop, _} -> {ok, State} end @@ -267,12 +281,13 @@ process(?UNSUBSCRIBE_PACKET(PacketId, []), State) -> send(?UNSUBACK_PACKET(PacketId), State); process(?UNSUBSCRIBE_PACKET(PacketId, RawTopics), - State = #proto_state{client_id = ClientId, - username = Username, - session = Session}) -> + State = #proto_state{client_id = ClientId, + username = Username, + mountpoint = MountPoint, + session = Session}) -> case emqttd_hooks:run('client.unsubscribe', [ClientId, Username], parse_topics(RawTopics)) of {ok, TopicTable} -> - emqttd_session:unsubscribe(Session, TopicTable); + emqttd_session:unsubscribe(Session, mount(MountPoint, TopicTable)); {stop, _} -> ok end, @@ -286,11 +301,12 @@ process(?PACKET(?DISCONNECT), State) -> {stop, normal, State#proto_state{will_msg = undefined}}. publish(Packet = ?PUBLISH_PACKET(?QOS_0, _PacketId), - #proto_state{client_id = ClientId, - username = Username, - session = Session}) -> + #proto_state{client_id = ClientId, + username = Username, + mountpoint = MountPoint, + session = Session}) -> Msg = emqttd_message:from_packet(Username, ClientId, Packet), - emqttd_session:publish(Session, Msg); + emqttd_session:publish(Session, mount(MountPoint, Msg)); publish(Packet = ?PUBLISH_PACKET(?QOS_1, _PacketId), State) -> with_puback(?PUBACK, Packet, State); @@ -299,11 +315,12 @@ publish(Packet = ?PUBLISH_PACKET(?QOS_2, _PacketId), State) -> with_puback(?PUBREC, Packet, State). with_puback(Type, Packet = ?PUBLISH_PACKET(_Qos, PacketId), - State = #proto_state{client_id = ClientId, - username = Username, - session = Session}) -> + State = #proto_state{client_id = ClientId, + username = Username, + mountpoint = MountPoint, + session = Session}) -> Msg = emqttd_message:from_packet(Username, ClientId, Packet), - case emqttd_session:publish(Session, Msg) of + case emqttd_session:publish(Session, mount(MountPoint, Msg)) of ok -> send(?PUBACK_PACKET(Type, PacketId), State); {error, Error} -> @@ -311,10 +328,12 @@ with_puback(Type, Packet = ?PUBLISH_PACKET(_Qos, PacketId), end. -spec(send(mqtt_message() | mqtt_packet(), proto_state()) -> {ok, proto_state()}). -send(Msg, State = #proto_state{client_id = ClientId, username = Username}) +send(Msg, State = #proto_state{client_id = ClientId, + username = Username, + mountpoint = MountPoint}) when is_record(Msg, mqtt_message) -> emqttd_hooks:run('message.delivered', [ClientId, Username], Msg), - send(emqttd_message:to_packet(Msg), State); + send(emqttd_message:to_packet(unmount(MountPoint, Msg)), State); send(Packet = ?PACKET(Type), State = #proto_state{sendfun = SendFun, stats_data = Stats}) -> @@ -371,8 +390,11 @@ shutdown(Error, State = #proto_state{will_msg = WillMsg}) -> %% emqttd_cm:unreg(ClientId). ok. -willmsg(Packet) when is_record(Packet, mqtt_packet_connect) -> - emqttd_message:from_packet(Packet). +willmsg(Packet, #proto_state{mountpoint = MountPoint}) when is_record(Packet, mqtt_packet_connect) -> + case emqttd_message:from_packet(Packet) of + undefined -> undefined; + Msg -> mount(MountPoint, Msg) + end. %% Generate a client if if nulll maybe_set_clientid(State = #proto_state{client_id = NullId}) @@ -513,3 +535,23 @@ check_acl(subscribe, Topic, Client) -> sp(true) -> 1; sp(false) -> 0. + +%%-------------------------------------------------------------------- +%% Mount Point +%%-------------------------------------------------------------------- + +mount(undefined, Any) -> + Any; +mount(MountPoint, Msg = #mqtt_message{topic = Topic}) -> + Msg#mqtt_message{topic = <>}; +mount(MountPoint, TopicTable) when is_list(TopicTable) -> + [{<>, Opts} || {Topic, Opts} <- TopicTable]. + +unmount(undefined, Any) -> + Any; +unmount(MountPoint, Msg = #mqtt_message{topic = Topic}) -> + case catch split_binary(Topic, byte_size(MountPoint)) of + {MountPoint, Topic0} -> Msg#mqtt_message{topic = Topic0}; + _ -> Msg + end. +