diff --git a/.gitignore b/.gitignore index 60ddd53e8..4e53f8f05 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ +lib ebin/* *.log diff --git a/Makefile b/Makefile index e577ac1ca..1dc31e2b7 100644 --- a/Makefile +++ b/Makefile @@ -3,5 +3,8 @@ all: compile run: compile erl -pa ebin -config etc/emqtt.config -s emqtt_app start -compile: +compile: deps rebar compile + +deps: + rebar get-deps diff --git a/etc/emqtt.config b/etc/emqtt.config index 0d974db0e..d36330143 100644 --- a/etc/emqtt.config +++ b/etc/emqtt.config @@ -9,6 +9,15 @@ ]}, {mnesia, [ {dir, "var/mnesia"} + ]}, + {emqtt, [ + {tcp_listeners, [1883]}, + {tcp_listen_options, [ + binary, + {packet, raw}, + {reuseaddr, true}, + {backlog, 128}, + {nodelay, true}]} ]} ]. diff --git a/include/emqtt.hrl b/include/emqtt.hrl index fc5f70e56..971238c02 100644 --- a/include/emqtt.hrl +++ b/include/emqtt.hrl @@ -1,10 +1,11 @@ +-define(COPYRIGHT, "Copyright (C) 2007-2012 VMware, Inc."). +-define(LICENSE_MESSAGE, "Licensed under the MPL."). +-define(PROTOCOL_VERSION, "MQTT/3.1"). +-define(ERTS_MINIMUM, "5.6.3"). --record(direct_topic, {name, node}). - --record(wildcard_topic, {words, node}). +-record(topic, {words, path}). -record(subscriber, {topic, pid}). - diff --git a/include/emqtt_frame.hrl b/include/emqtt_frame.hrl new file mode 100644 index 000000000..fb6eebe29 --- /dev/null +++ b/include/emqtt_frame.hrl @@ -0,0 +1,100 @@ +% +% NOTICE: copy from rabbitmq mqtt-adaper +% + + +%% The contents of this file are subject to the Mozilla Public License +%% Version 1.1 (the "License"); you may not use this file except in +%% compliance with the License. You may obtain a copy of the License +%% at http://www.mozilla.org/MPL/ +%% +%% Software distributed under the License is distributed on an "AS IS" +%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See +%% the License for the specific language governing rights and +%% limitations under the License. +%% +%% The Original Code is RabbitMQ. +%% +%% The Initial Developer of the Original Code is VMware, Inc. +%% Copyright (c) 2007-2012 VMware, Inc. All rights reserved. +%% + +-define(MQTT_PROTO_MAJOR, 3). +-define(MQTT_PROTO_MINOR, 1). + +%% frame types + +-define(CONNECT, 1). +-define(CONNACK, 2). +-define(PUBLISH, 3). +-define(PUBACK, 4). +-define(PUBREC, 5). +-define(PUBREL, 6). +-define(PUBCOMP, 7). +-define(SUBSCRIBE, 8). +-define(SUBACK, 9). +-define(UNSUBSCRIBE, 10). +-define(UNSUBACK, 11). +-define(PINGREQ, 12). +-define(PINGRESP, 13). +-define(DISCONNECT, 14). + +%% connect return codes + +-define(CONNACK_ACCEPT, 0). +-define(CONNACK_PROTO_VER, 1). %% unacceptable protocol version +-define(CONNACK_INVALID_ID, 2). %% identifier rejected +-define(CONNACK_SERVER, 3). %% server unavailable +-define(CONNACK_CREDENTIALS, 4). %% bad user name or password +-define(CONNACK_AUTH, 5). %% not authorized + +%% qos levels + +-define(QOS_0, 0). +-define(QOS_1, 1). +-define(QOS_2, 2). + +-record(mqtt_frame, {fixed, + variable, + payload}). + +-record(mqtt_frame_fixed, {type = 0, + dup = 0, + qos = 0, + retain = 0}). + +-record(mqtt_frame_connect, {proto_ver, + will_retain, + will_qos, + will_flag, + clean_sess, + keep_alive, + client_id, + will_topic, + will_msg, + username, + password}). + +-record(mqtt_frame_connack, {return_code}). + +-record(mqtt_frame_publish, {topic_name, + message_id}). + +-record(mqtt_frame_subscribe,{message_id, + topic_table}). + +-record(mqtt_frame_suback, {message_id, + qos_table = []}). + +-record(mqtt_topic, {name, + qos}). + +-record(mqtt_frame_other, {other}). + +-record(mqtt_msg, {retain, + qos, + topic, + dup, + message_id, + payload}). + diff --git a/rebar.config b/rebar.config new file mode 100644 index 000000000..c50c864af --- /dev/null +++ b/rebar.config @@ -0,0 +1,8 @@ + +{lib_dirs,["lib"]}. + +{deps_dir, ["lib"]}. + +{deps, [ + {'rabbitlib', ".*", {git, "git://github.com/emqtt/rabbitlib.git", {branch, "master"}}} +]}. diff --git a/src/.emqtt_topic.erl.swp b/src/.emqtt_topic.erl.swp deleted file mode 100644 index a5e73750e..000000000 Binary files a/src/.emqtt_topic.erl.swp and /dev/null differ diff --git a/src/emqtt_app.erl b/src/emqtt_app.erl index efd7cc6dd..e29320ba2 100644 --- a/src/emqtt_app.erl +++ b/src/emqtt_app.erl @@ -19,13 +19,14 @@ start_app(mnesia) -> start_app(App) -> application:start(App). - %% =================================================================== %% Application callbacks %% =================================================================== start(_StartType, _StartArgs) -> - emqtt_sup:start_link(). + {ok, Sup} = emqtt_sup:start_link(), + emqtt_networking:boot(), + {ok, Sup}. stop(_State) -> ok. diff --git a/src/emqtt_client.erl b/src/emqtt_client.erl index bd83b5f7b..01ffe0897 100644 --- a/src/emqtt_client.erl +++ b/src/emqtt_client.erl @@ -1,5 +1,40 @@ -%simulate a mqtt connection -module(emqtt_client). +-behaviour(gen_server2). +-export([start_link/0, go/2]). +-export([init/1, + handle_call/3, + handle_cast/2, + handle_info/2, + code_change/3, + terminate/2]). + +-include("emqtt.hrl"). + +go(Pid, Sock) -> + gen_server2:call(Pid, {go, Sock}). + +start_link() -> + gen_server2:start_link(?MODULE, [], []). + +init([]) -> + {ok, undefined, hibernate, {backoff, 1000, 1000, 10000}}. + +handle_call({go, Sock}, _From, State) -> + error_logger:info_msg("go.... sock: ~p", [Sock]), + {reply, ok, State}. + +handle_cast(Msg, State) -> + {stop, {badmsg, Msg}, State}. + +handle_info(Info, State) -> + {stop, {badinfo, Info}, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + diff --git a/src/emqtt_client_sup.erl b/src/emqtt_client_sup.erl index ce0f47441..be21d86e9 100644 --- a/src/emqtt_client_sup.erl +++ b/src/emqtt_client_sup.erl @@ -1 +1,16 @@ -module(emqtt_client_sup). + +-export([start_link/0]). + +-behaviour(supervisor2). + +-export([init/1]). + +start_link() -> + supervisor2:start_link({local, ?MODULE}, ?MODULE, []). + +init([]) -> + {ok, {{simple_one_for_one_terminate, 0, 1}, + [{client, {emqtt_client, start_link, []}, + temporary, 5000, worker, [emqtt_client]}]}}. + diff --git a/src/emqtt_frame.erl b/src/emqtt_frame.erl index 102310f9a..5b9a76c87 100644 --- a/src/emqtt_frame.erl +++ b/src/emqtt_frame.erl @@ -1 +1,233 @@ +%% This file is a copy of `rabbitmq_mqtt_frame.erl' from rabbitmq. +%% License: +%% The contents of this file are subject to the Mozilla Public License +%% Version 1.1 (the "License"); you may not use this file except in +%% compliance with the License. You may obtain a copy of the License +%% at http://www.mozilla.org/MPL/ +%% +%% Software distributed under the License is distributed on an "AS IS" +%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See +%% the License for the specific language governing rights and +%% limitations under the License. +%% +%% The Original Code is RabbitMQ. +%% +%% The Initial Developer of the Original Code is VMware, Inc. +%% Copyright (c) 2007-2012 VMware, Inc. All rights reserved. +%% -module(emqtt_frame). + +-include("emqtt_frame.hrl"). + +-export([parse/2, initial_state/0]). +-export([serialise/1]). + +-define(RESERVED, 0). +-define(PROTOCOL_MAGIC, "MQIsdp"). +-define(MAX_LEN, 16#fffffff). +-define(HIGHBIT, 2#10000000). +-define(LOWBITS, 2#01111111). + +initial_state() -> none. + +parse(<<>>, none) -> + {more, fun(Bin) -> parse(Bin, none) end}; +parse(<>, none) -> + parse_remaining_len(Rest, #mqtt_frame_fixed{ type = MessageType, + dup = bool(Dup), + qos = QoS, + retain = bool(Retain) }); +parse(Bin, Cont) -> Cont(Bin). + +parse_remaining_len(<<>>, Fixed) -> + {more, fun(Bin) -> parse_remaining_len(Bin, Fixed) end}; +parse_remaining_len(Rest, Fixed) -> + parse_remaining_len(Rest, Fixed, 1, 0). + +parse_remaining_len(_Bin, _Fixed, _Multiplier, Length) + when Length > ?MAX_LEN -> + {error, invalid_mqtt_frame_len}; +parse_remaining_len(<<>>, Fixed, Multiplier, Length) -> + {more, fun(Bin) -> parse_remaining_len(Bin, Fixed, Multiplier, Length) end}; +parse_remaining_len(<<1:1, Len:7, Rest/binary>>, Fixed, Multiplier, Value) -> + parse_remaining_len(Rest, Fixed, Multiplier * ?HIGHBIT, Value + Len * Multiplier); +parse_remaining_len(<<0:1, Len:7, Rest/binary>>, Fixed, Multiplier, Value) -> + parse_frame(Rest, Fixed, Value + Len * Multiplier). + +parse_frame(Bin, #mqtt_frame_fixed{ type = Type, + qos = Qos } = Fixed, Length) -> + case {Type, Bin} of + {?CONNECT, <>} -> + {ProtocolMagic, Rest1} = parse_utf(FrameBin), + <> = Rest1, + <> = Rest2, + {ClientId, Rest4} = parse_utf(Rest3), + {WillTopic, Rest5} = parse_utf(Rest4, WillFlag), + {WillMsg, Rest6} = parse_msg(Rest5, WillFlag), + {UserName, Rest7} = parse_utf(Rest6, UsernameFlag), + {PasssWord, <<>>} = parse_utf(Rest7, PasswordFlag), + case ProtocolMagic == ?PROTOCOL_MAGIC of + true -> + wrap(Fixed, + #mqtt_frame_connect{ + proto_ver = ProtoVersion, + will_retain = bool(WillRetain), + will_qos = WillQos, + will_flag = bool(WillFlag), + clean_sess = bool(CleanSession), + keep_alive = KeepAlive, + client_id = ClientId, + will_topic = WillTopic, + will_msg = WillMsg, + username = UserName, + password = PasssWord}, Rest); + false -> + {error, protocol_header_corrupt} + end; + {?PUBLISH, <>} -> + {TopicName, Rest1} = parse_utf(FrameBin), + {MessageId, Payload} = case Qos of + 0 -> {undefined, Rest1}; + _ -> <> = Rest1, + {M, R} + end, + wrap(Fixed, #mqtt_frame_publish { topic_name = TopicName, + message_id = MessageId }, + Payload, Rest); + {?PUBACK, <>} -> + <> = FrameBin, + wrap(Fixed, #mqtt_frame_publish { message_id = MessageId }, Rest); + {Subs, <>} + when Subs =:= ?SUBSCRIBE orelse Subs =:= ?UNSUBSCRIBE -> + 1 = Qos, + <> = FrameBin, + Topics = parse_topics(Subs, Rest1, []), + wrap(Fixed, #mqtt_frame_subscribe { message_id = MessageId, + topic_table = Topics }, Rest); + {Minimal, Rest} + when Minimal =:= ?DISCONNECT orelse Minimal =:= ?PINGREQ -> + Length = 0, + wrap(Fixed, Rest); + {_, TooShortBin} -> + {more, fun(BinMore) -> + parse_frame(<>, + Fixed, Length) + end} + end. + +parse_topics(_, <<>>, Topics) -> + Topics; +parse_topics(?SUBSCRIBE = Sub, Bin, Topics) -> + {Name, <<_:6, QoS:2, Rest/binary>>} = parse_utf(Bin), + parse_topics(Sub, Rest, [#mqtt_topic { name = Name, qos = QoS } | Topics]); +parse_topics(?UNSUBSCRIBE = Sub, Bin, Topics) -> + {Name, <>} = parse_utf(Bin), + parse_topics(Sub, Rest, [#mqtt_topic { name = Name } | Topics]). + +wrap(Fixed, Variable, Payload, Rest) -> + {ok, #mqtt_frame { variable = Variable, fixed = Fixed, payload = Payload }, Rest}. +wrap(Fixed, Variable, Rest) -> + {ok, #mqtt_frame { variable = Variable, fixed = Fixed }, Rest}. +wrap(Fixed, Rest) -> + {ok, #mqtt_frame { fixed = Fixed }, Rest}. + +parse_utf(Bin, 0) -> + {undefined, Bin}; +parse_utf(Bin, _) -> + parse_utf(Bin). + +parse_utf(<>) -> + {binary_to_list(Str), Rest}. + +parse_msg(Bin, 0) -> + {undefined, Bin}; +parse_msg(<>, _) -> + {Msg, Rest}. + +bool(0) -> false; +bool(1) -> true. + +%% serialisation + +serialise(#mqtt_frame{ fixed = Fixed, + variable = Variable, + payload = Payload }) -> + serialise_variable(Fixed, Variable, serialise_payload(Payload)). + +serialise_payload(undefined) -> <<>>; +serialise_payload(B) when is_binary(B) -> B. + +serialise_variable(#mqtt_frame_fixed { type = ?CONNACK } = Fixed, + #mqtt_frame_connack { return_code = ReturnCode }, + <<>> = PayloadBin) -> + VariableBin = <>, + serialise_fixed(Fixed, VariableBin, PayloadBin); + +serialise_variable(#mqtt_frame_fixed { type = SubAck } = Fixed, + #mqtt_frame_suback { message_id = MessageId, + qos_table = Qos }, + <<>> = _PayloadBin) + when SubAck =:= ?SUBACK orelse SubAck =:= ?UNSUBACK -> + VariableBin = <>, + QosBin = << <> || Q <- Qos >>, + serialise_fixed(Fixed, VariableBin, QosBin); + +serialise_variable(#mqtt_frame_fixed { type = ?PUBLISH, + qos = Qos } = Fixed, + #mqtt_frame_publish { topic_name = TopicName, + message_id = MessageId }, + PayloadBin) -> + TopicBin = serialise_utf(TopicName), + MessageIdBin = case Qos of + 0 -> <<>>; + 1 -> <> + end, + serialise_fixed(Fixed, <>, PayloadBin); + +serialise_variable(#mqtt_frame_fixed { type = ?PUBACK } = Fixed, + #mqtt_frame_publish { message_id = MessageId }, + PayloadBin) -> + MessageIdBin = <>, + serialise_fixed(Fixed, MessageIdBin, PayloadBin); + +serialise_variable(#mqtt_frame_fixed {} = Fixed, + undefined, + <<>> = _PayloadBin) -> + serialise_fixed(Fixed, <<>>, <<>>). + +serialise_fixed(#mqtt_frame_fixed{ type = Type, + dup = Dup, + qos = Qos, + retain = Retain }, VariableBin, PayloadBin) + when is_integer(Type) andalso ?CONNECT =< Type andalso Type =< ?DISCONNECT -> + Len = size(VariableBin) + size(PayloadBin), + true = (Len =< ?MAX_LEN), + LenBin = serialise_len(Len), + <>. + +serialise_utf(String) -> + StringBin = unicode:characters_to_binary(String), + Len = size(StringBin), + true = (Len =< 16#ffff), + <>. + +serialise_len(N) when N =< ?LOWBITS -> + <<0:1, N:7>>; +serialise_len(N) -> + <<1:1, (N rem ?HIGHBIT):7, (serialise_len(N div ?HIGHBIT))/binary>>. + +opt(undefined) -> ?RESERVED; +opt(false) -> 0; +opt(true) -> 1; +opt(X) when is_integer(X) -> X. + + diff --git a/src/emqtt_networking.erl b/src/emqtt_networking.erl new file mode 100644 index 000000000..a8493f9c0 --- /dev/null +++ b/src/emqtt_networking.erl @@ -0,0 +1,253 @@ +-module(emqtt_networking). + +-export([boot/0]). + +-export([start_tcp_listener/1, stop_tcp_listener/1, tcp_host/1, ntoab/1]). + +%callback. + +-export([tcp_listener_started/3, tcp_listener_stopped/3, start_client/1]). + +-include_lib("kernel/include/inet.hrl"). + +-define(FIRST_TEST_BIND_PORT, 10000). + +boot() -> + {ok, TcpListeners} = application:get_env(tcp_listeners), + [ok = start_tcp_listener(Listener) || Listener <- TcpListeners]. + +start_tcp_listener(Listener) -> + start_listener(Listener, emqtt, "TCP Listener", + {?MODULE, start_client, []}). + +start_listener(Listener, Protocol, Label, OnConnect) -> + [start_listener0(Address, Protocol, Label, OnConnect) || + Address <- tcp_listener_addresses(Listener)], + ok. + +start_listener0(Address, Protocol, Label, OnConnect) -> + Spec = tcp_listener_spec(emqtt_tcp_listener_sup, Address, tcp_opts(), + Protocol, Label, OnConnect), + case supervisor:start_child(emqtt_sup, Spec) of + {ok, _} -> ok; + {error, {shutdown, _}} -> {IPAddress, Port, _Family} = Address, + exit({could_not_start_tcp_listener, + {ntoa(IPAddress), Port}}) + end. + +stop_tcp_listener(Listener) -> + [stop_tcp_listener0(Address) || + Address <- tcp_listener_addresses(Listener)], + ok. + +stop_tcp_listener0({IPAddress, Port, _Family}) -> + Name = tcp_name(emqtt_tcp_listener_sup, IPAddress, Port), + ok = supervisor:terminate_child(emqtt_sup, Name), + ok = supervisor:delete_child(emqtt_sup, Name). + +tcp_listener_addresses(Port) when is_integer(Port) -> + tcp_listener_addresses_auto(Port); +tcp_listener_addresses({"auto", Port}) -> + %% Variant to prevent lots of hacking around in bash and batch files + tcp_listener_addresses_auto(Port); +tcp_listener_addresses({Host, Port}) -> + %% auto: determine family IPv4 / IPv6 after converting to IP address + tcp_listener_addresses({Host, Port, auto}); +tcp_listener_addresses({Host, Port, Family0}) + when is_integer(Port) andalso (Port >= 0) andalso (Port =< 65535) -> + [{IPAddress, Port, Family} || + {IPAddress, Family} <- getaddr(Host, Family0)]; +tcp_listener_addresses({_Host, Port, _Family0}) -> + error_logger:error_msg("invalid port ~p - not 0..65535~n", [Port]), + throw({error, {invalid_port, Port}}). + +tcp_listener_addresses_auto(Port) -> + lists:append([tcp_listener_addresses(Listener) || + Listener <- port_to_listeners(Port)]). + +tcp_listener_spec(NamePrefix, {IPAddress, Port, Family}, SocketOpts, + Protocol, Label, OnConnect) -> + {tcp_name(NamePrefix, IPAddress, Port), + {tcp_listener_sup, start_link, + [IPAddress, Port, [Family | SocketOpts], + {?MODULE, tcp_listener_started, [Protocol]}, + {?MODULE, tcp_listener_stopped, [Protocol]}, + OnConnect, Label]}, + transient, infinity, supervisor, [tcp_listener_sup]}. + + +tcp_listener_started(Protocol, IPAddress, Port) -> + %% We need the ip to distinguish e.g. 0.0.0.0 and 127.0.0.1 + %% We need the host so we can distinguish multiple instances of the above + %% in a cluster. + error_logger:info_msg("tcp listener started: ~p ~p:~p", [Protocol, IPAddress, Port]). + +tcp_listener_stopped(Protocol, IPAddress, Port) -> + error_logger:info_msg("tcp listener stopped: ~p ~p:~p", [Protocol, IPAddress, Port]). + +start_client(Sock) -> + {ok, Client} = supervisor:start_child(emqtt_client_sup, []), + ok = gen_tcp:controlling_process(Sock, Client), + emqtt_client:go(Client, Sock), + Client. + +%%-------------------------------------------------------------------- +tcp_host({0,0,0,0}) -> + hostname(); + +tcp_host({0,0,0,0,0,0,0,0}) -> + hostname(); + +tcp_host(IPAddress) -> + case inet:gethostbyaddr(IPAddress) of + {ok, #hostent{h_name = Name}} -> Name; + {error, _Reason} -> ntoa(IPAddress) + end. + +hostname() -> + {ok, Hostname} = inet:gethostname(), + case inet:gethostbyname(Hostname) of + {ok, #hostent{h_name = Name}} -> Name; + {error, _Reason} -> Hostname + end. + +tcp_opts() -> + {ok, Opts} = application:get_env(emqtt, tcp_listen_options), + Opts. + +%% inet_parse:address takes care of ip string, like "0.0.0.0" +%% inet:getaddr returns immediately for ip tuple {0,0,0,0}, +%% and runs 'inet_gethost' port process for dns lookups. +%% On Windows inet:getaddr runs dns resolver for ip string, which may fail. +getaddr(Host, Family) -> + case inet_parse:address(Host) of + {ok, IPAddress} -> [{IPAddress, resolve_family(IPAddress, Family)}]; + {error, _} -> gethostaddr(Host, Family) + end. + +gethostaddr(Host, auto) -> + Lookups = [{Family, inet:getaddr(Host, Family)} || Family <- [inet, inet6]], + case [{IP, Family} || {Family, {ok, IP}} <- Lookups] of + [] -> host_lookup_error(Host, Lookups); + IPs -> IPs + end; + +gethostaddr(Host, Family) -> + case inet:getaddr(Host, Family) of + {ok, IPAddress} -> [{IPAddress, Family}]; + {error, Reason} -> host_lookup_error(Host, Reason) + end. + +host_lookup_error(Host, Reason) -> + error_logger:error_msg("invalid host ~p - ~p~n", [Host, Reason]), + throw({error, {invalid_host, Host, Reason}}). + +resolve_family({_,_,_,_}, auto) -> inet; +resolve_family({_,_,_,_,_,_,_,_}, auto) -> inet6; +resolve_family(IP, auto) -> throw({error, {strange_family, IP}}); +resolve_family(_, F) -> F. + +%%-------------------------------------------------------------------- + +%% There are three kinds of machine (for our purposes). +%% +%% * Those which treat IPv4 addresses as a special kind of IPv6 address +%% ("Single stack") +%% - Linux by default, Windows Vista and later +%% - We also treat any (hypothetical?) IPv6-only machine the same way +%% * Those which consider IPv6 and IPv4 to be completely separate things +%% ("Dual stack") +%% - OpenBSD, Windows XP / 2003, Linux if so configured +%% * Those which do not support IPv6. +%% - Ancient/weird OSes, Linux if so configured +%% +%% How to reconfigure Linux to test this: +%% Single stack (default): +%% echo 0 > /proc/sys/net/ipv6/bindv6only +%% Dual stack: +%% echo 1 > /proc/sys/net/ipv6/bindv6only +%% IPv4 only: +%% add ipv6.disable=1 to GRUB_CMDLINE_LINUX_DEFAULT in /etc/default/grub then +%% sudo update-grub && sudo reboot +%% +%% This matters in (and only in) the case where the sysadmin (or the +%% app descriptor) has only supplied a port and we wish to bind to +%% "all addresses". This means different things depending on whether +%% we're single or dual stack. On single stack binding to "::" +%% implicitly includes all IPv4 addresses, and subsequently attempting +%% to bind to "0.0.0.0" will fail. On dual stack, binding to "::" will +%% only bind to IPv6 addresses, and we need another listener bound to +%% "0.0.0.0" for IPv4. Finally, on IPv4-only systems we of course only +%% want to bind to "0.0.0.0". +%% +%% Unfortunately it seems there is no way to detect single vs dual stack +%% apart from attempting to bind to the port. +port_to_listeners(Port) -> + IPv4 = {"0.0.0.0", Port, inet}, + IPv6 = {"::", Port, inet6}, + case ipv6_status(?FIRST_TEST_BIND_PORT) of + single_stack -> [IPv6]; + ipv6_only -> [IPv6]; + dual_stack -> [IPv6, IPv4]; + ipv4_only -> [IPv4] + end. + +ipv6_status(TestPort) -> + IPv4 = [inet, {ip, {0,0,0,0}}], + IPv6 = [inet6, {ip, {0,0,0,0,0,0,0,0}}], + case gen_tcp:listen(TestPort, IPv6) of + {ok, LSock6} -> + case gen_tcp:listen(TestPort, IPv4) of + {ok, LSock4} -> + %% Dual stack + gen_tcp:close(LSock6), + gen_tcp:close(LSock4), + dual_stack; + %% Checking the error here would only let us + %% distinguish single stack IPv6 / IPv4 vs IPv6 only, + %% which we figure out below anyway. + {error, _} -> + gen_tcp:close(LSock6), + case gen_tcp:listen(TestPort, IPv4) of + %% Single stack + {ok, LSock4} -> gen_tcp:close(LSock4), + single_stack; + %% IPv6-only machine. Welcome to the future. + {error, eafnosupport} -> ipv6_only; %% Linux + {error, eprotonosupport}-> ipv6_only; %% FreeBSD + %% Dual stack machine with something already + %% on IPv4. + {error, _} -> ipv6_status(TestPort + 1) + end + end; + %% IPv4-only machine. Welcome to the 90s. + {error, eafnosupport} -> %% Linux + ipv4_only; + {error, eprotonosupport} -> %% FreeBSD + ipv4_only; + %% Port in use + {error, _} -> + ipv6_status(TestPort + 1) + end. + +ntoa({0,0,0,0,0,16#ffff,AB,CD}) -> + inet_parse:ntoa({AB bsr 8, AB rem 256, CD bsr 8, CD rem 256}); +ntoa(IP) -> + inet_parse:ntoa(IP). + +ntoab(IP) -> + Str = ntoa(IP), + case string:str(Str, ":") of + 0 -> Str; + _ -> "[" ++ Str ++ "]" + end. + +tcp_name(Prefix, IPAddress, Port) + when is_atom(Prefix) andalso is_number(Port) -> + list_to_atom( + format("~w_~s:~w", [Prefix, inet_parse:ntoa(IPAddress), Port])). + +format(Fmt, Args) -> lists:flatten(io_lib:format(Fmt, Args)). + + + diff --git a/src/emqtt_reader.erl b/src/emqtt_reader.erl deleted file mode 100644 index a91f0dc82..000000000 --- a/src/emqtt_reader.erl +++ /dev/null @@ -1,3 +0,0 @@ -%tcp data reader --module(emqtt_reader). - diff --git a/src/emqtt_router.erl b/src/emqtt_router.erl index ae262a5c2..4d59458fe 100644 --- a/src/emqtt_router.erl +++ b/src/emqtt_router.erl @@ -4,6 +4,10 @@ -export([start_link/0]). +-export([route/2, + insert/1, + delete/1]). + -behaviour(gen_server). -export([init/1, @@ -18,12 +22,31 @@ start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). +route(Topic, Msg) -> + [ Pid ! {route, Msg} || #subscriber{pid=Pid} <- ets:lookup(subscriber, Topic) ]. + +insert(Sub) when is_record(Sub, subscriber) -> + gen_server:call(?MODULE, {insert, Sub}). + +delete(Sub) when is_record(Sub, subscriber) -> + gen_server:cast(?MODULE, {delete, Sub}). + init([]) -> + ets:new(subscriber, [bag, protected, {keypos, 2}]), + error_logger:info_msg("emqtt_router is started."), {ok, #state{}}. +handle_call({insert, Sub}, _From, State) -> + ets:insert(subscriber, Sub), + {reply, ok, State}; + handle_call(Req, _From, State) -> {stop, {badreq, Req}, State}. +handle_cast({delete, Sub}, State) -> + ets:delete_object(subscriber, Sub), + {noreply, State}; + handle_cast(Msg, State) -> {stop, {badmsg, Msg}, State}. diff --git a/src/emqtt_subscriber.erl b/src/emqtt_subscriber.erl deleted file mode 100644 index 5dfb74414..000000000 --- a/src/emqtt_subscriber.erl +++ /dev/null @@ -1,42 +0,0 @@ --module(emqtt_subscriber). - --include("emqtt.hrl"). - --export([start_link/0]). - --behaviour(gen_server). - --export([init/1, - handle_call/3, - handle_cast/2, - handle_info/2, - terminate/2, - code_change/3]). - --record(state,{}). - -start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). - -init([]) -> - ets:new(subscriber, [bag, protected, {keypos, 2}]), - {ok, #state{}}. - -handle_call(Req, _From, State) -> - {stop, {badreq, Req}, State}. - -handle_cast(Msg, State) -> - {stop, {badmsg, Msg}, State}. - -handle_info(Info, State) -> - {stop, {badinfo, Info}, State}. - -terminate(_Reason, _State) -> - ok. - -code_change(_OldVsn, _State, _Extra) -> - ok. - - - - diff --git a/src/emqtt_sup.erl b/src/emqtt_sup.erl index c00618cda..89d61eab3 100644 --- a/src/emqtt_sup.erl +++ b/src/emqtt_sup.erl @@ -27,6 +27,7 @@ start_link() -> init([]) -> {ok, { {one_for_all, 5, 10}, [ ?CHILD(emqtt_topic, worker), - ?CHILD(emqtt_router, worker) + ?CHILD(emqtt_router, worker), + ?CHILD(emqtt_client_sup, supervisor) ]} }. diff --git a/src/emqtt_topic.erl b/src/emqtt_topic.erl index 9cdbedecd..2a6bb8d51 100644 --- a/src/emqtt_topic.erl +++ b/src/emqtt_topic.erl @@ -1,5 +1,4 @@ - -module(emqtt_topic). -include("emqtt.hrl"). @@ -24,15 +23,15 @@ start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). match(Topic) when is_binary(Topic) -> - DirectMatches = mnesia:dirty_read(direct_topic, Topic), Words = topic_split(Topic), - WildcardMatches = lists:append([ + DirectMatches = mnesia:dirty_read(direct_topic, Words), + WildcardMatches = lists:append([ mnesia:dirty_read(wildcard_topic, Key) || - Key <- mnesia:dirty_all_keys(wildcard_topic), topic_match(Words, Key) + Key <- mnesia:dirty_all_keys(wildcard_topic), + topic_match(Words, Key) ]), DirectMatches ++ WildcardMatches. - insert(Topic) when is_binary(Topic) -> gen_server:call(?MODULE, {insert, Topic}). @@ -42,12 +41,15 @@ delete(Topic) when is_binary(Topic) -> init([]) -> {atomic, ok} = mnesia:create_table( direct_topic, [ + {record_name, topic}, {ram_copies, [node()]}, - {attributes, record_info(fields, direct_topic)}]), + {attributes, record_info(fields, topic)}]), {atomic, ok} = mnesia:create_table( wildcard_topic, [ + {record_name, topic}, {ram_copies, [node()]}, - {attributes, record_info(fields, wildcard_topic)}]), + {attributes, record_info(fields, topic)}]), + error_logger:info_msg("emqtt_topic is started."), {ok, #state{}}. handle_call({insert, Topic}, _From, State) -> @@ -55,9 +57,9 @@ handle_call({insert, Topic}, _From, State) -> Reply = case topic_type(Words) of direct -> - mnesia:dirty_write(#direct_topic{name=Topic}); + mnesia:dirty_write(direct_topic, #topic{words=Words, path=Topic}); wildcard -> - mnesia:dirty_write(#wildcard_topic{words=Words}) + mnesia:dirty_write(wildcard_topic, #topic{words=Words, path=Topic}) end, {reply, Reply, State}; @@ -68,9 +70,9 @@ handle_cast({delete, Topic}, State) -> Words = topic_split(Topic), case topic_type(Words) of direct -> - mnesia:dirty_delete(direct_topic, Topic); + mnesia:dirty_delete(direct_topic, #topic{words=Words, path=Topic}); wildcard -> - mnesia:direct_delete(wildcard_topic, Words) + mnesia:direct_delete(wildcard_topic, #topic{words=Words, path=Topic}) end, {noreply, State}; @@ -113,3 +115,4 @@ topic_match([], [_H|_T2]) -> topic_split(S) -> binary:split(S, [<<"/">>], [global]). +