@@ -102,196 +102,6 @@ namespace hipblaslt_ext
102102 return m_workspace_bytes;
103103 }
104104
105- struct GemmInstance ::ConversionHelper
106- {
107- using Conversions = std::tuple<HipBufferPtr, // src
108- HipBufferPtr, // dst
109- hipDataType, // srcType
110- hipDataType, // dstType
111- std::size_t , // numElements
112- HipBufferPtr>; // scale
113- std::vector<std::vector<Conversions>> m_auxiliary_conversion_buffers;
114- ConversionHelper (const std::vector<GemmProblemType>& problemTypes,
115- GemmInputs& inputs,
116- hipDataType conversionDType,
117- int64_t batchSize,
118- int64_t strideA,
119- int64_t strideB,
120- int64_t strideC,
121- int64_t strideD)
122- {
123- const auto numGemms = problemTypes.size ();
124- if (m_auxiliary_conversion_buffers.size () != numGemms)
125- {
126- m_auxiliary_conversion_buffers.resize (numGemms);
127- }
128-
129- for (std::size_t j = 0 ; j < m_auxiliary_conversion_buffers.size (); ++j)
130- {
131- const std::vector<std::int64_t > sizes{strideA, strideB, strideC};
132- const std::vector<void *> gemmInputs{inputs.a , inputs.b , inputs.c };
133- const std::vector<void *> scales{inputs.scaleA , inputs.scaleB , inputs.scaleC };
134- auto & conversions = m_auxiliary_conversion_buffers.at (j);
135- auto & problem = problemTypes.at (j);
136- const std::vector<hipDataType> dtypes{
137- problem.type_a , problem.type_b , problem.type_c };
138-
139- // a, b and c
140- for (std::size_t i = 0 ; i < sizes.size (); ++i)
141- {
142- auto dtype = dtypes.at (i);
143- const auto numElements = sizes.at (i);
144-
145- if (dtype == HIP_R_8F_E4M3_FNUZ || dtype == HIP_R_8F_E5M2_FNUZ)
146- {
147- const auto numBytes = numElements * 2 ;
148- conversions.emplace_back (
149- std::make_tuple (std::move (HipBufferPtr (gemmInputs.at (i), NullDeleter)),
150- std::move (makeHipBuffer (numBytes)),
151- dtype,
152- conversionDType,
153- numElements,
154- std::move (HipBufferPtr (scales.at (i), NullDeleter))));
155- }
156- else
157- {
158- conversions.emplace_back (
159- std::make_tuple (std::move (HipBufferPtr (gemmInputs.at (i), NullDeleter)),
160- std::move (makeHipBuffer (0 )),
161- dtype,
162- conversionDType,
163- numElements,
164- std::move (HipBufferPtr (scales.at (i), NullDeleter))));
165- }
166- }
167-
168- // for d
169- auto output = inputs.d ;
170- const auto numElements = strideD * batchSize;
171-
172- if (problem.type_d == HIP_R_8F_E4M3_FNUZ || problem.type_d == HIP_R_8F_E5M2_FNUZ)
173- {
174- auto numBytes = numElements * 2 ;
175- conversions.emplace_back (
176- std::make_tuple (std::move (makeHipBuffer (numBytes)),
177- std::move (HipBufferPtr (output, NullDeleter)),
178- conversionDType,
179- problem.type_d ,
180- numElements,
181- std::move (HipBufferPtr (inputs.scaleD , NullDeleter))));
182- }
183- else
184- {
185- conversions.emplace_back (
186- std::make_tuple (std::move (makeHipBuffer (0 )),
187- std::move (HipBufferPtr (output, NullDeleter)),
188- conversionDType,
189- problem.type_d ,
190- numElements,
191- std::move (HipBufferPtr (inputs.scaleD , NullDeleter))));
192- }
193- }
194- }
195-
196- ~ConversionHelper () = default ;
197- ConversionHelper (ConversionHelper&& rhs) noexcept = default ;
198- // force move
199- ConversionHelper (const ConversionHelper& rhs) = delete ;
200- ConversionHelper& operator =(const ConversionHelper&) = delete ;
201- ConversionHelper& operator =(ConversionHelper&&) = default ;
202-
203- void convertInputs (hipStream_t stream)
204- {
205- if (m_auxiliary_conversion_buffers.size ())
206- {
207- for (auto & conversions : m_auxiliary_conversion_buffers)
208- {
209- for (size_t i = 0 ; i < 3 ; ++i)
210- {
211- auto & conversion = conversions.at (i);
212- auto & dst = std::get<1 >(conversion);
213- auto & src = std::get<0 >(conversion);
214-
215- if (src && dst)
216- {
217- auto srcType = std::get<2 >(conversion);
218- auto dstType = std::get<3 >(conversion);
219- const auto numElements = std::get<4 >(conversion);
220- auto & scale = std::get<5 >(conversion);
221- constexpr auto numWorkitemsPerWg = 256 ;
222- const auto numWg = (numElements / numWorkitemsPerWg)
223- + !!(numElements % numWorkitemsPerWg);
224-
225- if (srcType == HIP_R_8F_E4M3_FNUZ)
226- {
227- datatypeConversion<hipblaslt_f8_fnuz, hipblasLtHalf>
228- <<<numWg, numWorkitemsPerWg, 0 , stream>>>(
229- (const hipblaslt_f8_fnuz*)src.get (),
230- (hipblasLtHalf*)dst.get (),
231- (const float *)scale.get (),
232- numElements);
233- }
234- else if (srcType == HIP_R_8F_E5M2_FNUZ)
235- {
236- datatypeConversion<hipblaslt_bf8_fnuz, hipblasLtHalf>
237- <<<numWg, numWorkitemsPerWg, 0 , stream>>>(
238- (const hipblaslt_bf8_fnuz*)src.get (),
239- (hipblasLtHalf*)dst.get (),
240- (const float *)scale.get (),
241- numElements);
242- }
243- }
244- }
245- }
246- }
247- }
248-
249- void convertOutputs (hipStream_t stream)
250- {
251- if (m_auxiliary_conversion_buffers.size ())
252- {
253- for (auto & conversions : m_auxiliary_conversion_buffers)
254- {
255- if (conversions.size () > 3 )
256- {
257- auto & conversion = conversions.at (3 );
258- auto & src = std::get<0 >(conversion);
259- auto & dst = std::get<1 >(conversion);
260- auto srcType = std::get<2 >(conversion);
261- auto dstType = std::get<3 >(conversion);
262- const auto numElements = std::get<4 >(conversion);
263- auto & scale = std::get<5 >(conversion);
264- constexpr auto numWorkitemsPerWg = 256 ;
265- const auto numWg = (numElements / numWorkitemsPerWg)
266- + !!(numElements % numWorkitemsPerWg);
267- // indicates d needs datatype conversion
268- if (src && dst)
269- {
270- if (dstType == HIP_R_8F_E4M3_FNUZ)
271- {
272- datatypeConversion<hipblasLtHalf, hipblaslt_f8_fnuz>
273- <<<numWg, numWorkitemsPerWg, 0 , stream>>>(
274- (const hipblasLtHalf*)src.get (),
275- (hipblaslt_f8_fnuz*)dst.get (),
276- (const float *)scale.get (),
277- numElements);
278- }
279- else if (dstType == HIP_R_8F_E5M2_FNUZ)
280- {
281- datatypeConversion<hipblasLtHalf, hipblaslt_bf8_fnuz>
282- <<<numWg, numWorkitemsPerWg, 0 , stream>>>(
283- (const hipblasLtHalf*)src.get (),
284- (hipblaslt_bf8_fnuz*)dst.get (),
285- (const float *)scale.get (),
286- numElements);
287- }
288- }
289- }
290- }
291- }
292- }
293- };
294-
295105 GemmInstance::GemmInstance (hipblasLtHandle_t handle, GemmType type)
296106 : m_gemm_type(type)
297107 , m_handle(handle)
@@ -423,20 +233,10 @@ namespace hipblaslt_ext
423233 if (m_gemm_count == 0 )
424234 return HIPBLAS_STATUS_INVALID_VALUE;
425235
426- if (m_conversion_helper)
427- {
428- m_conversion_helper->convertInputs (stream);
429- }
430-
431236 auto gemmType = static_cast <rocblaslt::RocGemmType>(m_gemm_type);
432237 auto status = RocBlasLtStatusToHIPStatus (
433238 rocblaslt_run_cpp ((rocblaslt_handle)m_handle, gemmType, m_data, stream));
434239
435- if (m_conversion_helper)
436- {
437- m_conversion_helper->convertOutputs (stream);
438- }
439-
440240 return status;
441241 }
442242 catch (...)
@@ -543,64 +343,8 @@ namespace hipblaslt_ext
543343 GemmInputs& inputs,
544344 GemmProblemType& problemtype)
545345 {
546- constexpr auto conversionDType = HIP_R_16F;
547- auto needConversion = [&problemtype]() -> bool {
548- using std::begin;
549- using std::end;
550- const auto types = {problemtype.type_a , problemtype.type_b , problemtype.type_c };
551- auto mixedPrecision
552- = end (types) != std::adjacent_find (begin (types), end (types), std::not_equal_to<>());
553- return mixedPrecision && !currentArchSupportsFp8 ();
554- }();
555-
556- if (needConversion)
557- {
558- m_conversion_helper = std::make_unique<ConversionHelper>(m_problem_types,
559- inputs,
560- conversionDType,
561- batch_count,
562- strideA,
563- strideB,
564- strideC,
565- strideD);
566- }
567-
568- // Shallow copy
569346 GemmInputs gemmInputs = inputs;
570347 GemmProblemType gemmProblemType = problemtype;
571- auto & problem = m_problem_types.at (0 );
572-
573- if (needConversion)
574- {
575- if (auto & a
576- = std::get<1 >(m_conversion_helper->m_auxiliary_conversion_buffers .at (0 ).at (0 )))
577- {
578- gemmInputs.a = a.get ();
579- gemmProblemType.type_a = conversionDType;
580- }
581-
582- if (auto & b
583- = std::get<1 >(m_conversion_helper->m_auxiliary_conversion_buffers .at (0 ).at (1 )))
584- {
585- gemmInputs.b = b.get ();
586- gemmProblemType.type_b = conversionDType;
587- }
588-
589- if (auto & c
590- = std::get<1 >(m_conversion_helper->m_auxiliary_conversion_buffers .at (0 ).at (2 )))
591- {
592- gemmInputs.c = c.get ();
593- gemmProblemType.type_c = conversionDType;
594- }
595-
596- if (auto & d
597- = std::get<0 >(m_conversion_helper->m_auxiliary_conversion_buffers .at (0 ).at (3 )))
598- {
599- gemmInputs.d = d.get ();
600- gemmProblemType.type_d = conversionDType;
601- }
602- }
603-
604348 auto rocepilogue = reinterpret_cast <rocblaslt::RocGemmEpilogue*>(&epilogue);
605349 auto rocepinputs = reinterpret_cast <rocblaslt::RocGemmInputs*>(&gemmInputs);
606350 auto rocproblemtype = reinterpret_cast <rocblaslt::RocGemmProblemType*>(&gemmProblemType);
0 commit comments