"""OctoAI Image Generation."""
from __future__ import (
annotations, # required to allow 3.7+ python use type | syntax introduced in 3.10
)
from enum import Enum
from typing import Dict, List
from clients.image_gen.models import ImageEncoding, Scheduler, SDXLStyles
from clients.image_gen.models.pre_defined_styles import PreDefinedStyles
from octoai.client import Client
from octoai.clients.asset_orch import Asset
from octoai.errors import OctoAIValidationError
from octoai.types import Image
# Allowable width -> {heights} resolution pairs, per engine family.
# The server returns a 500 for unsupported width/height combinations (see the
# comment above _validate_height_and_width_to_engine), so they are validated
# client-side against these tables before the request is sent.
# NOTE: key order is observable — validation error messages print
# list(width_to_height.keys()) — so do not reorder entries.
SDXL_ALLOWABLE_WIDTH_TO_HEIGHT = {
    1024: {1024, 768},
    896: {1152, 896},
    832: {1216, 512},
    768: {1344, 1024, 576, 512},
    704: {1216, 384},
    640: {1536, 768, 640},
    576: {1024, 448, 768},
    512: {832, 768, 704, 512},
    448: {576},
    384: {704},
    1536: {640},
    1344: {768},
    1216: {832, 704},
    1152: {896},
}
# Stable Diffusion 1.5 resolutions (also used for controlnet-sd).
SD_ALLOWABLE_WIDTH_TO_HEIGHT = {
    576: {1024, 768},
    512: {512, 704, 768},
    640: {512, 768},
    768: {512},
    1024: {576},
}
# Stable Diffusion SSD resolutions.
SSD_ALLOWABLE_WIDTH_TO_HEIGHT = {
    640: {1536},
    768: {1344},
    832: {1216},
    896: {1152},
    1024: {1024},
    1152: {896},
    1216: {832},
    1344: {768},
    1536: {640},
}
# Upper bounds enforced by _validate_inputs (lower bounds are 0 / 1).
MAX_CFG_SCALE = 50
MAX_NUM_IMAGES = 16
MAX_STEPS = 200
class Engine(str, Enum):
    """Inference engines available for image generation.

    SDXL: Stable Diffusion XL
    SD: Stable Diffusion 1.5
    SSD: Stable Diffusion SSD
    CONTROLNET_SDXL: ControlNet Stable Diffusion XL
    CONTROLNET_SD: ControlNet Stable Diffusion 1.5
    """

    SDXL = "sdxl"
    SD = "sd"
    SSD = "ssd"
    CONTROLNET_SDXL = "controlnet-sdxl"
    CONTROLNET_SD = "controlnet-sd"

    def to_name(self):
        """Return the human-readable display name for this engine."""
        # Mapping keeps every member's display name in one place; a missing
        # member would raise KeyError here instead of silently returning None
        # as the old if/elif chain did.
        display_names = {
            Engine.SDXL: "Stable Diffusion XL",
            Engine.SD: "Stable Diffusion",
            Engine.SSD: "Stable Diffusion SSD",
            Engine.CONTROLNET_SDXL: "ControlNet Stable Diffusion XL",
            Engine.CONTROLNET_SD: "ControlNet Stable Diffusion 1.5",
        }
        return display_names[self]
class ImageGenerateResponse:
    """
    Image generation response.

    Contains a list of images as well as a counter of those filtered for safety.
    """

    def __init__(self, images: List[Image], removed_for_safety: int):
        # Stored privately; exposed read-only through the properties below.
        self._images = images
        self._removed_for_safety = removed_for_safety

    def __repr__(self) -> str:
        # Added for debuggability: avoids dumping full base64 image payloads.
        return (
            f"{type(self).__name__}(images=<{len(self._images)} image(s)>, "
            f"removed_for_safety={self._removed_for_safety})"
        )

    @property
    def images(self) -> List[Image]:
        """Return list of :class:`Image` generated from request."""
        return self._images

    @property
    def removed_for_safety(self) -> int:
        """Return int representing number of images removed for safety."""
        return self._removed_for_safety
