diff --git a/apps/emqx_s3/src/emqx_s3_client.erl b/apps/emqx_s3/src/emqx_s3_client.erl index 99e2de4da..8e54d90f9 100644 --- a/apps/emqx_s3/src/emqx_s3_client.erl +++ b/apps/emqx_s3/src/emqx_s3_client.erl @@ -12,8 +12,10 @@ create/1, put_object/3, + put_object/4, start_multipart/2, + start_multipart/3, upload_part/5, complete_multipart/4, abort_multipart/3, @@ -84,12 +86,18 @@ create(Config) -> }. -spec put_object(client(), key(), iodata()) -> ok_or_error(term()). +put_object(Client, Key, Value) -> + put_object(Client, #{}, Key, Value). + +-spec put_object(client(), headers(), key(), iodata()) -> ok_or_error(term()). put_object( #{bucket := Bucket, upload_options := Options, headers := Headers, aws_config := AwsConfig}, + SpecialHeaders, Key, Value ) -> - try erlcloud_s3:put_object(Bucket, key(Key), Value, Options, Headers, AwsConfig) of + AllHeaders = join_headers(Headers, SpecialHeaders), + try erlcloud_s3:put_object(Bucket, key(Key), Value, Options, AllHeaders, AwsConfig) of Props when is_list(Props) -> ok catch @@ -99,11 +107,17 @@ put_object( end. -spec start_multipart(client(), key()) -> ok_or_error(upload_id(), term()). +start_multipart(Client, Key) -> + start_multipart(Client, #{}, Key). + +-spec start_multipart(client(), headers(), key()) -> ok_or_error(upload_id(), term()). start_multipart( #{bucket := Bucket, upload_options := Options, headers := Headers, aws_config := AwsConfig}, + SpecialHeaders, Key ) -> - case erlcloud_s3:start_multipart(Bucket, key(Key), Options, Headers, AwsConfig) of + AllHeaders = join_headers(Headers, SpecialHeaders), + case erlcloud_s3:start_multipart(Bucket, key(Key), Options, AllHeaders, AwsConfig) of {ok, Props} -> {ok, proplists:get_value('uploadId', Props)}; {error, Reason} -> @@ -300,6 +314,9 @@ erlcloud_string_headers(Headers) -> binary_headers(Headers) -> [{to_binary(K), V} || {K, V} <- Headers]. +join_headers(Headers, SpecialHeaders) -> + Headers ++ string_headers(maps:to_list(SpecialHeaders)). + to_binary(Val) when is_list(Val) -> list_to_binary(Val); to_binary(Val) when is_binary(Val) -> Val. diff --git a/apps/emqx_s3/src/emqx_s3_uploader.erl b/apps/emqx_s3/src/emqx_s3_uploader.erl index 07ae5bd19..f6414669d 100644 --- a/apps/emqx_s3/src/emqx_s3_uploader.erl +++ b/apps/emqx_s3/src/emqx_s3_uploader.erl @@ -93,10 +93,11 @@ callback_mode() -> handle_event_function. init([ProfileId, #{key := Key} = Opts]) -> process_flag(trap_exit, true), {ok, ClientConfig, UploaderConfig} = emqx_s3_profile_conf:checkout_config(ProfileId), - Client = client(ClientConfig, Opts), + Client = client(ClientConfig), {ok, upload_not_started, #{ profile_id => ProfileId, client => Client, + headers => maps:get(headers, Opts, #{}), key => Key, buffer => [], buffer_size => 0, @@ -205,8 +206,8 @@ maybe_start_upload(#{buffer_size := BufferSize, min_part_size := MinPartSize} = end. -spec start_upload(data()) -> {started, data()} | {error, term()}. -start_upload(#{client := Client, key := Key} = Data) -> - case emqx_s3_client:start_multipart(Client, Key) of +start_upload(#{client := Client, key := Key, headers := Headers} = Data) -> + case emqx_s3_client:start_multipart(Client, Headers, Key) of {ok, UploadId} -> NewData = Data#{upload_id => UploadId}, {started, NewData}; @@ -293,10 +294,11 @@ put_object( #{ client := Client, key := Key, - buffer := Buffer + buffer := Buffer, + headers := Headers } ) -> - case emqx_s3_client:put_object(Client, Key, lists:reverse(Buffer)) of + case emqx_s3_client:put_object(Client, Headers, Key, lists:reverse(Buffer)) of ok -> ok; {error, _} = Error -> @@ -320,6 +322,5 @@ unwrap(WrappedData) -> is_valid_part(WriteData, #{max_part_size := MaxPartSize, buffer_size := BufferSize}) -> BufferSize + iolist_size(WriteData) =< MaxPartSize. -client(Config, Opts) -> - Headers = maps:get(headers, Opts, #{}), - emqx_s3_client:create(Config#{headers => Headers}). +client(Config) -> + emqx_s3_client:create(Config). diff --git a/apps/emqx_s3/test/emqx_s3_client_SUITE.erl b/apps/emqx_s3/test/emqx_s3_client_SUITE.erl index f5a507653..ec7d5ebcf 100644 --- a/apps/emqx_s3/test/emqx_s3_client_SUITE.erl +++ b/apps/emqx_s3/test/emqx_s3_client_SUITE.erl @@ -109,7 +109,7 @@ t_url(Config) -> Client = client(Config), ok = emqx_s3_client:put_object(Client, Key, <<"data">>), - Url = emqx_s3_client:url(Client, Key), + Url = emqx_s3_client:uri(Client, Key), ?assertMatch( {ok, {{_StatusLine, 200, "OK"}, _Headers, "data"}},