mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-06 04:17:48 +08:00
Compare commits
2 Commits
comfyanony
...
mem_fix_at
| Author | SHA1 | Date | |
|---|---|---|---|
| a78019266f | |||
| f5c4bb1f02 |
@ -4,12 +4,12 @@ early_access: false
|
||||
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
|
||||
|
||||
reviews:
|
||||
profile: "assertive"
|
||||
request_changes_workflow: true
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: false
|
||||
poem: false
|
||||
review_status: false
|
||||
review_details: true
|
||||
review_details: false
|
||||
commit_status: true
|
||||
collapse_walkthrough: true
|
||||
changed_files_summary: false
|
||||
@ -39,14 +39,6 @@ reviews:
|
||||
- path: "**"
|
||||
instructions: |
|
||||
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
|
||||
Treat AGENTS.md as mandatory repository policy, not optional style guidance.
|
||||
Flag PR changes that violate AGENTS.md even when the code is otherwise functional.
|
||||
In particular, enforce architecture boundaries, dtype/device/memory rules,
|
||||
interface contracts, import style, no unnecessary try/except blocks, no inline
|
||||
imports, no outbound internet paths in core ComfyUI, and narrow scoped fixes.
|
||||
Prefer direct findings over suggestions when a rule is violated. Only ignore
|
||||
AGENTS.md when it clearly conflicts with a newer explicit maintainer instruction
|
||||
in the PR.
|
||||
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
|
||||
de-indented, or reformatted without logic changes. If code appears in the diff
|
||||
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
|
||||
@ -131,10 +123,5 @@ chat:
|
||||
|
||||
knowledge_base:
|
||||
opt_out: false
|
||||
code_guidelines:
|
||||
enabled: true
|
||||
filePatterns:
|
||||
- files: "AGENTS.md"
|
||||
applyTo: "**"
|
||||
learnings:
|
||||
scope: "auto"
|
||||
|
||||
@ -543,24 +543,18 @@ class SDTokenizer:
|
||||
def _try_get_embedding(self, embedding_name:str):
|
||||
'''
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
Returns a Tuple consisting of the embedding, the cleaned embedding name, and any leftover string, embedding can be None.
|
||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||
'''
|
||||
split_embed = embedding_name.split()
|
||||
embedding_name = split_embed[0]
|
||||
leftover = ' '.join(split_embed[1:])
|
||||
|
||||
match = re.search(r'[<\[]', embedding_name)
|
||||
if match is not None:
|
||||
leftover = embedding_name[match.start():] + (" " + leftover if leftover else "")
|
||||
embedding_name = embedding_name[:match.start()]
|
||||
|
||||
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
if embed is None:
|
||||
stripped = embedding_name.strip(',')
|
||||
if len(stripped) < len(embedding_name):
|
||||
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
return (embed, embedding_name, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, embedding_name, leftover)
|
||||
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, leftover)
|
||||
|
||||
def pad_tokens(self, tokens, amount):
|
||||
if self.pad_left:
|
||||
@ -591,7 +585,7 @@ class SDTokenizer:
|
||||
tokens = []
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = unescape_important(weighted_segment)
|
||||
split = re.split(r'(?<=\s){}'.format(re.escape(self.embedding_identifier)), to_tokenize)
|
||||
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
|
||||
to_tokenize = [split[0]]
|
||||
for i in range(1, len(split)):
|
||||
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
|
||||
@ -601,7 +595,7 @@ class SDTokenizer:
|
||||
# if we find an embedding, deal with the embedding
|
||||
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
||||
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
||||
embed, embedding_name, leftover = self._try_get_embedding(embedding_name)
|
||||
embed, leftover = self._try_get_embedding(embedding_name)
|
||||
if embed is None:
|
||||
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
||||
else:
|
||||
|
||||
@ -937,41 +937,22 @@ class BaseGenerate:
|
||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||
|
||||
# Sampling mode
|
||||
if len(token_history) > 0 and (repetition_penalty != 1.0 or (presence_penalty is not None and presence_penalty != 0.0)):
|
||||
token_ids = torch.tensor(list(set(token_history)), device=logits.device)
|
||||
token_logits = logits[:, token_ids]
|
||||
if repetition_penalty != 1.0:
|
||||
token_logits = torch.where(token_logits < 0, token_logits * repetition_penalty, token_logits / repetition_penalty)
|
||||
if presence_penalty is not None and presence_penalty != 0.0:
|
||||
token_logits = token_logits - presence_penalty
|
||||
logits[:, token_ids] = token_logits
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(logits.shape[0]):
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
||||
|
||||
if presence_penalty is not None and presence_penalty != 0.0:
|
||||
for i in range(logits.shape[0]):
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] -= presence_penalty
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
if top_k > 0:
|
||||
top_k = min(top_k, logits.shape[-1])
|
||||
logits, top_indices = torch.topk(logits, top_k)
|
||||
|
||||
if min_p > 0.0:
|
||||
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
|
||||
top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
|
||||
min_threshold = min_p * top_probs
|
||||
indices_to_remove = probs_before_filter < min_threshold
|
||||
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 0] = False
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
|
||||
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1, generator=generator)
|
||||
return top_indices.gather(1, next_token)
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||
|
||||
if min_p > 0.0:
|
||||
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
import psutil
|
||||
import time
|
||||
@ -528,6 +529,38 @@ class RAMPressureCache(LRUCache):
|
||||
if psutil.virtual_memory().available >= target:
|
||||
return
|
||||
|
||||
def remove_cache_key(key):
|
||||
del self.cache[key]
|
||||
self.used_generation.pop(key, None)
|
||||
self.timestamps.pop(key, None)
|
||||
self.children.pop(key, None)
|
||||
|
||||
def has_old_model_patcher(outputs):
|
||||
if outputs is None:
|
||||
return False
|
||||
for output in outputs:
|
||||
if isinstance(output, (list, tuple)):
|
||||
if has_old_model_patcher(output):
|
||||
return True
|
||||
elif isinstance(output, ModelPatcher):
|
||||
return True
|
||||
return False
|
||||
|
||||
old_modelpatcher_keys = []
|
||||
for key, cache_entry in self.cache.items():
|
||||
if self.used_generation[key] == self.generation:
|
||||
continue
|
||||
if has_old_model_patcher(cache_entry.outputs):
|
||||
old_modelpatcher_keys.append(key)
|
||||
|
||||
for key in old_modelpatcher_keys:
|
||||
remove_cache_key(key)
|
||||
|
||||
if old_modelpatcher_keys:
|
||||
gc.collect()
|
||||
if psutil.virtual_memory().available >= target:
|
||||
return
|
||||
|
||||
clean_list = []
|
||||
|
||||
for key, cache_entry in self.cache.items():
|
||||
@ -545,19 +578,17 @@ class RAMPressureCache(LRUCache):
|
||||
scan_list_for_ram_usage(output)
|
||||
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
||||
ram_usage += output.numel() * output.element_size()
|
||||
elif isinstance(output, ModelPatcher) and self.used_generation[key] != self.generation:
|
||||
#old ModelPatchers are the first to go
|
||||
ram_usage = 1e30
|
||||
scan_list_for_ram_usage(cache_entry.outputs)
|
||||
|
||||
oom_score *= ram_usage
|
||||
#In the case where we have no information on the node ram usage at all,
|
||||
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], ram_usage, key))
|
||||
|
||||
while psutil.virtual_memory().available < target and clean_list:
|
||||
_, _, key = clean_list.pop()
|
||||
del self.cache[key]
|
||||
self.used_generation.pop(key, None)
|
||||
self.timestamps.pop(key, None)
|
||||
self.children.pop(key, None)
|
||||
to_free = target - psutil.virtual_memory().available
|
||||
while to_free > 0 and clean_list:
|
||||
_, _, ram_usage, key = clean_list.pop()
|
||||
remove_cache_key(key)
|
||||
to_free -= ram_usage
|
||||
|
||||
gc.collect()
|
||||
|
||||
@ -16,30 +16,23 @@ class ColorToRGBInt(io.ComfyNode):
|
||||
],
|
||||
outputs=[
|
||||
io.Int.Output(display_name="rgb_int"),
|
||||
io.Color.Output(display_name="hex"),
|
||||
io.Float.Output(display_name="alpha"),
|
||||
io.Color.Output(display_name="hex")
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, color: str) -> io.NodeOutput:
|
||||
# expect format #RRGGBB or #RRGGBBAA
|
||||
if len(color) not in (7, 9) or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA")
|
||||
# expect format #RRGGBB
|
||||
if len(color) != 7 or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB")
|
||||
try:
|
||||
int(color[1:], 16)
|
||||
except ValueError:
|
||||
raise ValueError("Color must be in format #RRGGBB or #RRGGBBAA") from None
|
||||
|
||||
alpha = 1.0
|
||||
if len(color) == 9:
|
||||
alpha = int(color[7:9], 16) / 255.0
|
||||
color = color[:7]
|
||||
|
||||
raise ValueError("Color must be in format #RRGGBB") from None
|
||||
r, g, b = hex_to_rgb(color)
|
||||
|
||||
rgb_int = r * 256 * 256 + g * 256 + b
|
||||
return io.NodeOutput(rgb_int, color, alpha)
|
||||
return io.NodeOutput(rgb_int, color)
|
||||
|
||||
|
||||
class ColorExtension(ComfyExtension):
|
||||
|
||||
2
main.py
2
main.py
@ -314,7 +314,7 @@ def prompt_worker(q, server_instance):
|
||||
cache_ram = 0
|
||||
cache_ram_inactive = 0
|
||||
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
|
||||
cache_ram = min(10.0, max(2.0, comfy.model_management.total_ram * 0.10 / 1024.0))
|
||||
cache_ram = min(10.0, max(1.5, comfy.model_management.total_ram * 0.05 / 1024.0))
|
||||
cache_ram_inactive = min(96.0, comfy.model_management.total_ram / 1024.0)
|
||||
if len(args.cache_ram) > 0:
|
||||
cache_ram = args.cache_ram[0]
|
||||
|
||||
Reference in New Issue
Block a user