Skip to content

Commit b319014

Browse files
authored
[CPU]PageAttn with 4bit-quantization (#27992)
### Details: - *Add new hint to set group_size for key/value cache* - *Add grouped 4bit sym/asym quantization support for PageAttentionNode* - *Add grouped quantization for U8 quantization for PageAttentionNode* ### Tickets: - *CVS-151586* --------- Signed-off-by: [email protected] <[email protected]> Signed-off-by: Zhang Yi3 <[email protected]> Signed-off-by: Zhang Yi <[email protected]>
1 parent 345163f commit b319014

File tree

18 files changed

+1413
-411
lines changed

18 files changed

+1413
-411
lines changed

src/bindings/python/src/openvino/runtime/properties/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
from openvino._pyopenvino.properties import loaded_from_cache
2929
from openvino._pyopenvino.properties import cache_encryption_callbacks
3030
from openvino._pyopenvino.properties import weights_path
31+
from openvino._pyopenvino.properties import key_cache_precision
32+
from openvino._pyopenvino.properties import value_cache_precision
33+
from openvino._pyopenvino.properties import key_cache_group_size
34+
from openvino._pyopenvino.properties import value_cache_group_size
3135

3236
# Submodules
3337
from openvino.runtime.properties import hint

src/bindings/python/src/pyopenvino/core/properties/properties.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ void regmodule_properties(py::module m) {
3434
wrap_property_RW(m_properties, ov::force_tbb_terminate, "force_tbb_terminate");
3535
wrap_property_RW(m_properties, ov::enable_mmap, "enable_mmap");
3636
wrap_property_RW(m_properties, ov::weights_path, "weights_path");
37+
wrap_property_RW(m_properties, ov::key_cache_precision, "key_cache_precision");
38+
wrap_property_RW(m_properties, ov::value_cache_precision, "value_cache_precision");
39+
wrap_property_RW(m_properties, ov::key_cache_group_size, "key_cache_group_size");
40+
wrap_property_RW(m_properties, ov::value_cache_group_size, "value_cache_group_size");
3741

3842
wrap_property_RO(m_properties, ov::supported_properties, "supported_properties");
3943
wrap_property_RO(m_properties, ov::available_devices, "available_devices");

src/bindings/python/tests/test_runtime/test_properties.py

+12
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,18 @@ def test_properties_ro(ov_property_ro, expected_value):
257257
"WEIGHTS_PATH",
258258
(("./model.bin", "./model.bin"),),
259259
),
260+
(
261+
props.key_cache_group_size,
262+
"KEY_CACHE_GROUP_SIZE",
263+
((64, 64),),
264+
),
265+
(
266+
props.value_cache_group_size,
267+
"VALUE_CACHE_GROUP_SIZE",
268+
((64, 64),),
269+
),
270+
(props.key_cache_precision, "KEY_CACHE_PRECISION", ((Type.f32, Type.f32),)),
271+
(props.value_cache_precision, "VALUE_CACHE_PRECISION", ((Type.f32, Type.f32),)),
260272
(hints.inference_precision, "INFERENCE_PRECISION_HINT", ((Type.f32, Type.f32),)),
261273
(
262274
hints.model_priority,

src/inference/include/openvino/runtime/properties.hpp

+24
Original file line numberDiff line numberDiff line change
@@ -1301,4 +1301,28 @@ static constexpr Property<std::vector<std::string>, PropertyMutability::RO> exec
13011301
* @note This property is used for weightless caching. Only used when ov::CacheMode Property is set to "OPTIMIZE_SIZE".
13021302
*/
13031303
static constexpr Property<std::string, PropertyMutability::RW> weights_path{"WEIGHTS_PATH"};
1304+
1305+
/**
1306+
* @brief The precision of key cache compression
1307+
* @ingroup ov_runtime_cpp_prop_api
1308+
*/
1309+
static constexpr Property<element::Type, PropertyMutability::RW> key_cache_precision{"KEY_CACHE_PRECISION"};
1310+
1311+
/**
1312+
* @brief The precision of value cache compression
1313+
* @ingroup ov_runtime_cpp_prop_api
1314+
*/
1315+
static constexpr Property<element::Type, PropertyMutability::RW> value_cache_precision{"VALUE_CACHE_PRECISION"};
1316+
1317+
/**
1318+
* @brief The group_size of key cache compression
1319+
* @ingroup ov_runtime_cpp_prop_api
1320+
*/
1321+
static constexpr Property<uint64_t, PropertyMutability::RW> key_cache_group_size{"KEY_CACHE_GROUP_SIZE"};
1322+
1323+
/**
1324+
* @brief The group_size of value cache compression
1325+
* @ingroup ov_runtime_cpp_prop_api
1326+
*/
1327+
static constexpr Property<uint64_t, PropertyMutability::RW> value_cache_group_size{"VALUE_CACHE_GROUP_SIZE"};
13041328
} // namespace ov

src/plugins/intel_cpu/src/compiled_model.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ ov::Any CompiledModel::get_property(const std::string& name) const {
256256
RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()),
257257
RO_property(ov::hint::dynamic_quantization_group_size.name()),
258258
RO_property(ov::hint::kv_cache_precision.name()),
259+
RO_property(ov::key_cache_precision.name()),
260+
RO_property(ov::value_cache_precision.name()),
261+
RO_property(ov::key_cache_group_size.name()),
262+
RO_property(ov::value_cache_group_size.name()),
259263
};
260264

