Refactor softmax templates to use outer dims #844

Closed · wants to merge 2 commits
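This PR reworks the CUDA softmax codegen so the templates index shapes by the reduction dim rather than assuming the reduction always runs over the last dimension: the generated shape function computes M as the product of the outer dims (range(reduction_dim) instead of range(input_ndim - 1)), the generated signature and call pass only those outer dims, and the fallback kernel reads {{K}} directly in place of the removed size_t n local.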
53 changes: 26 additions & 27 deletions python/aitemplate/backend/cuda/softmax/softmax.py
@@ -44,8 +44,6 @@
 {{func_signature}}
 {
   {{shape_functions}}
-  size_t m0 = {{m}};
-  size_t n = {{K}};
   size_t m = M;
   bool success = true;

@@ -64,7 +62,7 @@
 {% elif K > 3840 %}
   // K/8 > 480
   using vec8 = VecTFor<{{dtype}}>::vec8;
-  LaunchSoftmaxBlockAll<vec8, {{dtype}},{{K}}>(reinterpret_cast<const vec8*>(input), reinterpret_cast<vec8*>(output), M, stream, &success);
+  LaunchSoftmaxBlockAll<vec8, {{dtype}}, {{K}}>(reinterpret_cast<const vec8*>(input), reinterpret_cast<vec8*>(output), M, stream, &success);
 {% endif %}
 {% elif K % 4 == 0 %}
   // K % 4 == 0: vector4 kernels
@@ -77,7 +75,7 @@
 {% elif K > 1920 %}
   // K/4 > 480
   using vec4 = VecTFor<{{dtype}}>::vec4;
-  LaunchSoftmaxBlockAll<vec4,{{dtype}},{{K}}>(reinterpret_cast<const vec4*>(input), reinterpret_cast<vec4*>(output), M, stream, &success);
+  LaunchSoftmaxBlockAll<vec4, {{dtype}}, {{K}}>(reinterpret_cast<const vec4*>(input), reinterpret_cast<vec4*>(output), M, stream, &success);
 {% endif %}
 {% elif K % 2 == 0 %}
   // K % 2 == 0: vector2 kernels
@@ -90,7 +88,7 @@
 {% elif K > 1152 %}
   // K/2 > 576
   using vec2 = VecTFor<{{dtype}}>::vec2;
-  LaunchSoftmaxBlockAll<vec2,{{dtype}},{{K}}>(reinterpret_cast<const vec2*>(input), reinterpret_cast<vec2*>(output), M, stream, &success);
+  LaunchSoftmaxBlockAll<vec2, {{dtype}}, {{K}}>(reinterpret_cast<const vec2*>(input), reinterpret_cast<vec2*>(output), M, stream, &success);
 {% endif %}
 {% else %}
   // odd K
@@ -102,12 +100,12 @@
   LaunchSoftmaxK1Middle<{{dtype}}, {{K}}>(static_cast<const {{dtype}}*>(input), static_cast<{{dtype}}*>(output), M, stream);
 {% elif K > 1408 %}
   // K > 1408
-  LaunchSoftmaxBlockAll<{{dtype}},{{dtype}},{{K}}>( (const {{dtype}}*) input, ({{dtype}}*) output, m, stream, &success);
+  LaunchSoftmaxBlockAll<{{dtype}}, {{dtype}}, {{K}}>( (const {{dtype}}*) input, ({{dtype}}*) output, m, stream, &success);
 {% endif %}
 {% endif %}

   if (!success) {
-    softmaxBlockNocache<{{dtype}}><<<m, 1024, 0, stream>>>(({{dtype}}*)input, ({{dtype}}*)output, m, n);
+    softmaxBlockNocache<{{dtype}}><<<m, 1024, 0, stream>>>(({{dtype}}*)input, ({{dtype}}*)output, m, {{K}});
   }
 }
 """
@@ -116,7 +114,7 @@
 SHAPE_FUNCTIONS = jinja2.Template(
     """
 int64_t M = 1;
-{% for idx in range(input_ndim - 1) %}
+{% for idx in range(reduction_dim) %}
 M *= *in_{{idx}};
 {% endfor %}
 """
@@ -125,39 +123,40 @@
 FUNC_SIGNATURE = jinja2.Template(
     """
 void {{func_name}}(void* input,
-                   void* output,
-{% for idx in range(input_ndim - 1) %}
-                   int64_t* in_{{idx}},
+                   void* output,
+{% for idx in range(reduction_dim) %}
+                   int64_t* in_{{idx}},
 {% endfor %}
-                   cudaStream_t stream)
-"""
+                   cudaStream_t stream)
+""",
+    trim_blocks=True,
 )

 FUNC_DECL = jinja2.Template(
     """
-{{func_signature}};
-"""
+{{func_signature}};
+""",
 )

 FUNC_CALL_TEMPLATE = jinja2.Template(
     """
 {{indent}}{{func_name}}(
 {{indent}}   {{input}},
 {{indent}}   {{output}},
-{% for name in input_dim_names[:-1] %}
-{{indent}}   &{{name}},
+{% for name in outer_dim_names %}
+{{indent}}   &{{name}},
 {% endfor %}
 {{indent}}   stream
 {{indent}});
-"""
+""",
+    trim_blocks=True,
 )


 def get_func_signature(func_attrs: Dict[str, Any]) -> str:
-    input_ndim = func_attrs["inputs"][0]._rank()
     return FUNC_SIGNATURE.render(
         func_name=func_attrs["name"],
-        input_ndim=input_ndim,
+        reduction_dim=func_attrs["dim"],
     ).strip()
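As a standalone sanity check of the signature codegen, here is a sketch that renders the new template outside AITemplate; the function name softmax_0 and the dim value are made up, while the template string is copied from above:

import jinja2

FUNC_SIGNATURE = jinja2.Template(
    """
void {{func_name}}(void* input,
                   void* output,
{% for idx in range(reduction_dim) %}
                   int64_t* in_{{idx}},
{% endfor %}
                   cudaStream_t stream)
""",
    trim_blocks=True,
)

# Softmax over dim 2 of a rank-3 tensor: only the two outer dims are passed.
print(FUNC_SIGNATURE.render(func_name="softmax_0", reduction_dim=2).strip())

This prints roughly:

void softmax_0(void* input,
               void* output,
               int64_t* in_0,
               int64_t* in_1,
               cudaStream_t stream)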


@@ -180,7 +179,6 @@ def find_tile_size(k: int) -> int:
 def softmax_gen_function(func_attrs: Dict[str, Any]) -> str:
     dim = func_attrs["dim"]
     shapes = func_attrs["inputs"][0]._attrs["shape"]
-    rank = len(shapes)

     assert isinstance(
         shapes[dim], IntImm
@@ -197,7 +195,7 @@ def softmax_gen_function(func_attrs: Dict[str, Any]) -> str:
             os.path.dirname(__file__), "softmax.cuh"
         ),
         func_signature=get_func_signature(func_attrs),
-        shape_functions=SHAPE_FUNCTIONS.render(input_ndim=rank),
+        shape_functions=SHAPE_FUNCTIONS.render(reduction_dim=dim),
         dtype=elem_input_type,
         K=k,
         m=find_tile_size(k),
@@ -217,17 +215,18 @@ def softmax_gen_function_call(func_attrs, indent=" "):
     input_name = func_attrs["inputs"][0]._attrs["name"]
     output_name = func_attrs["outputs"][0]._attrs["name"]

-    shapes = func_attrs["inputs"][0]._attrs["shape"]
+    shape = func_attrs["inputs"][0]._attrs["shape"]
     assert (
-        len(shapes) >= 2
-    ), f"Softmax only supports input with rank >= 2, current rank: {len(shapes)}"
+        len(shape) >= 2
+    ), f"Softmax only supports input with rank >= 2, current rank: {len(shape)}"

-    input_dim_names = [shape._attrs["name"] for shape in shapes]
+    reduction_dim = func_attrs["dim"]
+    outer_dim_names = [dim._attrs["name"] for dim in shape[:reduction_dim]]

     return FUNC_CALL_TEMPLATE.render(
         func_name=func_attrs["name"],
         input=input_name,
         output=output_name,
-        input_dim_names=input_dim_names,
+        outer_dim_names=outer_dim_names,
         indent=indent,
     )
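The call side can be mocked end to end as well; the _Dim class and the names batch, seq, X, Y are stand-ins for AITemplate's shape and tensor objects, used here only for illustration:

class _Dim:
    # Minimal stand-in for an AITemplate dim; only _attrs["name"] is used.
    def __init__(self, name):
        self._attrs = {"name": name}

shape = [_Dim("batch"), _Dim("seq"), _Dim("K")]  # hypothetical rank-3 input
reduction_dim = 2                                # softmax over the last dim
outer_dim_names = [dim._attrs["name"] for dim in shape[:reduction_dim]]
assert outer_dim_names == ["batch", "seq"]

Rendering FUNC_CALL_TEMPLATE above with these names, input="X", and output="Y" then emits a call that passes the addresses of just the outer dims, roughly:

softmax_0(
   X,
   Y,
   &batch,
   &seq,
   stream
);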