"""
Server for OctoAI endpoints created with the ``octoai`` CLI.
Developers that want to create an endpoint should not use
this module directly. Instead, they should use the ``octoai``
command-line interface, which directs them to implement the
``octoai.service.Service`` class and use the ``octoai`` CLI to help
build and deploy their endpoint.
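
For illustration, a minimal service might look like the following sketch
(``setup`` and ``infer`` are the methods this server invokes; the field
names and placeholder "model" are assumptions, not part of the SDK)::

    from octoai.service import Service

    class EchoService(Service):
        def setup(self):
            # Load models or other heavy resources once per replica.
            self.prefix = "echo: "

        def infer(self, prompt: str) -> str:
            # Handle a single inference request.
            return self.prefix + prompt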
"""
import asyncio
import dataclasses
import enum
import importlib
import inspect
import json
import logging
import multiprocessing
import os
import signal
import sys
import time
from contextlib import asynccontextmanager
from http import HTTPStatus
from multiprocessing import Pipe, Process, Queue
from multiprocessing.connection import Connection
from typing import Any, Dict, NamedTuple, Optional, Type
import chevron
import click
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
from octoai import documenter
from octoai.documenter import MustacheGenerationResult
from .service import (
STORE_ASSETS_NOT_OVERRIDDEN,
VOLUME_ENVIRONMENT,
ResponseAnalytics,
Service,
find_additional_endpoints,
implements_form_data,
inspect_input_types,
inspect_output_types,
transform_form_data_signature,
)
_LOG = logging.getLogger(__name__)
_OCTOAI_SERVICE_MODULE = "octoai.service"
_OCTOAI_BASE_SERVICE_CLASS = "Service"
_PREDICT_LOOP_WATCHDOG_SECONDS = 2
"""Delay in seconds between checking if the predict loop is running."""
_process_mutex = multiprocessing.Lock()
"""Lock for spawning predict loop."""
_TERMINATE_PREDICT_LOOP_REQUEST = "terminate"
"""Sent to predict loop process to terminate it."""
_LOAD_MODEL_MESSAGE = "load_model"
"""Sent from predict loop when it is loading the model."""
_MODEL_LOADED_MESSAGE = "model_loaded"
"""Sent from predict loop when it is ready for inference."""
_STORE_ASSET_NEEDED_CODE = 10
"""Exit status code when store asset is required."""
_STORE_ASSET_NOT_NEEDED_CODE = 11
"""Exit status code when store asset is not required."""
class ServiceMethod(enum.Enum):
"""Class for distinguishing route implementations in the request queue."""
INFER_JSON = "infer"
INFER_FORM_DATA = "infer_form_data"
class InferenceRequest(NamedTuple):
    """Inference request passed from an HTTP handler to the predict loop."""
response_pipe: Connection
method: str
inputs: Any
class InferenceResponse(NamedTuple):
"""Class for returning inference results."""
inference_time_ms: float
outputs: Any
def is_store_assets_needed(ctx) -> bool:
"""Check if store-assets step is required."""
    def check_store_assets_overridden(queue, ctx, dummy) -> None:
service = load_service(
ctx.parent.params["service_module"],
class_name=ctx.parent.params["service_class"],
)
queue.put(not hasattr(service.store_assets, STORE_ASSETS_NOT_OVERRIDDEN))
queue: Queue[bool] = Queue()
    p = Process(target=check_store_assets_overridden, args=(queue, ctx, 0))
p.start()
p.join(timeout=10)
result = queue.get()
return result
def maybe_set_volume_environment_variables(ctx) -> None:
"""Set volume environment variables if needed."""
if not is_store_assets_needed(ctx):
return
for key, value in VOLUME_ENVIRONMENT.items():
os.environ[key] = value
def _predict_loop(service, _request_queue):
"""Loop which handles prediction requests.
This loop runs for the duration of the server and receives prediction
requests posted to the _REQUEST_QUEUE. When the request is done processing
the results are posted to the response_pipe where they are handled by
the main /predict endpoint.
"""
startup_pipe = _request_queue.get()
startup_pipe.send(_LOAD_MODEL_MESSAGE)
try:
service.setup()
except Exception as e:
_LOG.error("_predict_loop: model setup failed with {e}", exc_info=1)
startup_pipe.send(e)
startup_pipe.close()
return
startup_pipe.send(_MODEL_LOADED_MESSAGE)
startup_pipe.close()
def signal_handler(_signum, _frame):
# This will only kill the _predict_loop process, not the parent
sys.exit()
signal.signal(signal.SIGINT, signal_handler)
while True:
try:
inference_request = _request_queue.get()
if (
isinstance(inference_request, type(_TERMINATE_PREDICT_LOOP_REQUEST))
and inference_request == _TERMINATE_PREDICT_LOOP_REQUEST
):
sys.exit()
try:
start_time = time.perf_counter_ns()
infer_fn = getattr(service, inference_request.method)
results = infer_fn(**inference_request.inputs)
stop_time = time.perf_counter_ns()
                response = InferenceResponse((stop_time - start_time) / 1e6, results)
except Exception as e:
_LOG.error("infer() raised Exception", exc_info=1)
response = e
inference_request.response_pipe.send(response)
inference_request.response_pipe.close()
except Exception:
# We only end up here if something went wrong outside the predict call
# continue loop
pass
class Server:
"""
Server for OctoAI endpoints created with the ``octoai`` CLI.
Developers that want to create an endpoint should not use
this class directly. Instead, they should use the ``octoai``
command-line interface, which directs them to implement the
``octoai.service.Service`` class and use the ``octoai`` CLI to
help build and deploy their endpoint.
"""
    class State(enum.Enum):
"""Describes the states of Server."""
UNINITIALIZED = "UNINITIALIZED"
LAUNCH_PREDICT_LOOP = "LAUNCH_PREDICT_LOOP"
SETUP_SERVICE = "SETUP_SERVICE"
RUNNING = "RUNNING"
SHUTTING_DOWN = "SHUTTING_DOWN"
STOPPED = "STOPPED"
def __init__(self, service: Service, async_enable: bool = True):
self.app = FastAPI(lifespan=lambda _: self.prepare_for_serving())
self.service: Service = service
self._state = self.State.UNINITIALIZED
self.is_async = async_enable
        self._request_queue: Optional["multiprocessing.Queue[Any]"] = None
        self._predict_loop_watchdog_task: Optional[asyncio.Task] = None
self.response_headers = {
"OCTOAI_REPLICA_NAME": os.environ.get("OCTOAI_REPLICA_NAME", ""),
}
        # Build Pydantic models for input/output for the /infer route
Input = inspect_input_types(service.infer)
Output = inspect_output_types(service.infer)
# Build Pydantic models for input/output for additional routes
Inputs, Outputs = {}, {}
additional_endpoints = find_additional_endpoints(service)
for method_name, method_info in additional_endpoints.items():
method, _ = method_info
Inputs[method_name] = inspect_input_types(method)
Outputs[method_name] = inspect_output_types(method)
        # Used to read inference responses from a pipe in async mode
async def _pipe_reader(read: Connection):
"""Async multiprocessing.Pipe reader.
:param read: pipe file handle to read from.
:return: the contents of the pipe when read.
"""
data_available = asyncio.Event()
asyncio.get_event_loop().add_reader(read.fileno(), data_available.set)
if not read.poll():
await data_available.wait()
result = read.recv()
data_available.clear()
asyncio.get_event_loop().remove_reader(read.fileno())
return result
# Async mode: Put request in the queue for predict loop
# Sync mode: Call predict method directly
async def _infer_common(
service_method: str,
method_args: Dict[str, Any],
output_type: Type[BaseModel],
) -> Response:
if not self.is_running:
return JSONResponse(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
content={"status": self.state.name},
)
if self.is_async:
read_conn, write_conn = Pipe()
start_perf = time.perf_counter_ns()
request = InferenceRequest(write_conn, service_method, method_args)
self._request_queue.put(request)
response = await _pipe_reader(read_conn)
performance_time_ms = (time.perf_counter_ns() - start_perf) / 1e6
if isinstance(response, Exception):
raise response
prediction = response.outputs
inference_time_ms = response.inference_time_ms
else:
# track time elapsed in nanoseconds only while app is not asleep
start_process = time.process_time_ns()
# track time elapsed in nanoseconds including any sleep time
start_perf = time.perf_counter_ns()
infer_fn = getattr(service, service_method)
prediction = infer_fn(**method_args)
inference_time_ms = (time.process_time_ns() - start_process) / 1e6
performance_time_ms = (time.perf_counter_ns() - start_perf) / 1e6
return Response(
status_code=HTTPStatus.OK,
headers=self.response_headers,
media_type="application/json",
content=output_type(
output=prediction,
analytics=ResponseAnalytics(
inference_time_ms=inference_time_ms,
performance_time_ms=performance_time_ms,
),
).model_dump_json(),
)
# Implementation for form data route.
# This function signature is dynamically redefined later based on that of
# Service.infer_form_data() if it is implemented, so that FastAPI
# can know the parameters and their types when registering the route.
async def infer_form_data(**kwargs):
return await _infer_common(
service_method=ServiceMethod.INFER_FORM_DATA.value,
method_args=kwargs,
output_type=OutputFormData,
)
# Add form data route to FastAPI if implemented.
if implements_form_data(service):
infer_form_data.__signature__ = transform_form_data_signature(service)
OutputFormData = inspect_output_types(service.infer_form_data)
_LOG.info(
"adding endpoint infer_form_data() as %s",
service.infer_form_data.__path__,
)
self.app.add_api_route(
path=service.infer_form_data.__path__,
endpoint=infer_form_data,
methods=["POST"],
response_model=OutputFormData,
)
@self.app.get("/healthcheck")
def health() -> JSONResponse:
return JSONResponse(
status_code=HTTPStatus.OK
if self.is_running
else HTTPStatus.SERVICE_UNAVAILABLE,
content={"status": self.state.name, "async_enable": self.is_async},
)
@self.app.get("/")
def root() -> JSONResponse:
return JSONResponse(
status_code=HTTPStatus.OK,
content={
"docs": "/docs",
"openapi": "/openapi.json",
},
)
@self.app.post(
"/infer",
response_model=Output,
)
async def infer(request: Input) -> Response:
return await _infer_common(
service_method=ServiceMethod.INFER_JSON.value,
method_args={k: v for k, v in request},
output_type=Output,
)
# Create endpoint function for additional endpoints
def _get_infer_endpoint_fn(method_name: str):
async def _infer_endpoint(request: Inputs[method_name]) -> Response:
return await _infer_common(
service_method=method_name,
method_args={k: v for k, v in request},
output_type=Outputs[method_name],
)
return _infer_endpoint
# Add additional endpoints to FastAPI
for method_name, method_info in additional_endpoints.items():
method, method_path = method_info
method_path = method_path or f"/{method_name}".replace("_", "-")
_LOG.info("adding endpoint %s() as %s", method_name, method_path)
self.app.add_api_route(
path=method_path,
endpoint=_get_infer_endpoint_fn(method_name),
methods=["POST"],
response_model=Outputs[method_name],
)
@property
def is_running(self):
"""True when this server instance can serve a request."""
return self.state == self.State.RUNNING
@property
def state(self):
"""Get the status of this server."""
return self._state
@state.setter
def state(self, new_state):
"""Set the status of the server, and log transition."""
_LOG.info("status: %s -> %s", self._state, new_state)
self._state = new_state
    @asynccontextmanager
async def prepare_for_serving(self):
"""Context manager that should surround all serving.
This is intended to be used as an ASGI application's lifetime handler.
"""
assert self.state in (
self.State.UNINITIALIZED,
self.State.STOPPED,
), f"prepare_for_serving: status not UNINITIALIZED or STOPPED: {self.state}"
_LOG.info("lifecycle: on_server_startup")
self.service.on_server_startup()
if self.is_async:
self.state = self.State.LAUNCH_PREDICT_LOOP
            # service.setup() is handled in the predict loop so that the
            # subprocess does all GPU access.
self._start_predict_loop()
else:
self.state = self.State.SETUP_SERVICE
self.service.setup()
self.state = self.State.RUNNING
yield
self.state = self.State.SHUTTING_DOWN
if self.is_async:
self._stop_predict_loop()
_LOG.info("lifecycle: on_server_shutdown")
self.service.on_server_shutdown()
self.state = self.State.STOPPED
async def _check_predict_loop(self):
if not self._predict_process.is_alive():
self._start_predict_loop()
def _start_predict_loop(self):
context = multiprocessing.get_context("spawn")
if not self._request_queue:
# Only need to create this queue once. This function may be called
# multiple times if the predict loop dies.
self._request_queue = context.Queue()
self._predict_process = context.Process(
target=_predict_loop,
name="_predict_loop",
args=(self.service, self._request_queue),
)
self._predict_process.start()
read_conn, write_conn = Pipe()
def _remove_reader():
            asyncio.get_event_loop().remove_reader(read_conn.fileno())
read_conn.close()
def read_startup_message():
try:
value = read_conn.recv()
except EOFError:
                asyncio.get_event_loop().remove_reader(read_conn.fileno())
return
if isinstance(value, Exception):
_LOG.error("predict loop died during setup", exc_info=value)
                # Though unconventional, this appears to be the proper way to cleanly
# terminate uvicorn. It's not easy to get ahold of
# uvicorn.Server.handle_exit() from here.
self._stop_predict_loop()
os.kill(os.getpid(), signal.SIGTERM)
if value == _LOAD_MODEL_MESSAGE:
if self.state != self.State.LAUNCH_PREDICT_LOOP:
_LOG.error(
"_read_predict_loop_startup_message: %s message: "
"not in expected state: %s",
value,
self.state,
)
_remove_reader()
return
self.state = self.State.SETUP_SERVICE
elif value == _MODEL_LOADED_MESSAGE:
if self.state != self.State.SETUP_SERVICE:
_LOG.error(
"_read_predict_loop_startup_message: %s message: "
"not in expected state: %s",
value,
self.state,
)
_remove_reader()
return
self.state = self.State.RUNNING
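                # Start a watchdog that periodically restarts the predict loop
                # subprocess if it dies (see _check_predict_loop).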
self._predict_loop_watchdog_task = asyncio.create_task(
self._predict_loop_watchdog( # noqa
_PREDICT_LOOP_WATCHDOG_SECONDS, self._check_predict_loop # noqa
)
)
_remove_reader()
asyncio.get_event_loop().add_reader(read_conn.fileno(), read_startup_message)
self._request_queue.put(write_conn)
def _stop_predict_loop(self):
if not self._request_queue:
return
if self._predict_loop_watchdog_task is not None:
self._predict_loop_watchdog_task.cancel()
self._request_queue.put(_TERMINATE_PREDICT_LOOP_REQUEST)
self._predict_process.join()
self._request_queue = None
async def _predict_loop_watchdog(self, interval, periodic_function):
while True:
await asyncio.gather(
asyncio.sleep(interval),
periodic_function(),
)
    def get_api_schema(self) -> Dict[str, Any]:
        """Return the OpenAPI schema for the underlying service."""
return self.app.openapi()
    def get_usage_examples(
self, format: documenter.DocumentationFormat
) -> MustacheGenerationResult:
"""Return the Mustache generation result for the underlying service."""
return documenter.generate_usage_examples(self.service, format)
    def store_assets(self) -> None:
"""Run service store assets."""
self.service.store_assets()
    def run(self, port: int, timeout_keep_alive: int):
"""Run the server exposing the underlying service."""
uvicorn.run(
self.app,
host="0.0.0.0",
port=port,
timeout_keep_alive=timeout_keep_alive,
lifespan="on",
)
def load_service(module_name: str, class_name: Optional[str] = None) -> Service:
"""Load a class from service implementation."""
try:
module = importlib.import_module(module_name)
if class_name is not None:
# if service class is provided, instantiate it
class_ = getattr(module, class_name)
else:
# if service class not provided, look for it
class_ = None
for name, class_obj in inspect.getmembers(module, inspect.isclass):
for class_base in class_obj.__bases__:
if (
class_base.__module__ == _OCTOAI_SERVICE_MODULE
and class_base.__name__ == _OCTOAI_BASE_SERVICE_CLASS
):
class_ = class_obj
break
if class_ is None:
raise ValueError(
f"Module '{module_name}' contains no classes extending "
f"base '{_OCTOAI_SERVICE_MODULE}.{_OCTOAI_BASE_SERVICE_CLASS}'"
)
_LOG.info(f"Using service in {module_name}.{class_.__name__}.")
return class_()
except ModuleNotFoundError:
error_msg = f"Module '{module_name}' not found. "
if module_name == "service":
error_msg += "Ensure your service is defined in service.py."
raise ValueError(error_msg)
def _load_server(ctx, async_enable: bool = True) -> Server:
service = load_service(
ctx.parent.params["service_module"],
class_name=ctx.parent.params["service_class"],
)
return Server(service, async_enable=async_enable)
@click.group(name="server")
@click.option(
"--log-level",
type=click.Choice(["ERROR", "WARNING", "INFO", "DEBUG"], case_sensitive=False),
default="INFO",
envvar="OCTOAI_LOG_LEVEL",
)
@click.option("--service-module", default="service")
@click.option("--service-class", default=None)
@click.pass_context
def server(ctx, log_level, service_module, service_class):
"""CLI for OctoAI server."""
logging.basicConfig(
level=log_level,
stream=sys.stderr,
format="%(asctime)s %(levelname)s %(filename)s:%(lineno)d %(message)s",
)
click.echo("octoai server")
ctx.ensure_object(dict)
@server.command()
@click.option("--output-file", default=None)
@click.pass_context
def api_schema(ctx, output_file):
"""Generate OpenAPI schema for the given service."""
_LOG.info("api-schema")
server: Server = _load_server(ctx)
schema = server.get_api_schema()
if output_file:
with open(output_file, "w") as f:
json.dump(schema, f, indent=2)
else:
click.echo(json.dumps(schema, indent=2))
@server.command(hidden=True)
@click.option("--no-render", is_flag=True, default=False)
@click.option(
"--format",
required=True,
type=click.Choice(["curl", "python"], case_sensitive=False),
)
@click.option("--output-file", default=None)
@click.pass_context
def generate_usage_examples(ctx, no_render, format, output_file):
"""Generate client usage examples for the given service."""
_LOG.info("generate-usage-examples")
def convert_format(f: str) -> documenter.DocumentationFormat:
if f.lower() == "curl":
return documenter.DocumentationFormat.CURL
elif f.lower() == "python":
return documenter.DocumentationFormat.PYTHON
else:
raise ValueError(f"format '{f}' is not supported")
server: Server = _load_server(ctx)
generation_result = server.get_usage_examples(convert_format(format))
if not no_render:
rendered_result = chevron.render(
generation_result.generated_template, generation_result.template_hash
)
if output_file:
with open(output_file, "w") as f:
f.write(rendered_result)
else:
click.echo(rendered_result)
else:
if output_file:
with open(output_file, "w") as f:
json.dump(dataclasses.asdict(generation_result), f, indent=2)
else:
click.echo(json.dumps(dataclasses.asdict(generation_result), indent=2))
@server.command()
@click.pass_context
def setup_service(ctx):
"""Run the setup code for the given service."""
_LOG.info("setup-service")
server: Server = _load_server(ctx, async_enable=False)
server.service.setup()
@server.command()
@click.pass_context
@click.option("--port", type=int, default=8080)
@click.option("--async-enable", default=True)
@click.option(
"--timeout-keep-alive",
type=int,
default=900,
help="Connection keep alive timeout in seconds",
)
def run(ctx, port, async_enable, timeout_keep_alive):
"""Run the server for the given service."""
_LOG.info("run")
maybe_set_volume_environment_variables(ctx)
server: Server = _load_server(ctx, async_enable=async_enable)
server.run(port, timeout_keep_alive)
@server.command()
@click.pass_context
@click.option("--check-is-needed", is_flag=True)
def store_assets(ctx, check_is_needed):
"""Run the store_assets code for the given model."""
_LOG.info("store_assets")
if check_is_needed:
sys.exit(
_STORE_ASSET_NEEDED_CODE
if is_store_assets_needed(ctx)
else _STORE_ASSET_NOT_NEEDED_CODE
)
maybe_set_volume_environment_variables(ctx)
server: Server = _load_server(ctx, async_enable=False)
server.store_assets()
if __name__ == "__main__":
server(obj={})