"""Contains the base interface that OctoAI endpoints should implement.

Developers that want to create an endpoint should implement the
``Service`` class in this module as directed by the ``octoai`` command-line
interface.
"""
import functools
import inspect
import os
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Mapping, Tuple, Type
import pydantic_core
from fastapi import Form
from pydantic import BaseModel, Field, create_model
from .types import File
DEFAULT_VOLUME_PATH = "/octoai/cache"


def volume_path() -> str:
    """Get mounted volume path in docker.

    Reads the ``OCTOAI_VOLUME_PATH`` environment variable and falls back to
    ``DEFAULT_VOLUME_PATH`` when the variable is unset or empty, so the
    function always returns a string as its annotation promises.

    :return: Docker path.
    """
    docker_path = os.environ.get("OCTOAI_VOLUME_PATH", None)
    if docker_path:
        return docker_path
    # Unset or empty override: use the built-in default location.
    return DEFAULT_VOLUME_PATH
"HUGGINGFACE_HUB_CACHE": os.path.join(volume_path(), "huggingface_hub_cache"),
"TORCH_HOME": os.path.join(volume_path(), "torch_home"),
class Service(ABC):
    """
    The base interface that OctoAI endpoints should implement.

    Developers that want to create an endpoint should implement this
    class as directed by the ``octoai`` command-line interface.
    """

    def setup(self) -> None:
        """
        Perform service initialization.

        A common operation to include here is loading weights and making
        those available to the ``infer()`` method in a member variable.
        """

    def store_assets(self) -> None:
        """Download model assets."""

    def on_server_startup(self) -> None:
        """
        Perform any necessary initialization when the server starts.

        This method is separate from setup() because setup() can be called
        outside the serving context to include weights when building the image.
        """

    def on_server_shutdown(self) -> None:
        """Perform any necessary cleanup when the server stops."""

    @abstractmethod
    def infer(self, **kwargs: Any) -> Any:
        """Perform inference."""

    # Tag the base ``store_assets`` function with the attribute named by
    # STORE_ASSETS_NOT_OVERRIDDEN. A subclass override is a different
    # function object and does not carry the tag -- presumably this is how
    # callers detect whether a service provides its own asset download.
    setattr(store_assets, STORE_ASSETS_NOT_OVERRIDDEN, True)
class ResponseAnalytics(BaseModel):
    """Additional analytics metadata."""

    # Both fields are declared as floats; the ``_ms`` suffix suggests
    # milliseconds -- confirm against the code that fills them in.
    inference_time_ms: float = Field(
        description="Inference execution time (without pauses)"
    )
    performance_time_ms: float = Field(
        description="Inference execution time (including pauses)"
    )
[docs]def inspect_output_types(method: Callable) -> Type[BaseModel]:
"""Create Pydantic output model from inference method signature."""
# The declared return annotation of ``method`` becomes the type of "output".
signature = inspect.signature(method)
# A method without a return annotation is rejected with ValueError.
if signature.return_annotation == inspect._empty:
raise ValueError(f"{method.__name__}() requires a return type annotation")
# Each entry is a tuple of (annotation, None), keyed by the field name.
rets = OrderedDict()
rets["output"] = (signature.return_annotation, None)
rets["analytics"] = (ResponseAnalytics, None)
# NOTE(review): the arguments of this create_model(...) call are not visible
# in this view; presumably ``rets`` supplies the field definitions -- confirm
# against the full source before editing.
return create_model(
def find_additional_endpoints(service: Service) -> Mapping[str, Tuple[Callable, str]]:
    """Find additional methods that should be exposed as endpoints.

    Collects every bound method of ``service`` whose name does not start
    with an underscore and is not a public attribute of the base
    ``Service`` class.

    :param service: Service instance to inspect.
    :return: Mapping from method name to a ``(bound method, path)`` tuple,
        where ``path`` is the method's ``__path__`` attribute, or ``None``
        when the method has no such attribute.
    """
    # Public names defined on the base class are not additional endpoints.
    reserved = {attr for attr in dir(Service) if not attr.startswith("_")}
    return {
        name: (method, getattr(method, "__path__", None))
        for name, method in inspect.getmembers(service, predicate=inspect.ismethod)
        if name not in reserved and not name.startswith("_")
    }
def path(path: str):
    """Specify the path for a service method.

    Decorator factory: the returned decorator wraps the function in a
    pass-through wrapper and stores ``path`` on that wrapper as ``__path__``.

    :param path: Value stored on the wrapper as ``__path__``.
    :return: Decorator that takes a callable and returns the wrapper.
    """

    def wrapped(fn: Callable):
        # Preserve the wrapped function's name, docstring and signature so
        # that introspection of the wrapper still sees the real function.
        @functools.wraps(fn)
        def wrapped_fn(*args, **kwargs):
            return fn(*args, **kwargs)

        wrapped_fn.__path__ = path  # type: ignore[attr-defined]
        return wrapped_fn

    return wrapped