"""
Entities and APIs for working with text generation models.
Instead of using these classes directly, developers should
use the octoai.client.Client class. For example:
client = octoai.client.Client()
completion = client.chat.completions.create(...)
"""
from enum import Enum
from typing import Iterable, List, Optional, Union
from pydantic import BaseModel, ValidationError
from typing_extensions import Literal
from clients.ollm.models.chat_completion_response_format import (
ChatCompletionResponseFormat,
)
from clients.ollm.models.chat_message import ChatMessage
from clients.ollm.models.create_chat_completion_request import (
CreateChatCompletionRequest,
)
from octoai.client import Client
from octoai.errors import OctoAIValidationError
TEXT_DEFAULT_ENDPOINT = "https://text.octoai.run/v1/chat/completions"
TEXT_SECURELINK_ENDPOINT = "https://text.securelink.octo.ai/v1/chat/completions"


class TextModel(str, Enum):
"""List of available text models."""
LLAMA_2_13B_CHAT = "llama-2-13b-chat"
LLAMA_2_70B_CHAT = "llama-2-70b-chat"
CODELLAMA_7B_INSTRUCT = "codellama-7b-instruct"
CODELLAMA_13B_INSTRUCT = "codellama-13b-instruct"
CODELLAMA_34B_INSTRUCT = "codellama-34b-instruct"
CODELLAMA_70B_INSTRUCT = "codellama-70b-instruct"
MISTRAL_7B_INSTRUCT = "mistral-7b-instruct"
MIXTRAL_8X7B_INSTRUCT = "mixtral-8x7b-instruct"

    def to_name(self):
        """Return the name of the model."""
        return self.value


def get_model_list() -> List[str]:
"""Return a list of available text models."""
return [model.value for model in TextModel]


class ChoiceDelta(BaseModel):
"""Contents for streaming text completion responses."""
content: Optional[str] = None
role: Optional[Literal["system", "user", "assistant", "tool"]] = None


class Choice(BaseModel):
"""A single choice in a text completion response."""
index: int
    message: Optional[ChatMessage] = None
    delta: Optional[ChoiceDelta] = None
finish_reason: Optional[
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
] = None


class CompletionUsage(BaseModel):
"""Usage statistics for a text completion response."""
completion_tokens: int
prompt_tokens: int
total_tokens: int


class ChatCompletion(BaseModel):
"""A text completion response."""
id: str
choices: List[Choice]
created: int
model: str
object: Optional[Literal["chat.completion", "chat.completion.chunk"]] = None
system_fingerprint: Optional[str] = None
usage: Optional[CompletionUsage] = None


class Completions:
"""Text completions API."""
client: Client
endpoint: str = TEXT_DEFAULT_ENDPOINT
def __init__(self, client: Client) -> None:
self.client = client
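        # Clients configured for SecureLink route traffic through the
        # dedicated SecureLink endpoint instead of the default one.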
if self.client.secure_link:
self.endpoint = TEXT_SECURELINK_ENDPOINT

    def create(
self,
*,
messages: List[ChatMessage],
model: Union[str, TextModel],
frequency_penalty: Optional[float] = 0.0,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = 0.0,
response_format: Optional[ChatCompletionResponseFormat] = None,
stop: Optional[str] = None,
stream: Optional[bool] = False,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 1.0,
) -> Union[ChatCompletion, Iterable[ChatCompletion]]:
"""
Create a chat completion with a text generation model.
:param messages: Required. A list of messages to use as context for the
completion.
:param model: Required. The model to use for the completion. Supported models
are listed in the `octoai.chat.TextModel` enum.
        :param frequency_penalty: Penalizes tokens in proportion to how often they
            have already appeared, so positive values discourage verbatim repetition.
            Valid values are between -2.0 and 2.0.
        :param max_tokens: The maximum number of tokens to generate.
        :param presence_penalty: Penalizes tokens that have already appeared at all,
            so positive values encourage the model to introduce new tokens. Valid
            values are between -2.0 and 2.0.
:param response_format: An object specifying the format that the model must
output.
        :param stop: A sequence at which the model stops generating tokens.
:param stream: Whether to return a generator that yields partial message
deltas as they become available, instead of waiting to return the entire
response.
:param temperature: Sampling temperature. A value between 0 and 2. Higher values
make the model more creative by sampling less likely tokens.
        :param top_p: Nucleus sampling: only the most likely tokens whose cumulative
            probability is within `top_p` are considered. Alter `temperature` or
            `top_p`, but not both.
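
        Example (a minimal sketch; the `role` and `content` fields shown on
        `ChatMessage` are assumptions based on the OpenAI-style message schema):

            client = octoai.client.Client()
            completion = client.chat.completions.create(
                model="mixtral-8x7b-instruct",
                messages=[ChatMessage(role="user", content="Hello!")],
                max_tokens=128,
            )
            print(completion.choices[0].message.content)

        With `stream=True`, iterate over the returned chunks and read
        `choices[0].delta.content` from each one instead.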
"""
request = CreateChatCompletionRequest(
messages=messages,
model=model.value if isinstance(model, TextModel) else model,
frequency_penalty=frequency_penalty,
function_call=None,
functions=None,
logit_bias=None,
max_tokens=max_tokens,
n=1,
presence_penalty=presence_penalty,
response_format=response_format,
stop=stop,
stream=stream,
temperature=temperature,
top_p=top_p,
user=None,
)
inputs = request.to_dict()
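        # Streaming: return a generator that maps each partial response chunk
        # onto a ChatCompletion whose choices carry `delta` updates.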
if stream:
return self.client.infer_stream(
self.endpoint, inputs, map_fn=lambda resp: ChatCompletion(**resp)
) # type: ignore
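        # Non-streaming: perform a single inference call and validate the full
        # response against the ChatCompletion model.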
resp = self.client.infer(self.endpoint, inputs)
try:
return ChatCompletion(**resp)
except ValidationError as e:
raise OctoAIValidationError(
"Unable to validate response from server.", caused_by=e
)


class Chat:
"""Chat API for text generation models."""
completions: Completions
def __init__(self, client: Client):
self.completions = Completions(client)