Skip to content

Commit 99c451e

Browse files
committed
apacheGH-39865: [C++] Strip extension metadata when importing a registered extension
1 parent 787afa1 commit 99c451e

File tree

4 files changed

+44
-29
lines changed

4 files changed

+44
-29
lines changed

cpp/src/arrow/c/bridge.cc

+6
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,8 @@ struct DecodedMetadata {
914914
std::shared_ptr<KeyValueMetadata> metadata;
915915
std::string extension_name;
916916
std::string extension_serialized;
917+
int extension_name_index = -1; // index of extension_name in metadata
918+
int extension_serialized_index = -1; // index of extension_serialized in metadata
917919
};
918920

919921
Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
@@ -956,8 +958,10 @@ Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
956958
RETURN_NOT_OK(read_string(&values[i]));
957959
if (keys[i] == kExtensionTypeKeyName) {
958960
decoded.extension_name = values[i];
961+
decoded.extension_name_index = i;
959962
} else if (keys[i] == kExtensionMetadataKeyName) {
960963
decoded.extension_serialized = values[i];
964+
decoded.extension_serialized_index = i;
961965
}
962966
}
963967
decoded.metadata = key_value_metadata(std::move(keys), std::move(values));
@@ -1046,6 +1050,8 @@ struct SchemaImporter {
10461050
ARROW_ASSIGN_OR_RAISE(
10471051
type_, registered_ext_type->Deserialize(std::move(type_),
10481052
metadata_.extension_serialized));
1053+
RETURN_NOT_OK(metadata_.metadata->DeleteMany(
1054+
{metadata_.extension_name_index, metadata_.extension_serialized_index}));
10491055
}
10501056
}
10511057

cpp/src/arrow/c/bridge_test.cc

+24-14
Original file line numberDiff line numberDiff line change
@@ -1870,7 +1870,7 @@ class TestSchemaImport : public ::testing::Test, public SchemaStructBuilder {
18701870
ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
18711871
Reset(); // for further tests
18721872
cb.AssertCalled(); // was released
1873-
AssertTypeEqual(*expected, *type);
1873+
AssertTypeEqual(*expected, *type, /*check_metadata=*/true);
18741874
}
18751875

18761876
void CheckImport(const std::shared_ptr<Field>& expected) {
@@ -1890,7 +1890,7 @@ class TestSchemaImport : public ::testing::Test, public SchemaStructBuilder {
18901890
ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
18911891
Reset(); // for further tests
18921892
cb.AssertCalled(); // was released
1893-
AssertSchemaEqual(*expected, *schema);
1893+
AssertSchemaEqual(*expected, *schema, /*check_metadata=*/true);
18941894
}
18951895

18961896
void CheckImportError() {
@@ -3569,7 +3569,7 @@ class TestSchemaRoundtrip : public ::testing::Test {
35693569
// Recreate the type
35703570
ASSERT_OK_AND_ASSIGN(actual, ImportType(&c_schema));
35713571
type = factory_expected();
3572-
AssertTypeEqual(*type, *actual);
3572+
AssertTypeEqual(*type, *actual, /*check_metadata=*/true);
35733573
type.reset();
35743574
actual.reset();
35753575

@@ -3600,7 +3600,7 @@ class TestSchemaRoundtrip : public ::testing::Test {
36003600
// Recreate the schema
36013601
ASSERT_OK_AND_ASSIGN(actual, ImportSchema(&c_schema));
36023602
schema = factory();
3603-
AssertSchemaEqual(*schema, *actual);
3603+
AssertSchemaEqual(*schema, *actual, /*check_metadata=*/true);
36043604
schema.reset();
36053605
actual.reset();
36063606

@@ -3693,13 +3693,23 @@ TEST_F(TestSchemaRoundtrip, Dictionary) {
36933693
}
36943694
}
36953695

3696+
std::shared_ptr<Field> GetStorageWithMetadata(const std::string& field_name,
3697+
const std::shared_ptr<DataType>& type) {
3698+
const auto& ext_type = checked_cast<const ExtensionType&>(*type);
3699+
auto storage_type = ext_type.storage_type();
3700+
auto md = KeyValueMetadata::Make({kExtensionTypeKeyName, kExtensionMetadataKeyName},
3701+
{ext_type.extension_name(), ext_type.Serialize()});
3702+
return field(field_name, storage_type, /*nullable=*/true, md);
3703+
}
3704+
36963705
TEST_F(TestSchemaRoundtrip, UnregisteredExtension) {
36973706
TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); });
36983707
TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(), utf8()); });
36993708

37003709
// Inside nested type
3701-
TestWithTypeFactory([]() { return list(dict_extension_type()); },
3702-
[]() { return list(dictionary(int8(), utf8())); });
3710+
TestWithTypeFactory(
3711+
[]() { return list(dict_extension_type()); },
3712+
[]() { return list(GetStorageWithMetadata("item", dict_extension_type())); });
37033713
}
37043714

37053715
TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
@@ -3808,7 +3818,7 @@ class TestArrayRoundtrip : public ::testing::Test {
38083818
{
38093819
std::shared_ptr<Array> expected;
38103820
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
3811-
AssertTypeEqual(*expected->type(), *array->type());
3821+
AssertTypeEqual(*expected->type(), *array->type(), /*check_metadata=*/true);
38123822
AssertArraysEqual(*expected, *array, true);
38133823
}
38143824
array.reset();
@@ -3848,7 +3858,7 @@ class TestArrayRoundtrip : public ::testing::Test {
38483858
{
38493859
std::shared_ptr<RecordBatch> expected;
38503860
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
3851-
AssertSchemaEqual(*expected->schema(), *batch->schema());
3861+
AssertSchemaEqual(*expected->schema(), *batch->schema(), /*check_metadata=*/true);
38523862
AssertBatchesEqual(*expected, *batch);
38533863
}
38543864
batch.reset();
@@ -4228,7 +4238,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test {
42284238
{
42294239
std::shared_ptr<Array> expected;
42304240
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
4231-
AssertTypeEqual(*expected->type(), *array->type());
4241+
AssertTypeEqual(*expected->type(), *array->type(), /*check_metadata=*/true);
42324242
AssertArraysEqual(*expected, *array, true);
42334243
}
42344244
array.reset();
@@ -4274,7 +4284,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test {
42744284
{
42754285
std::shared_ptr<RecordBatch> expected;
42764286
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
4277-
AssertSchemaEqual(*expected->schema(), *batch->schema());
4287+
AssertSchemaEqual(*expected->schema(), *batch->schema(), /*check_metadata=*/true);
42784288
AssertBatchesEqual(*expected, *batch);
42794289
}
42804290
batch.reset();
@@ -4351,7 +4361,7 @@ class TestArrayStreamExport : public BaseArrayStreamTest {
43514361
SchemaExportGuard schema_guard(&c_schema);
43524362
ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
43534363
ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
4354-
AssertSchemaEqual(expected, *schema);
4364+
AssertSchemaEqual(expected, *schema, /*check_metadata=*/true);
43554365
}
43564366

43574367
void AssertStreamEnd(struct ArrowArrayStream* c_stream) {
@@ -4435,7 +4445,7 @@ TEST_F(TestArrayStreamExport, ArrayLifetime) {
44354445
{
44364446
SchemaExportGuard schema_guard(&c_schema);
44374447
ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema));
4438-
AssertSchemaEqual(*schema, *got_schema);
4448+
AssertSchemaEqual(*schema, *got_schema, /*check_metadata=*/true);
44394449
}
44404450

44414451
ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
@@ -4460,7 +4470,7 @@ TEST_F(TestArrayStreamExport, Errors) {
44604470
{
44614471
SchemaExportGuard schema_guard(&c_schema);
44624472
ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
4463-
AssertSchemaEqual(schema, arrow::schema({}));
4473+
AssertSchemaEqual(schema, arrow::schema({}), /*check_metadata=*/true);
44644474
}
44654475

44664476
struct ArrowArray c_array;
@@ -4537,7 +4547,7 @@ TEST_F(TestArrayStreamRoundtrip, Simple) {
45374547
ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(batches, orig_schema));
45384548

45394549
Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>& reader) {
4540-
AssertSchemaEqual(*orig_schema, *reader->schema());
4550+
AssertSchemaEqual(*orig_schema, *reader->schema(), /*check_metadata=*/true);
45414551
AssertReaderNext(reader, *batches[0]);
45424552
AssertReaderNext(reader, *batches[1]);
45434553
AssertReaderEnd(reader);

cpp/src/arrow/util/key_value_metadata.cc

+8-10
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ void KeyValueMetadata::Append(std::string key, std::string value) {
9090
values_.push_back(std::move(value));
9191
}
9292

93-
Result<std::string> KeyValueMetadata::Get(const std::string& key) const {
93+
Result<std::string> KeyValueMetadata::Get(std::string_view key) const {
9494
auto index = FindKey(key);
9595
if (index < 0) {
9696
return Status::KeyError(key);
@@ -129,7 +129,7 @@ Status KeyValueMetadata::DeleteMany(std::vector<int64_t> indices) {
129129
return Status::OK();
130130
}
131131

132-
Status KeyValueMetadata::Delete(const std::string& key) {
132+
Status KeyValueMetadata::Delete(std::string_view key) {
133133
auto index = FindKey(key);
134134
if (index < 0) {
135135
return Status::KeyError(key);
@@ -138,20 +138,18 @@ Status KeyValueMetadata::Delete(const std::string& key) {
138138
}
139139
}
140140

141-
Status KeyValueMetadata::Set(const std::string& key, const std::string& value) {
141+
Status KeyValueMetadata::Set(std::string key, std::string value) {
142142
auto index = FindKey(key);
143143
if (index < 0) {
144-
Append(key, value);
144+
Append(std::move(key), std::move(value));
145145
} else {
146-
keys_[index] = key;
147-
values_[index] = value;
146+
keys_[index] = std::move(key);
147+
values_[index] = std::move(value);
148148
}
149149
return Status::OK();
150150
}
151151

152-
bool KeyValueMetadata::Contains(const std::string& key) const {
153-
return FindKey(key) >= 0;
154-
}
152+
bool KeyValueMetadata::Contains(std::string_view key) const { return FindKey(key) >= 0; }
155153

156154
void KeyValueMetadata::reserve(int64_t n) {
157155
DCHECK_GE(n, 0);
@@ -188,7 +186,7 @@ std::vector<std::pair<std::string, std::string>> KeyValueMetadata::sorted_pairs(
188186
return pairs;
189187
}
190188

191-
int KeyValueMetadata::FindKey(const std::string& key) const {
189+
int KeyValueMetadata::FindKey(std::string_view key) const {
192190
for (size_t i = 0; i < keys_.size(); ++i) {
193191
if (keys_[i] == key) {
194192
return static_cast<int>(i);

cpp/src/arrow/util/key_value_metadata.h

+6-5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <cstdint>
2121
#include <memory>
2222
#include <string>
23+
#include <string_view>
2324
#include <unordered_map>
2425
#include <utility>
2526
#include <vector>
@@ -44,13 +45,13 @@ class ARROW_EXPORT KeyValueMetadata {
4445
void ToUnorderedMap(std::unordered_map<std::string, std::string>* out) const;
4546
void Append(std::string key, std::string value);
4647

47-
Result<std::string> Get(const std::string& key) const;
48-
bool Contains(const std::string& key) const;
48+
Result<std::string> Get(std::string_view key) const;
49+
bool Contains(std::string_view key) const;
4950
// Note that deleting may invalidate known indices
50-
Status Delete(const std::string& key);
51+
Status Delete(std::string_view key);
5152
Status Delete(int64_t index);
5253
Status DeleteMany(std::vector<int64_t> indices);
53-
Status Set(const std::string& key, const std::string& value);
54+
Status Set(std::string key, std::string value);
5455

5556
void reserve(int64_t n);
5657

@@ -63,7 +64,7 @@ class ARROW_EXPORT KeyValueMetadata {
6364
std::vector<std::pair<std::string, std::string>> sorted_pairs() const;
6465

6566
/// \brief Perform linear search for key, returning -1 if not found
66-
int FindKey(const std::string& key) const;
67+
int FindKey(std::string_view key) const;
6768

6869
std::shared_ptr<KeyValueMetadata> Copy() const;
6970

0 commit comments

Comments
 (0)