Revert "execution: fold in dependency aware caching / Fix --cache-none with l…" (#10422)
This reverts commit b1467da480.
This commit is contained in:
34
execution.py
34
execution.py
@@ -18,7 +18,7 @@ from comfy_execution.caching import (
|
||||
BasicCache,
|
||||
CacheKeySetID,
|
||||
CacheKeySetInputSignature,
|
||||
NullCache,
|
||||
DependencyAwareCache,
|
||||
HierarchicalCache,
|
||||
LRUCache,
|
||||
)
|
||||
@@ -91,13 +91,13 @@ class IsChangedCache:
|
||||
class CacheType(Enum):
|
||||
CLASSIC = 0
|
||||
LRU = 1
|
||||
NONE = 2
|
||||
DEPENDENCY_AWARE = 2
|
||||
|
||||
|
||||
class CacheSet:
|
||||
def __init__(self, cache_type=None, cache_size=None):
|
||||
if cache_type == CacheType.NONE:
|
||||
self.init_null_cache()
|
||||
if cache_type == CacheType.DEPENDENCY_AWARE:
|
||||
self.init_dependency_aware_cache()
|
||||
logging.info("Disabling intermediate node cache.")
|
||||
elif cache_type == CacheType.LRU:
|
||||
if cache_size is None:
|
||||
@@ -120,12 +120,11 @@ class CacheSet:
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_null_cache(self):
|
||||
self.outputs = NullCache()
|
||||
#The UI cache is expected to be iterable at the end of each workflow
|
||||
#so it must cache at least a full workflow. Use Heirachical
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.objects = NullCache()
|
||||
# only hold cached items while the decendents have not executed
|
||||
def init_dependency_aware_cache(self):
|
||||
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
|
||||
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
|
||||
self.objects = DependencyAwareCache(CacheKeySetID)
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = {
|
||||
@@ -136,7 +135,7 @@ class CacheSet:
|
||||
|
||||
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||
if is_v3:
|
||||
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
||||
@@ -154,10 +153,10 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if execution_list is None:
|
||||
if outputs is None:
|
||||
mark_missing()
|
||||
continue # This might be a lazily-evaluated input
|
||||
cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
|
||||
cached_output = outputs.get(input_unique_id)
|
||||
if cached_output is None:
|
||||
mark_missing()
|
||||
continue
|
||||
@@ -406,7 +405,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
@@ -436,7 +434,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
for r in result:
|
||||
if is_link(r):
|
||||
source_node, source_output = r[0], r[1]
|
||||
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
|
||||
node_output = caches.outputs.get(source_node)[source_output]
|
||||
for o in node_output:
|
||||
resolved_output.append(o)
|
||||
|
||||
@@ -448,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
@@ -551,15 +549,11 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
subcache.clean_unused()
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
execution_list.cache_link(node_id, unique_id)
|
||||
for link in new_output_links:
|
||||
execution_list.add_strong_link(link[0], link[1], unique_id)
|
||||
pending_subgraph_results[unique_id] = cached_outputs
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
|
||||
caches.outputs.set(unique_id, output_data)
|
||||
execution_list.cache_update(unique_id, output_data)
|
||||
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
logging.info("Processing interrupted")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user