diff --git a/.changeset/twenty-waves-reply.md b/.changeset/twenty-waves-reply.md new file mode 100644 index 0000000000000..d3d2cd56c4385 --- /dev/null +++ b/.changeset/twenty-waves-reply.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Fix typing for components in `gr.Interface` and docstring in `image.py` diff --git a/gradio/components/image.py b/gradio/components/image.py index 28c90d1a5362c..da1a0534c784b 100644 --- a/gradio/components/image.py +++ b/gradio/components/image.py @@ -117,7 +117,7 @@ def __init__( every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer. inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change. show_label: if True, will display label. - show_download_button: If True, will display button to download image. + show_download_button: If True, will display button to download image. Only applies if interactive is False (e.g. if the component is used as an output). container: If True, will place the component in a container - providing some extra padding around the border. scale: relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True. min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first. diff --git a/gradio/interface.py b/gradio/interface.py index 677dfac307531..fe10d63fc662f 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -222,28 +222,43 @@ def __init__( get_component_instance(i, unrender=True) for i in additional_inputs ] - if not isinstance(inputs, (str, list, Component)): + if not isinstance(inputs, (Sequence, Component)): raise TypeError( f"inputs must be a string, list, or Component, not {inputs}" ) - if not isinstance(outputs, (str, list, Component)): + if not isinstance(outputs, (Sequence, Component)): raise TypeError( f"outputs must be a string, list, or Component, not {outputs}" ) - if not isinstance(inputs, list): + if isinstance(inputs, (str, Component)): inputs = [inputs] - if not isinstance(outputs, list): + if isinstance(outputs, (str, Component)): outputs = [outputs] self.cache_examples = cache_examples self.cache_mode: Literal["eager", "lazy"] | None = cache_mode + self.main_input_components = [ + get_component_instance(i, unrender=True) for i in inputs + ] + self.input_components = ( + self.main_input_components + self.additional_input_components + ) + self.output_components = [ + get_component_instance(o, unrender=True) + for o in outputs # type: ignore + ] + state_input_indexes = [ - idx for idx, i in enumerate(inputs) if i == "state" or isinstance(i, State) + idx + for idx, i in enumerate(self.input_components) + if i == "state" or isinstance(i, State) ] state_output_indexes = [ - idx for idx, o in enumerate(outputs) if o == "state" or isinstance(o, State) + idx + for idx, o in enumerate(self.output_components) + if o == "state" or isinstance(o, State) ] if len(state_input_indexes) == 0 and len(state_output_indexes) == 0: @@ -255,14 +270,14 @@ def __init__( else: state_input_index = state_input_indexes[0] state_output_index = state_output_indexes[0] - if inputs[state_input_index] == "state": + if self.input_components[state_input_index] == "state": default = utils.get_default_args(fn)[state_input_index] state_variable = State(value=default) else: - state_variable = inputs[state_input_index] + state_variable = self.input_components[state_input_index] - inputs[state_input_index] = state_variable - outputs[state_output_index] = state_variable + self.input_components[state_input_index] = state_variable + self.output_components[state_output_index] = state_variable if cache_examples: warnings.warn( @@ -271,10 +286,6 @@ def __init__( ) self.cache_examples = False - self.main_input_components = [ - get_component_instance(i, unrender=True) for i in inputs - ] - if additional_inputs_accordion is None: self.additional_inputs_accordion_params = { "label": "Additional Inputs", @@ -294,13 +305,6 @@ def __init__( raise ValueError( f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}" ) - self.input_components = ( - self.main_input_components + self.additional_input_components - ) - self.output_components = [ - get_component_instance(o, unrender=True) - for o in outputs # type: ignore - ] for component in self.input_components + self.output_components: if not (isinstance(component, Component)): diff --git a/test/test_interfaces.py b/test/test_interfaces.py index e304598e239b5..06cb1626ef873 100644 --- a/test/test_interfaces.py +++ b/test/test_interfaces.py @@ -69,7 +69,7 @@ def test(parameter_name1, parameter_name2): t = Textbox() i = Image() - Interface(test, [t, i], "text") + Interface(test, (t, i), "text") assert t.label == "parameter_name1" assert i.label == "parameter_name2"