LLVM 22.0.0git
DXILShaderFlags.cpp
Go to the documentation of this file.
1//===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file This file contains helper objects and APIs for working with DXIL
10/// Shader Flags.
11///
12//===----------------------------------------------------------------------===//
13
14#include "DXILShaderFlags.h"
15#include "DirectX.h"
20#include "llvm/IR/Attributes.h"
22#include "llvm/IR/Instruction.h"
25#include "llvm/IR/Intrinsics.h"
26#include "llvm/IR/IntrinsicsDirectX.h"
27#include "llvm/IR/Module.h"
31
32using namespace llvm;
33using namespace llvm::dxil;
34
35static bool hasUAVsAtEveryStage(const DXILResourceMap &DRM,
36 const ModuleMetadataInfo &MMDI) {
37 if (DRM.uavs().empty())
38 return false;
39
40 switch (MMDI.ShaderProfile) {
41 default:
42 return false;
43 case Triple::EnvironmentType::Compute:
44 case Triple::EnvironmentType::Pixel:
45 return false;
46 case Triple::EnvironmentType::Vertex:
47 case Triple::EnvironmentType::Geometry:
48 case Triple::EnvironmentType::Hull:
49 case Triple::EnvironmentType::Domain:
50 return true;
51 case Triple::EnvironmentType::Library:
52 case Triple::EnvironmentType::RayGeneration:
53 case Triple::EnvironmentType::Intersection:
54 case Triple::EnvironmentType::AnyHit:
55 case Triple::EnvironmentType::ClosestHit:
56 case Triple::EnvironmentType::Miss:
57 case Triple::EnvironmentType::Callable:
58 case Triple::EnvironmentType::Mesh:
59 case Triple::EnvironmentType::Amplification:
60 return MMDI.ValidatorVersion < VersionTuple(1, 8);
61 }
62}
63
64static bool checkWaveOps(Intrinsic::ID IID) {
65 // Currently unsupported intrinsics
66 // case Intrinsic::dx_wave_getlanecount:
67 // case Intrinsic::dx_wave_allequal:
68 // case Intrinsic::dx_wave_ballot:
69 // case Intrinsic::dx_wave_readfirst:
70 // case Intrinsic::dx_wave_reduce.and:
71 // case Intrinsic::dx_wave_reduce.or:
72 // case Intrinsic::dx_wave_reduce.xor:
73 // case Intrinsic::dx_wave_prefixop:
74 // case Intrinsic::dx_quad.readat:
75 // case Intrinsic::dx_quad.readacrossx:
76 // case Intrinsic::dx_quad.readacrossy:
77 // case Intrinsic::dx_quad.readacrossdiagonal:
78 // case Intrinsic::dx_wave_prefixballot:
79 // case Intrinsic::dx_wave_match:
80 // case Intrinsic::dx_wavemulti.*:
81 // case Intrinsic::dx_wavemulti.ballot:
82 // case Intrinsic::dx_quad.vote:
83 switch (IID) {
84 default:
85 return false;
86 case Intrinsic::dx_wave_is_first_lane:
87 case Intrinsic::dx_wave_getlaneindex:
88 case Intrinsic::dx_wave_any:
89 case Intrinsic::dx_wave_all:
90 case Intrinsic::dx_wave_readlane:
91 case Intrinsic::dx_wave_active_countbits:
92 // Wave Active Op Variants
93 case Intrinsic::dx_wave_reduce_sum:
94 case Intrinsic::dx_wave_reduce_usum:
95 case Intrinsic::dx_wave_reduce_max:
96 case Intrinsic::dx_wave_reduce_umax:
97 return true;
98 }
99}
100
101/// Update the shader flags mask based on the given instruction.
102/// \param CSF Shader flags mask to update.
103/// \param I Instruction to check.
104void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
105 const Instruction &I,
107 const ModuleMetadataInfo &MMDI) {
108 if (!CSF.Doubles)
109 CSF.Doubles = I.getType()->getScalarType()->isDoubleTy();
110
111 if (!CSF.Doubles) {
112 for (const Value *Op : I.operands()) {
113 if (Op->getType()->getScalarType()->isDoubleTy()) {
114 CSF.Doubles = true;
115 break;
116 }
117 }
118 }
119
120 if (CSF.Doubles) {
121 switch (I.getOpcode()) {
122 case Instruction::FDiv:
123 case Instruction::UIToFP:
124 case Instruction::SIToFP:
125 case Instruction::FPToUI:
126 case Instruction::FPToSI:
127 CSF.DX11_1_DoubleExtensions = true;
128 break;
129 }
130 }
131
132 if (!CSF.LowPrecisionPresent)
133 CSF.LowPrecisionPresent = I.getType()->getScalarType()->isIntegerTy(16) ||
134 I.getType()->getScalarType()->isHalfTy();
135
136 if (!CSF.LowPrecisionPresent) {
137 for (const Value *Op : I.operands()) {
138 if (Op->getType()->getScalarType()->isIntegerTy(16) ||
139 Op->getType()->getScalarType()->isHalfTy()) {
140 CSF.LowPrecisionPresent = true;
141 break;
142 }
143 }
144 }
145
146 if (CSF.LowPrecisionPresent) {
147 if (CSF.NativeLowPrecisionMode)
148 CSF.NativeLowPrecision = true;
149 else
150 CSF.MinimumPrecision = true;
151 }
152
153 if (!CSF.Int64Ops)
154 CSF.Int64Ops = I.getType()->getScalarType()->isIntegerTy(64);
155
156 if (!CSF.Int64Ops && !isa<LifetimeIntrinsic>(&I)) {
157 for (const Value *Op : I.operands()) {
158 if (Op->getType()->getScalarType()->isIntegerTy(64)) {
159 CSF.Int64Ops = true;
160 break;
161 }
162 }
163 }
164
165 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
166 switch (II->getIntrinsicID()) {
167 default:
168 break;
169 case Intrinsic::dx_resource_handlefrombinding: {
170 dxil::ResourceTypeInfo &RTI = DRTM[cast<TargetExtType>(II->getType())];
171
172 // Set ResMayNotAlias if DXIL validator version >= 1.8 and the function
173 // uses UAVs
174 if (!CSF.ResMayNotAlias && CanSetResMayNotAlias &&
175 MMDI.ValidatorVersion >= VersionTuple(1, 8) && RTI.isUAV())
176 CSF.ResMayNotAlias = true;
177
178 switch (RTI.getResourceKind()) {
181 CSF.EnableRawAndStructuredBuffers = true;
182 break;
183 default:
184 break;
185 }
186 break;
187 }
188 case Intrinsic::dx_resource_load_typedbuffer: {
190 DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())];
191 if (RTI.isTyped())
192 CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1;
193 break;
194 }
195 }
196 }
197 // Handle call instructions
198 if (auto *CI = dyn_cast<CallInst>(&I)) {
199 const Function *CF = CI->getCalledFunction();
200 // Merge-in shader flags mask of the called function in the current module
201 if (FunctionFlags.contains(CF))
202 CSF.merge(FunctionFlags[CF]);
203
204 // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
205 // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
206
207 CSF.WaveOps |= checkWaveOps(CI->getIntrinsicID());
208 }
209}
210
211/// Set shader flags that apply to all functions within the module
213ModuleShaderFlags::gatherGlobalModuleFlags(const Module &M,
214 const DXILResourceMap &DRM,
215 const ModuleMetadataInfo &MMDI) {
216
218
219 // Set DisableOptimizations flag based on the presence of OptimizeNone
220 // attribute of entry functions.
221 if (MMDI.EntryPropertyVec.size() > 0) {
222 CSF.DisableOptimizations = MMDI.EntryPropertyVec[0].Entry->hasFnAttribute(
223 llvm::Attribute::OptimizeNone);
224 // Ensure all entry functions have the same optimization attribute
225 for (const auto &EntryFunProps : MMDI.EntryPropertyVec)
226 if (CSF.DisableOptimizations !=
227 EntryFunProps.Entry->hasFnAttribute(llvm::Attribute::OptimizeNone))
228 EntryFunProps.Entry->getContext().diagnose(DiagnosticInfoUnsupported(
229 *(EntryFunProps.Entry), "Inconsistent optnone attribute "));
230 }
231
232 CSF.UAVsAtEveryStage = hasUAVsAtEveryStage(DRM, MMDI);
233
234 // Set the Max64UAVs flag if the number of UAVs is > 8
235 uint32_t NumUAVs = 0;
236 for (auto &UAV : DRM.uavs())
237 if (MMDI.ValidatorVersion < VersionTuple(1, 6))
238 NumUAVs++;
239 else // MMDI.ValidatorVersion >= VersionTuple(1, 6)
240 NumUAVs += UAV.getBinding().Size;
241 if (NumUAVs > 8)
242 CSF.Max64UAVs = true;
243
244 // Set the module flag that enables native low-precision execution mode.
245 // NativeLowPrecisionMode can only be set when the command line option
246 // -enable-16bit-types is provided. This is indicated by the dx.nativelowprec
247 // module flag being set
248 // This flag is needed even if the module does not use 16-bit types because a
249 // corresponding debug module may include 16-bit types, and tools that use the
250 // debug module may expect it to have the same flags as the original
251 if (auto *NativeLowPrec = mdconst::extract_or_null<ConstantInt>(
252 M.getModuleFlag("dx.nativelowprec")))
253 if (MMDI.ShaderModelVersion >= VersionTuple(6, 2))
254 CSF.NativeLowPrecisionMode = NativeLowPrec->getValue().getBoolValue();
255
256 // Set ResMayNotAlias to true if DXIL validator version < 1.8 and there
257 // are UAVs present globally.
258 if (CanSetResMayNotAlias && MMDI.ValidatorVersion < VersionTuple(1, 8))
259 CSF.ResMayNotAlias = !DRM.uavs().empty();
260
261 return CSF;
262}
263
264/// Construct ModuleShaderFlags for module Module M
266 const DXILResourceMap &DRM,
267 const ModuleMetadataInfo &MMDI) {
268
269 CanSetResMayNotAlias = MMDI.DXILVersion >= VersionTuple(1, 7);
270 // The command line option -res-may-alias will set the dx.resmayalias module
271 // flag to 1, thereby disabling the ability to set the ResMayNotAlias flag
272 if (auto *ResMayAlias = mdconst::extract_or_null<ConstantInt>(
273 M.getModuleFlag("dx.resmayalias")))
274 if (ResMayAlias->getValue().getBoolValue())
275 CanSetResMayNotAlias = false;
276
277 ComputedShaderFlags GlobalSFMask = gatherGlobalModuleFlags(M, DRM, MMDI);
278
279 CallGraph CG(M);
280
281 // Compute Shader Flags Mask for all functions using post-order visit of SCC
282 // of the call graph.
283 for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
284 ++SCCI) {
285 const std::vector<CallGraphNode *> &CurSCC = *SCCI;
286
287 // Union of shader masks of all functions in CurSCC
289 // List of functions in CurSCC that are neither external nor declarations
290 // and hence whose flags are collected
291 SmallVector<Function *> CurSCCFuncs;
292 for (CallGraphNode *CGN : CurSCC) {
293 Function *F = CGN->getFunction();
294 if (!F)
295 continue;
296
297 if (F->isDeclaration()) {
298 assert(!F->getName().starts_with("dx.op.") &&
299 "DXIL Shader Flag analysis should not be run post-lowering.");
300 continue;
301 }
302
303 ComputedShaderFlags CSF = GlobalSFMask;
304 for (const auto &BB : *F)
305 for (const auto &I : BB)
306 updateFunctionFlags(CSF, I, DRTM, MMDI);
307 // Update combined shader flags mask for all functions in this SCC
308 SCCSF.merge(CSF);
309
310 CurSCCFuncs.push_back(F);
311 }
312
313 // Update combined shader flags mask for all functions of the module
314 CombinedSFMask.merge(SCCSF);
315
316 // Shader flags mask of each of the functions in an SCC of the call graph is
317 // the union of all functions in the SCC. Update shader flags masks of
318 // functions in CurSCC accordingly. This is trivially true if SCC contains
319 // one function.
320 for (Function *F : CurSCCFuncs)
321 // Merge SCCSF with that of F
322 FunctionFlags[F].merge(SCCSF);
323 }
324}
325
327 uint64_t FlagVal = (uint64_t) * this;
328 OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
329 if (FlagVal == 0)
330 return;
331 OS << "; Note: shader requires additional functionality:\n";
332#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str) \
333 if (FlagName) \
334 (OS << ";").indent(7) << Str << "\n";
335#include "llvm/BinaryFormat/DXContainerConstants.def"
336 OS << "; Note: extra DXIL module flags:\n";
337#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
338 if (FlagName) \
339 (OS << ";").indent(7) << Str << "\n";
340#include "llvm/BinaryFormat/DXContainerConstants.def"
341 OS << ";\n";
342}
343
344/// Return the shader flags mask of the specified function Func.
347 auto Iter = FunctionFlags.find(Func);
348 assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
349 "Get Shader Flags : No Shader Flags Mask exists for function");
350 return Iter->second;
351}
352
353//===----------------------------------------------------------------------===//
354// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
355
356// Provide an explicit template instantiation for the static ID.
357AnalysisKey ShaderFlagsAnalysis::Key;
358
364
// Compute the per-function and combined shader flags of M and return them as
// the analysis result.
366 MSFI.initialize(M, DRTM, DRM, MMDI);
367
368 return MSFI;
369}
370
// Print the combined shader flags of the module, then the flags mask of each
// function that has a body (declarations are skipped). Nothing in the module
// is modified, so all analyses are reported as preserved.
373 const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
374 // Print description of combined shader flags for all module functions
375 OS << "; Combined Shader Flags for Module\n";
376 FlagsInfo.getCombinedFlags().print(OS);
377 // Print shader flags mask for each of the module functions
378 OS << "; Shader Flags for Module Functions\n";
379 for (const auto &F : M.getFunctionList()) {
380 if (F.isDeclaration())
381 continue;
382 const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
// One line per function: its name and its flags mask formatted in hex.
383 OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
384 (uint64_t)(SFMask));
385 }
386
387 return PreservedAnalyses::all();
388}
389
390//===----------------------------------------------------------------------===//
// Legacy pass manager wrapper for the shader flags analysis
392
// Legacy pass manager entry point: fetch the resource type map, resource map
// and module metadata from their wrapper passes, then compute the shader
// flags into MSFI. Returns false because the module is only inspected.
394 DXILResourceTypeMap &DRTM =
395 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
396 DXILResourceMap &DRM =
397 getAnalysis<DXILResourceWrapperPass>().getResourceMap();
398 const ModuleMetadataInfo MMDI =
399 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
400
401 MSFI.initialize(M, DRTM, DRM, MMDI);
402 return false;
403}
404
// This analysis never transforms the module, so it preserves all analyses.
406 AU.setPreservesAll();
410}
411
413
// Legacy pass registration for the shader flags analysis
// (INITIALIZE_PASS_BEGIN / INITIALIZE_PASS_DEPENDENCY / INITIALIZE_PASS_END);
// the trailing `true, true` are the macros' cfg and analysis arguments.
415 "DXIL Shader Flag Analysis", true, true)
419 "DXIL Shader Flag Analysis", true, true)
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file contains the simple types necessary to represent the attributes associated with functions a...
block Block Frequency Analysis
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
static bool hasUAVsAtEveryStage(const DXILResourceMap &DRM, const ModuleMetadataInfo &MMDI)
static bool checkWaveOps(Intrinsic::ID IID)
Module.h This file contains the declarations for the Module class.
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
module summary analysis
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:39
This builds on the llvm/ADT/GraphTraits.h file to find the strongly connected components (SCCs) of a ...
raw_pwrite_stream & OS
This file defines the SmallVector class.
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:255
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:412
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
A node in the call graph for a module.
Definition: CallGraph.h:162
The basic data container for the call graph of a Module of IR.
Definition: CallGraph.h:72
This class represents an Operation in the Expression.
iterator_range< iterator > uavs()
Definition: DXILResource.h:528
Diagnostic information for unsupported feature in backend.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:118
void push_back(const T &Elt)
Definition: SmallVector.h:414
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1197
LLVM Value Representation.
Definition: Value.h:75
Represents a version number in the form major[.minor[.subminor[.build]]].
Definition: VersionTuple.h:30
LLVM_ABI bool isUAV() const
LLVM_ABI bool isTyped() const
LLVM_ABI TypedInfo getTyped() const
dxil::ResourceKind getResourceKind() const
Definition: DXILResource.h:325
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Wrapper pass for the legacy pass manager.
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do their job.
bool runOnModule(Module &M) override
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
ModuleShaderFlags run(Module &M, ModuleAnalysisManager &AM)
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:53
Enumerate the SCCs of a directed graph in reverse topological order of the SCC DAG.
Definition: SCCIterator.h:49
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
scc_iterator< T > scc_begin(const T &G)
Construct the begin iterator for a deduced graph type T.
Definition: SCCIterator.h:233
auto formatv(bool Validate, const char *Fmt, Ts &&...Vals)
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:29
void merge(const ComputedShaderFlags CSF)
void print(raw_ostream &OS=dbgs()) const
Triple::EnvironmentType ShaderProfile
SmallVector< EntryProperties > EntryPropertyVec
const ComputedShaderFlags & getFunctionFlags(const Function *) const
Return the shader flags mask of the specified function Func.
void initialize(Module &, DXILResourceTypeMap &DRTM, const DXILResourceMap &DRM, const ModuleMetadataInfo &MMDI)
Construct ModuleShaderFlags for module Module M.
const ComputedShaderFlags & getCombinedFlags() const