Skip to content

Commit c129682

Browse files
committed
Fix issues in jitlayers and aotcompile
1 parent a223989 commit c129682

File tree

2 files changed

+82
-20
lines changed

2 files changed

+82
-20
lines changed

src/aotcompile.cpp

+30-10
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ static void makeSafeName(GlobalObject &G)
228228
G.setName(StringRef(SafeName.data(), SafeName.size()));
229229
}
230230

231-
static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance_t *mi, size_t world, jl_code_instance_t **ci_out, jl_code_info_t **src_out)
231+
static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance_t *mi JL_REQUIRE_PIN, size_t world, jl_code_instance_t **ci_out, jl_code_info_t **src_out)
232232
{
233233
++CICacheLookups;
234234
jl_value_t *ci = cgparams.lookup(mi, world, world);
@@ -273,6 +273,7 @@ void replaceUsesWithLoad(Function &F, function_ref<GlobalVariable *(Instruction
273273
extern "C" JL_DLLEXPORT
274274
void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int _policy, int _imaging_mode, int _external_linkage, size_t _world)
275275
{
276+
PTR_PIN(methods);
276277
++CreateNativeCalls;
277278
CreateNativeMax.updateMax(jl_array_len(methods));
278279
if (cgparams == NULL)
@@ -320,12 +321,16 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
320321
// each item in this list is either a MethodInstance indicating something
321322
// to compile, or an svec(rettype, sig) describing a C-callable alias to create.
322323
jl_value_t *item = jl_array_ptr_ref(methods, i);
324+
PTR_PIN(item);
323325
if (jl_is_simplevector(item)) {
324326
if (worlds == 1)
325327
jl_compile_extern_c(wrap(&clone), &params, NULL, jl_svecref(item, 0), jl_svecref(item, 1));
328+
PTR_UNPIN(item);
326329
continue;
327330
}
331+
PTR_UNPIN(item);
328332
mi = (jl_method_instance_t*)item;
333+
PTR_PIN(mi);
329334
src = NULL;
330335
// if this method is generally visible to the current compilation world,
331336
// and this is either the primary world, or not applicable in the primary world
@@ -337,20 +342,24 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
337342
if (src && !emitted.count(codeinst)) {
338343
// now add it to our compilation results
339344
JL_GC_PROMISE_ROOTED(codeinst->rettype);
345+
PTR_PIN(codeinst->rettype);
340346
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(codeinst->def),
341347
params.tsctx, params.imaging,
342348
clone.getModuleUnlocked()->getDataLayout(),
343349
Triple(clone.getModuleUnlocked()->getTargetTriple()));
344350
jl_llvm_functions_t decls = jl_emit_code(result_m, mi, src, codeinst->rettype, params);
351+
PTR_UNPIN(codeinst->rettype);
345352
if (result_m)
346353
emitted[codeinst] = {std::move(result_m), std::move(decls)};
347354
}
348355
}
356+
PTR_UNPIN(mi);
349357
}
350358

351359
// finally, make sure all referenced methods also get compiled or fixed up
352360
jl_compile_workqueue(emitted, *clone.getModuleUnlocked(), params, policy);
353361
}
362+
PTR_UNPIN(methods);
354363
JL_UNLOCK(&jl_codegen_lock); // Might GC
355364
JL_GC_POP();
356365

@@ -1048,31 +1057,38 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, siz
10481057
return;
10491058
}
10501059

