Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/close_sparks_sessions' into clos…
Browse files Browse the repository at this point in the history
…e_sparks_sessions
  • Loading branch information
gibchikafa committed Jul 10, 2024
2 parents 5746523 + 043e39b commit 4bf3396
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/hsfs/constructor/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
self._on = util.parse_features(on)
self._left_on = util.parse_features(left_on)
self._right_on = util.parse_features(right_on)
self._join_type = join_type or self.INNER
self._join_type = join_type or self.LEFT
self._prefix = prefix

def to_dict(self) -> Dict[str, Any]:
Expand Down
16 changes: 14 additions & 2 deletions python/hsfs/constructor/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def join(
on: Optional[List[str]] = None,
left_on: Optional[List[str]] = None,
right_on: Optional[List[str]] = None,
join_type: Optional[str] = "inner",
join_type: Optional[str] = "left",
prefix: Optional[str] = None,
) -> "Query":
"""Join Query with another Query.
Expand Down Expand Up @@ -769,7 +769,7 @@ def featuregroups(
"""List of feature groups used in the query"""
featuregroups = {self._left_feature_group}
for join_obj in self.joins:
featuregroups.add(join_obj.query._left_feature_group)
self._fg_rec_add(join_obj, featuregroups)
return list(featuregroups)

@property
Expand Down Expand Up @@ -809,6 +809,18 @@ def get_feature(self, feature_name: str) -> "Feature":
"""
return self._get_feature_by_name(feature_name)[0]

def _fg_rec_add(self, join_object, featuregroups):
"""
Recursively get a feature groups from nested join and add to featuregroups list.
# Arguments
join_object: `Join object`.
"""
if len(join_object.query.joins) > 0:
for nested_join in join_object.query.joins:
self._fg_rec_add(nested_join, featuregroups)
featuregroups.add(join_object.query._left_feature_group)

def __getattr__(self, name: str) -> Any:
try:
return self.__getitem__(name)
Expand Down
2 changes: 1 addition & 1 deletion python/hsfs/engine/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,7 +1190,7 @@ def add_cols_to_delta_table(self, feature_group, new_features):
"spark.databricks.delta.schema.autoMerge.enabled", "true"
).save(feature_group.location)

def _apply_transformation_function(self, transformation_functions, dataset):
def _apply_transformation_function(self, transformation_functions, dataset, **kwargs):
# generate transformation function expressions
transformed_feature_names = []
transformation_fn_expressions = []
Expand Down
2 changes: 1 addition & 1 deletion python/tests/constructor/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_from_response_json_basic_info(self, mocker, backend_fixtures):
assert len(j._on) == 0
assert len(j._left_on) == 0
assert len(j._right_on) == 0
assert j._join_type == "INNER"
assert j._join_type == "LEFT"
assert j._prefix is None

def test_from_response_json_left_join(self, mocker, backend_fixtures):
Expand Down

0 comments on commit 4bf3396

Please sign in to comment.