Skip to content

Commit fa562d5

Browse files
Daniel Ohayonfacebook-github-bot
Daniel Ohayon
authored andcommitted
handle components whose implementation lives in a different file (pytorch#1075)
Summary: Add support for cases like: ```lang=python # some_file.py # ==================== def my_component(...) -> specs.AppDef: ... # other_file.py # ==================== from some_file import my_component ``` where the component is invoked with `torchx run ... other_file.py:my_component` This was currently failing with a validation error because in the step where we inspect the AST of the component, we assume that the file where the component is being looked up is the same as the file where it is implemented. Reviewed By: kiukchung Differential Revision: D75496839
1 parent df5d30c commit fa562d5

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

torchx/specs/finder.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,23 @@ def _get_validation_errors(
274274
linter_errors = validate(path, function_name, validators)
275275
return [linter_error.description for linter_error in linter_errors]
276276

277+
def _get_path_to_function_decl(
278+
self, function: Callable[..., Any] # pyre-ignore[2]
279+
) -> str:
280+
"""
281+
Attempts to return the path to the file where the function is implemented.
282+
This can be different from the path where the function is looked up, for example if we have:
283+
my_component defined in some_file.py, imported in other_file.py
284+
and the component is invoked as other_file.py:my_component
285+
"""
286+
path_to_function_decl = inspect.getabsfile(function)
287+
if path_to_function_decl is None or not os.path.isfile(path_to_function_decl):
288+
return self._filepath
289+
return path_to_function_decl
290+
277291
def find(
278292
self, validators: Optional[List[TorchxFunctionValidator]]
279293
) -> List[_Component]:
280-
validation_errors = self._get_validation_errors(
281-
self._filepath, self._function_name, validators
282-
)
283294

284295
file_source = read_conf_file(self._filepath)
285296
namespace = copy.copy(globals())
@@ -292,6 +303,12 @@ def find(
292303
)
293304
app_fn = namespace[self._function_name]
294305
fn_desc, _ = get_fn_docstring(app_fn)
306+
307+
func_path = self._get_path_to_function_decl(app_fn)
308+
validation_errors = self._get_validation_errors(
309+
func_path, self._function_name, validators
310+
)
311+
295312
return [
296313
_Component(
297314
name=f"{self._filepath}:{self._function_name}",

torchx/specs/test/finder_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
from pathlib import Path
1515
from unittest.mock import MagicMock, patch
1616

17-
import torchx.specs.finder as finder
18-
1917
from importlib_metadata import EntryPoints
18+
19+
import torchx.specs.finder as finder
2020
from torchx.runner import get_runner
2121
from torchx.runtime.tracking import FsspecResultTracker
2222
from torchx.specs.api import AppDef, AppState, Role
@@ -29,6 +29,7 @@
2929
get_components,
3030
ModuleComponentsFinder,
3131
)
32+
from torchx.specs.test.components.a import comp_a
3233
from torchx.util.test.entrypoints_test import EntryPoint_from_text
3334
from torchx.util.types import none_throws
3435

@@ -238,6 +239,10 @@ def test_get_component_invalid(self) -> None:
238239
with self.assertRaises(ComponentValidationException):
239240
get_component(f"{current_file_path()}:invalid_component")
240241

242+
def test_get_component_imported_from_other_file(self) -> None:
243+
component = get_component(f"{current_file_path()}:comp_a")
244+
self.assertListEqual([], component.validation_errors)
245+
241246

242247
class GetBuiltinSourceTest(unittest.TestCase):
243248
def setUp(self) -> None:

0 commit comments

Comments
 (0)