Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for dynamic control flow of auto placement. #476

Open
wants to merge 1 commit into
base: dev_mock_pipeline_inference
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions projects/mock_transformers/dist_infer_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,18 @@ def __init__(self, *args, **kwargs):
)

# generate id
for i in range(100):
with global_mode(True, **placement_sbp_dict):
model = init_env.compile_auto_placement(
model,
input_ids
)
generated_ids = model.generate(input_ids, max_length=30)
out_put_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(out_put_ids)
# for i in range(100):

# generated_ids = model.generate(input_ids, max_length=30)
# raise KeyError
with global_mode(True, **placement_sbp_dict):
compiled_model = init_env.compile_auto_placement(
model,
input_ids=input_ids,
)
# print(model.code) # use this to print the compiled module code
generated_ids = compiled_model.run(input_ids)
print(generated_ids)
# generated_ids = model(input_ids)
# out_put_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# print(generated_ids)
263 changes: 250 additions & 13 deletions projects/mock_transformers/init_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@


import copy # noqa
import sys
sys.path.append("../")
sys.path.append("./")
sys.path.append("./libai/")
import onefx as fx # noqa
from typing import List, Dict, Any # noqa
from oneflow import Tensor, nn # noqa
Expand Down Expand Up @@ -212,10 +216,235 @@ def auto_set_pipeline_stage_id(model, pipeline_parallel_size=1):

# ---------------def fx for auto changing placement ----------------------

import inspect
import math
from typing import Tuple, Dict, Optional, Any, Callable, Union
from copy import deepcopy
import traceback
import builtins

_customized_not_wrapped_oneflow_functions = [
flow.ones_like,
flow.zeros_like,
flow.randn,
flow.randn_like,
flow.randint, flow.randint_like,
flow.device
]

class CustomiziedTracer(fx.Tracer):
def __init__(self, autowrap_modules = (math, ), autowrap_functions: Tuple[Callable, ...] = (), param_shapes_constant: bool = False,
not_wrapped_oneflow_functions=_customized_not_wrapped_oneflow_functions, input_args=None) -> None:
super().__init__(autowrap_modules, autowrap_functions, param_shapes_constant, not_wrapped_oneflow_functions)
self.registered_values = {}
self.args_iter = iter(input_args)

def to_bool(self, obj: fx.Proxy) -> bool: #override
if obj.node.name in self.registered_values:
return self.registered_values[obj.node.name]
return super().to_bool(obj)

def create_proxy(self, kind: str, target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
name: Optional[str] = None, type_expr : Optional[Any] = None,
proxy_factory_fn: Callable[[fx.Node], fx.Proxy] = None): #override
arg_values = []
for i, arg in enumerate(args):
# if isinstance(arg, tuple) and isinstance(arg[0], fx.Proxy) and callable(target)
# and list(inspect.signature(target).parameters.keys())[0].startswith('*'):
# arg = arg[0]
if isinstance(arg, tuple):
has_proxy = [1 if isinstance(item, fx.Proxy) else 0 for item in arg]
has_proxy = sum(has_proxy)
current_arg_value = []
if has_proxy > 0:
for proxy in arg:
if not isinstance(proxy, fx.Proxy):
current_arg_value.append(proxy)
continue
if not proxy.node.name in self.registered_values:
raise ValueError(f"{arg.node.name} cannot be found.")
else:
current_arg_value.append(self.registered_values[proxy.node.name])
arg_values.append(tuple(current_arg_value))
continue
if not isinstance(arg, fx.Proxy):
arg_values.append(arg)
continue
if not arg.node.name in self.registered_values:
raise ValueError(f"{arg.node.name} cannot be found.")
else:
arg_values.append(self.registered_values[arg.node.name])

kwarg_values = {}
for arg_name, arg in kwargs.items():
if isinstance(arg, tuple):
has_proxy = [1 if isinstance(item, fx.Proxy) else 0 for item in arg]
has_proxy = sum(has_proxy)
current_arg_value = []
if has_proxy > 0:
for proxy in arg:
if not isinstance(proxy, fx.Proxy):
current_arg_value.append(proxy)
continue
if not proxy.node.name in self.registered_values:
raise ValueError(f"{arg.node.name} cannot be found.")
else:
current_arg_value.append(self.registered_values[proxy.node.name])
kwarg_values[arg_name] = tuple(current_arg_value)
continue
if not isinstance(arg, fx.Proxy):
kwarg_values[arg_name] = arg
continue
if not arg.node.name in self.registered_values:
raise ValueError(f"{arg.node.name} cannot be found.")
else:
kwarg_values[arg_name] = self.registered_values[arg.node.name]

assert kind != "call_function" or callable(target)

with fx.fx_no_wrap_context(self):
if kind == "call_function":
result_value = target(*arg_values, **kwarg_values)
elif kind == "call_method":
self_obj, *args_tail = arg_values

# Execute the method and return the result
assert isinstance(target, str)
method = getattr(self_obj, target)
result_value = method(*args_tail, **kwarg_values)
elif kind == "call_module":
assert isinstance(target, str)
submod = self.fetch_attr(target)

result_value = submod(*arg_values, **kwarg_values)
elif kind == "placeholder":
assert isinstance(target, str)
if target.startswith('*'):
# For a starred parameter e.g. `*args`, retrieve all
# remaining values from the args list.
result_value = list(self.args_iter)
else:
try:
result_value = next(self.args_iter)
except StopIteration as si:
raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si
elif kind == "get_attr":
assert isinstance(target, str)
result_value = self.fetch_attr(target)
elif kind == "output":
result_value = arg_values[0]
elif kind == "root":
raise NotImplementedError
else:
raise NotImplementedError

