feat: add curve inputs and raise uniform limit for GLSL shader node (#13158)
* feat: add curve inputs and raise uniform limit for GLSL shader node * allow arbitrary size for curve
This commit is contained in:
@@ -87,7 +87,9 @@ class SizeModeInput(TypedDict):
|
||||
|
||||
|
||||
MAX_IMAGES = 5 # u_image0-4
|
||||
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
||||
MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
|
||||
MAX_BOOLS = 10 # u_bool0-9
|
||||
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
|
||||
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
|
||||
# Vertex shader using gl_VertexID trick - no VBO needed.
|
||||
@@ -497,6 +499,8 @@ def _render_shader_batch(
|
||||
image_batches: list[list[np.ndarray]],
|
||||
floats: list[float],
|
||||
ints: list[int],
|
||||
bools: list[bool] | None = None,
|
||||
curves: list[np.ndarray] | None = None,
|
||||
) -> list[list[np.ndarray]]:
|
||||
"""
|
||||
Render a fragment shader for multiple batches efficiently.
|
||||
@@ -511,6 +515,8 @@ def _render_shader_batch(
|
||||
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
||||
floats: List of float uniforms
|
||||
ints: List of int uniforms
|
||||
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
|
||||
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
|
||||
|
||||
Returns:
|
||||
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
||||
@@ -533,11 +539,17 @@ def _render_shader_batch(
|
||||
# Detect multi-pass rendering
|
||||
num_passes = _detect_pass_count(fragment_code)
|
||||
|
||||
if bools is None:
|
||||
bools = []
|
||||
if curves is None:
|
||||
curves = []
|
||||
|
||||
# Track resources for cleanup
|
||||
program = None
|
||||
fbo = None
|
||||
output_textures = []
|
||||
input_textures = []
|
||||
curve_textures = []
|
||||
ping_pong_textures = []
|
||||
ping_pong_fbos = []
|
||||
|
||||
@@ -624,6 +636,28 @@ def _render_shader_batch(
|
||||
if loc >= 0:
|
||||
gl.glUniform1i(loc, v)
|
||||
|
||||
for i, v in enumerate(bools):
|
||||
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
|
||||
if loc >= 0:
|
||||
gl.glUniform1i(loc, 1 if v else 0)
|
||||
|
||||
# Create 1D LUT textures for curves (bound after image texture units)
|
||||
for i, lut in enumerate(curves):
|
||||
tex = gl.glGenTextures(1)
|
||||
curve_textures.append(tex)
|
||||
unit = MAX_IMAGES + i
|
||||
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
||||
|
||||
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
|
||||
if loc >= 0:
|
||||
gl.glUniform1i(loc, unit)
|
||||
|
||||
# Get u_pass uniform location for multi-pass
|
||||
pass_loc = gl.glGetUniformLocation(program, "u_pass")
|
||||
|
||||
@@ -718,6 +752,8 @@ def _render_shader_batch(
|
||||
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in curve_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
@@ -754,6 +790,20 @@ class GLSLShader(io.ComfyNode):
|
||||
max=MAX_UNIFORMS,
|
||||
)
|
||||
|
||||
bool_template = io.Autogrow.TemplatePrefix(
|
||||
io.Boolean.Input("bool", default=False),
|
||||
prefix="u_bool",
|
||||
min=0,
|
||||
max=MAX_BOOLS,
|
||||
)
|
||||
|
||||
curve_template = io.Autogrow.TemplatePrefix(
|
||||
io.Curve.Input("curve"),
|
||||
prefix="u_curve",
|
||||
min=0,
|
||||
max=MAX_CURVES,
|
||||
)
|
||||
|
||||
return io.Schema(
|
||||
node_id="GLSLShader",
|
||||
display_name="GLSL Shader",
|
||||
@@ -762,6 +812,7 @@ class GLSLShader(io.ComfyNode):
|
||||
"Apply GLSL ES fragment shaders to images. "
|
||||
"u_resolution (vec2) is always available."
|
||||
),
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input(
|
||||
"fragment_shader",
|
||||
@@ -796,6 +847,8 @@ class GLSLShader(io.ComfyNode):
|
||||
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
|
||||
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
|
||||
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
|
||||
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
|
||||
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
|
||||
@@ -813,13 +866,19 @@ class GLSLShader(io.ComfyNode):
|
||||
images: io.Autogrow.Type,
|
||||
floats: io.Autogrow.Type = None,
|
||||
ints: io.Autogrow.Type = None,
|
||||
bools: io.Autogrow.Type = None,
|
||||
curves: io.Autogrow.Type = None,
|
||||
**kwargs,
|
||||
) -> io.NodeOutput:
|
||||
|
||||
image_list = [v for v in images.values() if v is not None]
|
||||
float_list = (
|
||||
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
||||
)
|
||||
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
||||
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
|
||||
|
||||
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
|
||||
|
||||
if not image_list:
|
||||
raise ValueError("At least one input image is required")
|
||||
@@ -846,6 +905,8 @@ class GLSLShader(io.ComfyNode):
|
||||
image_batches,
|
||||
float_list,
|
||||
int_list,
|
||||
bool_list,
|
||||
curve_luts,
|
||||
)
|
||||
|
||||
# Collect outputs into tensors
|
||||
|
||||
Reference in New Issue
Block a user