Skip to content

Commit

Permalink
Consolidate more code
Browse files Browse the repository at this point in the history
  • Loading branch information
nnarayen committed Jan 18, 2025
1 parent 5a406f2 commit 6f78f3e
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 73 deletions.
2 changes: 1 addition & 1 deletion truss-chains/tests/traditional_truss/truss_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class PassthroughModel(chains.ModelBase):
),
)

def __init__(self, **kwargs):
def __init__(self):
self._call_count = 0

async def run_remote(self, call_count_increment: int) -> int:
Expand Down
10 changes: 0 additions & 10 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,6 @@ class ABCChainlet(abc.ABC):
def has_custom_init(cls) -> bool:
return cls.__init__ is not object.__init__

@classmethod
def truss_type(cls) -> str:
return "Chainlet"

@classproperty
@classmethod
def name(cls) -> str:
Expand All @@ -531,12 +527,6 @@ def display_name(cls) -> str:
# ...


class ABCModel(ABCChainlet):
@classmethod
def truss_type(cls) -> str:
return "Chainlet"


class TypeDescriptor(SafeModelNonSerializable):
"""For describing I/O types of Chainlets."""

Expand Down
67 changes: 7 additions & 60 deletions truss-chains/truss_chains/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
Callable,
Iterable,
Iterator,
Literal,
Mapping,
MutableMapping,
Optional,
Expand Down Expand Up @@ -136,9 +135,7 @@ def _collect_error(msg: str, kind: _ErrorKind, location: _ErrorLocation):
)


def raise_validation_errors(
truss_type: Literal["Chainlet", "Model"] = "Chainlet",
) -> None:
def raise_validation_errors() -> None:
"""Raises validation errors as combined ``ChainsUsageError``"""
if _global_error_collector.has_errors:
error_msg = _global_error_collector.format_errors()
Expand All @@ -149,7 +146,7 @@ def raise_validation_errors(
)
_global_error_collector.clear() # Clear errors so `atexit` won't display them
raise definitions.ChainsUsageError(
f"The {truss_type} definitions contain {errors_count}:\n{error_msg}"
f"The Chainlet definitions contain {errors_count}:\n{error_msg}"
)


Expand Down Expand Up @@ -740,7 +737,7 @@ def _validate_remote_config(
definitions.RemoteConfig,
):
_collect_error(
f"{cls.truss_type}s must have a `{definitions.REMOTE_CONFIG_NAME}` class variable "
f"Chainlets must have a `{definitions.REMOTE_CONFIG_NAME}` class variable "
f"of type `{definitions.RemoteConfig}`. Got `{type(remote_config)}` "
f"for `{cls}`.",
_ErrorKind.TYPE_ERROR,
Expand Down Expand Up @@ -772,22 +769,6 @@ def validate_and_register_chain(cls: Type[definitions.ABCChainlet]) -> None:
_global_chainlet_registry.register_chainlet(chainlet_descriptor)


def validate_base_model(cls: Type[definitions.ABCModel]) -> None:
src_path = os.path.abspath(inspect.getfile(cls))
line = inspect.getsourcelines(cls)[1]
location = _ErrorLocation(src_path=src_path, line=line)
_validate_remote_config(cls, location)

base_model_descriptor = definitions.ChainletAPIDescriptor(
chainlet_cls=cls,
dependencies={},
has_context=False,
endpoint=_validate_and_describe_endpoint(cls, location),
src_path=src_path,
)
_global_chainlet_registry.register_chainlet(base_model_descriptor)


# Dependency-Injection / Registry ######################################################


Expand Down Expand Up @@ -1181,10 +1162,10 @@ def _get_entrypoint_chainlets(symbols) -> set[Type[definitions.ABCChainlet]]:

@contextlib.contextmanager
def import_target(
module_path: pathlib.Path, target_name: Optional[str]
module_path: pathlib.Path, target_name: Optional[str] = None
) -> Iterator[Type[definitions.ABCChainlet]]:
resolved_module_path = pathlib.Path(module_path).resolve()
module, loader = _load_module(module_path, "Chainlet")
module, loader = _load_module(module_path)
modules_before = set(sys.modules.keys())
modules_after = set()

Expand Down Expand Up @@ -1238,42 +1219,8 @@ def import_target(
_global_chainlet_registry.unregister_chainlet(chainlet_name)


@contextlib.contextmanager
def import_model_target(
module_path: pathlib.Path,
) -> Iterator[Type[definitions.ABCModel]]:
resolved_module_path = pathlib.Path(module_path).resolve()
module, loader = _load_module(resolved_module_path, "Model")
modules_before = set(sys.modules.keys())
modules_after = set()
try:
try:
loader.exec_module(module)
raise_validation_errors("Model")
finally:
modules_after = set(sys.modules.keys())

module_vars = (getattr(module, name) for name in dir(module))
models: set[Type[definitions.ABCModel]] = {
sym
for sym in module_vars
if utils.issubclass_safe(sym, definitions.ABCModel)
}
if len(models) == 0:
raise ValueError(f"No class in `{module_path}` extends `ModelBase`.")

target_cls = utils.expect_one(models)
if not utils.issubclass_safe(target_cls, definitions.ABCModel):
raise TypeError(f"Target `{target_cls}` is not a {definitions.ABCModel}.")

yield target_cls
finally:
_cleanup_module_imports(modules_before, modules_after, resolved_module_path)


def _load_module(
module_path: pathlib.Path,
truss_type: Literal["Chainlet", "Model"],
) -> tuple[types.ModuleType, Loader]:
"""The context manager ensures that modules imported by the Model/Chain
are removed upon exit.
Expand All @@ -1284,7 +1231,7 @@ def _load_module(
if not os.path.isfile(module_path):
raise ImportError(
f"`{module_path}` is not a file. You must point to a python file where "
f"the entrypoint {truss_type} is defined."
f"the entrypoint Chainlet is defined."
)

import_error_msg = f"Could not import `{module_path}`. Check path."
Expand Down Expand Up @@ -1340,7 +1287,7 @@ def truss_handle_from_code_config(model_file: pathlib.Path) -> TrussHandle:
# TODO(nikhil): Improve detection of directory structure, since right now
# we assume a flat structure
root_dir = model_file.absolute().parent
with import_model_target(model_file) as entrypoint_cls:
with import_target(model_file) as entrypoint_cls:
descriptor = _global_chainlet_registry.get_descriptor(entrypoint_cls)
generated_dir = code_gen.gen_truss_chainlet(
chain_root=root_dir,
Expand Down
4 changes: 2 additions & 2 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init_with_arg_check__(self, *args, **kwargs):
cls.__init__ = __init_with_arg_check__ # type: ignore[method-assign]


class ModelBase(definitions.ABCModel):
class ModelBase(definitions.ABCChainlet):
"""Base class for all singular truss models.
Inheriting from this class adds validations to make sure subclasses adhere to the
Expand All @@ -132,7 +132,7 @@ class ModelBase(definitions.ABCModel):
def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
cls.meta_data = definitions.ChainletMetadata(is_entrypoint=True)
framework.validate_base_model(cls)
framework.validate_and_register_chain(cls)


@overload
Expand Down

0 comments on commit 6f78f3e

Please sign in to comment.