Skip to content

Commit e335569

Browse files
authored
feat: Add support for %dpip magic to install packages on Spark session (#176)
1 parent 3013822 commit e335569

File tree

6 files changed

+409
-0
lines changed

6 files changed

+409
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .magics import DataprocMagics
16+
17+
18+
def load_ipython_extension(ipython):
    """Entry point for `%load_ext`: register the Dataproc magics on the shell."""
    ipython.register_magics(DataprocMagics)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Dataproc magic implementations."""
16+
17+
import shlex
18+
from IPython.core.magic import (Magics, magics_class, line_magic)
19+
from google.cloud.dataproc_spark_connect import DataprocSparkSession
20+
21+
22+
@magics_class
class DataprocMagics(Magics):
    """IPython line magics for Dataproc Spark Connect sessions."""

    def __init__(
        self,
        shell,
        **kwargs,
    ):
        super().__init__(shell, **kwargs)

    @line_magic
    def dpip(self, line):
        """
        Custom magic to install pip packages as Spark Connect artifacts.
        Usage: %dpip install pandas numpy

        Raises:
            RuntimeError: if the command line is malformed, if zero or more
                than one ``DataprocSparkSession`` is found in the user
                namespace, or if the package installation itself fails.
        """
        args = shlex.split(line)

        # Validate the command shape before touching any session state.
        if not args or args[0] != "install":
            raise RuntimeError(
                "Usage: %dpip install <package1> <package2> ..."
            )

        packages = args[1:]  # remove `install`

        if not packages:
            raise RuntimeError("Error: No packages specified.")

        if any(pkg.startswith("-") for pkg in packages):
            raise RuntimeError("Error: Flags are not currently supported.")

        # Find the unique active session in the interactive namespace.
        sessions = [
            (key, value)
            for key, value in self.shell.user_ns.items()
            if isinstance(value, DataprocSparkSession)
        ]

        if not sessions:
            raise RuntimeError(
                "Error: No active Dataproc Spark Session found. Please create one first."
            )
        if len(sessions) > 1:
            raise RuntimeError(
                "Error: Found more than one active Dataproc Spark Sessions."
            )

        ((name, session),) = sessions
        print(f"Active session found: {name}")
        print(f"Installing packages: {packages}")
        # Only wrap genuine installation failures; the validation errors
        # raised above are already user-facing RuntimeErrors and must not be
        # re-prefixed with "Failed to install packages".
        try:
            session.addArtifacts(*packages, pypi=True)
        except Exception as e:
            raise RuntimeError(f"Failed to install packages: {e}") from e

        print("Finished installing packages.")

tests/integration/dataproc_magics/__init__.py

Whitespace-only changes.
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
import pytest
16+
import certifi
17+
from unittest import mock
18+
19+
from google.cloud.dataproc_spark_connect import DataprocSparkSession
20+
21+
22+
# Path to an optional service-account key file; applied as
# GOOGLE_APPLICATION_CREDENTIALS by os_environment when present on disk.
_SERVICE_ACCOUNT_KEY_FILE_ = "service_account_key.json"
23+
24+
25+
@pytest.fixture(params=[None, "3.0"])
def image_version(request):
    """Parametrized image version under test (None lets the service decide)."""
    version = request.param
    return version
28+
29+
30+
@pytest.fixture
def test_project():
    """Project id for the test run, read from GOOGLE_CLOUD_PROJECT."""
    return os.environ.get("GOOGLE_CLOUD_PROJECT")
33+
34+
35+
@pytest.fixture
def test_region():
    """Region for the test run, read from GOOGLE_CLOUD_REGION."""
    return os.environ.get("GOOGLE_CLOUD_REGION")
38+
39+
40+
def is_ci_environment():
    """Detect if running in CI environment (CI or GITHUB_ACTIONS set to "true")."""
    ci_flags = (os.getenv("CI"), os.getenv("GITHUB_ACTIONS"))
    return "true" in ci_flags
43+
44+
45+
@pytest.fixture
def auth_type(request):
    """Auto-detect authentication type based on environment.

    CI environment (CI=true or GITHUB_ACTIONS=true): Uses SERVICE_ACCOUNT
    Local environment: Uses END_USER_CREDENTIALS
    Test parametrization can still override this default.
    """
    # An explicit parametrization always wins over auto-detection.
    if hasattr(request, "param"):
        return request.param
    return "SERVICE_ACCOUNT" if is_ci_environment() else "END_USER_CREDENTIALS"
62+
63+
64+
@pytest.fixture
def test_subnet():
    """Subnet for the session, read from DATAPROC_SPARK_CONNECT_SUBNET."""
    return os.environ.get("DATAPROC_SPARK_CONNECT_SUBNET")
67+
68+
69+
@pytest.fixture
def test_subnetwork_uri(test_subnet):
    """Subnetwork URI for the session.

    DATAPROC_SPARK_CONNECT_SUBNET is treated as the full URI, matching how a
    user would specify it in the project.
    """
    return test_subnet
74+
75+
76+
@pytest.fixture
def os_environment(auth_type, image_version, test_project, test_region):
    """Prepare process environment variables for a session, then restore them.

    The image_version/test_project/test_region fixtures are requested to
    drive parametrization even though only auth_type is read here.
    """
    saved_environment = dict(os.environ)
    if os.path.isfile(_SERVICE_ACCOUNT_KEY_FILE_):
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = (
            _SERVICE_ACCOUNT_KEY_FILE_
        )
    os.environ["DATAPROC_SPARK_CONNECT_AUTH_TYPE"] = auth_type
    if auth_type == "END_USER_CREDENTIALS":
        os.environ.pop("DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT", None)
    # SSL certificate fix: point both OpenSSL and requests at certifi's bundle.
    ca_bundle = certifi.where()
    os.environ["SSL_CERT_FILE"] = ca_bundle
    os.environ["REQUESTS_CA_BUNDLE"] = ca_bundle
    yield os.environ
    os.environ.clear()
    os.environ.update(saved_environment)
92+
93+
94+
@pytest.fixture
def connect_session(test_project, test_region, os_environment):
    """Yield a DataprocSparkSession, stopping it after the test completes."""
    spark = (
        DataprocSparkSession.builder.projectId(test_project)
        .location(test_region)
        .getOrCreate()
    )
    yield spark
    # Clean up the session after each test to prevent resource conflicts.
    try:
        spark.stop()
    except Exception:
        # Ignore cleanup errors to avoid masking the actual test failure.
        pass
108+
109+
110+
@pytest.fixture
def ipython_shell(connect_session):
    """Provides an IPython shell with a DataprocSparkSession in user_ns."""
    try:
        from IPython.terminal.interactiveshell import TerminalInteractiveShell
        from google.cloud import dataproc_magics

        ip = TerminalInteractiveShell.instance()
        ip.user_ns = {"spark": connect_session}

        # Register the %dpip magic on the shell.
        dataproc_magics.load_ipython_extension(ip)

        yield ip
    finally:
        # Re-import locally so cleanup works even if the try body failed early.
        from IPython.terminal.interactiveshell import TerminalInteractiveShell

        TerminalInteractiveShell.clear_instance()
128+
129+
130+
# Tests for magics.py
131+
def test_dpip_magic_loads(ipython_shell):
    """Test that %dpip magic is registered."""
    line_magics = ipython_shell.magics_manager.magics["line"]
    assert "dpip" in line_magics
134+
135+
136+
def test_dpip_install_success(connect_session, ipython_shell, capsys):
    """Test installing a single package with %dpip."""
    ipython_shell.run_line_magic("dpip", "install roman numpy")
    out = capsys.readouterr().out
    for expected in (
        "Active session found:",
        "Installing packages:",
        "Finished installing packages.",
    ):
        assert expected in out

    from pyspark.sql.connect.functions import udf
    from pyspark.sql.types import StringType

    # Exercise the installed `roman` package inside a UDF on the session.
    frame = connect_session.createDataFrame([(1666,)], ["number"])

    def to_roman(number):
        import roman

        return roman.toRoman(number)

    rows = frame.withColumn(
        "roman", udf(to_roman, StringType())("number")
    ).collect()

    assert rows[0]["roman"] == "MDCLXVI"

    connect_session.stop()
161+
162+
163+
def test_dpip_no_install_command(ipython_shell):
    """Test usage message when 'install' is missing."""
    expected = "Usage: %dpip install <package1> <package2>..."
    with pytest.raises(RuntimeError, match=expected):
        ipython_shell.run_line_magic("dpip", "pandas")
169+
170+
171+
def test_dpip_no_packages(ipython_shell):
    """Test message when no packages are specified."""
    expected = "Error: No packages specified."
    with pytest.raises(RuntimeError, match=expected):
        ipython_shell.run_line_magic("dpip", "install")
175+
176+
177+
def test_dpip_with_flags(ipython_shell):
    """Test installing multiple packages with flags like -U."""
    expected = "Error: Flags are not currently supported."
    with pytest.raises(RuntimeError, match=expected):
        ipython_shell.run_line_magic("dpip", "install -U numpy scikit-learn")
183+
184+
185+
def test_dpip_no_session(ipython_shell):
    """Test message when no Spark session is active."""
    # Replace the namespace so no DataprocSparkSession is discoverable.
    ipython_shell.user_ns = {}
    expected = "No active Dataproc Spark Session found."
    with pytest.raises(RuntimeError, match=expected):
        ipython_shell.run_line_magic("dpip", "install pandas")
192+
193+
194+
def test_dpip_install_failure(ipython_shell):
    """Test error message on installation failure."""
    with pytest.raises(RuntimeError, match="No matching distribution found"):
        ipython_shell.run_line_magic("dpip", "install dp-non-existent-package")
201+
202+
203+
def test_dpip_multiple_sessions(ipython_shell, connect_session):
    """Test error message when multiple Spark sessions found."""
    # Bind the same session under two extra names to trigger ambiguity.
    for alias in ("sparksession", "sparkanother"):
        ipython_shell.user_ns[alias] = connect_session
    expected = "Error: Found more than one active Dataproc Spark Sessions."
    with pytest.raises(RuntimeError, match=expected):
        ipython_shell.run_line_magic("dpip", "install pandas")

tests/unit/dataproc_magics/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)