1060+
jl_method_t *method = mi->def.method;
1061+
PTR_PIN(mi);
1062+
PTR_PIN(method);
10511063
// get the source code for this function
10521064
jl_value_t *jlrettype = (jl_value_t*)jl_any_type;
10531065
jl_code_info_t *src = NULL;
10541066
JL_GC_PUSH2(&src, &jlrettype);
1055-
if (jl_is_method(mi->def.method) && mi->def.method->source != NULL && jl_ir_flag_inferred((jl_array_t*)mi->def.method->source)) {
1056-
src = (jl_code_info_t*)mi->def.method->source;
1067+
if (jl_is_method(method) && method->source != NULL && jl_ir_flag_inferred((jl_array_t*)method->source)) {
1068+
src = (jl_code_info_t*)method->source;
10571069
if (src && !jl_is_code_info(src))
1058-
src = jl_uncompress_ir(mi->def.method, NULL, (jl_array_t*)src);
1070+
src = jl_uncompress_ir(method, NULL, (jl_array_t*)src);
10591071
} else {
10601072
jl_value_t *ci = jl_rettype_inferred(mi, world, world);
10611073
if (ci != jl_nothing) {
10621074
jl_code_instance_t *codeinst = (jl_code_instance_t*)ci;
10631075
src = (jl_code_info_t*)jl_atomic_load_relaxed(&codeinst->inferred);
1064-
if ((jl_value_t*)src != jl_nothing && !jl_is_code_info(src) && jl_is_method(mi->def.method))
1065-
src = jl_uncompress_ir(mi->def.method, codeinst, (jl_array_t*)src);
1076+
if ((jl_value_t*)src != jl_nothing && !jl_is_code_info(src) && jl_is_method(method)) {
1077+
PTR_PIN(codeinst);
1078+
src = jl_uncompress_ir(method, codeinst, (jl_array_t*)src);
1079+
PTR_UNPIN(codeinst);
1080+
}
10661081
jlrettype = codeinst->rettype;
1082+
10671083
}
10681084
if (!src || (jl_value_t*)src == jl_nothing) {
10691085
src = jl_type_infer(mi, world, 0);
10701086
if (src)
10711087
jlrettype = src->rettype;
1072-
else if (jl_is_method(mi->def.method)) {
1073-
src = mi->def.method->generator ? jl_code_for_staged(mi) : (jl_code_info_t*)mi->def.method->source;
1074-
if (src && !jl_is_code_info(src) && jl_is_method(mi->def.method))
1075-
src = jl_uncompress_ir(mi->def.method, NULL, (jl_array_t*)src);
1088+
else if (jl_is_method(method)) {
1089+
src = method->generator ? jl_code_for_staged(mi) : (jl_code_info_t*)method->source;
1090+
if (src && !jl_is_code_info(src) && jl_is_method(method))
1091+
src = jl_uncompress_ir(method, NULL, (jl_array_t*)src);
10761092
}
10771093
// TODO: use mi->uninferred
10781094
}
@@ -1132,10 +1148,14 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, siz
11321148
if (F) {
11331149
dump->TSM = wrap(new orc::ThreadSafeModule(std::move(m)));
11341150
dump->F = wrap(F);
1151+
PTR_UNPIN(mi);
1152+
PTR_UNPIN(method);
11351153
return;
11361154
}
11371155
}
11381156

11391157
const char *mname = name_from_method_instance(mi);
1158+
PTR_UNPIN(mi);
1159+
PTR_UNPIN(method);
11401160
jl_errorf("unable to compile source for function %s", mname);
11411161
}

src/jitlayers.cpp

+52-10
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ static jl_callptr_t _jl_compile_codeinst(
180180
size_t world,
181181
orc::ThreadSafeContext context)
182182
{
183+
PTR_PIN(codeinst);
184+
PTR_PIN(src);
183185
// caller must hold codegen_lock
184186
// and have disabled finalizers
185187
uint64_t start_time = 0;
@@ -246,6 +248,7 @@ static jl_callptr_t _jl_compile_codeinst(
246248
IndirectCodeinsts += emitted.size() - 1;
247249
}
248250
JL_TIMING(LLVM_MODULE_FINISH);
251+
PTR_UNPIN(src);
249252

250253
for (auto &def : emitted) {
251254
jl_code_instance_t *this_code = def.first;
@@ -307,6 +310,7 @@ static jl_callptr_t _jl_compile_codeinst(
307310
jl_printf(stream, "\"\n");
308311
}
309312
}
313+
PTR_UNPIN(codeinst);
310314
return fptr;
311315
}
312316

@@ -316,6 +320,8 @@ const char *jl_generate_ccallable(LLVMOrcThreadSafeModuleRef llvmmod, void *sysi
316320
extern "C" JL_DLLEXPORT
317321
int jl_compile_extern_c_impl(LLVMOrcThreadSafeModuleRef llvmmod, void *p, void *sysimg, jl_value_t *declrt, jl_value_t *sigt)
318322
{
323+
PTR_PIN(declrt);
324+
PTR_PIN(sigt);
319325
auto ct = jl_current_task;
320326
ct->reentrant_timing++;
321327
uint64_t compiler_start_time = 0;
@@ -339,6 +345,8 @@ int jl_compile_extern_c_impl(LLVMOrcThreadSafeModuleRef llvmmod, void *p, void *
339345
pparams = &params;
340346
assert(pparams->tsctx.getContext() == into->getContext().getContext());
341347
const char *name = jl_generate_ccallable(wrap(into), sysimg, declrt, sigt, *pparams);
348+
PTR_UNPIN(declrt);
349+
PTR_UNPIN(sigt);
342350
bool success = true;
343351
if (!sysimg) {
344352
if (jl_ExecutionEngine->getGlobalValueAddress(name)) {
@@ -368,13 +376,18 @@ int jl_compile_extern_c_impl(LLVMOrcThreadSafeModuleRef llvmmod, void *p, void *
368376
extern "C" JL_DLLEXPORT
369377
void jl_extern_c_impl(jl_value_t *declrt, jl_tupletype_t *sigt)
370378
{
379+
PTR_PIN(declrt);
380+
PTR_PIN(sigt);
381+
jl_svec_t *params = ((jl_datatype_t*)(sigt))->parameters;
382+
PTR_PIN(params);
371383
// validate arguments. try to do as many checks as possible here to avoid
372384
// throwing errors later during codegen.
373385
JL_TYPECHK(@ccallable, type, declrt);
374386
if (!jl_is_tuple_type(sigt))
375387
jl_type_error("@ccallable", (jl_value_t*)jl_anytuple_type_type, (jl_value_t*)sigt);
376388
// check that f is a guaranteed singleton type
377389
jl_datatype_t *ft = (jl_datatype_t*)jl_tparam0(sigt);
390+
PTR_PIN(ft);
378391
if (!jl_is_datatype(ft) || ft->instance == NULL)
379392
jl_error("@ccallable: function object must be a singleton");
380393

@@ -385,12 +398,13 @@ void jl_extern_c_impl(jl_value_t *declrt, jl_tupletype_t *sigt)
385398
jl_error("@ccallable: return type doesn't correspond to a C type");
386399

387400
// validate method signature
388-
size_t i, nargs = jl_nparams(sigt);
401+
size_t i, nargs = jl_svec_len(params);
389402
for (i = 1; i < nargs; i++) {
390-
jl_value_t *ati = jl_tparam(sigt, i);
403+
jl_value_t *ati = jl_svecref(params, i);
391404
if (!jl_is_concrete_type(ati) || jl_is_kind(ati) || !jl_type_mappable_to_c(ati))
392405
jl_error("@ccallable: argument types must be concrete");
393406
}
407+
PTR_UNPIN(params);
394408

395409
// save a record of this so that the alias is generated when we write an object file
396410
jl_method_t *meth = (jl_method_t*)jl_methtable_lookup(ft->name->mt, (jl_value_t*)sigt, jl_atomic_load_acquire(&jl_world_counter));
@@ -403,6 +417,9 @@ void jl_extern_c_impl(jl_value_t *declrt, jl_tupletype_t *sigt)
403417

404418
// create the alias in the current runtime environment
405419
int success = jl_compile_extern_c(NULL, NULL, NULL, declrt, (jl_value_t*)sigt);
420+
PTR_UNPIN(declrt);
421+
PTR_UNPIN(sigt);
422+
PTR_UNPIN(ft);
406423
if (!success)
407424
jl_error("@ccallable was already defined for this method name");
408425
}
@@ -411,6 +428,9 @@ void jl_extern_c_impl(jl_value_t *declrt, jl_tupletype_t *sigt)
411428
extern "C" JL_DLLEXPORT
412429
jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES_ROOT, size_t world)
413430
{
431+
PTR_PIN(mi);
432+
jl_method_t * method = mi->def.method;
433+
PTR_PIN(method);
414434
auto ct = jl_current_task;
415435
ct->reentrant_timing++;
416436
uint64_t compiler_start_time = 0;
@@ -425,19 +445,20 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
425445
jl_value_t *ci = jl_rettype_inferred(mi, world, world);
426446
jl_code_instance_t *codeinst = (ci == jl_nothing ? NULL : (jl_code_instance_t*)ci);
427447
if (codeinst) {
448+
PTR_PIN(codeinst);
428449
src = (jl_code_info_t*)jl_atomic_load_relaxed(&codeinst->inferred);
429450
if ((jl_value_t*)src == jl_nothing)
430451
src = NULL;
431-
else if (jl_is_method(mi->def.method))
432-
src = jl_uncompress_ir(mi->def.method, codeinst, (jl_array_t*)src);
452+
else if (jl_is_method(method))
453+
src = jl_uncompress_ir(method, codeinst, (jl_array_t*)src);
433454
}
434455
else {
435456
// identify whether this is an invalidated method that is being recompiled
436457
is_recompile = jl_atomic_load_relaxed(&mi->cache) != NULL;
437458
}
438-
if (src == NULL && jl_is_method(mi->def.method) &&
439-
jl_symbol_name(mi->def.method->name)[0] != '@') {
440-
if (mi->def.method->source != jl_nothing) {
459+
if (src == NULL && jl_is_method(method) &&
460+
jl_symbol_name(method->name)[0] != '@') {
461+
if (method->source != jl_nothing) {
441462
// If the caller didn't provide the source and IR is available,
442463
// see if it is inferred, or try to infer it for ourself.
443464
// (but don't bother with typeinf on macros or toplevel thunks)
@@ -446,22 +467,28 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
446467
}
447468
jl_code_instance_t *compiled = jl_method_compiled(mi, world);
448469
if (compiled) {
470+
if (codeinst) PTR_UNPIN(codeinst);
449471
codeinst = compiled;
472+
PTR_PIN(codeinst);
450473
}
451474
else if (src && jl_is_code_info(src)) {
452475
if (!codeinst) {
453476
codeinst = jl_get_method_inferred(mi, src->rettype, src->min_world, src->max_world);
477+
PTR_PIN(codeinst);
454478
if (src->inferred) {
455479
jl_value_t *null = nullptr;
456480
jl_atomic_cmpswap_relaxed(&codeinst->inferred, &null, jl_nothing);
457481
}
458482
}
459483
++SpecFPtrCount;
460484
_jl_compile_codeinst(codeinst, src, world, *jl_ExecutionEngine->getContext());
461-
if (jl_atomic_load_relaxed(&codeinst->invoke) == NULL)
485+
if (jl_atomic_load_relaxed(&codeinst->invoke) == NULL) {
486+
if (codeinst) PTR_UNPIN(codeinst);
462487
codeinst = NULL;
488+
}
463489
}
464490
else {
491+
if (codeinst) PTR_UNPIN(codeinst);
465492
codeinst = NULL;
466493
}
467494
JL_UNLOCK(&jl_codegen_lock);
@@ -473,6 +500,9 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
473500
jl_atomic_fetch_add_relaxed(&jl_cumulative_compile_time, t_comp);
474501
}
475502
JL_GC_POP();
503+
PTR_UNPIN(mi);
504+
PTR_UNPIN(method);
505+
if(codeinst) PTR_UNPIN(codeinst);
476506
return codeinst;
477507
}
478508

@@ -482,6 +512,7 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
482512
if (jl_atomic_load_relaxed(&unspec->invoke) != NULL) {
483513
return;
484514
}
515+
PTR_PIN(unspec);
485516
auto ct = jl_current_task;
486517
ct->reentrant_timing++;
487518
uint64_t compiler_start_time = 0;
@@ -494,6 +525,7 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
494525
JL_GC_PUSH1(&src);
495526
jl_method_t *def = unspec->def->def.method;
496527
if (jl_is_method(def)) {
528+
PTR_PIN(def);
497529
src = (jl_code_info_t*)def->source;
498530
if (src == NULL) {
499531
// TODO: this is wrong
@@ -503,6 +535,7 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
503535
}
504536
if (src && (jl_value_t*)src != jl_nothing)
505537
src = jl_uncompress_ir(def, NULL, (jl_array_t*)src);
538+
PTR_UNPIN(def);
506539
}
507540
else {
508541
src = (jl_code_info_t*)unspec->def->uninferred;
@@ -515,6 +548,7 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
515548
jl_atomic_cmpswap(&unspec->invoke, &null, jl_fptr_interpret_call_addr);
516549
JL_GC_POP();
517550
}
551+
PTR_UNPIN(unspec);
518552
JL_UNLOCK(&jl_codegen_lock); // Might GC
519553
if (!--ct->reentrant_timing && measure_compile_time_enabled) {
520554
auto end = jl_hrtime();
@@ -528,12 +562,15 @@ extern "C" JL_DLLEXPORT
528562
jl_value_t *jl_dump_method_asm_impl(jl_method_instance_t *mi, size_t world,
529563
char raw_mc, char getwrapper, const char* asm_variant, const char *debuginfo, char binary)
530564
{
565+
PTR_PIN(mi);
531566
// printing via disassembly
532567
jl_code_instance_t *codeinst = jl_generate_fptr(mi, world);
533568
if (codeinst) {
534569
uintptr_t fptr = (uintptr_t)jl_atomic_load_acquire(&codeinst->invoke);
535-
if (getwrapper)
570+
if (getwrapper) {
571+
PTR_UNPIN(mi);
536572
return jl_dump_fptr_asm(fptr, raw_mc, asm_variant, debuginfo, binary);
573+
}
537574
uintptr_t specfptr = (uintptr_t)jl_atomic_load_relaxed(&codeinst->specptr.fptr);
538575
if (fptr == (uintptr_t)jl_fptr_const_return_addr && specfptr == 0) {
539576
// normally we prevent native code from being generated for these functions,
@@ -545,6 +582,7 @@ jl_value_t *jl_dump_method_asm_impl(jl_method_instance_t *mi, size_t world,
545582
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
546583
if (measure_compile_time_enabled)
547584
compiler_start_time = jl_hrtime();
585+
PTR_PIN(codeinst);
548586
JL_LOCK(&jl_codegen_lock); // also disables finalizers, to prevent any unexpected recursion
549587
specfptr = (uintptr_t)jl_atomic_load_relaxed(&codeinst->specptr.fptr);
550588
if (specfptr == 0) {
@@ -569,19 +607,23 @@ jl_value_t *jl_dump_method_asm_impl(jl_method_instance_t *mi, size_t world,
569607
}
570608
JL_GC_POP();
571609
}
610+
PTR_UNPIN(codeinst);
572611
JL_UNLOCK(&jl_codegen_lock);
573612
if (!--ct->reentrant_timing && measure_compile_time_enabled) {
574613
auto end = jl_hrtime();
575614
jl_atomic_fetch_add_relaxed(&jl_cumulative_compile_time, end - compiler_start_time);
576615
}
577616
}
578-
if (specfptr != 0)
617+
if (specfptr != 0) {
618+
PTR_UNPIN(mi);
579619
return jl_dump_fptr_asm(specfptr, raw_mc, asm_variant, debuginfo, binary);
620+
}
580621
}
581622

582623
// whatever, that didn't work - use the assembler output instead
583624
jl_llvmf_dump_t llvmf_dump;
584625
jl_get_llvmf_defn(&llvmf_dump, mi, world, getwrapper, true, jl_default_cgparams);
626+
PTR_UNPIN(mi);
585627
if (!llvmf_dump.F)
586628
return jl_an_empty_string;
587629
return jl_dump_function_asm(&llvmf_dump, raw_mc, asm_variant, debuginfo, binary);

0 commit comments

Comments
 (0)