Skip to content

Commit ca33f43

Browse files
[bugfix] enable faster rcnn and sd model with oneflow backend
1 parent cb03b91 commit ca33f43

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

python/oneflow/framework/infer_compiler/import_tools/format_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ def _format_full_class_name(self, obj: Union[str, type, FunctionType]):
4040

4141
elif isinstance(obj, FunctionType):
4242
module = inspect.getmodule(obj)
43-
obj = f"{module.__name__}.{obj.__qualname__}"
43+
if (
44+
module.__name__ == "torch.nn.functional"
45+
and obj.__qualname__ == "boolean_dispatch.<locals>.fn"
46+
):
47+
obj = f"{module.__name__}.{obj.__name__}"
48+
else:
49+
obj = f"{module.__name__}.{obj.__qualname__}"
4450

4551
assert isinstance(obj, str), f"obj must be str, but got {type(obj)}"
4652

python/oneflow/framework/infer_compiler/with_fx_graph.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def fx_node_tranform(gm):
3838
# Align this with env setting in `with_oneflow_compile`.
3939
# Otherwise, infererence using PyTorch with OneFlow backend on
4040
# multiple input shapes may crash
41-
os.environ.setdefault("ONEFLOW_RUN_GRAPH_BY_VM", "1")
4241
os.environ.setdefault("ONEFLOW_GRAPH_DELAY_VARIABLE_OP_EXECUTION", "1")
4342
os.environ.setdefault("ONEFLOW_MLIR_CSE", "1")
4443
os.environ.setdefault("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1")
@@ -63,16 +62,22 @@ def fx_node_tranform(gm):
6362
os.environ.setdefault("ONEFLOW_MLIR_GROUP_MATMUL_QUANT", "1")
6463

6564
class OfGraph(flow.nn.Graph):
65+
@flow.nn.Graph.with_dynamic_input_shape()
6666
def __init__(self):
6767
super().__init__()
6868
self.fx_md = of_gm
6969
self.config.enable_cudnn_conv_heuristic_search_algo(False)
7070
self.config.allow_fuse_add_to_output(True)
7171

7272
def build(self, *args, **kwargs):
73-
return self.fx_md(*args, **kwargs)
73+
if self.fx_md.training:
74+
return self.fx_md(*args, **kwargs)
75+
with flow.no_grad():
76+
return self.fx_md(*args, **kwargs)
7477

7578
of_g = OfGraph()
79+
of_g._dynamic_input_graph_cache.set_cache_size(9)
80+
of_g._dynamic_input_graph_cache.enable_shared(True)
7681
oneflow_fn = lambda *args, **kwargs: of_g(*args, **kwargs)
7782

7883
return oneflow_fn

python/oneflow/framework/infer_compiler/with_oneflow_backend.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,22 @@ def input_fn(value):
4646
)
4747
else:
4848
output = transformed_fn(*args, **kwargs)
49-
if isinstance(output, tuple):
50-
return tuple(flow.utils.tensor.to_torch(i) for i in output)
51-
return flow.utils.tensor.to_torch(output)
49+
50+
def output_fn(value):
51+
if isinstance(value, flow.Tensor):
52+
return flow.utils.tensor.to_torch(value)
53+
else:
54+
return value
55+
56+
if isinstance(output, (tuple, list, flow._oneflow_internal.TensorTuple)):
57+
return tuple(output_fn(i) for i in output)
58+
elif isinstance(output, dict):
59+
return {k: output_fn(v) for (k, v) in output.items()}
60+
elif isinstance(output, flow.Tensor):
61+
return output_fn(output)
62+
else:
63+
raise NotImplementedError(
64+
f"How to handle {type(output)} output type is not implemented"
65+
)
5266

5367
return wrapped_forward

0 commit comments

Comments
 (0)