Skip to content

Commit

Permalink
Merge pull request #1224 from basetenlabs/bump-version-0.9.48
Browse files Browse the repository at this point in the history
Release 0.9.48
  • Loading branch information
marius-baseten authored Nov 1, 2024
2 parents e84e615 + 9c06114 commit 6272b0e
Show file tree
Hide file tree
Showing 12 changed files with 352 additions and 237 deletions.
442 changes: 228 additions & 214 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.47"
version = "0.9.48"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
6 changes: 5 additions & 1 deletion truss-chains/tests/chains_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def test_chain():
url = service.run_remote_url.replace("host.docker.internal", "localhost")

# Call without providing values for default arguments.
response = requests.post(url, json={"length": 30, "num_partitions": 3})
response = requests.post(
url,
json={"length": 30, "num_partitions": 3},
headers={"traceparent": "TEST TEST TEST"},
)
print(response.content)
assert response.status_code == 200
assert response.json() == [
Expand Down
15 changes: 8 additions & 7 deletions truss-chains/truss_chains/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,16 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
def_str = "async def" if chainlet_descriptor.endpoint.is_async else "def"
input_model_name = _get_input_model_name(chainlet_descriptor.name)
output_model_name = _get_output_model_name(chainlet_descriptor.name)
imports.add("import starlette.requests")
imports.add("from truss_chains import stub")
parts.append(
f"{def_str} predict(self, inputs: {input_model_name}) "
f"-> {output_model_name}:"
f"{def_str} predict(self, inputs: {input_model_name}, "
f"request: starlette.requests.Request) -> {output_model_name}:"
)
# Add error handling context manager:
parts.append(
_indent(
f"with utils.exception_to_http_error("
f"with stub.trace_parent(request), utils.exception_to_http_error("
f'include_stack=True, chainlet_name="{chainlet_descriptor.name}"):'
)
)
Expand Down Expand Up @@ -448,10 +450,6 @@ def _gen_truss_chainlet_model(
for stmt in node.body
)
)

imports.add("import logging")
imports.add("from truss_chains import utils")

