Compare commits

..

1 Commits

2 changed files with 22 additions and 26 deletions

View File

@ -543,18 +543,24 @@ class SDTokenizer:
def _try_get_embedding(self, embedding_name:str):
'''
Takes a potential embedding name and tries to retrieve it.
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
Returns a Tuple consisting of the embedding, the cleaned embedding name, and any leftover string, embedding can be None.
'''
split_embed = embedding_name.split()
embedding_name = split_embed[0]
leftover = ' '.join(split_embed[1:])
match = re.search(r'[<\[]', embedding_name)
if match is not None:
leftover = embedding_name[match.start():] + (" " + leftover if leftover else "")
embedding_name = embedding_name[:match.start()]
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
return (embed, embedding_name, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, embedding_name, leftover)
def pad_tokens(self, tokens, amount):
if self.pad_left:
@ -585,7 +591,7 @@ class SDTokenizer:
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment)
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
split = re.split(r'(?<=\s){}'.format(re.escape(self.embedding_identifier)), to_tokenize)
to_tokenize = [split[0]]
for i in range(1, len(split)):
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
@ -595,7 +601,7 @@ class SDTokenizer:
# if we find an embedding, deal with the embedding
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name)
embed, embedding_name, leftover = self._try_get_embedding(embedding_name)
if embed is None:
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
else:

View File

@ -844,18 +844,15 @@ class ImageMergeTileList(IO.ComfyNode):
# Format specifications
# ---------------------------------------------------------------------------
# Maps (file_format, bit_depth, num_channels) -> (quantization scale, numpy dtype,
# av frame pix_fmt, stream pix_fmt). Keeps the encode path declarative instead of branchy.
# Maps (file_format, bit_depth, has_alpha) -> (numpy dtype scale, av pixel format,
# stream pix_fmt). Keeps the encode path declarative instead of branchy.
_FORMAT_SPECS = {
("png", "8-bit", 1): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "gray", "stream_fmt": "gray"},
("png", "8-bit", 3): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgb24", "stream_fmt": "rgb24"},
("png", "8-bit", 4): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgba", "stream_fmt": "rgba"},
("png", "16-bit", 1): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "gray16le", "stream_fmt": "gray16be"},
("png", "16-bit", 3): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgb48le", "stream_fmt": "rgb48be"},
("png", "16-bit", 4): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgba64le", "stream_fmt": "rgba64be"},
("exr", "32-bit float", 1): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "grayf32le", "stream_fmt": "grayf32le"},
("exr", "32-bit float", 3): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrpf32le", "stream_fmt": "gbrpf32le"},
("exr", "32-bit float", 4): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrapf32le", "stream_fmt": "gbrapf32le"},
("png", "8-bit", False): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgb24", "stream_fmt": "rgb24"},
("png", "8-bit", True): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgba", "stream_fmt": "rgba"},
("png", "16-bit", False): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgb48le", "stream_fmt": "rgb48be"},
("png", "16-bit", True): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgba64le", "stream_fmt": "rgba64be"},
("exr", "32-bit float", False): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrpf32le", "stream_fmt": "gbrpf32le"},
("exr", "32-bit float", True): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrapf32le", "stream_fmt": "gbrapf32le"},
}
@ -1090,8 +1087,7 @@ def _encode_image(
bit_depth: str,
colorspace: str,
) -> bytes:
"""Encode a single HxWxC (or channel-less HxW grayscale) tensor to PNG or
EXR bytes in memory. Grayscale is written as single-channel PNG / Y-only EXR.
"""Encode a single HxWxC tensor to PNG or EXR bytes in memory.
For EXR the input is interpreted according to `colorspace` and converted
to scene-linear (EXR's convention) before writing:
@ -1105,16 +1101,10 @@ def _encode_image(
For PNG, colorspace selection does not modify pixels — PNG is delivered
sRGB-encoded and there is no PNG path for wide-gamut HDR in this node.
"""
if img_tensor.ndim == 2:
img_tensor = img_tensor.unsqueeze(-1) # Some nodes emit grayscale as (H, W) with no channel dim, mask-style.
height, width, num_channels = img_tensor.shape
has_alpha = num_channels == 4
spec = _FORMAT_SPECS.get((file_format, bit_depth, num_channels))
if spec is None:
raise ValueError(
f"No {file_format}/{bit_depth} encoder for {num_channels}-channel images: "
"supported channel counts are 1 (grayscale), 3 (RGB) and 4 (RGBA)."
)
spec = _FORMAT_SPECS[(file_format, bit_depth, has_alpha)]
if spec["dtype"] == np.float32:
# EXR path: preserve full range, no clamp.