if isinstance(result_value, fx.Proxy):
if result_value.node.name in self.registered_values:
result_value = self.registered_values[result_value.node.name]
else:
raise ValueError("Got a proxy object when running with original values.")

if not self.fx_no_wrap:
result_proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
self.registered_values[result_proxy.node.name] = result_value
return result_proxy

def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
if self.fx_no_wrap:
return attr_val
def maybe_get_proxy_for_attr(
attr_val, collection_to_search, parameter_proxy_cache
):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if (
"proxy_factory_fn"
in inspect.signature(self.create_proxy).parameters
):
kwargs["proxy_factory_fn"] = (
None
if not self.param_shapes_constant
else lambda node: fx.ParameterProxy(
self, node, n, attr_val
)
)
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None

if isinstance(attr_val, flow.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(
attr_val, self.root.named_parameters(), parameter_proxy_cache
)
if maybe_parameter_proxy is not None:
if not maybe_parameter_proxy.node.name in self.registered_values:
self.registered_values[maybe_parameter_proxy.node.name] = attr_val
return maybe_parameter_proxy

if self.proxy_buffer_attributes and isinstance(attr_val, flow.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(
attr_val, self.root.named_buffers(), parameter_proxy_cache
)
if maybe_buffer_proxy is not None:
if not maybe_buffer_proxy.node.name in self.registered_values:
self.registered_values[maybe_buffer_proxy.node.name] = attr_val
return maybe_buffer_proxy

return attr_val

def call_module(
self,
m: flow.nn.Module,
forward: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Any: # override
if self.fx_no_wrap:
return forward(*args, **kwargs)
else:
return super().call_module(m, forward, args, kwargs)

def trace(
self,
root: Union[flow.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
) -> fx.Graph: # override
self.module = root
return super().trace(root, concrete_args)

def fetch_attr(self, target : str):
target_atoms = target.split('.')
attr_itr = self.module
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
if not isinstance(attr_itr, fx.Proxy):
return attr_itr
if attr_itr.node.name in self.registered_values:
return self.registered_values[attr_itr.node.name]

raise ValueError(f"No attr <{target}> was found.")


def customized_symbolic_trace(
root: Union[flow.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
input_args=None
) -> fx.GraphModule:
tracer = CustomiziedTracer(input_args=input_args)
graph = tracer.trace(root, concrete_args)
name = (
root.__class__.__name__ if isinstance(root, flow.nn.Module) else root.__name__
)
return fx.GraphModule(tracer.root, graph, name)

class AutoPlacementInterpreter(fx.Interpreter):
def __init__(self, mod : flow.nn.Module):
gm = fx.symbolic_trace(mod)
def __init__(self, mod : flow.nn.Module, concrete_args=None, input_args=None):
gm = customized_symbolic_trace(mod, concrete_args=concrete_args, input_args=input_args)
super().__init__(gm)

self.global_infos : Dict[int, Dict[int, Any]] = {}
Expand Down Expand Up @@ -258,12 +487,11 @@ def run_node(self, n : fx.Node) -> Any:
return return_val


def add_auto_placement(model: flow.nn.Module, global_info_dict: Dict[int, Dict[int, List[int]]]) -> flow.nn.Module:
def add_auto_placement(model: flow.nn.Module, global_info_dict: Dict[int, Dict[int, List[int]]], concrete_args=None, input_args=None) -> flow.nn.Module:
model = copy.deepcopy(model)
fx_model: fx.GraphModule = fx.symbolic_trace(model)
fx_model: fx.GraphModule = customized_symbolic_trace(model, concrete_args=concrete_args, input_args=input_args)

for node_id, node in enumerate(fx_model.graph.nodes):
print(node_id, " ", node.op)
if not node_id in global_info_dict:
continue

Expand All @@ -277,14 +505,23 @@ def add_auto_placement(model: flow.nn.Module, global_info_dict: Dict[int, Dict[i

fx_model.graph.lint()
fx_model.recompile()
return fx_model

def compile_auto_placement(model: flow.nn.Module, input_x: flow.Tensor):
assert input_x.is_global
interpret = AutoPlacementInterpreter(model)
interpret.run(input_x)
model = add_auto_placement(model, interpret.global_infos)
return model
return fx.Interpreter(fx_model)

fx.wrap(len)
def compile_auto_placement(model: flow.nn.Module, concrete_args=None, **kwargs):
with fx.global_wrap([dist.get_nd_sbp, dist.same_sbp], dist):
with fx.global_wrap([flow.finfo], flow):
if concrete_args is None:
all_args = inspect.signature(model.forward).parameters
concrete_args = {}
for arg_name, param in all_args.items():
if not arg_name in kwargs and param.default != inspect._empty:
concrete_args.update({arg_name:param.default})

interpret = AutoPlacementInterpreter(model, concrete_args=concrete_args, input_args=list(kwargs.values()) + list(concrete_args.values()))
interpret.run(*(kwargs.values()))
model = add_auto_placement(model, interpret.global_infos, concrete_args, input_args=list(kwargs.values()) + list(concrete_args.values()))
return model

# b = flow.ones(
# (2,2),
Expand Down