class_definition: libcst.ClassDef = utils.expect_one(
node
for node in skeleton_tree.body
Expand Down Expand Up @@ -499,6 +497,7 @@ def _gen_truss_chainlet_file(
(chainlet_dir / truss_config.DEFAULT_MODEL_MODULE_DIR / "__init__.py").touch()
imports: set[str] = set()
src_parts: list[str] = []

if maybe_stub_src := _gen_stub_src_for_deps(dependencies):
_update_src(maybe_stub_src, src_parts, imports)

Expand Down Expand Up @@ -578,6 +577,8 @@ def _make_truss_config(
config.model_name = model_name
config.model_class_filename = _MODEL_FILENAME
config.model_class_name = _MODEL_CLS_NAME
config.runtime.enable_tracing_data = chains_config.options.enable_b10_tracing
config.environment_variables = dict(chains_config.options.env_variables)
# Compute.
compute = chains_config.get_compute_spec()
config.resources.cpu = str(compute.cpu_count)
Expand Down
25 changes: 23 additions & 2 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
TRUSS_CONFIG_CHAINS_KEY = "chains_metadata"
GENERATED_CODE_DIR = ".chains_generated"
DYNAMIC_CHAINLET_CONFIG_KEY = "dynamic_chainlet_config"

OTEL_TRACE_PARENT_HEADER_KEY = "traceparent"
# Below arg names must correspond to `definitions.ABCChainlet`.
ENDPOINT_METHOD_NAME = "run_remote" # Chainlet method name exposed as endpoint.
CONTEXT_ARG_NAME = "context" # Referring to Chainlets `__init__` signature.
Expand Down Expand Up @@ -340,6 +340,20 @@ def get_spec(self) -> AssetSpec:
return self._spec.copy(deep=True)


class ChainletOptions(SafeModelNonSerializable):
"""
Args:
enable_b10_tracing: enables baseten-internal trace data collection. This
helps baseten engineers better analyze chain performance in case of issues.
It is independent of a potentially user-configured tracing instrumentation.
Turning this on, could add performance overhead.
env_variables: static environment variables available to the deployed chainlet.
"""

enable_b10_tracing: bool = False
env_variables: Mapping[str, str] = {}


class RemoteConfig(SafeModelNonSerializable):
"""Bundles config values needed to deploy a chainlet remotely.
Expand All @@ -363,6 +377,7 @@ class MyChainlet(chains.ChainletBase):
compute: Compute = Compute()
assets: Assets = Assets()
name: Optional[str] = None
options: ChainletOptions = ChainletOptions()

def get_compute_spec(self) -> ComputeSpec:
return self.compute.get_spec()
Expand Down Expand Up @@ -391,7 +406,11 @@ class ServiceDescriptor(SafeModel):


class Environment(SafeModel):
"""The environment in which the chainlet is deployed."""
"""The environment the chainlet is deployed in.
Args:
name: The name of the environment.
"""

name: str
# can add more fields here as we add them to dynamic_config configmap
Expand All @@ -417,6 +436,8 @@ class DeploymentContext(SafeModelNonSerializable, Generic[UserConfigT]):
user_env: These values can be provided to
the deploy command and customize the behavior of deployed chainlets. E.g.
for differentiating between prod and dev version of the same chain.
environment: The environment that the chainlet is deployed in.
None if the chainlet is not associated with an environment.
"""

data_dir: Optional[pathlib.Path] = None
Expand Down
10 changes: 8 additions & 2 deletions truss-chains/truss_chains/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@


def _push_to_baseten(
truss_dir: pathlib.Path, options: definitions.PushOptionsBaseten
truss_dir: pathlib.Path, options: definitions.PushOptionsBaseten, chainlet_name: str
) -> b10_service.BasetenService:
truss_handle = truss.load(str(truss_dir))
model_name = truss_handle.spec.config.model_name
Expand All @@ -59,6 +59,9 @@ def _push_to_baseten(
trusted=True,
publish=options.publish,
origin=b10_types.ModelOrigin.CHAINS,
chain_environment=options.environment,
chainlet_name=chainlet_name,
chain_name=options.chain_name,
)
return cast(b10_service.BasetenService, service)

Expand Down Expand Up @@ -124,7 +127,10 @@ def _push_service(
)
elif isinstance(options, definitions.PushOptionsBaseten):
with utils.log_level(logging.INFO):
service = _push_to_baseten(truss_dir, options)
# We send the display_name of the chainlet in subsequent steps.
service = _push_to_baseten(
truss_dir, options, chainlet_descriptor.display_name
)
else:
raise NotImplementedError(options)

Expand Down
32 changes: 29 additions & 3 deletions truss-chains/truss_chains/stub.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,38 @@
import abc
import asyncio
import contextlib
import contextvars
import logging
import ssl
import threading
import time
from typing import Any, ClassVar, Mapping, Optional, Type, TypeVar, final
from typing import Any, ClassVar, Iterator, Mapping, Optional, Type, TypeVar, final

import aiohttp
import httpx
import starlette.requests
import tenacity

from truss_chains import definitions, utils

DEFAULT_MAX_CONNECTIONS = 1000
DEFAULT_MAX_KEEPALIVE_CONNECTIONS = 400

_trace_parent_context: contextvars.ContextVar[str] = contextvars.ContextVar(
"trace_parent"
)


@contextlib.contextmanager
def trace_parent(request: starlette.requests.Request) -> Iterator[None]:
token = _trace_parent_context.set(
request.headers.get(definitions.OTEL_TRACE_PARENT_HEADER_KEY, "")
)
try:
yield
finally:
_trace_parent_context.reset(token)


class BasetenSession:
"""Helper to invoke predict method on Baseten deployments."""
Expand Down Expand Up @@ -122,7 +140,11 @@ def predict_sync(self, json_payload):
with self._sync_num_requests as num_requests:
self._maybe_warn_for_overload(num_requests)
resp = self._client_sync().post(
self._service_descriptor.predict_url, json=json_payload
self._service_descriptor.predict_url,
json=json_payload,
headers={
definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get()
},
)
return utils.handle_response(resp, self.name)
# As a special case we invalidate the client in case of certificate
Expand All @@ -146,7 +168,11 @@ async def predict_async(self, json_payload):
async with self._async_num_requests as num_requests:
self._maybe_warn_for_overload(num_requests)
resp = await client.post(
self._service_descriptor.predict_url, json=json_payload
self._service_descriptor.predict_url,
json=json_payload,
headers={
definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get()
},
)
return await utils.handle_async_response(resp, self.name)
# As a special case we invalidate the client in case of certificate
Expand Down
6 changes: 6 additions & 0 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def create_model_from_truss(
is_trusted: bool,
deployment_name: Optional[str] = None,
origin: Optional[b10_types.ModelOrigin] = None,
chain_environment: Optional[str] = None,
chainlet_name: Optional[str] = None,
chain_name: Optional[str] = None,
):
query_string = f"""
mutation {{
Expand All @@ -121,6 +124,9 @@ def create_model_from_truss(
is_trusted: {'true' if is_trusted else 'false'},
{f'version_name: "{deployment_name}"' if deployment_name else ""}
{f'model_origin: {origin.value}' if origin else ""}
{f'chain_environment: "{chain_environment}"' if chain_environment else ""}
{f'chainlet_name: "{chainlet_name}"' if chainlet_name else ""}
{f'chain_name: "{chain_name}"' if chain_name else ""}
) {{
id,
name,
Expand Down
6 changes: 6 additions & 0 deletions truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ def create_truss_service(
deployment_name: Optional[str] = None,
origin: Optional[b10_types.ModelOrigin] = None,
environment: Optional[str] = None,
chain_environment: Optional[str] = None,
chainlet_name: Optional[str] = None,
chain_name: Optional[str] = None,
) -> Tuple[str, str]:
"""
Create a model in the Baseten remote.
Expand Down Expand Up @@ -291,6 +294,9 @@ def create_truss_service(
is_trusted=is_trusted,
deployment_name=deployment_name,
origin=origin,
chain_environment=chain_environment,
chainlet_name=chainlet_name,
chain_name=chain_name,
)
return model_version_json["id"], model_version_json["version_id"]

Expand Down
6 changes: 6 additions & 0 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def push( # type: ignore
deployment_name: Optional[str] = None,
origin: Optional[custom_types.ModelOrigin] = None,
environment: Optional[str] = None,
chain_environment: Optional[str] = None,
chainlet_name: Optional[str] = None,
chain_name: Optional[str] = None,
) -> BasetenService:
if model_name.isspace():
raise ValueError("Model name cannot be empty")
Expand Down Expand Up @@ -179,6 +182,9 @@ def push( # type: ignore
deployment_name=deployment_name,
origin=origin,
environment=environment,
chain_environment=chain_environment,
chainlet_name=chainlet_name,
chain_name=chain_name,
)

return BasetenService(
Expand Down
11 changes: 4 additions & 7 deletions truss/templates/server/common/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ def serialize(self) -> dict:

def _parse_input_type(input_parameters: MappingProxyType) -> Optional[type]:
parameter_types = list(input_parameters.values())

if len(parameter_types) > 1:
return None

# In `ArgConfig.from_signature` the arguments are validated.
input_type = parameter_types[0].annotation

if _annotation_is_pydantic_model(input_type):
Expand Down Expand Up @@ -158,11 +155,11 @@ def _extract_pydantic_base_models(union_args: tuple) -> List[Type[BaseModel]]:
2. Union[Awaitable[PydanticBaseModel], AsyncGenerator]
So for Awaitables, we need to extract the base class from the Awaitable type
"""

return [
retrieve_base_class_from_awaitable(arg) if _is_awaitable_type(arg) else arg
retrieve_base_class_from_awaitable(arg) if _is_awaitable_type(arg) else arg # type: ignore[misc] # Types are ok per filter condition.
for arg in union_args
if _is_awaitable_type(arg)
or (isinstance(arg, type) and issubclass(arg, BaseModel))
if _is_awaitable_type(arg) or _annotation_is_pydantic_model(arg)
]


Expand Down
28 changes: 28 additions & 0 deletions truss/tests/remote/baseten/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,34 @@ def test_create_model_from_truss(mock_post, baseten_api):
assert 'version_name: "deployment_name"' in gql_mutation


@mock.patch("requests.post", return_value=mock_create_model_response())
def test_create_model_from_truss_forwards_chainlet_data(mock_post, baseten_api):
baseten_api.create_model_from_truss(
"model_name",
"s3key",
"config_str",
"semver_bump",
"client_version",
False,
"deployment_name",
chain_environment="chainstaging",
chain_name="chainchain",
chainlet_name="chainlet-1",
)

gql_mutation = mock_post.call_args[1]["data"]["query"]
assert 'name: "model_name"' in gql_mutation
assert 's3_key: "s3key"' in gql_mutation
assert 'config: "config_str"' in gql_mutation
assert 'semver_bump: "semver_bump"' in gql_mutation
assert 'client_version: "client_version"' in gql_mutation
assert "is_trusted: false" in gql_mutation
assert 'version_name: "deployment_name"' in gql_mutation
assert 'chain_environment: "chainstaging"' in gql_mutation
assert 'chain_name: "chainchain"' in gql_mutation
assert 'chainlet_name: "chainlet-1"' in gql_mutation


@mock.patch("requests.post", return_value=mock_create_model_response())
def test_create_model_from_truss_does_not_send_deployment_name_if_not_specified(
mock_post, baseten_api
Expand Down

0 comments on commit 6272b0e

Please sign in to comment.