|
| 1 | +import logging |
| 2 | + |
| 3 | +import torch |
| 4 | +import triton |
| 5 | +import triton.language as tl |
| 6 | + |
| 7 | +from ..utils import libentry |
| 8 | + |
| 9 | + |
| 10 | +@libentry() |
| 11 | +@triton.autotune( |
| 12 | + configs=[ |
| 13 | + triton.Config({"BLOCK_SIZE": k}, num_warps=w) |
| 14 | + for w in [4, 8, 16, 32] |
| 15 | + for k in [512, 1024, 2048, 4096] |
| 16 | + ], |
| 17 | + key=[ |
| 18 | + "max_tile_elems", |
| 19 | + ], |
| 20 | +) |
| 21 | +@triton.jit |
| 22 | +def vstack_kernel( |
| 23 | + itensor_ptr0, |
| 24 | + itensor_ptr1, |
| 25 | + itensor_ptr2, |
| 26 | + itensor_ptr3, |
| 27 | + output_ptr, |
| 28 | + local_row0, |
| 29 | + local_row1, |
| 30 | + local_row2, |
| 31 | + local_row3, |
| 32 | + exc_row_offset0, |
| 33 | + exc_row_offset1, |
| 34 | + exc_row_offset2, |
| 35 | + exc_row_offset3, |
| 36 | + total_row_offset, |
| 37 | + row_stride, |
| 38 | + max_tile_elems, |
| 39 | + BLOCK_SIZE: tl.constexpr, |
| 40 | +): |
| 41 | + pid_x = tl.program_id(axis=0) |
| 42 | + tensor_idx = tl.program_id(axis=1) |
| 43 | + col_idx = tl.arange(0, BLOCK_SIZE) |
| 44 | + |
| 45 | + intensor_ptr = tl.where(tensor_idx == 0, itensor_ptr0, itensor_ptr1) |
| 46 | + intensor_ptr = tl.where(tensor_idx == 2, itensor_ptr2, intensor_ptr) |
| 47 | + intensor_ptr = tl.where(tensor_idx == 3, itensor_ptr3, intensor_ptr) |
| 48 | + base_exc_row_idx = tl.where(tensor_idx == 0, exc_row_offset0, exc_row_offset1) |
| 49 | + base_exc_row_idx = tl.where(tensor_idx == 2, exc_row_offset2, base_exc_row_idx) |
| 50 | + base_exc_row_idx = tl.where(tensor_idx == 3, exc_row_offset3, base_exc_row_idx) |
| 51 | + local_row = tl.where(tensor_idx == 0, local_row0, local_row1) |
| 52 | + local_row = tl.where(tensor_idx == 2, local_row2, local_row) |
| 53 | + local_row = tl.where(tensor_idx == 3, local_row3, local_row) |
| 54 | + |
| 55 | + end_idx = local_row * row_stride.to(tl.int64) |
| 56 | + idx = (pid_x * BLOCK_SIZE + col_idx).to(tl.int64) |
| 57 | + offset_mask = idx < end_idx |
| 58 | + in_offset = intensor_ptr + idx |
| 59 | + row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride.to(tl.int64) |
| 60 | + out_offset = output_ptr + row_stride_offset + idx |
| 61 | + out = tl.load(in_offset, mask=offset_mask) |
| 62 | + tl.store(out_offset, out, mask=offset_mask) |
| 63 | + |
| 64 | + |
| 65 | +def vstack(tensors: list[torch.Tensor]): |
| 66 | + logging.debug("GEMS VSTACK") |
| 67 | + |
| 68 | + tensors = torch.atleast_2d(tensors) |
| 69 | + num_tensors = len(tensors) |
| 70 | + assert num_tensors > 0 |
| 71 | + |
| 72 | + # Ensure all tensors are on the same device and have the same dtype |
| 73 | + device = tensors[0].device |
| 74 | + dtype = tensors[0].dtype |
| 75 | + for tensor in tensors: |
| 76 | + assert ( |
| 77 | + tensor.device == device |
| 78 | + and tensor.dtype == dtype |
| 79 | + and tensors[0].shape[1:] == tensor.shape[1:] |
| 80 | + ) |
| 81 | + |
| 82 | + c_tensors = [t.contiguous() for t in tensors] |
| 83 | + # Calculate the output shape |
| 84 | + total_rows = sum(tensor.shape[0] for tensor in c_tensors) |
| 85 | + output_shape = list(c_tensors[0].shape) |
| 86 | + output_shape[0] = total_rows |
| 87 | + output = torch.empty(output_shape, device=device, dtype=dtype) |
| 88 | + row_stride = c_tensors[0].stride(0) |
| 89 | + |
| 90 | + outer_iters = triton.cdiv(num_tensors, 4) |
| 91 | + total_row_offset = 0 |
| 92 | + for i in range(outer_iters): |
| 93 | + max_rows = 1 |
| 94 | + itensors = [] |
| 95 | + exclusive_row = [] |
| 96 | + local_row = [] |
| 97 | + array_row_offset = 0 |
| 98 | + scheduled_num_tensors = 0 |
| 99 | + for j in range(4): |
| 100 | + tensor_idx = i * 4 + j |
| 101 | + if tensor_idx < num_tensors: |
| 102 | + scheduled_num_tensors += 1 |
| 103 | + itensors.append(c_tensors[tensor_idx]) |
| 104 | + local_row.append(c_tensors[tensor_idx].shape[0]) |
| 105 | + exclusive_row.append(array_row_offset) |
| 106 | + array_row_offset += c_tensors[tensor_idx].shape[0] |
| 107 | + max_rows = max(max_rows, c_tensors[tensor_idx].shape[0]) |
| 108 | + else: |
| 109 | + empty_tensor = torch.empty( |
| 110 | + 0, dtype=c_tensors[0].dtype, device=c_tensors[0].device |
| 111 | + ) |
| 112 | + itensors.append(empty_tensor) |
| 113 | + local_row.append(local_row[-1]) |
| 114 | + exclusive_row.append(exclusive_row[-1]) |
| 115 | + max_tile_elems = max_rows * row_stride |
| 116 | + grid = lambda META: ( |
| 117 | + triton.cdiv(max_tile_elems, META["BLOCK_SIZE"]), |
| 118 | + scheduled_num_tensors, |
| 119 | + ) |
| 120 | + # Launch the kernel |
| 121 | + with torch.cuda.device(c_tensors[0].device): |
| 122 | + vstack_kernel[grid]( |
| 123 | + itensors[0], |
| 124 | + itensors[1], |
| 125 | + itensors[2], |
| 126 | + itensors[3], |
| 127 | + output, |
| 128 | + local_row[0], |
| 129 | + local_row[1], |
| 130 | + local_row[2], |
| 131 | + local_row[3], |
| 132 | + exclusive_row[0], |
| 133 | + exclusive_row[1], |
| 134 | + exclusive_row[2], |
| 135 | + exclusive_row[3], |
| 136 | + total_row_offset, |
| 137 | + row_stride, |
| 138 | + max_tile_elems, |
| 139 | + ) |
| 140 | + total_row_offset += array_row_offset |
| 141 | + return output |
0 commit comments