
Commit

Options to get_kafka_config with spark options
vatj committed Jul 4, 2024
1 parent cd2e8b5 commit 751e16a
Showing 2 changed files with 35 additions and 38 deletions.
python/hsfs/core/kafka_engine.py (11 changes: 8 additions & 3 deletions)
@@ -17,7 +17,7 @@

 import json
 from io import BytesIO
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, Tuple, Union

 from hsfs import client
 from hsfs.client import hopsworks
@@ -180,7 +180,9 @@ def get_encoder_func(writer_schema: str) -> callable:


 def get_kafka_config(
-    feature_store_id: int, write_options: Optional[Dict[str, Any]] = None
+    feature_store_id: int,
+    write_options: Optional[Dict[str, Any]] = None,
+    engine: Literal["spark", "confluent"] = "confluent",
 ) -> Dict[str, Any]:
     if write_options is None:
         write_options = {}
@@ -193,7 +195,10 @@ def get_kafka_config(
         feature_store_id, external
     )

-    config = storage_connector.confluent_options()
+    if engine == "spark":
+        config = storage_connector.spark_options()
+    elif engine == "confluent":
+        config = storage_connector.confluent_options()
     config.update(write_options.get("kafka_producer_config", {}))
     return config

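The new engine argument makes get_kafka_config the single entry point for both Kafka client flavours: confluent-kafka producers and Spark's Kafka source/sink. A minimal, self-contained sketch of the dispatch pattern follows; the _DemoKafkaConnector class and its option keys are illustrative assumptions, not the real hsfs storage connector, whose spark_options()/confluent_options() do the actual formatting:

    from typing import Any, Dict, Literal, Optional


    class _DemoKafkaConnector:
        """Illustrative stand-in for the hsfs Kafka storage connector."""

        def confluent_options(self) -> Dict[str, Any]:
            # librdkafka-style keys, as used by confluent-kafka clients (values assumed).
            return {"bootstrap.servers": "broker:9092", "security.protocol": "SSL"}

        def spark_options(self) -> Dict[str, Any]:
            # Spark's Kafka integration expects the same settings prefixed with "kafka.".
            return {f"kafka.{k}": v for k, v in self.confluent_options().items()}


    def demo_get_kafka_config(
        connector: _DemoKafkaConnector,
        write_options: Optional[Dict[str, Any]] = None,
        engine: Literal["spark", "confluent"] = "confluent",
    ) -> Dict[str, Any]:
        write_options = write_options or {}
        config = connector.spark_options() if engine == "spark" else connector.confluent_options()
        # User overrides are merged on top, mirroring the kafka_producer_config handling above.
        config.update(write_options.get("kafka_producer_config", {}))
        return config


    print(demo_get_kafka_config(_DemoKafkaConnector(), engine="spark"))
    # {'kafka.bootstrap.servers': 'broker:9092', 'kafka.security.protocol': 'SSL'}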
python/hsfs/engine/spark.py (62 changes: 27 additions & 35 deletions)
@@ -23,13 +23,14 @@
 import shutil
 import warnings
 from datetime import date, datetime, timezone
-from typing import TYPE_CHECKING, Any, List, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar, Union


 if TYPE_CHECKING:
     import great_expectations
+    from pyspark.rdd import RDD
+    from pyspark.sql import DataFrame

-import avro
 import numpy as np
 import pandas as pd
 import tzlocal
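The pyspark imports for RDD and DataFrame live under TYPE_CHECKING, so they are only evaluated by type checkers and add nothing to the runtime import path. A standalone sketch of the pattern (not the hsfs module itself):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by mypy/IDEs only; never executed at runtime.
        from pyspark.sql import DataFrame


    def row_count(df: DataFrame) -> int:
        # With `from __future__ import annotations` the hint stays a string at runtime,
        # so this module imports cleanly even where pyspark is not installed.
        return df.count()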
@@ -82,17 +83,16 @@ def iteritems(self):

 from hsfs import client, feature, training_dataset_feature, util
 from hsfs import feature_group as fg_mod
-from hsfs.client import hopsworks
 from hsfs.client.exceptions import FeatureStoreException
 from hsfs.constructor import query
 from hsfs.core import (
     dataset_api,
     delta_engine,
     hudi_engine,
-    storage_connector_api,
+    kafka_engine,
     transformation_function_engine,
 )
-from hsfs.core.constants import HAS_GREAT_EXPECTATIONS
+from hsfs.core.constants import HAS_AVRO, HAS_GREAT_EXPECTATIONS
 from hsfs.decorators import uses_great_expectations
 from hsfs.storage_connector import StorageConnector
 from hsfs.training_dataset_split import TrainingDatasetSplit
@@ -101,6 +101,9 @@ def iteritems(self):
 if HAS_GREAT_EXPECTATIONS:
     import great_expectations

+if HAS_AVRO:
+    import avro
+

 class Engine:
     HIVE_FORMAT = "hive"
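HAS_AVRO follows the same optional-dependency guard as the existing HAS_GREAT_EXPECTATIONS flag: avro is imported only when it is actually installed. The real constant lives in hsfs.core.constants; a self-contained sketch of how such a flag is commonly derived:

    import importlib.util

    # True only if the optional package is importable in the current environment.
    HAS_AVRO: bool = importlib.util.find_spec("avro") is not None

    if HAS_AVRO:
        import avro  # noqa: F401  (used later for Avro schema handling)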
@@ -123,7 +126,6 @@ def __init__(self):
         if importlib.util.find_spec("pydoop"):
             # If we are on Databricks don't setup Pydoop as it's not available and cannot be easily installed.
             util.setup_pydoop()
-        self._storage_connector_api = storage_connector_api.StorageConnectorApi()
         self._dataset_api = dataset_api.DatasetApi()

     def sql(
@@ -382,17 +384,17 @@ def save_dataframe(

     def save_stream_dataframe(
         self,
-        feature_group,
+        feature_group: Union[fg_mod.FeatureGroup, fg_mod.ExternalFeatureGroup],
         dataframe,
         query_name,
         output_mode,
-        await_termination,
+        await_termination: bool,
         timeout,
-        checkpoint_dir,
-        write_options,
+        checkpoint_dir: Optional[str],
+        write_options: Optional[Dict[str, Any]],
     ):
-        write_options = self._get_kafka_config(
-            feature_group.feature_store_id, write_options
+        write_options = kafka_engine.get_kafka_config(
+            feature_group.feature_store_id, write_options, "spark"
         )
         serialized_df = self._online_fg_to_avro(
             feature_group, self._encode_complex_features(feature_group, dataframe)
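With engine set to "spark", the returned dictionary is already keyed the way Spark's Kafka sink expects, so it can be applied directly to the stream writer. A hedged sketch of that wiring, reusing the parameters from the signature above; it approximates, rather than reproduces, the remainder of save_stream_dataframe:

    # serialized_df carries binary "key"/"value" columns, as produced by _online_fg_to_avro above.
    query = (
        serialized_df.writeStream.format("kafka")
        .outputMode(output_mode)
        .options(**write_options)  # e.g. kafka.bootstrap.servers, kafka.security.protocol, ...
        .option("topic", feature_group._online_topic_name)
        .option("checkpointLocation", checkpoint_dir)
        .queryName(query_name)
        .start()
    )
    if await_termination:
        query.awaitTermination(timeout)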
@@ -485,8 +487,8 @@ def _save_offline_dataframe(
         ).saveAsTable(feature_group._get_table_name())

     def _save_online_dataframe(self, feature_group, dataframe, write_options):
-        write_options = self._get_kafka_config(
-            feature_group.feature_store_id, write_options
+        write_options = kafka_engine.get_kafka_config(
+            feature_group.feature_store_id, write_options, "spark"
         )

         serialized_df = self._online_fg_to_avro(
@@ -511,7 +513,11 @@ def _save_online_dataframe(self, feature_group, dataframe, write_options):
             "topic", feature_group._online_topic_name
         ).save()

-    def _encode_complex_features(self, feature_group, dataframe):
+    def _encode_complex_features(
+        self,
+        feature_group: Union[fg_mod.FeatureGroup, fg_mod.ExternalFeatureGroup],
+        dataframe: Union[RDD, DataFrame],
+    ):
         """Encodes all complex type features to binary using their avro type as schema."""
         return dataframe.select(
             [
@@ -524,7 +530,11 @@ def _encode_complex_features(self, feature_group, dataframe):
             ]
         )

-    def _online_fg_to_avro(self, feature_group, dataframe):
+    def _online_fg_to_avro(
+        self,
+        feature_group: Union[fg_mod.FeatureGroup, fg_mod.ExternalFeatureGroup],
+        dataframe: Union[DataFrame, RDD],
+    ):
         """Packs all features into named struct to be serialized to single avro/binary
         column. And packs primary key into arry to be serialized for partitioning.
         """
@@ -976,7 +986,7 @@ def profile(
     @uses_great_expectations
     def validate_with_great_expectations(
         self,
-        dataframe: TypeVar("pyspark.sql.DataFrame"),  # noqa: F821
+        dataframe: DataFrame,  # noqa: F821
         expectation_suite: great_expectations.core.ExpectationSuite,  # noqa: F821
         ge_validate_kwargs: Optional[dict],
     ):
@@ -1388,24 +1398,6 @@ def cast_columns(df, schema, online=False):
             df = df.withColumn(_feat, col(_feat).cast(pyspark_schema[_feat]))
         return df

-    def _get_kafka_config(
-        self, feature_store_id: int, write_options: dict = None
-    ) -> dict:
-        if write_options is None:
-            write_options = {}
-        external = not (
-            isinstance(client.get_instance(), hopsworks.Client)
-            or write_options.get("internal_kafka", False)
-        )
-
-        storage_connector = self._storage_connector_api.get_kafka_connector(
-            feature_store_id, external
-        )
-
-        config = storage_connector.spark_options()
-        config.update(write_options)
-        return config
-
     @staticmethod
     def is_connector_type_supported(type):
         return True
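With the engine-local helper gone, both Spark write paths above delegate to the shared hsfs.core.kafka_engine helper. One difference visible in this diff: the removed method merged the entire write_options dict into the config, while the shared helper merges only write_options["kafka_producer_config"]. A before/after sketch of the call site:

    # Before this commit (engine-local helper, removed above):
    write_options = self._get_kafka_config(feature_group.feature_store_id, write_options)

    # After this commit (shared helper, imported at the top of the module):
    write_options = kafka_engine.get_kafka_config(
        feature_group.feature_store_id, write_options, "spark"
    )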
