feat(api-nodes): add price extractor feature; small fixes to Kling & Pika nodes (#10284)

This commit is contained in:
Alexander Piskun
2025-10-11 02:21:40 +03:00
committed by GitHub
parent aa895db7e8
commit 14d642acd6
3 changed files with 33 additions and 17 deletions

View File

@@ -782,9 +782,11 @@ class PollingOperation(Generic[T, R]):
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str],
failed_statuses: list[str],
*,
status_extractor: Callable[[R], Optional[str]],
progress_extractor: Callable[[R], Optional[float]] | None = None,
result_url_extractor: Callable[[R], Optional[str]] | None = None,
price_extractor: Callable[[R], Optional[float]] | None = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
@@ -815,10 +817,12 @@ class PollingOperation(Generic[T, R]):
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.price_extractor = price_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
self.final_response: Optional[R] = None
self.extracted_price: Optional[float] = None
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
@@ -840,6 +844,8 @@ class PollingOperation(Generic[T, R]):
def _display_text_on_node(self, text: str):
if not self.node_id:
return
if self.extracted_price is not None:
text = f"Price: {self.extracted_price}$\n{text}"
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int | float):
@@ -877,9 +883,7 @@ class PollingOperation(Generic[T, R]):
try:
logging.debug("[DEBUG] Polling attempt #%s", poll_count)
request_dict = (
None if self.request is None else self.request.model_dump(exclude_none=True)
)
request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)
if poll_count == 1:
logging.debug(
@@ -912,6 +916,11 @@ class PollingOperation(Generic[T, R]):
if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if self.price_extractor:
price = self.price_extractor(response_obj)
if price is not None:
self.extracted_price = price
if status == TaskStatus.COMPLETED:
message = "Task completed successfully"
if self.result_url_extractor: