LLVM 22.0.0git
SPIRVLegalizePointerCast.cpp
Go to the documentation of this file.
1//===-- SPIRVLegalizePointerCast.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The LLVM IR has multiple legal patterns we cannot lower to Logical SPIR-V.
10// This pass modifies such loads to have an IR we can directly lower to valid
11// logical SPIR-V.
12// OpenCL can avoid this because they rely on ptrcast, which is not supported
13// by logical SPIR-V.
14//
15// This pass relies on the assign_ptr_type intrinsic to deduce the type of the
16// pointed values, and must replace all occurrences of `ptrcast`. This is why
17// unhandled cases are reported as unreachable: we MUST cover all cases.
18//
19// 1. Loading the first element of an array
20//
21// %array = [10 x i32]
22// %value = load i32, ptr %array
23//
24// LLVM can skip the GEP instruction, and only request loading the first 4
25// bytes. In logical SPIR-V, we need an OpAccessChain to access the first
26// element. This pass will add a getelementptr instruction before the load.
27//
28//
29// 2. Implicit downcast from load
30//
31// %1 = getelementptr <4 x i32>, ptr %vec4, i64 0
32// %2 = load <3 x i32>, ptr %1
33//
34// The pointer in the GEP instruction is only used for offset computations,
35// but it doesn't NEED to match the pointed type. OpAccessChain however
36// requires this. Also, LLVM loads define the bitwidth of the load, not the
37// pointer. In this example, we can guess %vec4 is a vec4 thanks to the GEP
38// instruction basetype, but we only want to load the first 3 elements, hence
39// do a partial load. In logical SPIR-V, this is not legal. What we must do
40// is load the full vector (basetype), extract 3 elements, and recombine them
41// to form a 3-element vector.
42//
43//===----------------------------------------------------------------------===//
44
45#include "SPIRV.h"
46#include "SPIRVSubtarget.h"
47#include "SPIRVTargetMachine.h"
48#include "SPIRVUtils.h"
49#include "llvm/IR/IRBuilder.h"
51#include "llvm/IR/Intrinsics.h"
52#include "llvm/IR/IntrinsicsSPIRV.h"
55
56using namespace llvm;
57
58namespace {
59class SPIRVLegalizePointerCast : public FunctionPass {
60
61 // Builds the `spv_assign_type` assigning |Ty| to |Value| at the current
62 // builder position.
63 void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg) {
64 Value *OfType = PoisonValue::get(Ty);
65 CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
66 {Arg->getType()}, OfType, Arg, {}, B);
67 GR->addAssignPtrTypeInstr(Arg, AssignCI);
68 }
69
70 // Loads parts of the vector of type |SourceType| from the pointer |Source|
71 // and create a new vector of type |TargetType|. |TargetType| must be a vector
72 // type, and element types of |TargetType| and |SourceType| must match.
73 // Returns the loaded value.
74 Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
75 FixedVectorType *TargetType, Value *Source) {
76 LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
77 buildAssignType(B, SourceType, NewLoad);
78 Value *AssignValue = NewLoad;
79 if (TargetType->getElementType() != SourceType->getElementType()) {
80 const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
81 [[maybe_unused]] TypeSize TargetTypeSize =
82 DL.getTypeSizeInBits(TargetType);
83 [[maybe_unused]] TypeSize SourceTypeSize =
84 DL.getTypeSizeInBits(SourceType);
85 assert(TargetTypeSize == SourceTypeSize);
86 AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
87 {TargetType, SourceType}, {NewLoad});
88 buildAssignType(B, TargetType, AssignValue);
89 return AssignValue;
90 }
91
92 assert(TargetType->getNumElements() < SourceType->getNumElements());
93 SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
94 for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
95 Mask[I] = I;
96 Value *Output = B.CreateShuffleVector(AssignValue, AssignValue, Mask);
97 buildAssignType(B, TargetType, Output);
98 return Output;
99 }
100
101 // Loads the first value in an aggregate pointed by |Source| of containing
102 // elements of type |ElementType|. Load flags will be copied from |BadLoad|,
103 // which should be the load being legalized. Returns the loaded value.
104 Value *loadFirstValueFromAggregate(IRBuilder<> &B, Type *ElementType,
105 Value *Source, LoadInst *BadLoad) {
107 BadLoad->getPointerOperandType()};
108 SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(false), Source,
109 B.getInt32(0), B.getInt32(0)};
110 auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
111 GR->buildAssignPtr(B, ElementType, GEP);
112
113 LoadInst *LI = B.CreateLoad(ElementType, GEP);
114 LI->setAlignment(BadLoad->getAlign());
115 buildAssignType(B, ElementType, LI);
116 return LI;
117 }
118
119 // Replaces the load instruction to get rid of the ptrcast used as source
120 // operand.
121 void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
122 Value *OriginalOperand) {
123 Type *FromTy = GR->findDeducedElementType(OriginalOperand);
124 Type *ToTy = GR->findDeducedElementType(CastedOperand);
125 Value *Output = nullptr;
126
127 auto *SAT = dyn_cast<ArrayType>(FromTy);
128 auto *SVT = dyn_cast<FixedVectorType>(FromTy);
129 auto *SST = dyn_cast<StructType>(FromTy);
130 auto *DVT = dyn_cast<FixedVectorType>(ToTy);
131
132 B.SetInsertPoint(LI);
133
134 // Destination is the element type of Source, and source is an array ->
135 // Loading 1st element.
136 // - float a = array[0];
137 if (SAT && SAT->getElementType() == ToTy)
138 Output = loadFirstValueFromAggregate(B, SAT->getElementType(),
139 OriginalOperand, LI);
140 // Destination is the element type of Source, and source is a vector ->
141 // Vector to scalar.
142 // - float a = vector.x;
143 else if (!DVT && SVT && SVT->getElementType() == ToTy) {
144 Output = loadFirstValueFromAggregate(B, SVT->getElementType(),
145 OriginalOperand, LI);
146 }
147 // Destination is a smaller vector than source or different vector type.
148 // - float3 v3 = vector4;
149 // - float4 v2 = int4;
150 else if (SVT && DVT)
151 Output = loadVectorFromVector(B, SVT, DVT, OriginalOperand);
152 // Destination is the scalar type stored at the start of an aggregate.
153 // - struct S { float m };
154 // - float v = s.m;
155 else if (SST && SST->getTypeAtIndex(0u) == ToTy)
156 Output = loadFirstValueFromAggregate(B, ToTy, OriginalOperand, LI);
157 else
158 llvm_unreachable("Unimplemented implicit down-cast from load.");
159
160 GR->replaceAllUsesWith(LI, Output, /* DeleteOld= */ true);
161 DeadInstructions.push_back(LI);
162 }
163
164 // Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
165 Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
166 unsigned Index) {
167 Type *Int32Ty = Type::getInt32Ty(B.getContext());
168 SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
169 Element->getType(), Int32Ty};
170 SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
171 Instruction *NewI =
172 B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
173 buildAssignType(B, Vector->getType(), NewI);
174 return NewI;
175 }
176
177 // Creates an spv_extractelt instruction (equivalent to llvm's
178 // extractelement).
179 Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
180 unsigned Index) {
181 Type *Int32Ty = Type::getInt32Ty(B.getContext());
183 SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
184 Instruction *NewI =
185 B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
186 buildAssignType(B, ElementType, NewI);
187 return NewI;
188 }
189
190 // Stores the given Src vector operand into the Dst vector, adjusting the size
191 // if required.
192 Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
193 Align Alignment) {
194 FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
195 FixedVectorType *DstType =
196 cast<FixedVectorType>(GR->findDeducedElementType(Dst));
197 auto dstNumElements = DstType->getNumElements();
198 auto srcNumElements = SrcType->getNumElements();
199
200 // if the element type differs, it is a bitcast.
201 if (DstType->getElementType() != SrcType->getElementType()) {
202 // Support bitcast between vectors of different sizes only if
203 // the total bitwidth is the same.
204 [[maybe_unused]] auto dstBitWidth =
205 DstType->getElementType()->getScalarSizeInBits() * dstNumElements;
206 [[maybe_unused]] auto srcBitWidth =
207 SrcType->getElementType()->getScalarSizeInBits() * srcNumElements;
208 assert(dstBitWidth == srcBitWidth &&
209 "Unsupported bitcast between vectors of different sizes.");
210
211 Src =
212 B.CreateIntrinsic(Intrinsic::spv_bitcast, {DstType, SrcType}, {Src});
213 buildAssignType(B, DstType, Src);
214 SrcType = DstType;
215
216 StoreInst *SI = B.CreateStore(Src, Dst);
217 SI->setAlignment(Alignment);
218 return SI;
219 }
220
221 assert(DstType->getNumElements() >= SrcType->getNumElements());
222 LoadInst *LI = B.CreateLoad(DstType, Dst);
223 LI->setAlignment(Alignment);
224 Value *OldValues = LI;
225 buildAssignType(B, OldValues->getType(), OldValues);
226 Value *NewValues = Src;
227
228 for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
229 Value *Element =
230 makeExtractElement(B, SrcType->getElementType(), NewValues, I);
231 OldValues = makeInsertElement(B, OldValues, Element, I);
232 }
233
234 StoreInst *SI = B.CreateStore(OldValues, Dst);
235 SI->setAlignment(Alignment);
236 return SI;
237 }
238
239 void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
240 SmallVectorImpl<Value *> &Indices) {
241 Indices.push_back(B.getInt32(0));
242
243 if (Search == Aggregate)
244 return;
245
246 if (auto *ST = dyn_cast<StructType>(Aggregate))
247 buildGEPIndexChain(B, Search, ST->getTypeAtIndex(0u), Indices);
248 else if (auto *AT = dyn_cast<ArrayType>(Aggregate))
249 buildGEPIndexChain(B, Search, AT->getElementType(), Indices);
250 else if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
251 buildGEPIndexChain(B, Search, VT->getElementType(), Indices);
252 else
253 llvm_unreachable("Bad access chain?");
254 }
255
256 // Stores the given Src value into the first entry of the Dst aggregate.
257 Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
258 Type *DstPointeeType, Align Alignment) {
259 SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
260 SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst};
261 buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);
262 auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
263 GR->buildAssignPtr(B, Src->getType(), GEP);
264 StoreInst *SI = B.CreateStore(Src, GEP);
265 SI->setAlignment(Alignment);
266 return SI;
267 }
268
269 bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
270 if (Search == Aggregate)
271 return true;
272 if (auto *ST = dyn_cast<StructType>(Aggregate))
273 return isTypeFirstElementAggregate(Search, ST->getTypeAtIndex(0u));
274 if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
275 return isTypeFirstElementAggregate(Search, VT->getElementType());
276 if (auto *AT = dyn_cast<ArrayType>(Aggregate))
277 return isTypeFirstElementAggregate(Search, AT->getElementType());
278 return false;
279 }
280
281 // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
282 // operand into a valid logical SPIR-V store with no ptrcast.
283 void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
284 Value *Dst, Align Alignment) {
285 Type *ToTy = GR->findDeducedElementType(Dst);
286 Type *FromTy = Src->getType();
287
288 auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
289 auto *D_ST = dyn_cast<StructType>(ToTy);
290 auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
291
292 B.SetInsertPoint(BadStore);
293 if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
294 storeToFirstValueAggregate(B, Src, Dst, D_ST, Alignment);
295 else if (D_VT && S_VT)
296 storeVectorFromVector(B, Src, Dst, Alignment);
297 else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
298 storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
299 else
300 llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
301
302 DeadInstructions.push_back(BadStore);
303 }
304
305 void legalizePointerCast(IntrinsicInst *II) {
306 Value *CastedOperand = II;
307 Value *OriginalOperand = II->getOperand(0);
308
309 IRBuilder<> B(II->getContext());
310 std::vector<Value *> Users;
311 for (Use &U : II->uses())
312 Users.push_back(U.getUser());
313
314 for (Value *User : Users) {
315 if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
316 transformLoad(B, LI, CastedOperand, OriginalOperand);
317 continue;
318 }
319
320 if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
321 transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
322 SI->getAlign());
323 continue;
324 }
325
326 if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
327 if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
328 DeadInstructions.push_back(Intrin);
329 continue;
330 }
331
332 if (Intrin->getIntrinsicID() == Intrinsic::spv_gep) {
333 GR->replaceAllUsesWith(CastedOperand, OriginalOperand,
334 /* DeleteOld= */ false);
335 continue;
336 }
337
338 if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
339 Align Alignment;
340 if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
341 Alignment = Align(C->getZExtValue());
342 transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
343 Alignment);
344 continue;
345 }
346 }
347
348 llvm_unreachable("Unsupported ptrcast user. Please fix.");
349 }
350
351 DeadInstructions.push_back(II);
352 }
353
354public:
355 SPIRVLegalizePointerCast(SPIRVTargetMachine *TM) : FunctionPass(ID), TM(TM) {}
356
357 bool runOnFunction(Function &F) override {
358 const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
359 GR = ST.getSPIRVGlobalRegistry();
360 DeadInstructions.clear();
361
362 std::vector<IntrinsicInst *> WorkList;
363 for (auto &BB : F) {
364 for (auto &I : BB) {
365 auto *II = dyn_cast<IntrinsicInst>(&I);
366 if (II && II->getIntrinsicID() == Intrinsic::spv_ptrcast)
367 WorkList.push_back(II);
368 }
369 }
370
371 for (IntrinsicInst *II : WorkList)
372 legalizePointerCast(II);
373
374 for (Instruction *I : DeadInstructions)
375 I->eraseFromParent();
376
377 return DeadInstructions.size() != 0;
378 }
379
380private:
381 SPIRVTargetMachine *TM = nullptr;
382 SPIRVGlobalRegistry *GR = nullptr;
383 std::vector<Instruction *> DeadInstructions;
384
385public:
386 static char ID;
387};
388} // namespace
389
// Definition of the pass's static ID member, which the constructor passes to
// the FunctionPass base class.
390char SPIRVLegalizePointerCast::ID = 0;
// INITIALIZE_PASS arguments, in order: pass class, argument string,
// descriptive name, and the `cfg` and `analysis` flags (both false here).
391INITIALIZE_PASS(SPIRVLegalizePointerCast, "spirv-legalize-bitcast",
392 "SPIRV legalize bitcast pass", false, false)
393
395 return new SPIRVLegalizePointerCast(TM);
396}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool runOnFunction(Function &F, bool PostInlining)
Hexagon Common GEP
iv Induction Variable Users
Definition IVUsers.cpp:48
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
unsigned getNumElements() const
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
void setAlignment(Align Align)
Type * getPointerOperandType() const
Align getAlign() const
Return the alignment of the access that is being performed.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
void push_back(const T &Elt)
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition Type.cpp:231
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
Type * getElementType() const
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
ElementType
The element type of an SRV or UAV resource.
Definition DXILABI.h:60
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
FunctionAddr VTableAddr uintptr_t uintptr_t Int32Ty
Definition InstrProf.h:296
CallInst * buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef< Type * > Types, Value *Arg, Value *Arg2, ArrayRef< Constant * > Imms, IRBuilder<> &B)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
FunctionPass * createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM)