Skip to content

Fix render_labels color parameter KeyError #456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 209 additions & 52 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,12 +733,17 @@ def _set_color_source_vec(
table_name: str | None = None,
table_layer: str | None = None,
render_type: Literal["points"] | None = None,
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
) -> tuple[pd.Categorical | None, ArrayLike, bool]:
if value_to_plot is None and element is not None:
color = np.full(len(element), na_color)
return color, color, False
return None, color, False

# First check if value_to_plot is likely a color specification rather than a column name
if value_to_plot is not None and _is_color_like(value_to_plot) and element is not None:
# User passed a color, not a column name
color = np.full(len(element), value_to_plot)
return None, color, False

# Figure out where to get the color from
origins = _locate_value(
value_key=value_to_plot,
sdata=sdata,
Expand All @@ -760,27 +765,55 @@ def _set_color_source_vec(
table_layer=table_layer,
)[value_to_plot]

# numerical case, return early
# TODO temporary split until refactor is complete
if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype):
if (
not isinstance(element, GeoDataFrame)
and isinstance(palette, list)
and palette[0] is not None
or isinstance(element, GeoDataFrame)
and isinstance(palette, list)
):
logger.warning(
"Ignoring categorical palette which is given for a continuous variable. "
"Consider using `cmap` to pass a ColorMap."
)
return None, color_source_vector, False

color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series`
# Convert to categorical if not already
if not isinstance(color_source_vector, pd.Categorical):
try:
color_source_vector = pd.Categorical(color_source_vector)
except (ValueError, TypeError) as e:
logger.warning(f"Could not convert '{value_to_plot}' to categorical: {e}")
# For numeric data, return None to indicate non-categorical
if pd.api.types.is_numeric_dtype(color_source_vector):
if (
not isinstance(element, GeoDataFrame)
and isinstance(palette, list)
and palette[0] is not None
or isinstance(element, GeoDataFrame)
and isinstance(palette, list)
):
logger.warning(
"Ignoring categorical palette which is given for a continuous variable. "
"Consider using `cmap` to pass a ColorMap."
)
return None, color_source_vector, False
# For other types, try to use as is
return None, color_source_vector, False

# At this point color_source_vector should be categorical
adata_with_colors = None
cluster_key = value_to_plot

# First check if the table_name is specified
if table_name is not None and table_name in sdata.tables:
adata_with_colors = sdata.tables[table_name]
adata_with_colors.uns["spatialdata_key"] = table_name

# If not, but the element is annotated by any table, use that
elif element_name is not None:
annotator_tables = get_element_annotators(sdata, element_name)
if len(annotator_tables) > 0:
# Use the first table that annotates this element
first_table = next(iter(annotator_tables))
adata_with_colors = sdata.tables[first_table]
adata_with_colors.uns["spatialdata_key"] = first_table

# If no specific table is found, try using the default table
elif sdata.table is not None:
adata_with_colors = sdata.table
adata_with_colors.uns["spatialdata_key"] = "default_table"

color_mapping = _get_categorical_color_mapping(
adata=sdata.table,
cluster_key=value_to_plot,
adata=adata_with_colors,
cluster_key=cluster_key,
color_source_vector=color_source_vector,
cmap_params=cmap_params,
alpha=alpha,
Expand All @@ -790,18 +823,28 @@ def _set_color_source_vec(
render_type=render_type,
)

# Set categories to match the mapping keys
color_source_vector = color_source_vector.set_categories(color_mapping.keys())
if color_mapping is None:
raise ValueError("Unable to create color palette.")

# do not rename categories, as colors need not be unique
color_vector = color_source_vector.map(color_mapping)
# Map categorical values to colors
try:
color_vector = color_source_vector.map(color_mapping)
except (KeyError, TypeError, ValueError) as e:
logger.warning(f"Error mapping colors: {e}. Attempting alternate approach.")
# Try mapping with string conversion
str_mapping = {str(k): v for k, v in color_mapping.items()}
color_vector = pd.Series(
[str_mapping.get(str(x), color_mapping.get("NaN", "#d3d3d3")) for x in color_source_vector],
index=color_source_vector.index,
)

return color_source_vector, color_vector, True

logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.")
logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not found, using default colors.")
color = np.full(sdata[table_name].n_obs, to_hex(na_color))
return color, color, False
return None, color, False


def _map_color_seg(
Expand All @@ -817,20 +860,34 @@ def _map_color_seg(
) -> ArrayLike:
cell_id = np.array(cell_id)

if pd.api.types.is_categorical_dtype(color_vector.dtype):
# Case A: users wants to plot a categorical column
is_categorical = pd.api.types.is_categorical_dtype(getattr(color_vector, "dtype", None))
is_numeric = pd.api.types.is_numeric_dtype(getattr(color_vector, "dtype", None))
is_pandas_series = isinstance(color_vector, pd.Series)

# Case A: categorical column
if is_categorical:
if np.any(color_source_vector.isna()):
cell_id[color_source_vector.isna()] = 0
val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1)
cols = colors.to_rgba_array(color_vector.categories)
elif pd.api.types.is_numeric_dtype(color_vector.dtype):
# Case B: user wants to plot a continous column
if isinstance(color_vector, pd.Series):

# Case B: continuous column
elif is_numeric:
if is_pandas_series:
color_vector = color_vector.to_numpy()
cols = cmap_params.cmap(cmap_params.norm(color_vector))
val_im = map_array(seg.copy(), cell_id, cell_id)

# Case C & D: Other cases (could be strings, or hex colors)
else:
# Case C: User didn't specify any colors
# Get the first color safely, regardless of index structure
first_color = None
if is_pandas_series and len(color_vector) > 0:
first_color = color_vector.iloc[0]
elif not is_pandas_series and len(color_vector) > 0:
first_color = color_vector[0]

# Case C: Using default colors with random generation
if color_source_vector is not None and (
set(color_vector) == set(color_source_vector)
and len(set(color_vector)) == 1
Expand All @@ -840,14 +897,31 @@ def _map_color_seg(
val_im = map_array(seg.copy(), cell_id, cell_id)
RNG = default_rng(42)
cols = RNG.random((len(color_vector), 3))

# Case D: User specified explicit colors or we're using defaults
else:
# Case D: User didn't specify a column to color by, but modified the na_color
val_im = map_array(seg.copy(), cell_id, cell_id)
if "#" in str(color_vector[0]):
# we have hex colors
assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like."
cols = colors.to_rgba_array(color_vector)

# Check if we're dealing with hex colors
if first_color is not None and isinstance(first_color, str) and "#" in first_color:
# We have hex colors
all_is_color = True
for c in color_vector:
if not _is_color_like(c):
all_is_color = False
break

if all_is_color:
try:
cols = colors.to_rgba_array(color_vector)
except ValueError as e:
logger.warning(f"Error converting colors: {e}, falling back to default colormap")
cols = cmap_params.cmap(cmap_params.norm(np.arange(len(color_vector))))
else:
# Fall back to colormap
cols = cmap_params.cmap(cmap_params.norm(color_vector))
else:
# Use the colormap
cols = cmap_params.cmap(cmap_params.norm(color_vector))

if seg_erosionpx is not None:
Expand Down Expand Up @@ -879,20 +953,93 @@ def _generate_base_categorial_color_mapping(
na_color: ColorLike,
cmap_params: CmapParams | None = None,
) -> Mapping[str, str]:
if adata is not None and cluster_key in adata.uns and f"{cluster_key}_colors" in adata.uns:
colors = adata.uns[f"{cluster_key}_colors"]
categories = color_source_vector.categories.tolist() + ["NaN"]
if "#" not in na_color:
# should be unreachable, but just for safety
raise ValueError("Expected `na_color` to be a hex color, but got a non-hex color.")

colors = [to_hex(to_rgba(color)[:3]) for color in colors]
na_color = to_hex(to_rgba(na_color)[:3])

if na_color and len(categories) > len(colors):
return dict(zip(categories, colors + [na_color], strict=True))
color_key = f"{cluster_key}_colors"
color_found_in_uns_msg_template = (
"Using colors from '{cluster}_colors' in .uns slot of table '{table}' for plotting. "
"If this is unexpected, please delete the column from your AnnData object."
)

return dict(zip(categories, colors, strict=True))
if adata is not None and cluster_key is not None:
if cluster_key in adata.uns and isinstance(adata.uns[cluster_key], dict):
# We have a direct color mapping dictionary
color_dict = adata.uns[cluster_key]
table_name = getattr(adata, "uns", {}).get("spatialdata_key", "")
if table_name:
logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name))

# Ensure all values are hex colors
for k, v in color_dict.items():
if isinstance(v, str) and not v.startswith("#"):
color_dict[k] = to_hex(to_rgba(v))

categories = color_source_vector.categories.tolist()
na_color_hex = to_hex(to_rgba(na_color)[:3])

return {cat: color_dict.get(str(cat), color_dict.get(cat, na_color_hex)) for cat in categories}

if color_key in adata.uns:
colors = adata.uns[color_key]
table_name = getattr(adata, "uns", {}).get("spatialdata_key", "")
if table_name:
logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name))

if isinstance(colors, list):
colors = [to_hex(to_rgba(color)[:3]) for color in colors]
categories = color_source_vector.categories.tolist()

na_color_hex = to_hex(to_rgba(na_color)[:3])
if "NaN" not in categories:
categories.append("NaN")

if len(colors) < len(categories) - 1: # -1 for NaN
logger.warning(
f"Not enough colors in {color_key} ({len(colors)}) for all categories ({len(categories) - 1}). "
"Some categories will use default colors."
)
# Extend with default colors or duplicate the last color
colors.extend([na_color_hex] * (len(categories) - 1 - len(colors)))

return dict(zip(categories, colors + [na_color_hex], strict=False))

if isinstance(colors, np.ndarray):
colors = [to_hex(to_rgba(color)[:3]) for color in colors]
categories = color_source_vector.categories.tolist()

na_color_hex = to_hex(to_rgba(na_color)[:3])
if "NaN" not in categories:
categories.append("NaN")

if len(colors) < len(categories) - 1: # -1 for NaN
logger.warning(
f"Not enough colors in {color_key} ({len(colors)}) for all categories ({len(categories) - 1}). "
"Some categories will use default colors."
)
colors.extend([na_color_hex] * (len(categories) - 1 - len(colors)))

return dict(zip(categories, colors + [na_color_hex], strict=False))

if isinstance(colors, dict):
# Ensure all values are hex colors
for k, v in colors.items():
if isinstance(v, str) and not v.startswith("#"):
colors[k] = to_hex(to_rgba(v))

categories = color_source_vector.categories.tolist()
na_color_hex = to_hex(to_rgba(na_color)[:3])

# Try to match color keys to categories, accounting for string/categorical differences
result = {}
for cat in categories:
# Try direct match first
if cat in colors:
result[cat] = colors[cat]
# Then try string conversion - handles int/string mismatches
elif str(cat) in colors:
result[cat] = colors[str(cat)]
else:
result[cat] = na_color_hex

return result

return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params)

Expand Down Expand Up @@ -1007,13 +1154,23 @@ def _maybe_set_colors(
try:
if palette is not None:
raise KeyError("Unable to copy the palette when there was other explicitly specified.")
target.uns[color_key] = source.uns[color_key]

# First check if source has the colors
if color_key in source.uns:
logger.info(f"Copying color information for '{key}' from source to target AnnData")
target.uns[color_key] = source.uns[color_key]
# Then check if the base key has colors (direct dict mapping)
elif key in source.uns and isinstance(source.uns[key], dict):
logger.info(f"Copying direct color mappings for '{key}' from source to target AnnData")
target.uns[key] = source.uns[key]
else:
raise KeyError(f"No color information found for '{key}' in source AnnData")

except KeyError:
if isinstance(palette, str):
palette = ListedColormap([palette])
if isinstance(palette, ListedColormap): # `scanpy` requires it
palette = cycler(color=palette.colors)
palette = None
add_colors_for_categorical_sample_annotation(target, key=key, force_update_colors=True, palette=palette)


Expand Down
Loading