Skip to content

Commit 77e0580

Browse files
[HotFix] backward compabilities for 6.0 test tool (#722)
* Remove mix mode simulation for gfx90a * re-enable get heuristic for grouped gemm
1 parent 1549b02 commit 77e0580

File tree

30 files changed

+8169
-280
lines changed

30 files changed

+8169
-280
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ Full documentation for hipBLASLt is available at [rocm.docs.amd.com/projects/hip
1212
* Added `GemmTuning` extension parameter to set split-k by user
1313
* Support for mix precision datatype: fp16/fp8 in with fp16 out
1414

15+
### Deprecations
16+
17+
* algoGetHeuristic() ext API for GroupGemm will be deprecated in a future release of hipBLASLt
18+
1519
## hipBLASLt 0.6.0
1620

1721
### Additions

clients/gtest/matmul_gtest.yaml

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ Tests:
682682
alpha: 1
683683
beta: [ 0.0, 2.0 ]
684684
unit_check: 1
685-
algo_method: [2]
685+
algo_method: [0, 2]
686686

687687
- name: matmul_groupedgemm_zero_n
688688
category: pre_checkin
@@ -987,24 +987,6 @@ Tests:
987987
unit_check: 1
988988
gpu_arch: '94?'
989989

990-
- name: matmul_gemm_mix_precisions_gfx90a_sim
991-
category: pre_checkin
992-
function:
993-
matmul: *real_mix_precisions
994-
matrix_size:
995-
M: [127, 129]
996-
N: [127, 129]
997-
K: [127, 129]
998-
transA_transB: *transA_transB_range
999-
alpha: 1
1000-
beta: [ 0, 2 ]
1001-
use_ext: [1]
1002-
use_ext_setproblem: [1]
1003-
bias_vector: 1
1004-
bias_type: [f32_r]
1005-
unit_check: 1
1006-
gpu_arch: '90a'
1007-
1008990
- name: matmul_gemm_mix_precisions2
1009991
category: pre_checkin
1010992
function:

library/include/hipblaslt-ext.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,6 @@ namespace hipblaslt_ext
374374

375375
hipblasLtHandle_t m_handle;
376376
std::shared_ptr<void> m_data;
377-
struct ConversionHelper;
378-
std::unique_ptr<ConversionHelper> m_conversion_helper;
379377
};
380378

381379
/*! \ingroup types_module

library/src/amd_detail/hipblaslt-ext.cpp

Lines changed: 0 additions & 256 deletions
Original file line numberDiff line numberDiff line change
@@ -102,196 +102,6 @@ namespace hipblaslt_ext
102102
return m_workspace_bytes;
103103
}
104104

105-
struct GemmInstance::ConversionHelper
106-
{
107-
using Conversions = std::tuple<HipBufferPtr, //src
108-
HipBufferPtr, //dst
109-
hipDataType, //srcType
110-
hipDataType, //dstType
111-
std::size_t, //numElements
112-
HipBufferPtr>; //scale
113-
std::vector<std::vector<Conversions>> m_auxiliary_conversion_buffers;
114-
ConversionHelper(const std::vector<GemmProblemType>& problemTypes,
115-
GemmInputs& inputs,
116-
hipDataType conversionDType,
117-
int64_t batchSize,
118-
int64_t strideA,
119-
int64_t strideB,
120-
int64_t strideC,
121-
int64_t strideD)
122-
{
123-
const auto numGemms = problemTypes.size();
124-
if(m_auxiliary_conversion_buffers.size() != numGemms)
125-
{
126-
m_auxiliary_conversion_buffers.resize(numGemms);
127-
}
128-
129-
for(std::size_t j = 0; j < m_auxiliary_conversion_buffers.size(); ++j)
130-
{
131-
const std::vector<std::int64_t> sizes{strideA, strideB, strideC};
132-
const std::vector<void*> gemmInputs{inputs.a, inputs.b, inputs.c};
133-
const std::vector<void*> scales{inputs.scaleA, inputs.scaleB, inputs.scaleC};
134-
auto& conversions = m_auxiliary_conversion_buffers.at(j);
135-
auto& problem = problemTypes.at(j);
136-
const std::vector<hipDataType> dtypes{
137-
problem.type_a, problem.type_b, problem.type_c};
138-
139-
//a, b and c
140-
for(std::size_t i = 0; i < sizes.size(); ++i)
141-
{
142-
auto dtype = dtypes.at(i);
143-
const auto numElements = sizes.at(i);
144-
145-
if(dtype == HIP_R_8F_E4M3_FNUZ || dtype == HIP_R_8F_E5M2_FNUZ)
146-
{
147-
const auto numBytes = numElements * 2;
148-
conversions.emplace_back(
149-
std::make_tuple(std::move(HipBufferPtr(gemmInputs.at(i), NullDeleter)),
150-
std::move(makeHipBuffer(numBytes)),
151-
dtype,
152-
conversionDType,
153-
numElements,
154-
std::move(HipBufferPtr(scales.at(i), NullDeleter))));
155-
}
156-
else
157-
{
158-
conversions.emplace_back(
159-
std::make_tuple(std::move(HipBufferPtr(gemmInputs.at(i), NullDeleter)),
160-
std::move(makeHipBuffer(0)),
161-
dtype,
162-
conversionDType,
163-
numElements,
164-
std::move(HipBufferPtr(scales.at(i), NullDeleter))));
165-
}
166-
}
167-
168-
//for d
169-
auto output = inputs.d;
170-
const auto numElements = strideD * batchSize;
171-
172-
if(problem.type_d == HIP_R_8F_E4M3_FNUZ || problem.type_d == HIP_R_8F_E5M2_FNUZ)
173-
{
174-
auto numBytes = numElements * 2;
175-
conversions.emplace_back(
176-
std::make_tuple(std::move(makeHipBuffer(numBytes)),
177-
std::move(HipBufferPtr(output, NullDeleter)),
178-
conversionDType,
179-
problem.type_d,
180-
numElements,
181-
std::move(HipBufferPtr(inputs.scaleD, NullDeleter))));
182-
}
183-
else
184-
{
185-
conversions.emplace_back(
186-
std::make_tuple(std::move(makeHipBuffer(0)),
187-
std::move(HipBufferPtr(output, NullDeleter)),
188-
conversionDType,
189-
problem.type_d,
190-
numElements,
191-
std::move(HipBufferPtr(inputs.scaleD, NullDeleter))));
192-
}
193-
}
194-
}
195-
196-
~ConversionHelper() = default;
197-
ConversionHelper(ConversionHelper&& rhs) noexcept = default;
198-
//force move
199-
ConversionHelper(const ConversionHelper& rhs) = delete;
200-
ConversionHelper& operator=(const ConversionHelper&) = delete;
201-
ConversionHelper& operator=(ConversionHelper&&) = default;
202-
203-
void convertInputs(hipStream_t stream)
204-
{
205-
if(m_auxiliary_conversion_buffers.size())
206-
{
207-
for(auto& conversions : m_auxiliary_conversion_buffers)
208-
{
209-
for(size_t i = 0; i < 3; ++i)
210-
{
211-
auto& conversion = conversions.at(i);
212-
auto& dst = std::get<1>(conversion);
213-
auto& src = std::get<0>(conversion);
214-
215-
if(src && dst)
216-
{
217-
auto srcType = std::get<2>(conversion);
218-
auto dstType = std::get<3>(conversion);
219-
const auto numElements = std::get<4>(conversion);
220-
auto& scale = std::get<5>(conversion);
221-
constexpr auto numWorkitemsPerWg = 256;
222-
const auto numWg = (numElements / numWorkitemsPerWg)
223-
+ !!(numElements % numWorkitemsPerWg);
224-
225-
if(srcType == HIP_R_8F_E4M3_FNUZ)
226-
{
227-
datatypeConversion<hipblaslt_f8_fnuz, hipblasLtHalf>
228-
<<<numWg, numWorkitemsPerWg, 0, stream>>>(
229-
(const hipblaslt_f8_fnuz*)src.get(),
230-
(hipblasLtHalf*)dst.get(),
231-
(const float*)scale.get(),
232-
numElements);
233-
}
234-
else if(srcType == HIP_R_8F_E5M2_FNUZ)
235-
{
236-
datatypeConversion<hipblaslt_bf8_fnuz, hipblasLtHalf>
237-
<<<numWg, numWorkitemsPerWg, 0, stream>>>(
238-
(const hipblaslt_bf8_fnuz*)src.get(),
239-
(hipblasLtHalf*)dst.get(),
240-
(const float*)scale.get(),
241-
numElements);
242-
}
243-
}
244-
}
245-
}
246-
}
247-
}
248-
249-
void convertOutputs(hipStream_t stream)
250-
{
251-
if(m_auxiliary_conversion_buffers.size())
252-
{
253-
for(auto& conversions : m_auxiliary_conversion_buffers)
254-
{
255-
if(conversions.size() > 3)
256-
{
257-
auto& conversion = conversions.at(3);
258-
auto& src = std::get<0>(conversion);
259-
auto& dst = std::get<1>(conversion);
260-
auto srcType = std::get<2>(conversion);
261-
auto dstType = std::get<3>(conversion);
262-
const auto numElements = std::get<4>(conversion);
263-
auto& scale = std::get<5>(conversion);
264-
constexpr auto numWorkitemsPerWg = 256;
265-
const auto numWg = (numElements / numWorkitemsPerWg)
266-
+ !!(numElements % numWorkitemsPerWg);
267-
//indicates d needs datatype conversion
268-
if(src && dst)
269-
{
270-
if(dstType == HIP_R_8F_E4M3_FNUZ)
271-
{
272-
datatypeConversion<hipblasLtHalf, hipblaslt_f8_fnuz>
273-
<<<numWg, numWorkitemsPerWg, 0, stream>>>(
274-
(const hipblasLtHalf*)src.get(),
275-
(hipblaslt_f8_fnuz*)dst.get(),
276-
(const float*)scale.get(),
277-
numElements);
278-
}
279-
else if(dstType == HIP_R_8F_E5M2_FNUZ)
280-
{
281-
datatypeConversion<hipblasLtHalf, hipblaslt_bf8_fnuz>
282-
<<<numWg, numWorkitemsPerWg, 0, stream>>>(
283-
(const hipblasLtHalf*)src.get(),
284-
(hipblaslt_bf8_fnuz*)dst.get(),
285-
(const float*)scale.get(),
286-
numElements);
287-
}
288-
}
289-
}
290-
}
291-
}
292-
}
293-
};
294-
295105
GemmInstance::GemmInstance(hipblasLtHandle_t handle, GemmType type)
296106
: m_gemm_type(type)
297107
, m_handle(handle)
@@ -423,20 +233,10 @@ namespace hipblaslt_ext
423233
if(m_gemm_count == 0)
424234
return HIPBLAS_STATUS_INVALID_VALUE;
425235

426-
if(m_conversion_helper)
427-
{
428-
m_conversion_helper->convertInputs(stream);
429-
}
430-
431236
auto gemmType = static_cast<rocblaslt::RocGemmType>(m_gemm_type);
432237
auto status = RocBlasLtStatusToHIPStatus(
433238
rocblaslt_run_cpp((rocblaslt_handle)m_handle, gemmType, m_data, stream));
434239

435-
if(m_conversion_helper)
436-
{
437-
m_conversion_helper->convertOutputs(stream);
438-
}
439-
440240
return status;
441241
}
442242
catch(...)
@@ -543,64 +343,8 @@ namespace hipblaslt_ext
543343
GemmInputs& inputs,
544344
GemmProblemType& problemtype)
545345
{
546-
constexpr auto conversionDType = HIP_R_16F;
547-
auto needConversion = [&problemtype]() -> bool {
548-
using std::begin;
549-
using std::end;
550-
const auto types = {problemtype.type_a, problemtype.type_b, problemtype.type_c};
551-
auto mixedPrecision
552-
= end(types) != std::adjacent_find(begin(types), end(types), std::not_equal_to<>());
553-
return mixedPrecision && !currentArchSupportsFp8();
554-
}();
555-
556-
if(needConversion)
557-
{
558-
m_conversion_helper = std::make_unique<ConversionHelper>(m_problem_types,
559-
inputs,
560-
conversionDType,
561-
batch_count,
562-
strideA,
563-
strideB,
564-
strideC,
565-
strideD);
566-
}
567-
568-
//Shallow copy
569346
GemmInputs gemmInputs = inputs;
570347
GemmProblemType gemmProblemType = problemtype;
571-
auto& problem = m_problem_types.at(0);
572-
573-
if(needConversion)
574-
{
575-
if(auto& a
576-
= std::get<1>(m_conversion_helper->m_auxiliary_conversion_buffers.at(0).at(0)))
577-
{
578-
gemmInputs.a = a.get();
579-
gemmProblemType.type_a = conversionDType;
580-
}
581-
582-
if(auto& b
583-
= std::get<1>(m_conversion_helper->m_auxiliary_conversion_buffers.at(0).at(1)))
584-
{
585-
gemmInputs.b = b.get();
586-
gemmProblemType.type_b = conversionDType;
587-
}
588-
589-
if(auto& c
590-
= std::get<1>(m_conversion_helper->m_auxiliary_conversion_buffers.at(0).at(2)))
591-
{
592-
gemmInputs.c = c.get();
593-
gemmProblemType.type_c = conversionDType;
594-
}
595-
596-
if(auto& d
597-
= std::get<0>(m_conversion_helper->m_auxiliary_conversion_buffers.at(0).at(3)))
598-
{
599-
gemmInputs.d = d.get();
600-
gemmProblemType.type_d = conversionDType;
601-
}
602-
}
603-
604348
auto rocepilogue = reinterpret_cast<rocblaslt::RocGemmEpilogue*>(&epilogue);
605349
auto rocepinputs = reinterpret_cast<rocblaslt::RocGemmInputs*>(&gemmInputs);
606350
auto rocproblemtype = reinterpret_cast<rocblaslt::RocGemmProblemType*>(&gemmProblemType);

0 commit comments

Comments
 (0)