class ImageGenerator(Client):
    """Client for image generation."""

    def __init__(
        self,
        api_endpoint: str | None = None,
        *args,
        **kwargs,
    ):
        """Initialize the client.

        :param api_endpoint: base URL of the image-generation service,
            defaults to the public OctoAI endpoint. A trailing "/" is
            appended if missing.
        :raises OctoAIValidationError: if no Authorization header is
            configured on the underlying HTTP client (i.e. no token
            supplied via ``OCTOAI_TOKEN`` or the ``token`` argument).
        """
        if not api_endpoint:
            api_endpoint = "https://image.octoai.run/"
        # Normalize so endpoint paths can be appended directly.
        if not api_endpoint.endswith("/"):
            api_endpoint += "/"
        self.api_endpoint = api_endpoint
        # Modernized from super(ImageGenerator, self).__init__ — zero-arg
        # super() is equivalent and idiomatic on Python 3.
        super().__init__(*args, **kwargs)
        if self._httpx_client.headers.get("Authorization") is None:
            msg = (
                "Authorization is required. Please set an `OCTOAI_TOKEN` "
                "environment variable, or pass your token to the client using "
                "`client = ImageGenerator(token='your-octoai-api-token')`"
            )
            raise OctoAIValidationError(msg)
# Raises OctoAIValidationError on failure to validate variable.
# Does not validate strings currently, though should once API is stable.
# TODO: standardize error strings (once more input types are known).
def _validate_inputs(
self,
engine: Engine | str,
cfg_scale: float | None,
height: int | None,
high_noise_frac: float | None,
num_images: int | None,
seed: int | None,
steps: int | None,
strength: float | None,
width: int | None,
image_encoding: ImageEncoding | str | None,
sampler: Scheduler | str | None,
prompt_2: str | None,
negative_prompt_2: str | None,
use_refiner: bool | None,
init_image: str | None = None,
controlnet: str | None = None,
controlnet_image: str | None = None,
controlnet_conditioning_scale: float | None = None,
style_preset: str | None = None,
):
"""Validate inputs."""
engines = [e.value for e in Engine]
if engine not in engines:
raise OctoAIValidationError(
f"engine set to {engine}. Must be one of: {', '.join(engines)}."
)
# Check only compatible with engine attributes being used
if high_noise_frac and engine == "sd":
self._input_not_match_engine(
"high_noise_frac", high_noise_frac, engine, "sdxl"
)
if prompt_2 and engine == "sd":
self._input_not_match_engine("prompt_2", prompt_2, engine, "sdxl")
if negative_prompt_2 and engine == "sd":
self._input_not_match_engine(
"negative_prompt_2", negative_prompt_2, engine, "sdxl"
)
if use_refiner is not None and engine == "sd":
self._input_not_match_engine("use_refiner", use_refiner, engine, "sdxl")
# Check number values in range
if cfg_scale is not None and ((0 > cfg_scale) or (MAX_CFG_SCALE < cfg_scale)):
raise OctoAIValidationError(
f"cfg_scale set to: {cfg_scale}. Allowable range is > 0 and <= {MAX_CFG_SCALE}."
)
if high_noise_frac is not None and (0 > high_noise_frac or high_noise_frac > 1):
raise OctoAIValidationError(
f"high_noise_frac set to: {high_noise_frac}. Allowable range is "
f">= 0 and <= 1."
)
if num_images is not None and (0 >= num_images or num_images > MAX_NUM_IMAGES):
raise OctoAIValidationError(
f"num_images set to: {num_images}. Allowable range is > 0 and <= {MAX_NUM_IMAGES}."
)
if isinstance(seed, list):
for each in seed:
if each is not None and (0 > each or each >= 2**32):
raise OctoAIValidationError(
f"seed({seed}) contains {each}. Allowable range is >= 0 and "
f"< 2**32."
)
if type(seed) == int and (0 > seed or seed >= 2**32):
raise OctoAIValidationError(
f"seed set to: {seed}. Allowable range is >= 0 and < 2**32."
)
if steps is not None and (0 >= steps or steps > MAX_STEPS):
raise OctoAIValidationError(
f"steps set to: {steps}. Allowable range is > 0 and <= {MAX_STEPS}."
)
if strength is not None and (0 > strength or strength > 1):
raise OctoAIValidationError(
f"strength set to: {strength}. Allowable range is >= 0 and <= 1."
)
if (
controlnet_conditioning_scale is not None
and controlnet_conditioning_scale < 0
):
raise OctoAIValidationError(
f"controlnet_conditional_scale set to: {controlnet_conditioning_scale}."
" Allowable range is >= 0."
)
# More verifying required value
if strength is not None and init_image is None:
raise OctoAIValidationError(
f"init_image required for img2img generation. "
f"strength({strength}) cannot be set if "
f"init_image is None."
)
if style_preset is not None and engine == "sd":
self._input_not_match_engine("style_preset", style_preset, engine, "sdxl")
if style_preset is not None and engine == "sdxl":
try:
style_preset = PreDefinedStyles(style_preset)
except ValueError:
msg = (
f"style_preset({style_preset}) is not valid. "
f"Valid options include {[e.value for e in SDXLStyles]}."
)
raise OctoAIValidationError(msg)
# Validate width and height to engine
if (height is None) ^ (width is None):
raise OctoAIValidationError(
f"if height({height}) or width({width}) is set "
f"to None, both must be None."
)
engine = Engine(engine)
if height is not None and width is not None:
self._validate_height_and_width_to_engine(engine, height, width)
if init_image is not None:
if not isinstance(init_image, Image):
init_image = Image(init_image)
if not init_image.is_valid():
msg = (
"init_image is not a valid image. May either use the "
"octoai.types Image class or a base64 string."
)
raise OctoAIValidationError(msg)
if controlnet_image is not None:
if not isinstance(controlnet_image, Image):
controlnet_image = Image(controlnet_image)
if not controlnet_image.is_valid():
msg = (
"controlnet_image is not a valid image. May either use the "
"octoai.types Image class or a base64 string."
)
raise OctoAIValidationError(msg)
if engine in [
Engine.CONTROLNET_SDXL,
Engine.CONTROLNET_SDXL.value,
Engine.CONTROLNET_SD,
Engine.CONTROLNET_SD.value,
]:
if controlnet_image is None:
raise OctoAIValidationError(
f"controlnet_image is required for engine {engine}."
)
if controlnet is None:
raise OctoAIValidationError(
f"controlnet is required for engine {engine}."
)
# Server will return a 500 error if incorrect height and width are entered
@staticmethod
def _validate_height_and_width_to_engine(engine: Engine, height: int, width: int):
width_to_height_by_engine = {
Engine.SDXL.value: SDXL_ALLOWABLE_WIDTH_TO_HEIGHT,
Engine.SD.value: SD_ALLOWABLE_WIDTH_TO_HEIGHT,
Engine.SSD.value: SSD_ALLOWABLE_WIDTH_TO_HEIGHT,
Engine.CONTROLNET_SDXL.value: SDXL_ALLOWABLE_WIDTH_TO_HEIGHT,
Engine.CONTROLNET_SD.value: SD_ALLOWABLE_WIDTH_TO_HEIGHT,
}
# Set to correct const allowable values
width_to_height = width_to_height_by_engine.get(engine)
if width_to_height.get(width) is None:
raise OctoAIValidationError(
f"width ({width}): height ({height}) "
f"values must match {engine.to_name()} allowable values or both be "
f"None. Valid values for width are {list(width_to_height.keys())}."
)
if height not in width_to_height.get(width):
raise OctoAIValidationError(
f"width ({width}): height ({height}) "
f"values must match {engine.to_name()} allowable values or both be "
f"None. Valid width:height values are: {width_to_height}."
)
@staticmethod
def _input_not_match_engine(name, value, engine, ok_engine):
raise OctoAIValidationError(
f"{name}({value}) is set but engine is set to "
f"{engine}. {name} can only be used with {ok_engine}."
)
    # Client-side limits: steps <= MAX_STEPS (200), num_images <= MAX_NUM_IMAGES
    # (16); these are part of a larger limit discussion and may change.
    def generate(
        self,
        engine: Engine | str,
        prompt: str,
        prompt_2: str | None = None,  # SDXL only
        negative_prompt: str | None = None,
        negative_prompt_2: str | None = None,  # SDXL only
        checkpoint: str | Asset | None = None,
        vae: str | Asset | None = None,
        textual_inversions: Dict[str | Asset, str] | None = None,
        loras: Dict[str | Asset, float] | None = None,
        sampler: str | Scheduler | None = None,  # Server default DDIM
        height: int | None = None,  # Different defaults for sdxl and sd
        width: int | None = None,
        cfg_scale: float | None = 12.0,
        steps: int | None = 30,
        num_images: int | None = 1,
        seed: int | List[int] | None = None,
        init_image: str | Image | None = None,  # b64, img2img only
        controlnet: str | None = None,  # controlnet-sdxl|sd only
        controlnet_image: str | Image | None = None,  # b64, controlnet-sdxl|sd only
        controlnet_conditioning_scale: float
        | None = None,  # 1.0 server default, controlnet-sdxl|sd only
        strength: float | None = None,  # 0.8 server default, img2img only
        style_preset: str | SDXLStyles | None = None,
        use_refiner: bool | None = None,  # True default, SDXL only
        high_noise_frac: float | None = None,  # 0.8 server default, SDXL only
        enable_safety: bool | None = True,
        image_encoding: ImageEncoding | str | None = None,
    ) -> ImageGenerateResponse:
        """
        Generate a list of images based on request.

        :param engine: Required. "sdxl" for Stable Diffusion XL; "sd" for Stable
            Diffusion 1.5; "ssd" for Stable Diffusion SSD; "controlnet-sdxl" for
            ControlNet Stable Diffusion XL; "controlnet-sd" for ControlNet Stable
            Diffusion 1.5.
        :param prompt: Required. Describes the image to generate.
            ex. "An octopus playing chess, masterpiece, photorealistic"
        :param prompt_2: High level description of the image to generate,
            defaults to None.
        :param negative_prompt: Description of image traits to avoid, defaults to None.
            ex. "Fingers, distortions"
        :param negative_prompt_2: High level description of things to avoid during
            generation, defaults to None.
            ex. "Unusual proportions and distorted faces"
        :param checkpoint: Which checkpoint to use for inferences, defaults to None.
        :param vae: Custom VAE to be used during image generation, defaults to None.
        :param textual_inversions: A dictionary of textual inversion updates,
            defaults to None
            ex. {'name': 'trigger_word'}
        :param loras: A dictionary of LoRAs updates to apply and their weight, can also
            be used with Assets created in the SDK directly, defaults to None.
            ex. {'crayon-style': 0.3, my_created_asset: 0.1}
        :param sampler: :class:`Scheduler` to use when generating image,
            defaults to None.
        :param height: Height of image to generate, defaults to None.
        :param width: Width of image to generate, defaults to None.
        :param cfg_scale: How closely to adhere to prompt description, defaults to 12.0.
            Must be >= 0 and <= 50.
        :param steps: How many steps of diffusion to run, defaults to 30.
            May be > 0 and <= 200.
        :param num_images: How many images to generate, defaults to 1.
            May be > 0 and <= 16.
        :param seed: Fixed random seed (or one seed per image), useful when
            attempting to generate a specific image, defaults to None.
            May be >= 0 and < 2**32.
        :param init_image: Starting image for img2img mode, defaults to None.
            Requires a b64 string image or :class:`Image`.
        :param controlnet: String matching id of controlnet to use for controlnet engine
            inferences, defaults to None. Required for using controlnet engines.
        :param controlnet_image: Starting image for controlnet-sdxl mode, defaults to None.
            Requires a b64 string image or :class:`Image`.
        :param controlnet_conditioning_scale: How strong the effect of the controlnet
            should be, defaults to 1.0.
        :param strength: How creative to be in img2img mode, defaults to 0.8.
            May be >= 0 and <= 1. Must have an `init_image`.
        :param style_preset: Used to guide the output image towards a particular style,
            only usable with SDXL, defaults to None.
            ex. "low-poly"
        :param use_refiner: Whether to apply the sdxl refiner, defaults to True.
        :param high_noise_frac: Which fraction of steps to perform with the base model,
            defaults to 0.8.
            May be >= 0 and <= 1.
        :param enable_safety: Whether to use safety checking on generated outputs or
            not, defaults to True.
        :param image_encoding: Choose returned :class:`ImageEncoding` type,
            defaults to :class:`ImageEncoding.JPEG`.
        :return: :class:`ImageGenerateResponse` object including properties for a list
            of images as well as a counter of total images returned below the
            `num_images` value due to being removed for safety.
        """
        self._validate_inputs(
            engine,
            cfg_scale,
            height,
            high_noise_frac,
            num_images,
            seed,
            steps,
            strength,
            width,
            image_encoding,
            sampler,
            prompt_2,
            negative_prompt_2,
            use_refiner,
            init_image,
            controlnet,
            controlnet_image,
            controlnet_conditioning_scale,
            style_preset,
        )
        # Images are sent over the wire as base64 strings.
        if isinstance(init_image, Image):
            init_image = init_image.to_base64()
        if isinstance(controlnet_image, Image):
            controlnet_image = controlnet_image.to_base64()
        # NOTE: the request payload is built directly from locals(), so this
        # method's parameter names must match the server's payload keys, and
        # no extra locals may be introduced before this line.
        inputs = self._process_local_vars_to_inputs_dict(locals())
        images = []
        endpoint = self.api_endpoint + "generate/" + engine
        output = self.infer(endpoint, inputs)
        # Outputs flagged by the safety checker are counted, not returned.
        removed_for_safety = 0
        for image_b64 in output.get("images"):
            if image_b64.get("removed_for_safety"):
                removed_for_safety += 1
            else:
                image_b64_str = image_b64.get("image_b64")
                image = Image(image_b64_str)
                images.append(image)
        return ImageGenerateResponse(
            images=images, removed_for_safety=removed_for_safety
        )
