LLVM 22.0.0git
IR2Vec.cpp
Go to the documentation of this file.
1//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the IR2Vec algorithm.
11///
12//===----------------------------------------------------------------------===//
13
15
17#include "llvm/ADT/Sequence.h"
19#include "llvm/ADT/Statistic.h"
20#include "llvm/IR/CFG.h"
21#include "llvm/IR/Module.h"
22#include "llvm/IR/PassManager.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/Errc.h"
25#include "llvm/Support/Error.h"
27#include "llvm/Support/Format.h"
29
30using namespace llvm;
31using namespace ir2vec;
32
33#define DEBUG_TYPE "ir2vec"
34
35STATISTIC(VocabMissCounter,
36 "Number of lookups to entities not present in the vocabulary");
37
38namespace llvm {
39namespace ir2vec {
41
42// FIXME: Use a default vocab when not specified
44 VocabFile("ir2vec-vocab-path", cl::Optional,
45 cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
47cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
48 cl::desc("Weight for opcode embeddings"),
50cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
51 cl::desc("Weight for type embeddings"),
53cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
54 cl::desc("Weight for argument embeddings"),
57 "ir2vec-kind", cl::Optional,
59 "Generate symbolic embeddings"),
61 "Generate flow-aware embeddings")),
62 cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
64
65} // namespace ir2vec
66} // namespace llvm
67
69
//===----------------------------------------------------------------------===//
71// Local helper functions
72//===----------------------------------------------------------------------===//
73namespace llvm::json {
74inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
76 std::vector<double> TempOut;
77 if (!llvm::json::fromJSON(E, TempOut, P))
78 return false;
79 Out = Embedding(std::move(TempOut));
80 return true;
81}
82} // namespace llvm::json
83
//===----------------------------------------------------------------------===//
85// Embedding
86//===----------------------------------------------------------------------===//
88 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
89 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
90 std::plus<double>());
91 return *this;
92}
93
95 Embedding Result(*this);
96 Result += RHS;
97 return Result;
98}
99
101 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
102 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
103 std::minus<double>());
104 return *this;
105}
106
108 Embedding Result(*this);
109 Result -= RHS;
110 return Result;
111}
112
114 std::transform(this->begin(), this->end(), this->begin(),
115 [Factor](double Elem) { return Elem * Factor; });
116 return *this;
117}
118
119Embedding Embedding::operator*(double Factor) const {
120 Embedding Result(*this);
121 Result *= Factor;
122 return Result;
123}
124
125Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
126 assert(this->size() == Src.size() && "Vectors must have the same dimension");
127 for (size_t Itr = 0; Itr < this->size(); ++Itr)
128 (*this)[Itr] += Src[Itr] * Factor;
129 return *this;
130}
131
133 double Tolerance) const {
134 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
135 for (size_t Itr = 0; Itr < this->size(); ++Itr)
136 if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) {
137 LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "
138 << (*this)[Itr] << " vs " << RHS[Itr]
139 << "; Tolerance: " << Tolerance << "\n");
140 return false;
141 }
142 return true;
143}
144
146 OS << " [";
147 for (const auto &Elem : Data)
148 OS << " " << format("%.2f", Elem) << " ";
149 OS << "]\n";
150}
151
//===----------------------------------------------------------------------===//
153// Embedder and its subclasses
154//===----------------------------------------------------------------------===//
155
160
161std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
162 const Vocabulary &Vocab) {
163 switch (Mode) {
165 return std::make_unique<SymbolicEmbedder>(F, Vocab);
167 return std::make_unique<FlowAwareEmbedder>(F, Vocab);
168 }
169 return nullptr;
170}
171
173 if (InstVecMap.empty())
175 return InstVecMap;
176}
177
179 if (BBVecMap.empty())
181 return BBVecMap;
182}
183
185 auto It = BBVecMap.find(&BB);
186 if (It != BBVecMap.end())
187 return It->second;
189 return BBVecMap[&BB];
190}
191
193 // Currently, we always (re)compute the embeddings for the function.
194 // This is cheaper than caching the vector.
196 return FuncVector;
197}
198
200 if (F.isDeclaration())
201 return;
202
204
205 // Consider only the basic blocks that are reachable from entry
206 for (const BasicBlock *BB : depth_first(&F)) {
208 FuncVector += BBVecMap[BB];
209 }
210}
211
213 Embedding BBVector(Dimension, 0);
214
215 // We consider only the non-debug and non-pseudo instructions
216 for (const auto &I : BB.instructionsWithoutDebug()) {
217 Embedding ArgEmb(Dimension, 0);
218 for (const auto &Op : I.operands())
219 ArgEmb += Vocab[*Op];
220 auto InstVector =
221 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
222 if (const auto *IC = dyn_cast<CmpInst>(&I))
223 InstVector += Vocab[IC->getPredicate()];
224 InstVecMap[&I] = InstVector;
225 BBVector += InstVector;
226 }
227 BBVecMap[&BB] = BBVector;
228}
229
231 Embedding BBVector(Dimension, 0);
232
233 // We consider only the non-debug and non-pseudo instructions
234 for (const auto &I : BB.instructionsWithoutDebug()) {
235 // TODO: Handle call instructions differently.
236 // For now, we treat them like other instructions
237 Embedding ArgEmb(Dimension, 0);
238 for (const auto &Op : I.operands()) {
239 // If the operand is defined elsewhere, we use its embedding
240 if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
241 auto DefIt = InstVecMap.find(DefInst);
242 // Fixme (#159171): Ideally we should never miss an instruction
243 // embedding here.
244 // But when we have cyclic dependencies (e.g., phi
245 // nodes), we might miss the embedding. In such cases, we fall back to
246 // using the vocabulary embedding. This can be fixed by iterating to a
247 // fixed-point, or by using a simple solver for the set of simultaneous
248 // equations.
249 // Another case when we might miss an instruction embedding is when
250 // the operand instruction is in a different basic block that has not
251 // been processed yet. This can be fixed by processing the basic blocks
252 // in a topological order.
253 if (DefIt != InstVecMap.end())
254 ArgEmb += DefIt->second;
255 else
256 ArgEmb += Vocab[*Op];
257 }
258 // If the operand is not defined by an instruction, we use the vocabulary
259 else {
260 LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
261 << *Op << "=" << Vocab[*Op][0] << "\n");
262 ArgEmb += Vocab[*Op];
263 }
264 }
265 // Create the instruction vector by combining opcode, type, and arguments
266 // embeddings
267 auto InstVector =
268 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
269 // Add compare predicate embedding as an additional operand if applicable
270 if (const auto *IC = dyn_cast<CmpInst>(&I))
271 InstVector += Vocab[IC->getPredicate()];
272 InstVecMap[&I] = InstVector;
273 BBVector += InstVector;
274 }
275 BBVecMap[&BB] = BBVector;
276}
277
//===----------------------------------------------------------------------===//
279// VocabStorage
280//===----------------------------------------------------------------------===//
281
282VocabStorage::VocabStorage(std::vector<std::vector<Embedding>> &&SectionData)
283 : Sections(std::move(SectionData)), TotalSize([&] {
284 assert(!Sections.empty() && "Vocabulary has no sections");
285 // Compute total size across all sections
286 size_t Size = 0;
287 for (const auto &Section : Sections) {
288 assert(!Section.empty() && "Vocabulary section is empty");
289 Size += Section.size();
290 }
291 return Size;
292 }()),
293 Dimension([&] {
294 // Get dimension from the first embedding in the first section - all
295 // embeddings must have the same dimension
296 assert(!Sections.empty() && "Vocabulary has no sections");
297 assert(!Sections[0].empty() && "First section of vocabulary is empty");
298 unsigned ExpectedDim = static_cast<unsigned>(Sections[0][0].size());
299
300 // Verify that all embeddings across all sections have the same
301 // dimension
302 [[maybe_unused]] auto allSameDim =
303 [ExpectedDim](const std::vector<Embedding> &Section) {
304 return std::all_of(Section.begin(), Section.end(),
305 [ExpectedDim](const Embedding &Emb) {
306 return Emb.size() == ExpectedDim;
307 });
308 };
309 assert(std::all_of(Sections.begin(), Sections.end(), allSameDim) &&
310 "All embeddings must have the same dimension");
311
312 return ExpectedDim;
313 }()) {}
314
316 assert(SectionId < Storage->Sections.size() && "Invalid section ID");
317 assert(LocalIndex < Storage->Sections[SectionId].size() &&
318 "Local index out of range");
319 return Storage->Sections[SectionId][LocalIndex];
320}
321
323 ++LocalIndex;
324 // Check if we need to move to the next section
325 if (SectionId < Storage->getNumSections() &&
326 LocalIndex >= Storage->Sections[SectionId].size()) {
327 assert(LocalIndex == Storage->Sections[SectionId].size() &&
328 "Local index should be at the end of the current section");
329 LocalIndex = 0;
330 ++SectionId;
331 }
332 return *this;
333}
334
336 const const_iterator &Other) const {
337 return Storage == Other.Storage && SectionId == Other.SectionId &&
338 LocalIndex == Other.LocalIndex;
339}
340
342 const const_iterator &Other) const {
343 return !(*this == Other);
344}
345
347 const json::Value &ParsedVocabValue,
348 VocabMap &TargetVocab, unsigned &Dim) {
349 json::Path::Root Path("");
350 const json::Object *RootObj = ParsedVocabValue.getAsObject();
351 if (!RootObj)
353 "JSON root is not an object");
354
355 const json::Value *SectionValue = RootObj->get(Key);
356 if (!SectionValue)
358 "Missing '" + std::string(Key) +
359 "' section in vocabulary file");
360 if (!json::fromJSON(*SectionValue, TargetVocab, Path))
362 "Unable to parse '" + std::string(Key) +
363 "' section from vocabulary");
364
365 Dim = TargetVocab.begin()->second.size();
366 if (Dim == 0)
368 "Dimension of '" + std::string(Key) +
369 "' section of the vocabulary is zero");
370
371 if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
372 [Dim](const std::pair<StringRef, Embedding> &Entry) {
373 return Entry.second.size() == Dim;
374 }))
375 return createStringError(
377 "All vectors in the '" + std::string(Key) +
378 "' section of the vocabulary are not of the same dimension");
379
380 return Error::success();
381}
382
//===----------------------------------------------------------------------===//
384// Vocabulary
385//===----------------------------------------------------------------------===//
386
388 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
389#define HANDLE_INST(NUM, OPCODE, CLASS) \
390 if (Opcode == NUM) { \
391 return #OPCODE; \
392 }
393#include "llvm/IR/Instruction.def"
394#undef HANDLE_INST
395 return "UnknownOpcode";
396}
397
398// Helper function to classify an operand into OperandKind
408
409unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) {
412 else
415}
416
417CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) {
418 unsigned fcmpRange =
420 if (LocalIndex < fcmpRange)
422 LocalIndex);
423 else
425 LocalIndex - fcmpRange);
426}
427
429 static SmallString<16> PredNameBuffer;
431 PredNameBuffer = "FCMP_";
432 else
433 PredNameBuffer = "ICMP_";
434 PredNameBuffer += CmpInst::getPredicateName(Pred);
435 return PredNameBuffer;
436}
437
439 assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
440 // Opcode
441 if (Pos < MaxOpcodes)
442 return getVocabKeyForOpcode(Pos + 1);
443 // Type
444 if (Pos < OperandBaseOffset)
445 return getVocabKeyForCanonicalTypeID(
446 static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
447 // Operand
448 if (Pos < PredicateBaseOffset)
450 static_cast<OperandKind>(Pos - OperandBaseOffset));
451 // Predicates
452 return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
453}
454
455// For now, assume vocabulary is stable unless explicitly invalidated.
457 ModuleAnalysisManager::Invalidator &Inv) const {
458 auto PAC = PA.getChecker<IR2VecVocabAnalysis>();
459 return !(PAC.preservedWhenStateless());
460}
461
463 float DummyVal = 0.1f;
464
465 // Create sections for opcodes, types, operands, and predicates
466 // Order must match Vocabulary::Section enum
467 std::vector<std::vector<Embedding>> Sections;
468 Sections.reserve(4);
469
470 // Opcodes section
471 std::vector<Embedding> OpcodeSec;
472 OpcodeSec.reserve(MaxOpcodes);
473 for (unsigned I = 0; I < MaxOpcodes; ++I) {
474 OpcodeSec.emplace_back(Dim, DummyVal);
475 DummyVal += 0.1f;
476 }
477 Sections.push_back(std::move(OpcodeSec));
478
479 // Types section
480 std::vector<Embedding> TypeSec;
481 TypeSec.reserve(MaxCanonicalTypeIDs);
482 for (unsigned I = 0; I < MaxCanonicalTypeIDs; ++I) {
483 TypeSec.emplace_back(Dim, DummyVal);
484 DummyVal += 0.1f;
485 }
486 Sections.push_back(std::move(TypeSec));
487
488 // Operands section
489 std::vector<Embedding> OperandSec;
490 OperandSec.reserve(MaxOperandKinds);
491 for (unsigned I = 0; I < MaxOperandKinds; ++I) {
492 OperandSec.emplace_back(Dim, DummyVal);
493 DummyVal += 0.1f;
494 }
495 Sections.push_back(std::move(OperandSec));
496
497 // Predicates section
498 std::vector<Embedding> PredicateSec;
499 PredicateSec.reserve(MaxPredicateKinds);
500 for (unsigned I = 0; I < MaxPredicateKinds; ++I) {
501 PredicateSec.emplace_back(Dim, DummyVal);
502 DummyVal += 0.1f;
503 }
504 Sections.push_back(std::move(PredicateSec));
505
506 return VocabStorage(std::move(Sections));
507}
508
//===----------------------------------------------------------------------===//
510// IR2VecVocabAnalysis
511//===----------------------------------------------------------------------===//
512
513// FIXME: Make this optional. We can avoid file reads
514// by auto-generating a default vocabulary during the build time.
515Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
516 VocabMap &TypeVocab,
517 VocabMap &ArgVocab) {
518 auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
519 if (!BufOrError)
520 return createFileError(VocabFile, BufOrError.getError());
521
522 auto Content = BufOrError.get()->getBuffer();
523
524 Expected<json::Value> ParsedVocabValue = json::parse(Content);
525 if (!ParsedVocabValue)
526 return ParsedVocabValue.takeError();
527
528 unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
529 if (auto Err = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue,
530 OpcVocab, OpcodeDim))
531 return Err;
532
533 if (auto Err = VocabStorage::parseVocabSection("Types", *ParsedVocabValue,
534 TypeVocab, TypeDim))
535 return Err;
536
537 if (auto Err = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue,
538 ArgVocab, ArgDim))
539 return Err;
540
541 if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
543 "Vocabulary sections have different dimensions");
544
545 return Error::success();
546}
547
548void IR2VecVocabAnalysis::generateVocabStorage(VocabMap &OpcVocab,
549 VocabMap &TypeVocab,
550 VocabMap &ArgVocab) {
551
552 // Helper for handling missing entities in the vocabulary.
553 // Currently, we use a zero vector. In the future, we will throw an error to
554 // ensure that *all* known entities are present in the vocabulary.
555 auto handleMissingEntity = [](const std::string &Val) {
556 LLVM_DEBUG(errs() << Val
557 << " is not in vocabulary, using zero vector; This "
558 "would result in an error in future.\n");
559 ++VocabMissCounter;
560 };
561
562 unsigned Dim = OpcVocab.begin()->second.size();
563 assert(Dim > 0 && "Vocabulary dimension must be greater than zero");
564
565 // Handle Opcodes
566 std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
567 Embedding(Dim));
568 for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
569 StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
570 auto It = OpcVocab.find(VocabKey.str());
571 if (It != OpcVocab.end())
572 NumericOpcodeEmbeddings[Opcode] = It->second;
573 else
574 handleMissingEntity(VocabKey.str());
575 }
576
577 // Handle Types - only canonical types are present in vocabulary
578 std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
579 Embedding(Dim));
580 for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
581 StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
582 static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
583 if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
584 NumericTypeEmbeddings[CTypeID] = It->second;
585 continue;
586 }
587 handleMissingEntity(VocabKey.str());
588 }
589
590 // Handle Arguments/Operands
591 std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
592 Embedding(Dim));
593 for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
595 StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
596 auto It = ArgVocab.find(VocabKey.str());
597 if (It != ArgVocab.end()) {
598 NumericArgEmbeddings[OpKind] = It->second;
599 continue;
600 }
601 handleMissingEntity(VocabKey.str());
602 }
603
604 // Handle Predicates: part of Operands section. We look up predicate keys
605 // in ArgVocab.
606 std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
607 Embedding(Dim, 0));
608 for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
609 StringRef VocabKey =
610 Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
611 auto It = ArgVocab.find(VocabKey.str());
612 if (It != ArgVocab.end()) {
613 NumericPredEmbeddings[PK] = It->second;
614 continue;
615 }
616 handleMissingEntity(VocabKey.str());
617 }
618
619 // Create section-based storage instead of flat vocabulary
620 // Order must match Vocabulary::Section enum
621 std::vector<std::vector<Embedding>> Sections(4);
622 Sections[static_cast<unsigned>(Vocabulary::Section::Opcodes)] =
623 std::move(NumericOpcodeEmbeddings); // Section::Opcodes
624 Sections[static_cast<unsigned>(Vocabulary::Section::CanonicalTypes)] =
625 std::move(NumericTypeEmbeddings); // Section::CanonicalTypes
626 Sections[static_cast<unsigned>(Vocabulary::Section::Operands)] =
627 std::move(NumericArgEmbeddings); // Section::Operands
628 Sections[static_cast<unsigned>(Vocabulary::Section::Predicates)] =
629 std::move(NumericPredEmbeddings); // Section::Predicates
630
631 // Create VocabStorage from organized sections
632 Vocab.emplace(std::move(Sections));
633}
634
635void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
636 handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
637 Ctx.emitError("Error reading vocabulary: " + EI.message());
638 });
639}
640
643 auto Ctx = &M.getContext();
644 // If vocabulary is already populated by the constructor, use it.
645 if (Vocab.has_value())
646 return Vocabulary(std::move(Vocab.value()));
647
648 // Otherwise, try to read from the vocabulary file.
649 if (VocabFile.empty()) {
650 // FIXME: Use default vocabulary
651 Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to "
652 "set it using --ir2vec-vocab-path");
653 return Vocabulary(); // Return invalid result
654 }
655
656 VocabMap OpcVocab, TypeVocab, ArgVocab;
657 if (auto Err = readVocabulary(OpcVocab, TypeVocab, ArgVocab)) {
658 emitError(std::move(Err), *Ctx);
659 return Vocabulary();
660 }
661
662 // Scale the vocabulary sections based on the provided weights
663 auto scaleVocabSection = [](VocabMap &Vocab, double Weight) {
664 for (auto &Entry : Vocab)
665 Entry.second *= Weight;
666 };
667 scaleVocabSection(OpcVocab, OpcWeight);
668 scaleVocabSection(TypeVocab, TypeWeight);
669 scaleVocabSection(ArgVocab, ArgWeight);
670
671 // Generate the numeric lookup vocabulary
672 generateVocabStorage(OpcVocab, TypeVocab, ArgVocab);
673
674 return Vocabulary(std::move(Vocab.value()));
675}
676
//===----------------------------------------------------------------------===//
678// Printer Passes
679//===----------------------------------------------------------------------===//
680
683 auto &Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
684 assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
685
686 for (Function &F : M) {
688 if (!Emb) {
689 OS << "Error creating IR2Vec embeddings \n";
690 continue;
691 }
692
693 OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
694 OS << "Function vector: ";
695 Emb->getFunctionVector().print(OS);
696
697 OS << "Basic block vectors:\n";
698 const auto &BBMap = Emb->getBBVecMap();
699 for (const BasicBlock &BB : F) {
700 auto It = BBMap.find(&BB);
701 if (It != BBMap.end()) {
702 OS << "Basic block: " << BB.getName() << ":\n";
703 It->second.print(OS);
704 }
705 }
706
707 OS << "Instruction vectors:\n";
708 const auto &InstMap = Emb->getInstVecMap();
709 for (const BasicBlock &BB : F) {
710 for (const Instruction &I : BB) {
711 auto It = InstMap.find(&I);
712 if (It != InstMap.end()) {
713 OS << "Instruction: ";
714 I.print(OS);
715 It->second.print(OS);
716 }
717 }
718 }
719 }
720 return PreservedAnalyses::all();
721}
722
725 auto &IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
726 assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");
727
728 // Print each entry
729 unsigned Pos = 0;
730 for (const auto &Entry : IR2VecVocabulary) {
731 OS << "Key: " << IR2VecVocabulary.getStringKey(Pos++) << ": ";
732 Entry.print(OS);
733 }
734 return PreservedAnalyses::all();
735}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), the core ir2vec::Embedder inte...
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define P(N)
ModuleAnalysisManager MAM
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
LLVM Basic Block Representation.
Definition BasicBlock.h:62
LLVM_ABI iterator_range< filter_iterator< BasicBlock::const_iterator, std::function< bool(const Instruction &)> > > instructionsWithoutDebug(bool SkipPseudoOp=true) const
Return a const iterator range over the instructions in the block, skipping any debug instructions.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
static LLVM_ABI StringRef getPredicateName(Predicate P)
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:167
iterator end()
Definition DenseMap.h:81
virtual std::string message() const
Return the error message as a string.
Definition Error.h:52
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
Tagged union holding either a T or a Error.
Definition Error.h:485
Error takeError()
Take ownership of the stored error.
Definition Error.h:612
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:681
This analysis provides the vocabulary for IR2Vec.
Definition IR2Vec.h:607
ir2vec::Vocabulary Result
Definition IR2Vec.h:622
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:642
static LLVM_ABI AnalysisKey Key
Definition IR2Vec.h:618
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:723
LLVM_ABI void emitError(const Instruction *I, const Twine &ErrorStr)
emitError - Emit an error message to the currently installed error handler with optional location inf...
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
Definition SmallString.h:26
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
std::string str() const
str - Get the contents as an std::string.
Definition StringRef.h:225
LLVM Value Representation.
Definition Value.h:75
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
LLVM_ABI const Embedding & getBBVector(const BasicBlock &BB) const
Returns the embedding for a given basic block in the function F if it has been computed.
Definition IR2Vec.cpp:184
static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)
Factory method to create an Embedder object.
Definition IR2Vec.cpp:161
BBEmbeddingsMap BBVecMap
Definition IR2Vec.h:539
LLVM_ABI const BBEmbeddingsMap & getBBVecMap() const
Returns a map containing basic block and the corresponding embeddings for the function F if it has be...
Definition IR2Vec.cpp:178
const Vocabulary & Vocab
Definition IR2Vec.h:527
void computeEmbeddings() const
Function to compute embeddings.
Definition IR2Vec.cpp:199
const float TypeWeight
Definition IR2Vec.h:534
LLVM_ABI const InstEmbeddingsMap & getInstVecMap() const
Returns a map containing instructions and the corresponding embeddings for the function F if it has b...
Definition IR2Vec.cpp:172
const float OpcWeight
Weights for different entities (like opcode, arguments, types) in the IR instructions to generate the...
Definition IR2Vec.h:534
const unsigned Dimension
Dimension of the vector representation; captured from the input vocabulary.
Definition IR2Vec.h:530
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
Definition IR2Vec.cpp:156
const float ArgWeight
Definition IR2Vec.h:534
Embedding FuncVector
Definition IR2Vec.h:538
LLVM_ABI const Embedding & getFunctionVector() const
Computes and returns the embedding for the current function.
Definition IR2Vec.cpp:192
InstEmbeddingsMap InstVecMap
Definition IR2Vec.h:540
const Function & F
Definition IR2Vec.h:526
Iterator support for section-based access.
Definition IR2Vec.h:196
const_iterator(const VocabStorage *Storage, unsigned SectionId, size_t LocalIndex)
Definition IR2Vec.h:202
LLVM_ABI bool operator!=(const const_iterator &Other) const
Definition IR2Vec.cpp:341
LLVM_ABI const_iterator & operator++()
Definition IR2Vec.cpp:322
LLVM_ABI const Embedding & operator*() const
Definition IR2Vec.cpp:315
LLVM_ABI bool operator==(const const_iterator &Other) const
Definition IR2Vec.cpp:335
Generic storage class for section-based vocabularies.
Definition IR2Vec.h:151
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
Definition IR2Vec.cpp:346
unsigned getNumSections() const
Get number of sections.
Definition IR2Vec.h:179
VocabStorage()
Default constructor creates empty storage (invalid state)
Definition IR2Vec.h:164
size_t size() const
Get total number of entries across all sections.
Definition IR2Vec.h:176
std::map< std::string, Embedding > VocabMap
Definition IR2Vec.h:217
Class for storing and accessing the IR2Vec vocabulary.
Definition IR2Vec.h:242
static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)
Function to get vocabulary key for a given OperandKind.
Definition IR2Vec.h:352
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const
Definition IR2Vec.cpp:456
static LLVM_ABI OperandKind getOperandKind(const Value *Op)
Function to classify an operand into OperandKind.
Definition IR2Vec.cpp:399
friend class llvm::IR2VecVocabAnalysis
Definition IR2Vec.h:243
static LLVM_ABI StringRef getStringKey(unsigned Pos)
Returns the string key for a given index position in the vocabulary.
Definition IR2Vec.cpp:438
static constexpr unsigned MaxCanonicalTypeIDs
Definition IR2Vec.h:312
static constexpr unsigned MaxOperandKinds
Definition IR2Vec.h:314
OperandKind
Operand kinds supported by IR2Vec Vocabulary.
Definition IR2Vec.h:298
static LLVM_ABI StringRef getVocabKeyForPredicate(CmpInst::Predicate P)
Function to get vocabulary key for a given predicate.
Definition IR2Vec.cpp:428
static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)
Function to get vocabulary key for a given Opcode.
Definition IR2Vec.cpp:387
LLVM_ABI bool isValid() const
Definition IR2Vec.h:330
static LLVM_ABI VocabStorage createDummyVocabForTest(unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
Definition IR2Vec.cpp:462
static constexpr unsigned MaxPredicateKinds
Definition IR2Vec.h:318
CanonicalTypeID
Canonical type IDs supported by IR2Vec Vocabulary.
Definition IR2Vec.h:281
An Object is a JSON object, which maps strings to heterogenous JSON values.
Definition JSON.h:98
LLVM_ABI Value * get(StringRef K)
Definition JSON.cpp:30
The root is the trivial Path to the root value.
Definition JSON.h:713
A "cursor" marking a position within a Value.
Definition JSON.h:666
A Value is an JSON value of unknown type.
Definition JSON.h:290
const json::Object * getAsObject() const
Definition JSON.h:464
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
DenseMap< const Instruction *, Embedding > InstEmbeddingsMap
Definition IR2Vec.h:145
static cl::opt< std::string > VocabFile("ir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""), cl::cat(IR2VecCategory))
LLVM_ABI cl::opt< float > ArgWeight
DenseMap< const BasicBlock *, Embedding > BBEmbeddingsMap
Definition IR2Vec.h:146
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > TypeWeight
LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind
llvm::cl::OptionCategory IR2VecCategory
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
Definition JSON.cpp:675
bool fromJSON(const Value &E, std::string &Out, Path P)
Definition JSON.h:742
ir2vec::Embedding Embedding
Definition MIR2Vec.h:59
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
Definition Error.h:1399
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:644
void handleAllErrors(Error E, HandlerTs &&... Handlers)
Behaves the same as handleErrors, except that by contract all errors must be handled by the given han...
Definition Error.h:990
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
Definition Error.h:1305
@ illegal_byte_sequence
Definition Errc.h:52
@ invalid_argument
Definition Errc.h:56
IR2VecKind
IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
Definition IR2Vec.h:71
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:548
format_object< Ts... > format(const char *Fmt, const Ts &... Vals)
These are helper functions used to produce formatted output.
Definition Format.h:118
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
@ Other
Any other memory.
Definition ModRef.h:68
DWARFExpression::Operation Op
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1869
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:867
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
Embedding is a datatype that wraps std::vector<double>.
Definition IR2Vec.h:87
LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const
Returns true if the embedding is approximately equal to the RHS embedding within the specified tolera...
Definition IR2Vec.cpp:132
LLVM_ABI Embedding & operator+=(const Embedding &RHS)
Arithmetic operators.
Definition IR2Vec.cpp:87
LLVM_ABI Embedding operator-(const Embedding &RHS) const
Definition IR2Vec.cpp:107
LLVM_ABI Embedding & operator-=(const Embedding &RHS)
Definition IR2Vec.cpp:100
LLVM_ABI Embedding operator*(double Factor) const
Definition IR2Vec.cpp:119
size_t size() const
Definition IR2Vec.h:100
LLVM_ABI Embedding & operator*=(double Factor)
Definition IR2Vec.cpp:113
LLVM_ABI Embedding operator+(const Embedding &RHS) const
Definition IR2Vec.cpp:94
LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)
Adds Src Embedding scaled by Factor with the called Embedding.
Definition IR2Vec.cpp:125
LLVM_ABI void print(raw_ostream &OS) const
Definition IR2Vec.cpp:145