Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 99 additions & 55 deletions music_assistant/controllers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import base64
import logging
import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4

import aiofiles
Expand Down Expand Up @@ -69,6 +69,7 @@
from music_assistant.helpers.json import JSON_DECODE_EXCEPTIONS, async_json_dumps, async_json_loads
from music_assistant.helpers.util import load_provider_module
from music_assistant.models import ProviderModuleType
from music_assistant.models.music_provider import MusicProvider

if TYPE_CHECKING:
import asyncio
Expand Down Expand Up @@ -117,7 +118,7 @@ async def setup(self) -> None:
@property
def onboard_done(self) -> bool:
"""Return True if onboarding is done."""
return self.get(CONF_ONBOARD_DONE, False)
return bool(self.get(CONF_ONBOARD_DONE, False))

async def close(self) -> None:
"""Handle logic on server stop."""
Expand Down Expand Up @@ -196,12 +197,12 @@ async def get_provider_configs(
include_values: bool = False,
) -> list[ProviderConfig]:
"""Return all known provider configurations, optionally filtered by ProviderType."""
raw_values: dict[str, dict] = self.get(CONF_PROVIDERS, {})
raw_values = self.get(CONF_PROVIDERS, {})
prov_entries = {x.domain for x in self.mass.get_provider_manifests()}
return [
await self.get_provider_config(prov_conf["instance_id"])
if include_values
else ProviderConfig.parse([], prov_conf)
else cast("ProviderConfig", ProviderConfig.parse([], prov_conf))
for prov_conf in raw_values.values()
if (provider_type is None or prov_conf["type"] == provider_type)
and (provider_domain is None or prov_conf["domain"] == provider_domain)
Expand All @@ -224,7 +225,7 @@ async def get_provider_config(self, instance_id: str) -> ProviderConfig:
else:
msg = f"Unknown provider domain: {raw_conf['domain']}"
raise KeyError(msg)
return ProviderConfig.parse(config_entries, raw_conf)
return cast("ProviderConfig", ProviderConfig.parse(config_entries, raw_conf))
msg = f"No config found for provider id {instance_id}"
raise KeyError(msg)

Expand Down Expand Up @@ -284,23 +285,29 @@ async def get_provider_config_entries( # noqa: PLR0915
supported_features = provider.supported_features
else:
provider = None
supported_features: set[ProviderFeature] = getattr(
prov_mod, "SUPPORTED_FEATURES", set()
)
supported_features = getattr(prov_mod, "SUPPORTED_FEATURES", set())
extra_entries: list[ConfigEntry] = []
if manifest.type == ProviderType.MUSIC:
# library sync settings
if ProviderFeature.LIBRARY_ARTISTS in supported_features:
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ARTISTS)
if ProviderFeature.LIBRARY_ALBUMS in supported_features:
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ALBUMS)
if provider and provider.is_streaming_provider:
if (
provider
and isinstance(provider, MusicProvider)
and provider.is_streaming_provider
):
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_ALBUM_TRACKS)
if ProviderFeature.LIBRARY_TRACKS in supported_features:
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_TRACKS)
if ProviderFeature.LIBRARY_PLAYLISTS in supported_features:
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_PLAYLISTS)
if provider and provider.is_streaming_provider:
if (
provider
and isinstance(provider, MusicProvider)
and provider.is_streaming_provider
):
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_PLAYLIST_TRACKS)
if ProviderFeature.LIBRARY_AUDIOBOOKS in supported_features:
extra_entries.append(CONF_ENTRY_LIBRARY_SYNC_AUDIOBOOKS)
Expand Down Expand Up @@ -362,7 +369,11 @@ async def save_provider_config(
if instance_id is not None:
config = await self._update_provider_config(instance_id, values)
else:
config = await self._add_provider_config(provider_domain, values)
result = await self._add_provider_config(provider_domain, values)
if isinstance(result, list):
msg = "Unexpected return type from _add_provider_config"
raise TypeError(msg)
config = result
# mark onboard done whenever the (first) provider is added
# this will be replaced later by a more sophisticated onboarding process
self.set(CONF_ONBOARD_DONE, True)
Expand Down Expand Up @@ -413,7 +424,7 @@ async def get_player_configs(
return [
await self.get_player_config(raw_conf["player_id"])
if include_values
else PlayerConfig.parse([], raw_conf)
else cast("PlayerConfig", PlayerConfig.parse([], raw_conf))
for raw_conf in list(self.get(CONF_PLAYERS, {}).values())
# filter out unavailable providers (only if we requested the full info)
if (
Expand Down Expand Up @@ -447,7 +458,7 @@ async def get_player_config(
raw_conf["available"] = False
raw_conf["name"] = raw_conf.get("name")
raw_conf["default_name"] = raw_conf.get("default_name") or raw_conf["player_id"]
return PlayerConfig.parse(conf_entries, raw_conf)
return cast("PlayerConfig", PlayerConfig.parse(conf_entries, raw_conf))
msg = f"No config found for player id {player_id}"
raise KeyError(msg)

Expand Down Expand Up @@ -480,7 +491,7 @@ async def get_player_config_value(
player_id: str,
key: str,
unpack_splitted_values: bool = False,
) -> ConfigValueType:
) -> ConfigValueType | tuple[str, ...] | list[tuple[str, ...]]:
"""Return single configentry value for a player."""
conf = await self.get_player_config(player_id)
if unpack_splitted_values:
Expand All @@ -499,9 +510,12 @@ def get_raw_player_config_value(

Note that this only returns the stored value without any validation or default.
"""
return self.get(
f"{CONF_PLAYERS}/{player_id}/values/{key}",
self.get(f"{CONF_PLAYERS}/{player_id}/{key}", default),
return cast(
"ConfigValueType",
self.get(
f"{CONF_PLAYERS}/{player_id}/values/{key}",
self.get(f"{CONF_PLAYERS}/{player_id}/{key}", default),
),
)

def get_base_player_config(self, player_id: str, provider: str) -> PlayerConfig:
Expand All @@ -516,7 +530,7 @@ def get_base_player_config(self, player_id: str, provider: str) -> PlayerConfig:
"player_id": player_id,
"provider": provider,
}
return PlayerConfig.parse([], raw_conf)
return cast("PlayerConfig", PlayerConfig.parse([], raw_conf))

@api_command("config/players/save")
async def save_player_config(
Expand All @@ -527,7 +541,7 @@ async def save_player_config(
changed_keys = config.update(values)
if not changed_keys:
# no changes
return None
return config
# validate/handle the update in the player manager
await self.mass.players.on_player_config_change(config, changed_keys)
# actually store changes (if the above did not raise)
Expand Down Expand Up @@ -602,9 +616,15 @@ def get_player_dsp_config(self, player_id: str) -> DSPConfig:
dsp_config.filters.append(
ToneControlFilter(
enabled=True,
bass_level=deprecated_eq_bass,
mid_level=deprecated_eq_mid,
treble_level=deprecated_eq_treble,
bass_level=float(deprecated_eq_bass)
if isinstance(deprecated_eq_bass, (int, float, str))
else 0.0,
mid_level=float(deprecated_eq_mid)
if isinstance(deprecated_eq_mid, (int, float, str))
else 0.0,
treble_level=float(deprecated_eq_treble)
if isinstance(deprecated_eq_treble, (int, float, str))
else 0.0,
)
)

Expand Down Expand Up @@ -748,17 +768,20 @@ async def create_builtin_provider_config(self, provider_domain: str) -> None:
instance_id = f"{manifest.domain}--{shortuuid.random(8)}"
else:
instance_id = manifest.domain
default_config: ProviderConfig = ProviderConfig.parse(
config_entries,
{
"type": manifest.type.value,
"domain": manifest.domain,
"instance_id": instance_id,
"name": manifest.name,
# note: this will only work for providers that do
# not have any required config entries or provide defaults
"values": {},
},
default_config = cast(
"ProviderConfig",
ProviderConfig.parse(
config_entries,
{
"type": manifest.type.value,
"domain": manifest.domain,
"instance_id": instance_id,
"name": manifest.name,
# note: this will only work for providers that do
# not have any required config entries or provide defaults
"values": {},
},
),
)
default_config.validate()
conf_key = f"{CONF_PROVIDERS}/{default_config.instance_id}"
Expand All @@ -770,9 +793,12 @@ async def get_core_configs(self, include_values: bool = False) -> list[CoreConfi
return [
await self.get_core_config(core_controller)
if include_values
else CoreConfig.parse(
[],
self.get(f"{CONF_CORE}/{core_controller}", {"domain": core_controller}),
else cast(
"CoreConfig",
CoreConfig.parse(
[],
self.get(f"{CONF_CORE}/{core_controller}", {"domain": core_controller}),
),
)
for core_controller in CONFIGURABLE_CORE_CONTROLLERS
]
Expand All @@ -782,7 +808,7 @@ async def get_core_config(self, domain: str) -> CoreConfig:
"""Return configuration for a single core controller."""
raw_conf = self.get(f"{CONF_CORE}/{domain}", {"domain": domain})
config_entries = await self.get_core_config_entries(domain)
return CoreConfig.parse(config_entries, raw_conf)
return cast("CoreConfig", CoreConfig.parse(config_entries, raw_conf))

@api_command("config/core/get_value")
async def get_core_config_value(self, domain: str, key: str) -> ConfigValueType:
Expand Down Expand Up @@ -848,9 +874,12 @@ def get_raw_core_config_value(

Note that this only returns the stored value without any validation or default.
"""
return self.get(
f"{CONF_CORE}/{core_module}/values/{key}",
self.get(f"{CONF_CORE}/{core_module}/{key}", default),
return cast(
"ConfigValueType",
self.get(
f"{CONF_CORE}/{core_module}/values/{key}",
self.get(f"{CONF_CORE}/{core_module}/{key}", default),
),
)

def get_raw_provider_config_value(
Expand All @@ -861,9 +890,12 @@ def get_raw_provider_config_value(

Note that this only returns the stored value without any validation or default.
"""
return self.get(
f"{CONF_PROVIDERS}/{provider_instance}/values/{key}",
self.get(f"{CONF_PROVIDERS}/{provider_instance}/{key}", default),
return cast(
"ConfigValueType",
self.get(
f"{CONF_PROVIDERS}/{provider_instance}/values/{key}",
self.get(f"{CONF_PROVIDERS}/{provider_instance}/{key}", default),
),
)

def set_raw_provider_config_value(
Expand All @@ -883,6 +915,9 @@ def set_raw_provider_config_value(
msg = f"Invalid provider_instance: {provider_instance}"
raise KeyError(msg)
if encrypted:
if not isinstance(value, str):
msg = f"Cannot encrypt non-string value for key {key}"
raise ValueError(msg)
value = self.encrypt_string(value)
if key in BASE_KEYS:
self.set(f"{CONF_PROVIDERS}/{provider_instance}/{key}", value)
Expand Down Expand Up @@ -934,6 +969,7 @@ def encrypt_string(self, str_value: str) -> str:
"""Encrypt a (password)string with Fernet."""
if str_value.startswith(ENCRYPT_SUFFIX):
return str_value
assert self._fernet is not None
return ENCRYPT_SUFFIX + self._fernet.encrypt(str_value.encode()).decode()

def decrypt_string(self, encrypted_str: str) -> str:
Expand All @@ -942,6 +978,7 @@ def decrypt_string(self, encrypted_str: str) -> str:
return encrypted_str
if not encrypted_str.startswith(ENCRYPT_SUFFIX):
return encrypted_str
assert self._fernet is not None
try:
return self._fernet.decrypt(encrypted_str.replace(ENCRYPT_SUFFIX, "").encode()).decode()
except InvalidToken as err:
Expand Down Expand Up @@ -972,7 +1009,6 @@ async def _migrate(self) -> None: # noqa: PLR0915
instance_id: str
provider_config: dict[str, Any]
player_config: dict[str, Any]
values: dict[str, ConfigValueType]

# Older versions of MA can create corrupt entries with no domain if retrying
# logic runs after a provider has been removed. Remove those corrupt entries.
Expand Down Expand Up @@ -1020,7 +1056,12 @@ async def _migrate(self) -> None: # noqa: PLR0915
# migrate player_group entries
ugp_found = False
for player_config in self._data.get(CONF_PLAYERS, {}).values():
if not player_config.get("provider").startswith("player_group"):
provider = player_config.get("provider")
if (
not provider
or not isinstance(provider, str)
or not provider.startswith("player_group")
):
continue
if not (values := player_config.get("values")):
continue
Expand Down Expand Up @@ -1133,7 +1174,7 @@ async def _add_provider_config(
self,
provider_domain: str,
values: dict[str, ConfigValueType],
) -> list[ConfigEntry] | ProviderConfig:
) -> ProviderConfig:
"""
Add new Provider (instance).

Expand Down Expand Up @@ -1170,15 +1211,18 @@ async def _add_provider_config(
config_entries = await self.get_provider_config_entries(
provider_domain=provider_domain, instance_id=instance_id, values=values
)
config: ProviderConfig = ProviderConfig.parse(
config_entries,
{
"type": manifest.type.value,
"domain": manifest.domain,
"instance_id": instance_id,
"default_name": manifest.name,
"values": values,
},
config = cast(
"ProviderConfig",
ProviderConfig.parse(
config_entries,
{
"type": manifest.type.value,
"domain": manifest.domain,
"instance_id": instance_id,
"default_name": manifest.name,
"values": values,
},
),
)
# validate the new config
config.validate()
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ enable_error_code = [
]
exclude = [
'^music_assistant/controllers/cache.py$',
'^music_assistant/controllers/config.py$',
'^music_assistant/controllers/media/.*$',
'^music_assistant/controllers/music.py$',
'^music_assistant/controllers/player_queues.py$',
Expand Down