261265
return ro_properties;
@@ -313,6 +317,14 @@ ov::Any CompiledModel::get_property(const std::string& name) const {
313317
return decltype(ov::hint::dynamic_quantization_group_size)::value_type(config.fcDynamicQuantizationGroupSize);
314318
} else if (name == ov::hint::kv_cache_precision) {
315319
return decltype(ov::hint::kv_cache_precision)::value_type(config.kvCachePrecision);
320+
} else if (name == ov::key_cache_precision) {
321+
return decltype(ov::key_cache_precision)::value_type(config.keyCachePrecision);
322+
} else if (name == ov::value_cache_precision) {
323+
return decltype(ov::value_cache_precision)::value_type(config.valueCachePrecision);
324+
} else if (name == ov::key_cache_group_size) {
325+
return decltype(ov::key_cache_group_size)::value_type(config.keyCacheGroupSize);
326+
} else if (name == ov::value_cache_group_size) {
327+
return decltype(ov::value_cache_group_size)::value_type(config.valueCacheGroupSize);
316328
}
317329
OPENVINO_THROW("Unsupported property: ", name);
318330
}

src/plugins/intel_cpu/src/config.cpp

+85-1
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,60 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
309309
ov::hint::kv_cache_precision.name(),
310310
". Supported values: u8, bf16, f16, f32");
311311
}
312+
} else if (key == ov::key_cache_precision.name()) {
313+
try {
314+
keyCachePrecisionSetExplicitly = true;
315+
auto const prec = val.as<ov::element::Type>();
316+
if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) {
317+
keyCachePrecision = prec;
318+
} else {
319+
OPENVINO_THROW("keyCachePrecision doesn't support value ", prec);
320+
}
321+
} catch (ov::Exception&) {
322+
OPENVINO_THROW("Wrong value ",
323+
val.as<std::string>(),
324+
" for property key ",
325+
ov::key_cache_precision.name(),
326+
". Supported values: u8, bf16, f16, f32");
327+
}
328+
} else if (key == ov::value_cache_precision.name()) {
329+
try {
330+
valueCachePrecisionSetExplicitly = true;
331+
auto const prec = val.as<ov::element::Type>();
332+
if (one_of(prec,
333+
ov::element::f32,
334+
ov::element::f16,
335+
ov::element::bf16,
336+
ov::element::u8,
337+
ov::element::u4)) {
338+
valueCachePrecision = prec;
339+
} else {
340+
OPENVINO_THROW("valueCachePrecision doesn't support value ", prec);
341+
}
342+
} catch (ov::Exception&) {
343+
OPENVINO_THROW("Wrong value ",
344+
val.as<std::string>(),
345+
" for property key ",
346+
ov::value_cache_precision.name(),
347+
". Supported values: u4, u8, bf16, f16, f32");
348+
}
349+
} else if (key == ov::key_cache_group_size.name() || key == ov::value_cache_group_size.name()) {
350+
try {
351+
auto const groupSize = val.as<uint64_t>();
352+
if (key == ov::key_cache_group_size.name()) {
353+
keyCacheGroupSizeSetExplicitly = true;
354+
keyCacheGroupSize = groupSize;
355+
} else {
356+
valueCacheGroupSizeSetExplicitly = true;
357+
valueCacheGroupSize = groupSize;
358+
}
359+
} catch (ov::Exception&) {
360+
OPENVINO_THROW("Wrong value ",
361+
val.as<std::string>(),
362+
" for property key ",
363+
key,
364+
". Expected only unsinged integer numbers");
365+
}
312366
} else if (key == ov::cache_encryption_callbacks.name()) {
313367
try {
314368
const auto& encryption_callbacks = val.as<EncryptionCallbacks>();
@@ -344,6 +398,13 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
344398
aclFastMath = true;
345399
}
346400
#endif
401+
// key/value cache precision has higher priority, if not defined use kvCachePrecision
402+
if (!keyCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) {
403+
keyCachePrecision = kvCachePrecision;
404+
}
405+
if (!valueCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) {
406+
valueCachePrecision = kvCachePrecision;
407+
}
347408
// disable dynamic quantization and kv quantization for best accuracy
348409
if (executionMode == ov::hint::ExecutionMode::ACCURACY) {
349410
if (!fcDynamicQuantizationGroupSizeSetExplicitly) {
@@ -352,6 +413,12 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
352413
if (!kvCachePrecisionSetExplicitly) {
353414
kvCachePrecision = ov::element::f32;
354415
}
416+
if (!keyCachePrecisionSetExplicitly) {
417+
keyCachePrecision = ov::element::f32;
418+
}
419+
if (!valueCachePrecisionSetExplicitly) {
420+
valueCachePrecision = ov::element::f32;
421+
}
355422
}
356423

357424
if (!prop.empty())
@@ -398,14 +465,31 @@ void Config::applyRtInfo(const std::shared_ptr<const ov::Model>& model) {
398465
// if user sets explicitly, it will be higher priority than rt_info
399466
if (!kvCachePrecisionSetExplicitly &&
400467
model->has_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()})) {
401-
this->kvCachePrecision =
468+
this->kvCachePrecision = this->keyCachePrecision = this->valueCachePrecision =
402469
model->get_rt_info<ov::element::Type>({"runtime_options", ov::hint::kv_cache_precision.name()});
403470
}
404471
if (!fcDynamicQuantizationGroupSizeSetExplicitly &&
405472
model->has_rt_info({"runtime_options", ov::hint::dynamic_quantization_group_size.name()})) {
406473
this->fcDynamicQuantizationGroupSize =
407474
model->get_rt_info<uint64_t>({"runtime_options", ov::hint::dynamic_quantization_group_size.name()});
408475
}
476+
if (!keyCachePrecisionSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_precision.name()})) {
477+
this->keyCachePrecision =
478+
model->get_rt_info<ov::element::Type>({"runtime_options", ov::key_cache_precision.name()});
479+
}
480+
if (!valueCachePrecisionSetExplicitly &&
481+
model->has_rt_info({"runtime_options", ov::value_cache_precision.name()})) {
482+
this->valueCachePrecision =
483+
model->get_rt_info<ov::element::Type>({"runtime_options", ov::value_cache_precision.name()});
484+
}
485+
if (!keyCacheGroupSizeSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_group_size.name()})) {
486+
this->keyCacheGroupSize = model->get_rt_info<uint64_t>({"runtime_options", ov::key_cache_group_size.name()});
487+
}
488+
if (!valueCacheGroupSizeSetExplicitly &&
489+
model->has_rt_info({"runtime_options", ov::value_cache_group_size.name()})) {
490+
this->valueCacheGroupSize =
491+
model->get_rt_info<uint64_t>({"runtime_options", ov::value_cache_group_size.name()});
492+
}
409493
}
410494

