Compare commits

..

7 Commits

7 changed files with 82 additions and 67 deletions

View File

@ -4,12 +4,12 @@ early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
reviews:
profile: "chill"
request_changes_workflow: false
profile: "assertive"
request_changes_workflow: true
high_level_summary: false
poem: false
review_status: false
review_details: false
review_details: true
commit_status: true
collapse_walkthrough: true
changed_files_summary: false
@ -39,6 +39,14 @@ reviews:
- path: "**"
instructions: |
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
Treat AGENTS.md as mandatory repository policy, not optional style guidance.
Flag PR changes that violate AGENTS.md even when the code is otherwise functional.
In particular, enforce architecture boundaries, dtype/device/memory rules,
interface contracts, import style, no unnecessary try/except blocks, no inline
imports, no outbound internet paths in core ComfyUI, and narrow scoped fixes.
Prefer direct findings over suggestions when a rule is violated. Only ignore
AGENTS.md when it clearly conflicts with a newer explicit maintainer instruction
in the PR.
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
de-indented, or reformatted without logic changes. If code appears in the diff
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
@ -123,5 +131,10 @@ chat:
knowledge_base:
opt_out: false
code_guidelines:
enabled: true
filePatterns:
- files: "AGENTS.md"
applyTo: "**"
learnings:
scope: "auto"

1
CLAUDE.md Symbolic link
View File

@ -0,0 +1 @@
AGENTS.md

View File

@ -543,18 +543,24 @@ class SDTokenizer:
def _try_get_embedding(self, embedding_name:str):
'''
Takes a potential embedding name and tries to retrieve it.
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
Returns a tuple consisting of the embedding, the cleaned embedding name, and any leftover string. The embedding can be None.
'''
split_embed = embedding_name.split()
embedding_name = split_embed[0]
leftover = ' '.join(split_embed[1:])
match = re.search(r'[<\[]', embedding_name)
if match is not None:
leftover = embedding_name[match.start():] + (" " + leftover if leftover else "")
embedding_name = embedding_name[:match.start()]
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
return (embed, embedding_name, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, embedding_name, leftover)
def pad_tokens(self, tokens, amount):
if self.pad_left:
@ -585,7 +591,7 @@ class SDTokenizer:
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment)
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
split = re.split(r'(?<=\s){}'.format(re.escape(self.embedding_identifier)), to_tokenize)
to_tokenize = [split[0]]
for i in range(1, len(split)):
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
@ -595,7 +601,7 @@ class SDTokenizer:
# if we find an embedding, deal with the embedding
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name)
embed, embedding_name, leftover = self._try_get_embedding(embedding_name)
if embed is None:
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
else:

View File

@ -937,22 +937,41 @@ class BaseGenerate:
return torch.argmax(logits, dim=-1, keepdim=True)
# Sampling mode
if repetition_penalty != 1.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
if presence_penalty is not None and presence_penalty != 0.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] -= presence_penalty
if len(token_history) > 0 and (repetition_penalty != 1.0 or (presence_penalty is not None and presence_penalty != 0.0)):
token_ids = torch.tensor(list(set(token_history)), device=logits.device)
token_logits = logits[:, token_ids]
if repetition_penalty != 1.0:
token_logits = torch.where(token_logits < 0, token_logits * repetition_penalty, token_logits / repetition_penalty)
if presence_penalty is not None and presence_penalty != 0.0:
token_logits = token_logits - presence_penalty
logits[:, token_ids] = token_logits
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = torch.finfo(logits.dtype).min
top_k = min(top_k, logits.shape[-1])
logits, top_indices = torch.topk(logits, top_k)
if min_p > 0.0:
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
min_threshold = min_p * top_probs
indices_to_remove = probs_before_filter < min_threshold
logits[indices_to_remove] = torch.finfo(logits.dtype).min
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 0] = False
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = torch.finfo(logits.dtype).min
probs = torch.nn.functional.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1, generator=generator)
return top_indices.gather(1, next_token)
if min_p > 0.0:
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)

View File

@ -1,6 +1,5 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
@ -529,38 +528,6 @@ class RAMPressureCache(LRUCache):
if psutil.virtual_memory().available >= target:
return
def remove_cache_key(key):
del self.cache[key]
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)
def has_old_model_patcher(outputs):
if outputs is None:
return False
for output in outputs:
if isinstance(output, (list, tuple)):
if has_old_model_patcher(output):
return True
elif isinstance(output, ModelPatcher):
return True
return False
old_modelpatcher_keys = []
for key, cache_entry in self.cache.items():
if self.used_generation[key] == self.generation:
continue
if has_old_model_patcher(cache_entry.outputs):
old_modelpatcher_keys.append(key)
for key in old_modelpatcher_keys:
remove_cache_key(key)
if old_modelpatcher_keys:
gc.collect()
if psutil.virtual_memory().available >= target:
return
clean_list = []
for key, cache_entry in self.cache.items():
@ -578,17 +545,19 @@ class RAMPressureCache(LRUCache):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
ram_usage += output.numel() * output.element_size()
elif isinstance(output, ModelPatcher) and self.used_generation[key] != self.generation:
#old ModelPatchers are the first to go
ram_usage = 1e30
scan_list_for_ram_usage(cache_entry.outputs)
oom_score *= ram_usage
# In the case where we have no information on the node's RAM usage at all,
# break OOM score ties on the last-touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], ram_usage, key))
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
to_free = target - psutil.virtual_memory().available
while to_free > 0 and clean_list:
_, _, ram_usage, key = clean_list.pop()
remove_cache_key(key)
to_free -= ram_usage
gc.collect()
while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)

View File

@ -16,23 +16,30 @@ class ColorToRGBInt(io.ComfyNode):
],
outputs=[
io.Int.Output(display_name="rgb_int"),
io.Color.Output(display_name="hex")
io.Color.Output(display_name="hex"),
io.Float.Output(display_name="alpha"),
],
)
@classmethod
def execute(cls, color: str) -> io.NodeOutput:
# expect format #RRGGBB
if len(color) != 7 or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB")
# expect format #RRGGBB or #RRGGBBAA
if len(color) not in (7, 9) or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA")
try:
int(color[1:], 16)
except ValueError:
raise ValueError("Color must be in format #RRGGBB") from None
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA") from None
alpha = 1.0
if len(color) == 9:
alpha = int(color[7:9], 16) / 255.0
color = color[:7]
r, g, b = hex_to_rgb(color)
rgb_int = r * 256 * 256 + g * 256 + b
return io.NodeOutput(rgb_int, color)
return io.NodeOutput(rgb_int, color, alpha)
class ColorExtension(ComfyExtension):

View File

@ -314,7 +314,7 @@ def prompt_worker(q, server_instance):
cache_ram = 0
cache_ram_inactive = 0
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
cache_ram = min(10.0, max(1.5, comfy.model_management.total_ram * 0.05 / 1024.0))
cache_ram = min(10.0, max(2.0, comfy.model_management.total_ram * 0.10 / 1024.0))
cache_ram_inactive = min(96.0, comfy.model_management.total_ram / 1024.0)
if len(args.cache_ram) > 0:
cache_ram = args.cache_ram[0]