LLVM 22.0.0git
IR2Vec.cpp
Go to the documentation of this file.
1//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the IR2Vec algorithm.
11///
12//===----------------------------------------------------------------------===//
13
15
17#include "llvm/ADT/Sequence.h"
19#include "llvm/ADT/Statistic.h"
20#include "llvm/IR/CFG.h"
21#include "llvm/IR/Module.h"
22#include "llvm/IR/PassManager.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/Errc.h"
25#include "llvm/Support/Error.h"
27#include "llvm/Support/Format.h"
29
30using namespace llvm;
31using namespace ir2vec;
32
33#define DEBUG_TYPE "ir2vec"
34
35STATISTIC(VocabMissCounter,
36 "Number of lookups to entities not present in the vocabulary");
37
// Command-line options for IR2Vec. The visible fragments register the flags
// "ir2vec-vocab-path" (default ""), "ir2vec-opc-weight" (default 1.0),
// "ir2vec-type-weight" (default 0.5), "ir2vec-arg-weight" (default 0.2) and
// "ir2vec-kind" (default IR2VecKind::Symbolic).
// NOTE(review): several declaration lines are not visible in this listing
// (the opening of the VocabFile and embedding-kind options and the trailing
// arguments of each option) -- confirm against the original file.
38namespace llvm {
39namespace ir2vec {
41
42// FIXME: Use a default vocab when not specified
44 VocabFile("ir2vec-vocab-path", cl::Optional,
45 cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
// Scaling weights; IR2VecVocabAnalysis::run multiplies the opcode, type and
// argument vocabulary sections by these values.
47cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
48 cl::desc("Weight for opcode embeddings"),
50cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
51 cl::desc("Weight for type embeddings"),
53cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
54 cl::desc("Weight for argument embeddings"),
// Selects between symbolic and flow-aware embeddings.
57 "ir2vec-kind", cl::Optional,
59 "Generate symbolic embeddings"),
61 "Generate flow-aware embeddings")),
62 cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
64
65} // namespace ir2vec
66} // namespace llvm
67
69
70// ==----------------------------------------------------------------------===//
71// Local helper functions
72//===----------------------------------------------------------------------===//
73namespace llvm::json {
// JSON deserialization hook for Embedding: parses E into a
// std::vector<double> and, on success, moves that vector into Out.
// Returns false without assigning Out when the underlying parse fails.
// NOTE(review): the line declaring the path parameter P is not visible in
// this listing.
74inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
76 std::vector<double> TempOut;
77 if (!llvm::json::fromJSON(E, TempOut, P))
78 return false;
79 Out = Embedding(std::move(TempOut));
80 return true;
81}
82} // namespace llvm::json
83
84// ==----------------------------------------------------------------------===//
85// Embedding
86//===----------------------------------------------------------------------===//
// Body of Embedding::operator+= (signature line not shown in this listing).
// Adds RHS to this embedding element by element, in place, and returns
// *this. Equal sizes are required; this is checked only by the assert.
88 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
89 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
90 std::plus<double>());
91 return *this;
92}
93
// Body of Embedding::operator+ (signature line not shown in this listing).
// Returns the element-wise sum as a new Embedding by copying *this and
// delegating to operator+=; the receiver is not modified.
95 Embedding Result(*this);
96 Result += RHS;
97 return Result;
98}
99
// Body of Embedding::operator-= (signature line not shown in this listing).
// Subtracts RHS from this embedding element by element, in place, and
// returns *this. Equal sizes are required; this is checked only by the assert.
101 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
102 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
103 std::minus<double>());
104 return *this;
105}
106
// Body of Embedding::operator- (signature line not shown in this listing).
// Returns the element-wise difference as a new Embedding by copying *this
// and delegating to operator-=; the receiver is not modified.
108 Embedding Result(*this);
109 Result -= RHS;
110 return Result;
111}
112
// Body of Embedding::operator*= (signature line not shown in this listing).
// Multiplies every element by Factor in place and returns *this.
114 std::transform(this->begin(), this->end(), this->begin(),
115 [Factor](double Elem) { return Elem * Factor; });
116 return *this;
117}
118
119Embedding Embedding::operator*(double Factor) const {
120 Embedding Result(*this);
121 Result *= Factor;
122 return Result;
123}
124
125Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
126 assert(this->size() == Src.size() && "Vectors must have the same dimension");
127 for (size_t Itr = 0; Itr < this->size(); ++Itr)
128 (*this)[Itr] += Src[Itr] * Factor;
129 return *this;
130}
131
// Tail of Embedding::approximatelyEquals (first signature line not shown in
// this listing). Returns true when every pair of corresponding elements
// differs by at most Tolerance in absolute value; on the first pair that
// exceeds it, logs the index and both values under LLVM_DEBUG and returns
// false. Equal sizes are required; this is checked only by the assert.
133 double Tolerance) const {
134 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
135 for (size_t Itr = 0; Itr < this->size(); ++Itr)
136 if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) {
137 LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "
138 << (*this)[Itr] << " vs " << RHS[Itr]
139 << "; Tolerance: " << Tolerance << "\n");
140 return false;
141 }
142 return true;
143}
144
// Body of Embedding::print (signature line not shown in this listing).
// Writes the embedding to OS as " [", then each element formatted with
// "%.2f" and surrounded by single spaces, then "]" and a newline.
146 OS << " [";
147 for (const auto &Elem : Data)
148 OS << " " << format("%.2f", Elem) << " ";
149 OS << "]\n";
150}
151
152// ==----------------------------------------------------------------------===//
153// Embedder and its subclasses
154//===----------------------------------------------------------------------===//
155
// Factory for Embedder objects: depending on Mode, returns either a
// SymbolicEmbedder or a FlowAwareEmbedder built for function F with
// vocabulary Vocab. Returns nullptr when Mode matches no case.
// NOTE(review): the two `case` label lines are not visible in this listing;
// presumably they are the Symbolic and FlowAware enumerators of IR2VecKind.
156std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
157 const Vocabulary &Vocab) {
158 switch (Mode) {
160 return std::make_unique<SymbolicEmbedder>(F, Vocab);
162 return std::make_unique<FlowAwareEmbedder>(F, Vocab);
163 }
164 return nullptr;
165}
166
// Body of Embedder::computeEmbeddings() (signature line not shown in this
// listing). Computes the function-level embedding of F as the sum of the
// basic-block embeddings of every block visited by a depth-first walk from
// the entry. A function declaration yields the initial vector unchanged.
168 Embedding FuncVector(Dimension, 0.0);
169
170 if (F.isDeclaration())
171 return FuncVector;
172
173 // Consider only the basic blocks that are reachable from entry
174 for (const BasicBlock *BB : depth_first(&F))
175 FuncVector += computeEmbeddings(*BB);
176 return FuncVector;
177}
178
// Body of the basic-block overload of computeEmbeddings (signature line not
// shown in this listing). Returns the sum of the instruction embeddings of
// BB, iterating via BB.instructionsWithoutDebug().
180 Embedding BBVector(Dimension, 0);
181
182 // We consider only the non-debug and non-pseudo instructions
183 for (const auto &I : BB.instructionsWithoutDebug())
184 BBVector += computeEmbeddings(I);
185 return BBVector;
186}
187
// Body of the symbolic embedder's per-instruction computeEmbeddings
// (signature line not shown in this listing). The instruction vector is
//   Vocab[opcode] + Vocab[type id] + sum over operands of Vocab[operand],
// plus Vocab[predicate] when the instruction is a CmpInst. Nothing is
// cached; the vector is rebuilt on every call.
189 // Currently, we always (re)compute the embeddings for symbolic embedder.
190 // This is cheaper than caching the vectors.
191 Embedding ArgEmb(Dimension, 0);
192 for (const auto &Op : I.operands())
193 ArgEmb += Vocab[*Op];
194 auto InstVector =
195 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
196 if (const auto *IC = dyn_cast<CmpInst>(&I))
197 InstVector += Vocab[IC->getPredicate()];
198 return InstVector;
199}
200
// Body of the flow-aware per-instruction computeEmbeddings (signature line
// not shown in this listing; presumably FlowAwareEmbedder -- confirm).
// Results are memoized in InstVecMap, keyed by instruction address. For each
// operand that is itself an Instruction, the operand's already-computed
// embedding is used when present in InstVecMap; otherwise, and for all
// non-instruction operands, the vocabulary embedding is used. The final
// vector is Vocab[opcode] + Vocab[type id] + operand sum, plus
// Vocab[predicate] for CmpInst, and is stored in InstVecMap before returning.
202 // If we have already computed the embedding for this instruction, return it
203 auto It = InstVecMap.find(&I);
204 if (It != InstVecMap.end())
205 return It->second;
206
207 // TODO: Handle call instructions differently.
208 // For now, we treat them like other instructions
209 Embedding ArgEmb(Dimension, 0);
210 for (const auto &Op : I.operands()) {
211 // If the operand is defined elsewhere, we use its embedding
212 if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
213 auto DefIt = InstVecMap.find(DefInst);
214 // Fixme (#159171): Ideally we should never miss an instruction
215 // embedding here.
216 // But when we have cyclic dependencies (e.g., phi
217 // nodes), we might miss the embedding. In such cases, we fall back to
218 // using the vocabulary embedding. This can be fixed by iterating to a
219 // fixed-point, or by using a simple solver for the set of simultaneous
220 // equations.
221 // Another case when we might miss an instruction embedding is when
222 // the operand instruction is in a different basic block that has not
223 // been processed yet. This can be fixed by processing the basic blocks
224 // in a topological order.
225 if (DefIt != InstVecMap.end())
226 ArgEmb += DefIt->second;
227 else
228 ArgEmb += Vocab[*Op];
229 }
230 // If the operand is not defined by an instruction, we use the
231 // vocabulary
232 else {
// The debug message prints only element 0 of the operand's embedding.
233 LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
234 << *Op << "=" << Vocab[*Op][0] << "\n");
235 ArgEmb += Vocab[*Op];
236 }
237 }
238 // Create the instruction vector by combining opcode, type, and arguments
239 // embeddings
240 auto InstVector =
241 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
242 if (const auto *IC = dyn_cast<CmpInst>(&I))
243 InstVector += Vocab[IC->getPredicate()];
244 InstVecMap[&I] = InstVector;
245 return InstVector;
246}
247
248// ==----------------------------------------------------------------------===//
249// VocabStorage
250//===----------------------------------------------------------------------===//
251
252VocabStorage::VocabStorage(std::vector<std::vector<Embedding>> &&SectionData)
253 : Sections(std::move(SectionData)), TotalSize([&] {
254 assert(!Sections.empty() && "Vocabulary has no sections");
255 // Compute total size across all sections
256 size_t Size = 0;
257 for (const auto &Section : Sections) {
258 assert(!Section.empty() && "Vocabulary section is empty");
259 Size += Section.size();
260 }
261 return Size;
262 }()),
263 Dimension([&] {
264 // Get dimension from the first embedding in the first section - all
265 // embeddings must have the same dimension
266 assert(!Sections.empty() && "Vocabulary has no sections");
267 assert(!Sections[0].empty() && "First section of vocabulary is empty");
268 unsigned ExpectedDim = static_cast<unsigned>(Sections[0][0].size());
269
270 // Verify that all embeddings across all sections have the same
271 // dimension
272 [[maybe_unused]] auto allSameDim =
273 [ExpectedDim](const std::vector<Embedding> &Section) {
274 return std::all_of(Section.begin(), Section.end(),
275 [ExpectedDim](const Embedding &Emb) {
276 return Emb.size() == ExpectedDim;
277 });
278 };
279 assert(std::all_of(Sections.begin(), Sections.end(), allSameDim) &&
280 "All embeddings must have the same dimension");
281
282 return ExpectedDim;
283 }()) {}
284
// Body of VocabStorage::const_iterator::operator* (signature line not shown
// in this listing). Returns the embedding at (SectionId, LocalIndex) in the
// referenced storage. Both indices are range-checked by asserts only.
286 assert(SectionId < Storage->Sections.size() && "Invalid section ID");
287 assert(LocalIndex < Storage->Sections[SectionId].size() &&
288 "Local index out of range");
289 return Storage->Sections[SectionId][LocalIndex];
290}
291
// Body of VocabStorage::const_iterator::operator++ (signature line not shown
// in this listing). Advances to the next entry; when the local index reaches
// the end of the current section, it wraps to index 0 of the next section.
// After the last entry the iterator therefore holds
// (SectionId == number of sections, LocalIndex == 0).
293 ++LocalIndex;
294 // Check if we need to move to the next section
295 if (SectionId < Storage->getNumSections() &&
296 LocalIndex >= Storage->Sections[SectionId].size()) {
297 assert(LocalIndex == Storage->Sections[SectionId].size() &&
298 "Local index should be at the end of the current section");
299 LocalIndex = 0;
300 ++SectionId;
301 }
302 return *this;
303}
304
// Tail of VocabStorage::const_iterator::operator== (first signature line not
// shown in this listing). Two iterators are equal when they refer to the
// same storage object and hold the same section id and local index.
306 const const_iterator &Other) const {
307 return Storage == Other.Storage && SectionId == Other.SectionId &&
308 LocalIndex == Other.LocalIndex;
309}
310
// Tail of VocabStorage::const_iterator::operator!= (first signature line not
// shown in this listing). Defined as the negation of operator==.
312 const const_iterator &Other) const {
313 return !(*this == Other);
314}
315
// Tail of VocabStorage::parseVocabSection (first signature line and several
// `return createStringError(...)` lines are not shown in this listing).
// Looks up the member named Key in the root JSON object, deserializes it
// into TargetVocab, and reports the common embedding size through Dim.
// Error conditions visible here: root is not an object, the section is
// missing, the section fails to parse, the first entry has size zero, or the
// entries do not all have the same size. Returns Error::success() otherwise.
317 const json::Value &ParsedVocabValue,
318 VocabMap &TargetVocab, unsigned &Dim) {
319 json::Path::Root Path("");
320 const json::Object *RootObj = ParsedVocabValue.getAsObject();
321 if (!RootObj)
323 "JSON root is not an object");
324
325 const json::Value *SectionValue = RootObj->get(Key);
326 if (!SectionValue)
328 "Missing '" + std::string(Key) +
329 "' section in vocabulary file");
330 if (!json::fromJSON(*SectionValue, TargetVocab, Path))
332 "Unable to parse '" + std::string(Key) +
333 "' section from vocabulary");
334
// NOTE(review): begin() is dereferenced without checking that TargetVocab
// is non-empty; a section that parses to an empty map (e.g. "{}") would
// make this undefined behavior -- confirm whether an emptiness check is
// needed before this line.
335 Dim = TargetVocab.begin()->second.size();
336 if (Dim == 0)
338 "Dimension of '" + std::string(Key) +
339 "' section of the vocabulary is zero");
340
341 if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
342 [Dim](const std::pair<StringRef, Embedding> &Entry) {
343 return Entry.second.size() == Dim;
344 }))
345 return createStringError(
347 "All vectors in the '" + std::string(Key) +
348 "' section of the vocabulary are not of the same dimension");
349
350 return Error::success();
351}
352
353// ==----------------------------------------------------------------------===//
354// Vocabulary
355//===----------------------------------------------------------------------===//
356
// Body of Vocabulary::getVocabKeyForOpcode (signature line not shown in this
// listing). Returns the vocabulary key for a 1-based instruction opcode.
// HANDLE_INST is expanded once per entry of llvm/IR/Instruction.def, giving
// a chain of `if (Opcode == NUM) return "<OPCODE>";` tests; the final return
// yields "UnknownOpcode" when no entry matches.
358 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
359#define HANDLE_INST(NUM, OPCODE, CLASS) \
360 if (Opcode == NUM) { \
361 return #OPCODE; \
362 }
363#include "llvm/IR/Instruction.def"
364#undef HANDLE_INST
365 return "UnknownOpcode";
366}
367
368// Helper function to classify an operand into OperandKind
378
// Maps a CmpInst predicate P to an unsigned index via a two-way branch.
// NOTE(review): the condition and both return expressions are not visible in
// this listing; presumably this is the inverse of
// getPredicateFromLocalIndex -- confirm against the original file.
379unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) {
382 else
385}
386
// Maps a local predicate index back to a CmpInst::Predicate. Indices below
// fcmpRange are used as-is; indices at or above it are rebased by
// subtracting fcmpRange.
// NOTE(review): the definition of fcmpRange and the heads of both return
// expressions are not visible in this listing; judging by the name, the
// first range presumably covers the FCmp predicates -- confirm.
387CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) {
388 unsigned fcmpRange =
390 if (LocalIndex < fcmpRange)
392 LocalIndex);
393 else
395 LocalIndex - fcmpRange);
396}
397
// Body of Vocabulary::getVocabKeyForPredicate (signature line and the `if`
// condition are not shown in this listing). Builds the key as "FCMP_" or
// "ICMP_" followed by CmpInst::getPredicateName(Pred).
// NOTE(review): the result refers to a function-local static buffer that is
// overwritten on every call, so a previously returned key changes on the
// next call and concurrent callers would share the buffer -- confirm that
// all callers consume or copy the key before calling again.
399 static SmallString<16> PredNameBuffer;
401 PredNameBuffer = "FCMP_";
402 else
403 PredNameBuffer = "ICMP_";
404 PredNameBuffer += CmpInst::getPredicateName(Pred);
405 return PredNameBuffer;
406}
407
// Body of Vocabulary::getStringKey (signature line not shown in this
// listing). Translates a flat vocabulary position into its string key using
// consecutive ranges:
//   [0, MaxOpcodes)                           -> opcode key (opcode Pos + 1)
//   [MaxOpcodes, OperandBaseOffset)           -> canonical type key
//   [OperandBaseOffset, PredicateBaseOffset)  -> operand kind key
//   [PredicateBaseOffset, NumCanonicalEntries)-> predicate key
// Pos is range-checked by the assert only.
// NOTE(review): the head of the operand-kind return statement is not visible
// in this listing.
409 assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
410 // Opcode
411 if (Pos < MaxOpcodes)
412 return getVocabKeyForOpcode(Pos + 1);
413 // Type
414 if (Pos < OperandBaseOffset)
415 return getVocabKeyForCanonicalTypeID(
416 static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
417 // Operand
418 if (Pos < PredicateBaseOffset)
420 static_cast<OperandKind>(Pos - OperandBaseOffset));
421 // Predicates
422 return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
423}
424
425// For now, assume vocabulary is stable unless explicitly invalidated.
// Tail of Vocabulary::invalidate (first signature line not shown in this
// listing). Returns true -- i.e. the result must be invalidated -- exactly
// when the preserved-analyses checker for IR2VecVocabAnalysis reports that
// the analysis is not preserved-when-stateless.
427 ModuleAnalysisManager::Invalidator &Inv) const {
428 auto PAC = PA.getChecker<IR2VecVocabAnalysis>();
429 return !(PAC.preservedWhenStateless());
430}
431
// Body of Vocabulary::createDummyVocabForTest (signature line not shown in
// this listing). Builds a VocabStorage with four sections -- opcodes, types,
// operands, predicates, in that order -- holding MaxOpcodes,
// MaxCanonicalTypeIDs, MaxOperandKinds and MaxPredicateKinds entries.
// Each entry is constructed as Embedding(Dim, DummyVal) (presumably Dim
// elements all equal to DummyVal -- confirm against the Embedding
// constructor). DummyVal starts at 0.1f and grows by 0.1f per entry; it is
// not reset between sections, so values keep increasing across them.
433 float DummyVal = 0.1f;
434
435 // Create sections for opcodes, types, operands, and predicates
436 // Order must match Vocabulary::Section enum
437 std::vector<std::vector<Embedding>> Sections;
438 Sections.reserve(4);
439
440 // Opcodes section
441 std::vector<Embedding> OpcodeSec;
442 OpcodeSec.reserve(MaxOpcodes);
443 for (unsigned I = 0; I < MaxOpcodes; ++I) {
444 OpcodeSec.emplace_back(Dim, DummyVal);
445 DummyVal += 0.1f;
446 }
447 Sections.push_back(std::move(OpcodeSec));
448
449 // Types section
450 std::vector<Embedding> TypeSec;
451 TypeSec.reserve(MaxCanonicalTypeIDs);
452 for (unsigned I = 0; I < MaxCanonicalTypeIDs; ++I) {
453 TypeSec.emplace_back(Dim, DummyVal);
454 DummyVal += 0.1f;
455 }
456 Sections.push_back(std::move(TypeSec));
457
458 // Operands section
459 std::vector<Embedding> OperandSec;
460 OperandSec.reserve(MaxOperandKinds);
461 for (unsigned I = 0; I < MaxOperandKinds; ++I) {
462 OperandSec.emplace_back(Dim, DummyVal);
463 DummyVal += 0.1f;
464 }
465 Sections.push_back(std::move(OperandSec));
466
467 // Predicates section
468 std::vector<Embedding> PredicateSec;
469 PredicateSec.reserve(MaxPredicateKinds);
470 for (unsigned I = 0; I < MaxPredicateKinds; ++I) {
471 PredicateSec.emplace_back(Dim, DummyVal);
472 DummyVal += 0.1f;
473 }
474 Sections.push_back(std::move(PredicateSec));
475
476 return VocabStorage(std::move(Sections));
477}
478
479// ==----------------------------------------------------------------------===//
480// IR2VecVocabAnalysis
481//===----------------------------------------------------------------------===//
482
483// FIXME: Make this optional. We can avoid file reads
484// by auto-generating a default vocabulary during the build time.
// Reads the vocabulary file named by the VocabFile option (opened through
// MemoryBuffer::getFileOrSTDIN in text mode), parses it as JSON, and fills
// OpcVocab, TypeVocab and ArgVocab from the "Opcodes", "Types" and
// "Arguments" sections respectively.
// Returns an Error when the file cannot be opened, the JSON is malformed,
// any section fails to parse, or the three sections do not share the same
// dimension; Error::success() otherwise.
// NOTE(review): the head of the final `return createStringError(...)` is not
// visible in this listing.
485Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
486 VocabMap &TypeVocab,
487 VocabMap &ArgVocab) {
488 auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
489 if (!BufOrError)
490 return createFileError(VocabFile, BufOrError.getError());
491
492 auto Content = BufOrError.get()->getBuffer();
493
494 Expected<json::Value> ParsedVocabValue = json::parse(Content);
495 if (!ParsedVocabValue)
496 return ParsedVocabValue.takeError();
497
498 unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
499 if (auto Err = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue,
500 OpcVocab, OpcodeDim))
501 return Err;
502
503 if (auto Err = VocabStorage::parseVocabSection("Types", *ParsedVocabValue,
504 TypeVocab, TypeDim))
505 return Err;
506
507 if (auto Err = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue,
508 ArgVocab, ArgDim))
509 return Err;
510
511 if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
513 "Vocabulary sections have different dimensions");
514
515 return Error::success();
516}
517
// Converts the three string-keyed vocabulary maps into the index-based,
// section-organized VocabStorage held in the Vocab member.
// For every opcode, canonical type id, operand kind and predicate, the
// corresponding string key is looked up (predicates are looked up in
// ArgVocab); entries that are missing keep their initial Embedding and bump
// the VocabMissCounter statistic. The four resulting vectors are placed at
// the indices given by Vocabulary::Section and emplaced into Vocab.
// NOTE(review): the dimension is taken from OpcVocab.begin() without an
// emptiness check, so OpcVocab must be non-empty on entry -- confirm callers
// guarantee this. Also, the line that derives `Kind` from `OpKind` in the
// operand loop is not visible in this listing.
518void IR2VecVocabAnalysis::generateVocabStorage(VocabMap &OpcVocab,
519 VocabMap &TypeVocab,
520 VocabMap &ArgVocab) {
521
522 // Helper for handling missing entities in the vocabulary.
523 // Currently, we use a zero vector. In the future, we will throw an error to
524 // ensure that *all* known entities are present in the vocabulary.
525 auto handleMissingEntity = [](const std::string &Val) {
526 LLVM_DEBUG(errs() << Val
527 << " is not in vocabulary, using zero vector; This "
528 "would result in an error in future.\n");
529 ++VocabMissCounter;
530 };
531
532 unsigned Dim = OpcVocab.begin()->second.size();
533 assert(Dim > 0 && "Vocabulary dimension must be greater than zero");
534
535 // Handle Opcodes
536 std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
537 Embedding(Dim));
// Slot `Opcode` holds the embedding for the 1-based opcode `Opcode + 1`.
538 for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
539 StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
540 auto It = OpcVocab.find(VocabKey.str());
541 if (It != OpcVocab.end())
542 NumericOpcodeEmbeddings[Opcode] = It->second;
543 else
544 handleMissingEntity(VocabKey.str());
545 }
546
547 // Handle Types - only canonical types are present in vocabulary
548 std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
549 Embedding(Dim));
550 for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
551 StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
552 static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
553 if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
554 NumericTypeEmbeddings[CTypeID] = It->second;
555 continue;
556 }
557 handleMissingEntity(VocabKey.str());
558 }
559
560 // Handle Arguments/Operands
561 std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
562 Embedding(Dim));
563 for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
565 StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
566 auto It = ArgVocab.find(VocabKey.str());
567 if (It != ArgVocab.end()) {
568 NumericArgEmbeddings[OpKind] = It->second;
569 continue;
570 }
571 handleMissingEntity(VocabKey.str());
572 }
573
574 // Handle Predicates: part of Operands section. We look up predicate keys
575 // in ArgVocab.
576 std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
577 Embedding(Dim, 0));
578 for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
579 StringRef VocabKey =
580 Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
581 auto It = ArgVocab.find(VocabKey.str());
582 if (It != ArgVocab.end()) {
583 NumericPredEmbeddings[PK] = It->second;
584 continue;
585 }
586 handleMissingEntity(VocabKey.str());
587 }
588
589 // Create section-based storage instead of flat vocabulary
590 // Order must match Vocabulary::Section enum
591 std::vector<std::vector<Embedding>> Sections(4);
592 Sections[static_cast<unsigned>(Vocabulary::Section::Opcodes)] =
593 std::move(NumericOpcodeEmbeddings); // Section::Opcodes
594 Sections[static_cast<unsigned>(Vocabulary::Section::CanonicalTypes)] =
595 std::move(NumericTypeEmbeddings); // Section::CanonicalTypes
596 Sections[static_cast<unsigned>(Vocabulary::Section::Operands)] =
597 std::move(NumericArgEmbeddings); // Section::Operands
598 Sections[static_cast<unsigned>(Vocabulary::Section::Predicates)] =
599 std::move(NumericPredEmbeddings); // Section::Predicates
600
601 // Create VocabStorage from organized sections
602 Vocab.emplace(std::move(Sections));
603}
604
605void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
606 handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
607 Ctx.emitError("Error reading vocabulary: " + EI.message());
608 });
609}
610
// Body of IR2VecVocabAnalysis::run (signature lines not shown in this
// listing). Produces the Vocabulary analysis result:
//  1. If Vocab already holds a value, wrap it and return.
//  2. Otherwise, if no vocabulary path was given, emit an error on the
//     module's LLVMContext and return a default-constructed Vocabulary.
//  3. Otherwise read the three sections from the file (emitting an error and
//     returning a default-constructed Vocabulary on failure), scale them by
//     OpcWeight / TypeWeight / ArgWeight, build the section storage, and
//     return it wrapped in a Vocabulary.
// NOTE(review): both successful returns move out of Vocab.value() while
// Vocab stays engaged, so a second call would take branch 1 and hand out a
// moved-from storage -- confirm run() is invoked at most once per instance.
613 auto Ctx = &M.getContext();
614 // If vocabulary is already populated by the constructor, use it.
615 if (Vocab.has_value())
616 return Vocabulary(std::move(Vocab.value()));
617
618 // Otherwise, try to read from the vocabulary file.
619 if (VocabFile.empty()) {
620 // FIXME: Use default vocabulary
621 Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to "
622 "set it using --ir2vec-vocab-path");
623 return Vocabulary(); // Return invalid result
624 }
625
626 VocabMap OpcVocab, TypeVocab, ArgVocab;
627 if (auto Err = readVocabulary(OpcVocab, TypeVocab, ArgVocab)) {
628 emitError(std::move(Err), *Ctx);
629 return Vocabulary();
630 }
631
632 // Scale the vocabulary sections based on the provided weights
// The lambda parameter `Vocab` shadows the member of the same name.
633 auto scaleVocabSection = [](VocabMap &Vocab, double Weight) {
634 for (auto &Entry : Vocab)
635 Entry.second *= Weight;
636 };
637 scaleVocabSection(OpcVocab, OpcWeight);
638 scaleVocabSection(TypeVocab, TypeWeight);
639 scaleVocabSection(ArgVocab, ArgWeight);
640
641 // Generate the numeric lookup vocabulary
642 generateVocabStorage(OpcVocab, TypeVocab, ArgVocab);
643
644 return Vocabulary(std::move(Vocab.value()));
645}
646
647// ==----------------------------------------------------------------------===//
648// Printer Passes
649//===----------------------------------------------------------------------===//
650
// Body of a printer pass run(Module &M, ModuleAnalysisManager &MAM)
// (signature lines not shown in this listing; presumably the IR2Vec
// embedding printer pass -- confirm). For every function in M it prints the
// function vector, then one vector per basic block, then one vector per
// instruction, all to OS. Functions for which no embedder could be obtained
// get an error line and are skipped. Returns PreservedAnalyses::all().
// NOTE(review): the line that creates `Emb` is not visible in this listing.
653 auto &Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
654 assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
655
656 for (Function &F : M) {
658 if (!Emb) {
659 OS << "Error creating IR2Vec embeddings \n";
660 continue;
661 }
662
663 OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
664 OS << "Function vector: ";
665 Emb->getFunctionVector().print(OS);
666
667 OS << "Basic block vectors:\n";
668 for (const BasicBlock &BB : F) {
669 OS << "Basic block: " << BB.getName() << ":\n";
670 Emb->getBBVector(BB).print(OS);
671 }
672
673 OS << "Instruction vectors:\n";
674 for (const BasicBlock &BB : F) {
675 for (const Instruction &I : BB) {
676 OS << "Instruction: ";
677 I.print(OS);
678 Emb->getInstVector(I).print(OS);
679 }
680 }
681 }
682 return PreservedAnalyses::all();
683}
684
// Body of a printer pass run(Module &M, ModuleAnalysisManager &MAM)
// (signature lines not shown in this listing; presumably the IR2Vec
// vocabulary printer pass -- confirm). Iterates over the vocabulary obtained
// from IR2VecVocabAnalysis and prints, for each entry, "Key: <string key>: "
// followed by the embedding. Pos counts entries in iteration order and is
// used to look up the string key. Returns PreservedAnalyses::all().
687 auto &IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
688 assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");
689
690 // Print each entry
691 unsigned Pos = 0;
692 for (const auto &Entry : IR2VecVocabulary) {
693 OS << "Key: " << IR2VecVocabulary.getStringKey(Pos++) << ": ";
694 Entry.print(OS);
695 }
696 return PreservedAnalyses::all();
697}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), the core ir2vec::Embedder inte...
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define P(N)
ModuleAnalysisManager MAM
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
LLVM Basic Block Representation.
Definition BasicBlock.h:62
LLVM_ABI iterator_range< filter_iterator< BasicBlock::const_iterator, std::function< bool(const Instruction &)> > > instructionsWithoutDebug(bool SkipPseudoOp=true) const
Return a const iterator range over the instructions in the block, skipping any debug instructions.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
static LLVM_ABI StringRef getPredicateName(Predicate P)
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:167
iterator end()
Definition DenseMap.h:81
virtual std::string message() const
Return the error message as a string.
Definition Error.h:52
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
Tagged union holding either a T or a Error.
Definition Error.h:485
Error takeError()
Take ownership of the stored error.
Definition Error.h:612
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:651
This analysis provides the vocabulary for IR2Vec.
Definition IR2Vec.h:614
ir2vec::Vocabulary Result
Definition IR2Vec.h:629
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:612
static LLVM_ABI AnalysisKey Key
Definition IR2Vec.h:625
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:685
LLVM_ABI void emitError(const Instruction *I, const Twine &ErrorStr)
emitError - Emit an error message to the currently installed error handler with optional location inf...
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
Definition SmallString.h:26
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
std::string str() const
str - Get the contents as an std::string.
Definition StringRef.h:225
LLVM Value Representation.
Definition Value.h:75
static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)
Factory method to create an Embedder object.
Definition IR2Vec.cpp:156
const Vocabulary & Vocab
Definition IR2Vec.h:527
const unsigned Dimension
Dimension of the vector representation; captured from the input vocabulary.
Definition IR2Vec.h:530
Embedding computeEmbeddings() const
Function to compute embeddings.
Definition IR2Vec.cpp:167
const Function & F
Definition IR2Vec.h:526
Iterator support for section-based access.
Definition IR2Vec.h:196
const_iterator(const VocabStorage *Storage, unsigned SectionId, size_t LocalIndex)
Definition IR2Vec.h:202
LLVM_ABI bool operator!=(const const_iterator &Other) const
Definition IR2Vec.cpp:311
LLVM_ABI const_iterator & operator++()
Definition IR2Vec.cpp:292
LLVM_ABI const Embedding & operator*() const
Definition IR2Vec.cpp:285
LLVM_ABI bool operator==(const const_iterator &Other) const
Definition IR2Vec.cpp:305
Generic storage class for section-based vocabularies.
Definition IR2Vec.h:151
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
Definition IR2Vec.cpp:316
unsigned getNumSections() const
Get number of sections.
Definition IR2Vec.h:179
VocabStorage()
Default constructor creates empty storage (invalid state)
Definition IR2Vec.h:164
size_t size() const
Get total number of entries across all sections.
Definition IR2Vec.h:176
std::map< std::string, Embedding > VocabMap
Definition IR2Vec.h:217
Class for storing and accessing the IR2Vec vocabulary.
Definition IR2Vec.h:242
static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)
Function to get vocabulary key for a given OperandKind.
Definition IR2Vec.h:352
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const
Definition IR2Vec.cpp:426
static LLVM_ABI OperandKind getOperandKind(const Value *Op)
Function to classify an operand into OperandKind.
Definition IR2Vec.cpp:369
friend class llvm::IR2VecVocabAnalysis
Definition IR2Vec.h:243
static LLVM_ABI StringRef getStringKey(unsigned Pos)
Returns the string key for a given index position in the vocabulary.
Definition IR2Vec.cpp:408
static constexpr unsigned MaxCanonicalTypeIDs
Definition IR2Vec.h:312
static constexpr unsigned MaxOperandKinds
Definition IR2Vec.h:314
OperandKind
Operand kinds supported by IR2Vec Vocabulary.
Definition IR2Vec.h:298
static LLVM_ABI StringRef getVocabKeyForPredicate(CmpInst::Predicate P)
Function to get vocabulary key for a given predicate.
Definition IR2Vec.cpp:398
static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)
Function to get vocabulary key for a given Opcode.
Definition IR2Vec.cpp:357
LLVM_ABI bool isValid() const
Definition IR2Vec.h:330
static LLVM_ABI VocabStorage createDummyVocabForTest(unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
Definition IR2Vec.cpp:432
static constexpr unsigned MaxPredicateKinds
Definition IR2Vec.h:318
CanonicalTypeID
Canonical type IDs supported by IR2Vec Vocabulary.
Definition IR2Vec.h:281
An Object is a JSON object, which maps strings to heterogenous JSON values.
Definition JSON.h:98
LLVM_ABI Value * get(StringRef K)
Definition JSON.cpp:30
The root is the trivial Path to the root value.
Definition JSON.h:713
A "cursor" marking a position within a Value.
Definition JSON.h:666
A Value is an JSON value of unknown type.
Definition JSON.h:290
const json::Object * getAsObject() const
Definition JSON.h:464
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
static cl::opt< std::string > VocabFile("ir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""), cl::cat(IR2VecCategory))
LLVM_ABI cl::opt< float > ArgWeight
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > TypeWeight
LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind
llvm::cl::OptionCategory IR2VecCategory
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
Definition JSON.cpp:675
bool fromJSON(const Value &E, std::string &Out, Path P)
Definition JSON.h:742
ir2vec::Embedding Embedding
Definition MIR2Vec.h:59
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
Definition Error.h:1399
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:644
void handleAllErrors(Error E, HandlerTs &&... Handlers)
Behaves the same as handleErrors, except that by contract all errors must be handled by the given han...
Definition Error.h:990
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
Definition Error.h:1305
@ illegal_byte_sequence
Definition Errc.h:52
@ invalid_argument
Definition Errc.h:56
IR2VecKind
IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
Definition IR2Vec.h:71
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:548
format_object< Ts... > format(const char *Fmt, const Ts &... Vals)
These are helper functions used to produce formatted output.
Definition Format.h:118
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
@ Other
Any other memory.
Definition ModRef.h:68
DWARFExpression::Operation Op
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1869
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:867
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
Embedding is a datatype that wraps std::vector<double>.
Definition IR2Vec.h:87
LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const
Returns true if the embedding is approximately equal to the RHS embedding within the specified tolera...
Definition IR2Vec.cpp:132
LLVM_ABI Embedding & operator+=(const Embedding &RHS)
Arithmetic operators.
Definition IR2Vec.cpp:87
LLVM_ABI Embedding operator-(const Embedding &RHS) const
Definition IR2Vec.cpp:107
LLVM_ABI Embedding & operator-=(const Embedding &RHS)
Definition IR2Vec.cpp:100
LLVM_ABI Embedding operator*(double Factor) const
Definition IR2Vec.cpp:119
size_t size() const
Definition IR2Vec.h:100
LLVM_ABI Embedding & operator*=(double Factor)
Definition IR2Vec.cpp:113
LLVM_ABI Embedding operator+(const Embedding &RHS) const
Definition IR2Vec.cpp:94
LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)
Adds Src Embedding scaled by Factor with the called Embedding.
Definition IR2Vec.cpp:125
LLVM_ABI void print(raw_ostream &OS) const
Definition IR2Vec.cpp:145