# If the key is of type str, it does nothing, otherwise it returns a dict where
# objects with id fields have the id field used as the key instead.
# Examples are Assets as keys.
@staticmethod
def _replace_object_keys_with_ids(obj_dict: dict) -> dict:
result = {}
for key, value in obj_dict.items():
if isinstance(key, Asset):
key = key.id
elif type(key) != str:
msg = (
f"key({key}) is invalid. `loras` and `textual_inversions` "
f"require keys in dictionary to either be a str "
f"or :class:`Asset`."
)
raise OctoAIValidationError(msg)
result[key] = value
return result
# Purges irrelevant locals from inputs dict and converts Asset type to ids
def _process_local_vars_to_inputs_dict(self, inputs):
inputs.pop("self")
for key in list(inputs):
if inputs[key] is None:
inputs.pop(key)
inputs.pop("engine")
if "loras" in inputs:
inputs["loras"] = self._replace_object_keys_with_ids(inputs["loras"])
if "textual_inversions" in inputs:
inputs["textual_inversions"] = self._replace_object_keys_with_ids(
inputs["textual_inversions"]
)
if isinstance(inputs.get("vae"), Asset):
inputs["vae"] = inputs["vae"].id
if isinstance(inputs.get("checkpoint"), Asset):
inputs["checkpoint"] = inputs["checkpoint"].id
return inputs