[bugfix] enable faster rcnn and sd model with oneflow backend
crazy-JiangDongHua committed Mar 5, 2024
1 parent cb03b91 commit ca33f43
Showing 3 changed files with 31 additions and 6 deletions.
@@ -40,7 +40,13 @@ def _format_full_class_name(self, obj: Union[str, type, FunctionType]):
 
         elif isinstance(obj, FunctionType):
             module = inspect.getmodule(obj)
-            obj = f"{module.__name__}.{obj.__qualname__}"
+            if (
+                module.__name__ == "torch.nn.functional"
+                and obj.__qualname__ == "boolean_dispatch.<locals>.fn"
+            ):
+                obj = f"{module.__name__}.{obj.__name__}"
+            else:
+                obj = f"{module.__name__}.{obj.__qualname__}"
 
         assert isinstance(obj, str), f"obj must be str, but got {type(obj)}"
 
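Note: this special case exists because several torch.nn.functional entry points (max_pool2d, for example) are generated by PyTorch's internal boolean_dispatch helper, which sets the wrapper's __name__ to the public name but leaves __qualname__ as "boolean_dispatch.<locals>.fn", so the qualified name cannot identify the function. A minimal sketch of the distinction, assuming a standard PyTorch install (max_pool2d is just one boolean-dispatched function among several):

    import inspect
    from torch.nn import functional as F

    # The wrapper produced by boolean_dispatch keeps a generic __qualname__
    # but carries the real public name in __name__.
    print(F.max_pool2d.__qualname__)  # boolean_dispatch.<locals>.fn
    print(F.max_pool2d.__name__)      # max_pool2d

    # The fixed branch above therefore builds a usable full name:
    module = inspect.getmodule(F.max_pool2d)
    print(f"{module.__name__}.{F.max_pool2d.__name__}")  # torch.nn.functional.max_pool2d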
9 changes: 7 additions & 2 deletions python/oneflow/framework/infer_compiler/with_fx_graph.py
@@ -38,7 +38,6 @@ def fx_node_tranform(gm):
     # Align this with env setting in `with_oneflow_compile`.
     # Otherwise, inference using PyTorch with OneFlow backend on
     # multiple input shapes may crash
-    os.environ.setdefault("ONEFLOW_RUN_GRAPH_BY_VM", "1")
     os.environ.setdefault("ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", "1")
     os.environ.setdefault("ONEFLOW_MLIR_CSE", "1")
     os.environ.setdefault("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1")
@@ -63,16 +62,22 @@ def fx_node_tranform(gm):
     os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL_QUANT", "1")
 
     class OfGraph(flow.nn.Graph):
+        @flow.nn.Graph.with_dynamic_input_shape()
         def __init__(self):
             super().__init__()
             self.fx_md = of_gm
             self.config.enable_cudnn_conv_heuristic_search_algo(False)
             self.config.allow_fuse_add_to_output(True)
 
         def build(self, *args, **kwargs):
-            return self.fx_md(*args, **kwargs)
+            if self.fx_md.training:
+                return self.fx_md(*args, **kwargs)
+            with flow.no_grad():
+                return self.fx_md(*args, **kwargs)
 
     of_g = OfGraph()
+    of_g._dynamic_input_graph_cache.set_cache_size(9)
+    of_g._dynamic_input_graph_cache.enable_shared(True)
     oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs)
 
     return oneflow_fn
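Note: with_dynamic_input_shape together with the _dynamic_input_graph_cache calls lets one graph object serve several input shapes, caching up to 9 shape-specialized graphs (the size set above) and sharing what it can between them. This matters for SD and Faster R-CNN workloads, where input resolution varies between calls. A rough usage sketch, assuming CUDA is available and an image-like model; oneflow_fn is the callable returned above, and the shapes are illustrative:

    import oneflow as flow

    # Two calls with different spatial sizes should both be served by the
    # dynamic-input graph cache instead of failing on a shape mismatch.
    x_small = flow.randn(1, 3, 512, 512).to("cuda")
    x_large = flow.randn(1, 3, 768, 768).to("cuda")
    y_small = oneflow_fn(x_small)
    y_large = oneflow_fn(x_large)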
20 changes: 17 additions & 3 deletions python/oneflow/framework/infer_compiler/with_oneflow_backend.py
@@ -46,8 +46,22 @@ def input_fn(value):
             )
         else:
             output = transformed_fn(*args, **kwargs)
-        if isinstance(output, tuple):
-            return tuple(flow.utils.tensor.to_torch(i) for i in output)
-        return flow.utils.tensor.to_torch(output)
+
+        def output_fn(value):
+            if isinstance(value, flow.Tensor):
+                return flow.utils.tensor.to_torch(value)
+            else:
+                return value
+
+        if isinstance(output, (tuple, list, flow._oneflow_internal.TensorTuple)):
+            return tuple(output_fn(i) for i in output)
+        elif isinstance(output, dict):
+            return {k: output_fn(v) for (k, v) in output.items()}
+        elif isinstance(output, flow.Tensor):
+            return output_fn(output)
+        else:
+            raise NotImplementedError(
+                f"How to handle {type(output)} output type is not implemented"
+            )
 
     return wrapped_forward
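Note: the widened type handling is what Faster R-CNN needs; torchvision detection models return per-image dicts of boxes, labels, and scores (inside a list) rather than a bare tensor or tuple. A small self-contained sketch of the same conversion pattern, with output_fn mirroring the helper above and the sample dict purely illustrative:

    import oneflow as flow

    def output_fn(value):
        # Convert OneFlow tensors back to torch tensors; pass anything else through.
        if isinstance(value, flow.Tensor):
            return flow.utils.tensor.to_torch(value)
        return value

    # Detection-style output for one image: a dict of tensors.
    sample = {
        "boxes": flow.randn(2, 4),
        "labels": flow.zeros(2, dtype=flow.int64),
        "scores": flow.randn(2),
    }
    converted = {k: output_fn(v) for k, v in sample.items()}  # now torch tensors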
