Skip to content

Commit

Permalink
Slightly improve softmax's codegen formatting
Browse files Browse the repository at this point in the history
Summary:
Very minor cleanups I did while familiarizing myself with the code

Aside from whitespace changes I also removed a few unnecessary automatic variables

Differential Revision: D47732846

fbshipit-source-id: 24c939f0eb0cc74cb5f72444cc805b140f638735
  • Loading branch information
int3 authored and facebook-github-bot committed Jul 25, 2023
1 parent d50e946 commit 2efce5e
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions python/aitemplate/backend/cuda/softmax/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@
{{func_signature}}
{
{{shape_functions}}
size_t m0 = {{m}};
size_t n = {{K}};
size_t m = M;
bool success = true;
Expand All @@ -64,7 +62,7 @@
{% elif K > 3840 %}
// K/8 > 480
using vec8 = VecTFor<{{dtype}}>::vec8;
LaunchSoftmaxBlockAll<vec8, {{dtype}},{{K}}>(reinterpret_cast<const vec8*>(input), reinterpret_cast<vec8*>(output), M, stream, &success);
LaunchSoftmaxBlockAll<vec8, {{dtype}}, {{K}}>(reinterpret_cast<const vec8*>(input), reinterpret_cast<vec8*>(output), M, stream, &success);
{% endif %}
{% elif K % 4 == 0 %}
// K % 4 == 0: vector4 kernels
Expand All @@ -77,7 +75,7 @@
{% elif K > 1920 %}
// K/4 > 480
using vec4 = VecTFor<{{dtype}}>::vec4;
LaunchSoftmaxBlockAll<vec4,{{dtype}},{{K}}>(reinterpret_cast<const vec4*>(input), reinterpret_cast<vec4*>(output), M, stream, &success);
LaunchSoftmaxBlockAll<vec4, {{dtype}}, {{K}}>(reinterpret_cast<const vec4*>(input), reinterpret_cast<vec4*>(output), M, stream, &success);
{% endif %}
{% elif K % 2 == 0 %}
// K % 2 == 0: vector2 kernels
Expand All @@ -90,7 +88,7 @@
{% elif K > 1152 %}
// K/2 > 576
using vec2 = VecTFor<{{dtype}}>::vec2;
LaunchSoftmaxBlockAll<vec2,{{dtype}},{{K}}>(reinterpret_cast<const vec2*>(input), reinterpret_cast<vec2*>(output), M, stream, &success);
LaunchSoftmaxBlockAll<vec2, {{dtype}}, {{K}}>(reinterpret_cast<const vec2*>(input), reinterpret_cast<vec2*>(output), M, stream, &success);
{% endif %}
{% else %}
// odd K
Expand All @@ -102,12 +100,12 @@
LaunchSoftmaxK1Middle<{{dtype}}, {{K}}>(static_cast<const {{dtype}}*>(input), static_cast<{{dtype}}*>(output), M, stream);
{% elif K > 1408 %}
// K > 1408
LaunchSoftmaxBlockAll<{{dtype}},{{dtype}},{{K}}>( (const {{dtype}}*) input, ({{dtype}}*) output, m, stream, &success);
LaunchSoftmaxBlockAll<{{dtype}}, {{dtype}}, {{K}}>( (const {{dtype}}*) input, ({{dtype}}*) output, m, stream, &success);
{% endif %}
{% endif %}
if (!success) {
softmaxBlockNocache<{{dtype}}><<<m, 1024, 0, stream>>>(({{dtype}}*)input, ({{dtype}}*)output, m, n);
softmaxBlockNocache<{{dtype}}><<<m, 1024, 0, stream>>>(({{dtype}}*)input, ({{dtype}}*)output, m, {{K}});
}
}
"""
Expand All @@ -125,18 +123,19 @@
FUNC_SIGNATURE = jinja2.Template(
"""
void {{func_name}}(void* input,
void* output,
void* output,
{% for idx in range(input_ndim - 1) %}
int64_t* in_{{idx}},
int64_t* in_{{idx}},
{% endfor %}
cudaStream_t stream)
"""
cudaStream_t stream)
""",
trim_blocks=True,
)

FUNC_DECL = jinja2.Template(
"""
{{func_signature}};
"""
{{func_signature}};
""",
)

FUNC_CALL_TEMPLATE = jinja2.Template(
Expand All @@ -145,11 +144,12 @@
{{indent}} {{input}},
{{indent}} {{output}},
{% for name in input_dim_names[:-1] %}
{{indent}} &{{name}},
{{indent}} &{{name}},
{% endfor %}
{{indent}} stream
{{indent}});
"""
""",
trim_blocks=True,
)


Expand Down

0 comments on commit 2efce5e

Please sign in to comment.