LLVM 22.0.0git
SPIRVLegalizerInfo.cpp
Go to the documentation of this file.
1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
22
23using namespace llvm;
24using namespace llvm::LegalizeActions;
25using namespace llvm::LegalityPredicates;
26
27LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
28 return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
29 const LLT Ty = Query.Types[TypeIdx];
30 return IsExtendedInts && Ty.isValid() && Ty.isScalar();
31 };
32}
33
// NOTE(review): this listing is a Doxygen rendering; hyperlinked source lines
// were dropped by the capture (the constructor signature on line 34 and
// several getActionDefinitionsBuilder(...) call heads, e.g. lines 151-157,
// 162, 182, 214, 223, 227, 235, 239, 243, 250, 254, 264, 275, 329, 333, 344).
// The orphaned argument/continuation lines below belong to those calls —
// confirm against upstream before editing.
35 using namespace TargetOpcode;
36
// Cache the subtarget and its SPIR-V global registry; both are used by the
// custom legalization hooks (legalizeCustom / legalizeIsFPClass).
37 this->ST = &ST;
38 GR = ST.getSPIRVGlobalRegistry();
39
// Scalar LLTs referenced by the rule tables below.
40 const LLT s1 = LLT::scalar(1);
41 const LLT s8 = LLT::scalar(8);
42 const LLT s16 = LLT::scalar(16);
43 const LLT s32 = LLT::scalar(32);
44 const LLT s64 = LLT::scalar(64);
45
// Fixed-width vector LLTs for the element counts SPIR-V supports here
// (2, 3, 4, 8, 16 elements x 1/8/16/32/64-bit elements).
46 const LLT v16s64 = LLT::fixed_vector(16, 64);
47 const LLT v16s32 = LLT::fixed_vector(16, 32);
48 const LLT v16s16 = LLT::fixed_vector(16, 16);
49 const LLT v16s8 = LLT::fixed_vector(16, 8);
50 const LLT v16s1 = LLT::fixed_vector(16, 1);
51
52 const LLT v8s64 = LLT::fixed_vector(8, 64);
53 const LLT v8s32 = LLT::fixed_vector(8, 32);
54 const LLT v8s16 = LLT::fixed_vector(8, 16);
55 const LLT v8s8 = LLT::fixed_vector(8, 8);
56 const LLT v8s1 = LLT::fixed_vector(8, 1);
57
58 const LLT v4s64 = LLT::fixed_vector(4, 64);
59 const LLT v4s32 = LLT::fixed_vector(4, 32);
60 const LLT v4s16 = LLT::fixed_vector(4, 16);
61 const LLT v4s8 = LLT::fixed_vector(4, 8);
62 const LLT v4s1 = LLT::fixed_vector(4, 1);
63
64 const LLT v3s64 = LLT::fixed_vector(3, 64);
65 const LLT v3s32 = LLT::fixed_vector(3, 32);
66 const LLT v3s16 = LLT::fixed_vector(3, 16);
67 const LLT v3s8 = LLT::fixed_vector(3, 8);
68 const LLT v3s1 = LLT::fixed_vector(3, 1);
69
70 const LLT v2s64 = LLT::fixed_vector(2, 64);
71 const LLT v2s32 = LLT::fixed_vector(2, 32);
72 const LLT v2s16 = LLT::fixed_vector(2, 16);
73 const LLT v2s8 = LLT::fixed_vector(2, 8);
74 const LLT v2s1 = LLT::fixed_vector(2, 1);
75
// Pointer LLTs: one per address space, all of the target's pointer width.
// The trailing comment on each names the corresponding SPIR-V storage class.
76 const unsigned PSize = ST.getPointerSize();
77 const LLT p0 = LLT::pointer(0, PSize); // Function
78 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
79 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
80 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
81 const LLT p4 = LLT::pointer(4, PSize); // Generic
82 const LLT p5 =
83 LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
84 const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
85 const LLT p7 = LLT::pointer(7, PSize); // Input
86 const LLT p8 = LLT::pointer(8, PSize); // Output
87 const LLT p10 = LLT::pointer(10, PSize); // Private
88 const LLT p11 = LLT::pointer(11, PSize); // StorageBuffer
89 const LLT p12 = LLT::pointer(12, PSize); // Uniform
90
// Type-set groupings reused across many legality rules below.
91 // TODO: remove copy-pasting here by using concatenation in some way.
92 auto allPtrsScalarsAndVectors = {
93 p0, p1, p2, p3, p4, p5, p6, p7, p8,
94 p10, p11, p12, s1, s8, s16, s32, s64, v2s1,
95 v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64,
96 v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, v8s32,
97 v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
98
99 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
100 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
101 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
102 v16s8, v16s16, v16s32, v16s64};
103
104 auto allScalarsAndVectors = {
105 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
106 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
107 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
108
109 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
110 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
111 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
112 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
113
114 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
115
116 auto allIntScalars = {s8, s16, s32, s64};
117
118 auto allFloatScalars = {s16, s32, s64};
119
120 auto allFloatScalarsAndVectors = {
121 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
122 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
123
124 auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3,
125 p4, p5, p6, p7, p8, p10, p11, p12};
126
127 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
128
// Extended-integer support: any of these extensions lifts the usual
// bit-width restrictions, so the predicates below accept arbitrary valid
// (non-pointer) types when one of them is available.
129 bool IsExtendedInts =
130 ST.canUseExtension(
131 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
132 ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
133 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
// True for any valid non-pointer type at index 0 (extended ints only).
134 auto extendedScalarsAndVectors =
135 [IsExtendedInts](const LegalityQuery &Query) {
136 const LLT Ty = Query.Types[0];
137 return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
138 };
// Same, but both type indices 0 and 1 must be valid non-pointer types.
139 auto extendedScalarsAndVectorsProduct = [IsExtendedInts](
140 const LegalityQuery &Query) {
141 const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1];
142 return IsExtendedInts && Ty1.isValid() && Ty2.isValid() &&
143 !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector();
144 };
// Loosest form: any valid type at index 0, pointers included.
145 auto extendedPtrsScalarsAndVectors =
146 [IsExtendedInts](const LegalityQuery &Query) {
147 const LLT Ty = Query.Types[0];
148 return IsExtendedInts && Ty.isValid();
149 };
150
// NOTE(review): lines 151-157 are missing from this capture — presumably the
// call head for the vector-construction rule below; confirm upstream.
153
155
156 // TODO: add proper rules for vectors legalization.
158 {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
159 .alwaysLegal();
160
161 // Vector Reduction Operations
163 {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
164 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
165 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
166 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
167 .legalFor(allVectors)
168 .scalarize(1)
169 .lower();
170
171 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
172 .scalarize(2)
173 .lower();
174
175 // Merge/Unmerge
176 // TODO: add proper legalization rules.
177 getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
178
// Memory intrinsics are legal only between pointer-typed operands.
179 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
180 .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
181
// (Call head on line 182 missing — presumably G_MEMSET or similar; the rule
// requires a pointer destination and an integer-scalar second operand.)
183 all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars)));
184
185 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
186 .legalForCartesianProduct(allPtrs, allPtrs);
187
// Loads/stores are legal whenever the address (type index 1) is a pointer.
188 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
189
190 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
191 G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
192 G_USUBSAT, G_SCMP, G_UCMP})
193 .legalFor(allIntScalarsAndVectors)
194 .legalIf(extendedScalarsAndVectors);
195
196 getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
197 .legalFor(allFloatScalarsAndVectors);
198
199 getActionDefinitionsBuilder(G_STRICT_FLDEXP)
200 .legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars);
201
202 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
203 .legalForCartesianProduct(allIntScalarsAndVectors,
204 allFloatScalarsAndVectors);
205
206 getActionDefinitionsBuilder({G_FPTOSI_SAT, G_FPTOUI_SAT})
207 .legalForCartesianProduct(allIntScalarsAndVectors,
208 allFloatScalarsAndVectors);
209
210 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
211 .legalForCartesianProduct(allFloatScalarsAndVectors,
212 allScalarsAndVectors);
213
// (Call head on line 214 missing.)
215 .legalForCartesianProduct(allIntScalarsAndVectors)
216 .legalIf(extendedScalarsAndVectorsProduct);
217
218 // Extensions.
219 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
220 .legalForCartesianProduct(allScalarsAndVectors)
221 .legalIf(extendedScalarsAndVectorsProduct);
222
// (Call head on line 223 missing.)
224 .legalFor(allPtrsScalarsAndVectors)
225 .legalIf(extendedPtrsScalarsAndVectors);
226
// (Call head on line 227 missing.)
228 all(typeInSet(0, allPtrsScalarsAndVectors),
229 typeInSet(1, allPtrsScalarsAndVectors)));
230
231 getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
232
233 getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
234
// Pointer<->integer conversion rules: pointer side must be a known address
// space; the integer side may also be an extended-int scalar when enabled.
// (Call heads on lines 235, 239, 243 missing — presumably G_INTTOPTR /
// G_PTRTOINT / G_PTR_ADD family; confirm upstream.)
236 .legalForCartesianProduct(allPtrs, allIntScalars)
237 .legalIf(
238 all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
240 .legalForCartesianProduct(allIntScalars, allPtrs)
241 .legalIf(
242 all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)));
244 .legalForCartesianProduct(allPtrs, allIntScalars)
245 .legalIf(
246 all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
247
248 // ST.canDirectlyComparePointers() for pointer args is supported in
249 // legalizeCustom().
251 all(typeInSet(0, allBoolScalarsAndVectors),
252 typeInSet(1, allPtrsScalarsAndVectors)));
253
// (Call head on line 254 missing — a compare producing bool from floats.)
255 all(typeInSet(0, allBoolScalarsAndVectors),
256 typeInSet(1, allFloatScalarsAndVectors)));
257
// Integer atomic RMW ops: integer-scalar value, pointer address.
258 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
259 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
260 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
261 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
262 .legalForCartesianProduct(allIntScalars, allPtrs);
263
265 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
266 .legalForCartesianProduct(allFloatScalars, allPtrs);
267
268 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
269 .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);
270
271 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
272 // TODO: add proper legalization rules.
273 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
274
276 {G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
277 .alwaysLegal();
278
279 // FP conversions.
280 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
281 .legalForCartesianProduct(allFloatScalarsAndVectors);
282
283 // Pointer-handling.
284 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
285
286 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
287 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
288
289 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
290 // tighten these requirements. Many of these math functions are only legal on
291 // specific bitwidths, so they are not selectable for
292 // allFloatScalarsAndVectors.
293 getActionDefinitionsBuilder({G_STRICT_FSQRT,
294 G_FPOW,
295 G_FEXP,
296 G_FEXP2,
297 G_FLOG,
298 G_FLOG2,
299 G_FLOG10,
300 G_FABS,
301 G_FMINNUM,
302 G_FMAXNUM,
303 G_FCEIL,
304 G_FCOS,
305 G_FSIN,
306 G_FTAN,
307 G_FACOS,
308 G_FASIN,
309 G_FATAN,
310 G_FATAN2,
311 G_FCOSH,
312 G_FSINH,
313 G_FTANH,
314 G_FSQRT,
315 G_FFLOOR,
316 G_FRINT,
317 G_FNEARBYINT,
318 G_INTRINSIC_ROUND,
319 G_INTRINSIC_TRUNC,
320 G_FMINIMUM,
321 G_FMAXIMUM,
322 G_INTRINSIC_ROUNDEVEN})
323 .legalFor(allFloatScalarsAndVectors);
324
325 getActionDefinitionsBuilder(G_FCOPYSIGN)
326 .legalForCartesianProduct(allFloatScalarsAndVectors,
327 allFloatScalarsAndVectors);
328
// (Call head on line 329 missing.)
330 allFloatScalarsAndVectors, allIntScalarsAndVectors);
331
// Rules only available with the OpenCL extended instruction set.
332 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
334 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
335 .legalForCartesianProduct(allIntScalarsAndVectors,
336 allIntScalarsAndVectors);
337
338 // Struct return types become a single scalar, so cannot easily legalize.
339 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
340 }
341
// G_IS_FPCLASS is handled by legalizeIsFPClass() via legalizeCustom().
342 getActionDefinitionsBuilder(G_IS_FPCLASS).custom();
343
// Cross-check the completed rule tables against the target's instruction
// info. (Line 344 is missing from this capture — presumably the
// getLegacyLegalizerInfo().computeTables() call; confirm upstream.)
345 verify(*ST.getInstrInfo());
346}
347
// Convert the pointer-valued register \p Reg into a fresh virtual register
// of integer type \p ConvTy by emitting a G_PTRTOINT, registering \p SpvType
// as the SPIR-V type of the new register. Returns the new register.
// NOTE(review): the trailing parameter lines "MachineRegisterInfo &MRI,
// SPIRVGlobalRegistry *GR) {" (source lines 350-351) are missing from this
// Doxygen capture; the full signature appears in the page's index section.
348 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
349 LegalizerHelper &Helper,
// Create the destination vreg and attach both its register class and its
// SPIR-V type before emitting the conversion instruction.
352 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
353 MRI.setRegClass(ConvReg, GR->getRegClass(SpvType));
354 GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF());
355 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
356 .addDef(ConvReg)
357 .addUse(Reg);
358 return ConvReg;
359}
360
// Custom legalization hook, called for instructions whose rule is .custom().
// NOTE(review): the signature lines "bool SPIRVLegalizerInfo::legalizeCustom(
// LegalizerHelper &Helper, MachineInstr &MI," (source lines 361-362) are
// missing from this Doxygen capture; the declaration is visible in the
// page's index section.
363 LostDebugLocObserver &LocObserver) const {
364 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
// Dispatch on opcode; opcodes without a custom handler are reported as
// handled (see the TODO below).
365 switch (MI.getOpcode()) {
366 default:
367 // TODO: implement legalization for other opcodes.
368 return true;
369 case TargetOpcode::G_IS_FPCLASS:
370 return legalizeIsFPClass(Helper, MI, LocObserver);
371 case TargetOpcode::G_ICMP: {
372 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
// Operands 2 and 3 are the compared values; operand 1 is the predicate.
373 auto &Op0 = MI.getOperand(2);
374 auto &Op1 = MI.getOperand(3);
375 Register Reg0 = Op0.getReg();
376 Register Reg1 = Op1.getReg();
// (Line 377, the left-hand side of this predicate initialization, is
// missing from this capture.)
378 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
// When both operands are pointers and direct pointer comparison is not
// available, rewrite the compare to operate on pointer-sized integers by
// inserting G_PTRTOINT conversions via convertPtrToInt and updating the
// operands in place. (Line 380, completing this condition, is missing from
// this capture — confirm upstream.)
379 if ((!ST->canDirectlyComparePointers() ||
381 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
382 LLT ConvT = LLT::scalar(ST->getPointerSize());
383 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
384 ST->getPointerSize());
385 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(
386 LLVMTy, Helper.MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
387 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
388 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
389 }
390 return true;
391 }
392 }
393}
394
395 // Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
396 // to ensure that all instructions created during the lowering have SPIR-V types
397 // assigned to them.
// Lowers G_IS_FPCLASS: bitcasts the FP source to a same-width integer and
// synthesizes the requested class tests (Mask) from integer compares and
// bit operations, OR-ing the partial results into the destination.
398 bool SPIRVLegalizerInfo::legalizeIsFPClass(
// NOTE(review): source line 399 with the "LegalizerHelper &Helper,
// MachineInstr &MI," parameters is missing from this Doxygen capture.
400 LostDebugLocObserver &LocObserver) const {
401 auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
402 FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());
403
404 auto &MIRBuilder = Helper.MIRBuilder;
405 auto &MF = MIRBuilder.getMF();
406 MachineRegisterInfo &MRI = MF.getRegInfo();
407
// Build an LLVM IR integer (or vector-of-integer) type mirroring the
// destination LLT, and register the matching SPIR-V type.
408 Type *LLVMDstTy =
409 IntegerType::get(MIRBuilder.getContext(), DstTy.getScalarSizeInBits());
410 if (DstTy.isVector())
411 LLVMDstTy = VectorType::get(LLVMDstTy, DstTy.getElementCount());
412 SPIRVType *SPIRVDstTy = GR->getOrCreateSPIRVType(
413 LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
414 /*EmitIR*/ true);
415
416 unsigned BitSize = SrcTy.getScalarSizeInBits();
417 const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());
418
// Integer type of the same width/shape as the FP source, used for all the
// bit-pattern tests below.
419 LLT IntTy = LLT::scalar(BitSize);
420 Type *LLVMIntTy = IntegerType::get(MIRBuilder.getContext(), BitSize);
421 if (SrcTy.isVector()) {
422 IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
423 LLVMIntTy = VectorType::get(LLVMIntTy, SrcTy.getElementCount());
424 }
425 SPIRVType *SPIRVIntTy = GR->getOrCreateSPIRVType(
426 LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
427 /*EmitIR*/ true);
428
429 // Clang doesn't support capture of structured bindings:
430 LLT DstTyCopy = DstTy;
// Tag a freshly built instruction's destination vreg with the appropriate
// SPIR-V type (the integer working type or the destination type).
431 const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
432 // Assign this MI's (assumed only) destination to one of the two types we
433 // expect: either the G_IS_FPCLASS's destination type, or the integer type
434 // bitcast from the source type.
435 LLT MITy = MRI.getType(MI.getReg(0));
436 assert((MITy == IntTy || MITy == DstTyCopy) &&
437 "Unexpected LLT type while lowering G_IS_FPCLASS");
438 auto *SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
439 GR->assignSPIRVTypeToVReg(SPVTy, MI.getReg(0), MF);
440 return MI;
441 };
442
443 // Helper to build and assign a constant in one go
// For vectors this splats a typed scalar constant; the scalar element also
// gets a SPIR-V type assigned.
444 const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder {
445 if (!Ty.isFixedVector())
446 return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C));
447 auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C);
448 assert((Ty == IntTy || Ty == DstTyCopy) &&
449 "Unexpected LLT type while lowering constant for G_IS_FPCLASS");
450 SPIRVType *VecEltTy = GR->getOrCreateSPIRVType(
451 (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder,
452 SPIRV::AccessQualifier::ReadWrite,
453 /*EmitIR*/ true);
454 GR->assignSPIRVTypeToVReg(VecEltTy, ScalarC.getReg(0), MF);
455 return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Ty, ScalarC));
456 };
457
// Trivial masks: empty mask is constant-false, full mask is constant-true.
458 if (Mask == fcNone) {
459 MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 0));
460 MI.eraseFromParent();
461 return true;
462 }
463 if (Mask == fcAllFlags) {
464 MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 1));
465 MI.eraseFromParent();
466 return true;
467 }
468
469 // Note that rather than creating a COPY here (between a floating-point and
470 // integer type of the same size) we create a SPIR-V bitcast immediately. We
471 // can't create a G_BITCAST because the LLTs are the same, and we can't seem
472 // to correctly lower COPYs to SPIR-V bitcasts at this moment.
473 Register ResVReg = MRI.createGenericVirtualRegister(IntTy);
474 MRI.setRegClass(ResVReg, GR->getRegClass(SPIRVIntTy));
475 GR->assignSPIRVTypeToVReg(SPIRVIntTy, ResVReg, Helper.MIRBuilder.getMF());
476 auto AsInt = MIRBuilder.buildInstr(SPIRV::OpBitcast)
477 .addDef(ResVReg)
478 .addUse(GR->getSPIRVTypeID(SPIRVIntTy))
479 .addUse(SrcReg);
480 AsInt = assignSPIRVTy(std::move(AsInt));
481
482 // Various masks.
// Bit-level decomposition of the source FP encoding, derived from the
// format's semantics via APFloat.
483 APInt SignBit = APInt::getSignMask(BitSize);
484 APInt ValueMask = APInt::getSignedMaxValue(BitSize); // All bits but sign.
485 APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
486 APInt ExpMask = Inf;
487 APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
488 APInt QNaNBitMask =
489 APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
490 APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());
491
492 auto SignBitC = buildSPIRVConstant(IntTy, SignBit);
493 auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask);
494 auto InfC = buildSPIRVConstant(IntTy, Inf);
495 auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask);
496 auto ZeroC = buildSPIRVConstant(IntTy, 0);
497
// Abs clears the sign bit; Sign is true iff the sign bit was set.
498 auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC));
499 auto Sign = assignSPIRVTy(
500 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs));
501
// Res accumulates the OR of all partial class tests; starts at false.
502 auto Res = buildSPIRVConstant(DstTy, 0);
503
504 const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
505 Res = assignSPIRVTy(
506 MIRBuilder.buildOr(DstTyCopy, Res, assignSPIRVTy(std::move(ToAppend))));
507 };
508
509 // Tests that involve more than one class should be processed first.
510 if ((Mask & fcFinite) == fcFinite) {
511 // finite(V) ==> abs(V) u< exp_mask
512 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
513 ExpMaskC));
514 Mask &= ~fcFinite;
515 } else if ((Mask & fcFinite) == fcPosFinite) {
516 // finite(V) && V > 0 ==> V u< exp_mask
517 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
518 ExpMaskC));
519 Mask &= ~fcPosFinite;
520 } else if ((Mask & fcFinite) == fcNegFinite) {
521 // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
522 auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT,
523 DstTy, Abs, ExpMaskC));
524 appendToRes(MIRBuilder.buildAnd(DstTy, Cmp, Sign));
525 Mask &= ~fcNegFinite;
526 }
527
528 if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
529 // fcZero | fcSubnormal => test all exponent bits are 0
530 // TODO: Handle sign bit specific cases
531 // TODO: Handle inverted case
532 if (PartialCheck == (fcZero | fcSubnormal)) {
533 auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC));
534 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
535 ExpBits, ZeroC));
536 Mask &= ~PartialCheck;
537 }
538 }
539
540 // Check for individual classes.
541 if (FPClassTest PartialCheck = Mask & fcZero) {
542 if (PartialCheck == fcPosZero)
543 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
544 AsInt, ZeroC));
545 else if (PartialCheck == fcZero)
546 appendToRes(
547 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
548 else // fcNegZero
549 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
550 AsInt, SignBitC));
551 }
552
553 if (FPClassTest PartialCheck = Mask & fcSubnormal) {
554 // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
555 // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
556 auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
557 auto OneC = buildSPIRVConstant(IntTy, 1);
558 auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
559 auto SubnormalRes = assignSPIRVTy(
560 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
561 buildSPIRVConstant(IntTy, AllOneMantissa)));
562 if (PartialCheck == fcNegSubnormal)
563 SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
564 appendToRes(std::move(SubnormalRes));
565 }
566
567 if (FPClassTest PartialCheck = Mask & fcInf) {
568 if (PartialCheck == fcPosInf)
569 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
570 AsInt, InfC));
571 else if (PartialCheck == fcInf)
572 appendToRes(
573 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
574 else { // fcNegInf
575 APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
576 auto NegInfC = buildSPIRVConstant(IntTy, NegInf);
577 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
578 AsInt, NegInfC));
579 }
580 }
581
582 if (FPClassTest PartialCheck = Mask & fcNan) {
583 auto InfWithQnanBitC = buildSPIRVConstant(IntTy, Inf | QNaNBitMask);
584 if (PartialCheck == fcNan) {
585 // isnan(V) ==> abs(V) u> int(inf)
586 appendToRes(
587 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
588 } else if (PartialCheck == fcQNan) {
589 // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
590 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
591 InfWithQnanBitC));
592 } else { // fcSNan
593 // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
594 // abs(V) u< (unsigned(Inf) | quiet_bit)
595 auto IsNan = assignSPIRVTy(
596 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
597 auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp(
598 CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC));
599 appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
600 }
601 }
602
603 if (FPClassTest PartialCheck = Mask & fcNormal) {
604 // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
605 // (max_exp-1))
606 APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
607 auto ExpMinusOne = assignSPIRVTy(
608 MIRBuilder.buildSub(IntTy, Abs, buildSPIRVConstant(IntTy, ExpLSB)));
609 APInt MaxExpMinusOne = ExpMask - ExpLSB;
610 auto NormalRes = assignSPIRVTy(
611 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
612 buildSPIRVConstant(IntTy, MaxExpMinusOne)));
613 if (PartialCheck == fcNegNormal)
614 NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
615 else if (PartialCheck == fcPosNormal) {
616 auto PosSign = assignSPIRVTy(MIRBuilder.buildXor(
617 DstTy, Sign, buildSPIRVConstant(DstTy, InversionMask)));
618 NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
619 }
620 appendToRes(std::move(NormalRes));
621 }
622
// Publish the accumulated result and delete the original G_IS_FPCLASS.
623 MIRBuilder.buildCopy(DstReg, Res);
624 MI.eraseFromParent();
625 return true;
626}
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static void scalarize(Instruction *I, SmallVectorImpl< Instruction * > &Replace)
Definition: ExpandFp.cpp:577
IRTranslator LLVM IR MI
This file declares the MachineIRBuilder class.
ppc ctr loops verify
const SmallVectorImpl< MachineOperand > & Cond
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts)
APInt bitcastToAPInt() const
Definition: APFloat.h:1353
static APFloat getLargest(const fltSemantics &Sem, bool Negative=false)
Returns the largest finite number in the given semantics.
Definition: APFloat.h:1138
static APFloat getInf(const fltSemantics &Sem, bool Negative=false)
Factory for Positive and Negative Infinity.
Definition: APFloat.h:1098
Class for arbitrary precision integers.
Definition: APInt.h:78
static APInt getAllOnes(unsigned numBits)
Return an APInt of a specified width with all bits set.
Definition: APInt.h:234
static APInt getSignMask(unsigned BitWidth)
Get the SignMask for a specific bit width.
Definition: APInt.h:229
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1512
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:209
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:873
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:239
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:678
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:702
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:701
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:703
@ ICMP_EQ
equal
Definition: InstrTypes.h:699
@ ICMP_NE
not equal
Definition: InstrTypes.h:700
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:319
constexpr bool isScalar() const
Definition: LowLevelType.h:147
static constexpr LLT vector(ElementCount EC, unsigned ScalarSizeInBits)
Get a low-level vector of some number of elements and element width.
Definition: LowLevelType.h:65
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Definition: LowLevelType.h:43
constexpr bool isValid() const
Definition: LowLevelType.h:146
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
Definition: LowLevelType.h:58
static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
Definition: LowLevelType.h:101
constexpr bool isPointerOrPointerVector() const
Definition: LowLevelType.h:154
constexpr bool isFixedVector() const
Returns true if the LLT is a fixed vector.
Definition: LowLevelType.h:178
constexpr LLT getScalarType() const
Definition: LowLevelType.h:206
LLVM_ABI void computeTables()
Compute any ancillary tables needed to quickly decide how an operation should be handled.
LegalizeRuleSet & legalFor(std::initializer_list< LLT > Types)
The instruction is legal when type index 0 is any type in the given list.
LegalizeRuleSet & lower()
The instruction is lowered.
LegalizeRuleSet & custom()
Unconditionally custom lower.
LegalizeRuleSet & alwaysLegal()
LegalizeRuleSet & customIf(LegalityPredicate Predicate)
LegalizeRuleSet & scalarize(unsigned TypeIdx)
LegalizeRuleSet & legalForCartesianProduct(std::initializer_list< LLT > Types)
The instruction is legal when type indexes 0 and 1 are both in the given list.
LegalizeRuleSet & legalIf(LegalityPredicate Predicate)
The instruction is legal if predicate is true.
MachineIRBuilder & MIRBuilder
Expose MIRBuilder so clients can set their own RecordInsertInstruction functions.
LegalizeRuleSet & getActionDefinitionsBuilder(unsigned Opcode)
Get the action definition builder for the given opcode.
const LegacyLegalizerInfo & getLegacyLegalizerInfo() const
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
Definition: MachineInstr.h:72
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, const MachineFunction &MF)
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineInstr &I, SPIRV::AccessQualifier::AccessQualifier AQ, bool EmitIR)
Register getSPIRVTypeID(const SPIRVType *SpirvType) const
const TargetRegisterClass * getRegClass(SPIRVType *SpvType) const
SPIRVLegalizerInfo(const SPIRVSubtarget &ST)
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override
Called for instructions with the Custom LegalizationAction.
unsigned getPointerSize() const
bool canDirectlyComparePointers() const
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static LLVM_ABI VectorType * get(Type *ElementType, ElementCount EC)
This static method is the primary way to construct an VectorType.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:126
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
LLVM_ABI LegalityPredicate typeInSet(unsigned TypeIdx, std::initializer_list< LLT > TypesInit)
True iff the given type index is one of the specified types.
Predicate all(Predicate P0, Predicate P1)
True iff P0 and P1 are true.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
LLVM_ABI const llvm::fltSemantics & getFltSemanticForLLT(LLT Ty)
Get the appropriate floating point arithmetic semantic based on the bit size of the given scalar LLT.
FPClassTest
Floating-point class tests, supported by 'is_fpclass' intrinsic.
const std::set< unsigned > & getTypeFoldingSupportedOpcodes()
Definition: SPIRVUtils.cpp:922
std::function< bool(const LegalityQuery &)> LegalityPredicate
The LegalityQuery object bundles together all the information that's needed to decide whether a given...