411495
} // namespace intel_cpu

src/plugins/intel_cpu/src/config.h

+10
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,27 @@ struct Config {
4848
uint64_t fcDynamicQuantizationGroupSize = 32;
4949
bool fcDynamicQuantizationGroupSizeSetExplicitly = false;
5050
bool kvCachePrecisionSetExplicitly = false;
51+
bool keyCachePrecisionSetExplicitly = false;
52+
bool valueCachePrecisionSetExplicitly = false;
53+
bool keyCacheGroupSizeSetExplicitly = false;
54+
bool valueCacheGroupSizeSetExplicitly = false;
5155
#if defined(OV_CPU_WITH_ACL)
5256
bool aclFastMath = false;
5357
#endif
5458
#if defined(OPENVINO_ARCH_X86_64)
5559
ov::element::Type kvCachePrecision = ov::element::u8;
60+
ov::element::Type keyCachePrecision = ov::element::u8;
61+
ov::element::Type valueCachePrecision = ov::element::u8;
5662
size_t rtCacheCapacity = 5000ul;
5763
#else
5864
ov::element::Type kvCachePrecision = ov::element::f16;
65+
ov::element::Type keyCachePrecision = ov::element::f16;
66+
ov::element::Type valueCachePrecision = ov::element::f16;
5967
// TODO: Executor cache may leads to incorrect behavior on oneDNN ACL primitives
6068
size_t rtCacheCapacity = 0ul;
6169
#endif
70+
size_t keyCacheGroupSize = 0ul;
71+
size_t valueCacheGroupSize = 0ul;
6272
ov::threading::IStreamsExecutor::Config streamExecutorConfig;
6373
int streams = 1;
6474
bool streamsChanged = false;

0 commit comments

Comments
 (0)