
Commit 4db52ec

More SIM fixes
Signed-off-by: Yuanyuan Chen <[email protected]>
1 parent ae925d8 · commit 4db52ec

(Large commit: only a subset of the 56 changed files is shown below.)

56 files changed, 160 insertions(+), 128 deletions(-)
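
The commit message's "SIM" very likely refers to ruff's flake8-simplify lint rules: the hunks below collapse conditional dict lookups into .get(), drop redundant None defaults, replace splits of static strings with literal lists, un-negate a ternary, and add two # noqa suppressions where the "simplification" is not wanted. Only SIM222 and SIM223 appear verbatim in the diff; any other rule codes mentioned in the notes below are best-guess labels. A minimal sketch of the two dict-lookup patterns (illustrative only, not part of the commit):

node = {"bodyText": "hi"}

# conditional index vs. dict.get with a default (the trymerge.py change)
assert (node["createdAt"] if "createdAt" in node else "") == node.get("createdAt", "")

# .get() already defaults to None, so an explicit None is redundant
kwargs = {}
assert kwargs.get("_gen_fn", None) == kwargs.get("_gen_fn")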

.github/scripts/trymerge.py

Lines changed: 1 addition & 1 deletion
@@ -1091,7 +1091,7 @@ def _comment_from_node(node: Any) -> GitHubComment:
     editor = node["editor"]
     return GitHubComment(
         body_text=node["bodyText"],
-        created_at=node["createdAt"] if "createdAt" in node else "",
+        created_at=node.get("createdAt", ""),
         author_login=node["author"]["login"],
         author_association=node["authorAssociation"],
         editor_login=editor["login"] if editor else None,

test/dynamo/test_guard_serialization.py

Lines changed: 1 addition & 1 deletion
@@ -321,7 +321,7 @@ def _tracefunc(self, frame, event, arg):
     def _test_serialization(self, guard_type, fn, *args, **kwargs):
         # kwargs might contain a callable that generates kwargs
         torch._dynamo.reset()
-        kwarg_gen_fn = kwargs.get("_gen_fn", None)
+        kwarg_gen_fn = kwargs.get("_gen_fn")
         if kwarg_gen_fn is not None:
             kwargs = kwarg_gen_fn()

test/dynamo/test_recompile_ux.py

Lines changed: 8 additions & 10 deletions
@@ -242,11 +242,12 @@ def f(x):
             opt_f(torch.randn(8 + i))

         failure_str = "\n".join(failure_reasons)
-        for line in """\
-tensor 'x' size mismatch at index 0. expected 11, actual 12
-tensor 'x' size mismatch at index 0. expected 10, actual 12
-tensor 'x' size mismatch at index 0. expected 9, actual 12
-tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
+        for line in [
+            "tensor 'x' size mismatch at index 0. expected 11, actual 12",
+            "tensor 'x' size mismatch at index 0. expected 10, actual 12",
+            "tensor 'x' size mismatch at index 0. expected 9, actual 12",
+            "tensor 'x' size mismatch at index 0. expected 8, actual 12",
+        ]:
             self.assertIn(
                 line,
                 failure_str,
@@ -281,16 +282,13 @@ def filter_reasons():
         failure_reasons.clear()
         opt_f([7, 8])

-        for line in """\
-len(x) == 3""".split("\n"):
+        for line in ["len(x) == 3"]:
             self.assertIn(line, filter_reasons())

         failure_reasons.clear()
         opt_f([9])

-        for line in """\
-len(x) == 2
-len(x) == 3""".split("\n"):
+        for line in ["len(x) == 2", "len(x) == 3"]:
             self.assertIn(line, filter_reasons())

     @torch._dynamo.config.patch(recompile_limit=1)
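
The rewrite above replaces triple-quoted strings that were split on "\n" with the equivalent list literals; flake8-simplify flags .split() on a string constant (SIM905, if the usual rule numbering applies), and the same pattern drives the changes in test_cuda_nvml_based_avail.py and test_jit_string.py further down. The two forms produce the same list, e.g. for the second loop in this file:

before = """\
len(x) == 2
len(x) == 3""".split("\n")
after = ["len(x) == 2", "len(x) == 3"]
assert before == after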

test/inductor/test_cudagraph_trees.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def tearDown(self):

     def get_manager(self, device_index=None):
         return torch._inductor.cudagraph_trees.get_container(
-            self.device_idx if not device_index else device_index
+            device_index if device_index else self.device_idx
         ).tree_manager

     def get_roots(self):
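
Here the ternary's condition is un-negated and its arms swapped (flake8-simplify's "twisted arms" pattern; likely SIM212). Both spellings pick the same value for every input, including a falsy device_index such as 0, which both treat as "fall back to the default". A quick check, with pick_before/pick_after standing in for the old and new expression:

default_idx = 7  # stands in for self.device_idx

def pick_before(device_index):
    return default_idx if not device_index else device_index

def pick_after(device_index):
    return device_index if device_index else default_idx

for idx in (None, 0, 1, 3):
    assert pick_before(idx) == pick_after(idx)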

test/inductor/test_group_batch_fusion.py

Lines changed: 1 addition & 1 deletion
@@ -686,7 +686,7 @@ def build_graph(self, desc):
             unsatisfied += 1
             assert unsatisfied <= len(desc)  # cycle or bad input?
             name, v = desc.popleft()
-            args = tuple(lookup.get(n, None) for n in v)
+            args = tuple(lookup.get(n) for n in v)
             if None in args:
                 desc.append((name, v))
                 continue

test/profiler/test_memory_profiler.py

Lines changed: 1 addition & 1 deletion
@@ -901,7 +901,7 @@ def _run_and_format_categories(self, fn, indent=12):
             ptr_pair_to_key[(t.impl_ptr, t.storage_data_ptr)] = key

         def format_categories(ptr_pair: int):
-            target_key = ptr_pair_to_key.get(ptr_pair, None)
+            target_key = ptr_pair_to_key.get(ptr_pair)
             if target_key is None:
                 return "???"

test/test_cuda_nvml_based_avail.py

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ def test_partial_uuid_resolver(self):
             _transform_uuid_to_ordinals(["GPU-e4", "GPU-9e8d35e3"], uuids), [2, 1]
         )
         self.assertEqual(
-            _transform_uuid_to_ordinals("GPU-9e8d35e3,GPU-1,GPU-47".split(","), uuids),
+            _transform_uuid_to_ordinals(["GPU-9e8d35e3", "GPU-1", "GPU-47"], uuids),
             [1, 7, 5],
         )
         # First invalid UUID aborts parsing

test/test_jit_string.py

Lines changed: 20 additions & 20 deletions
@@ -241,17 +241,17 @@ def test_rpartition() -> tuple[tuple[str, str, str], tuple[str, str, str], tuple
         def test_split() -> tuple[list[str], list[str], list[str], list[str], list[str],
                                   list[str], list[str], list[str], list[str], list[str], list[str]]:
             return (
-                "a a a a a".split(),
-                "a a a a a".split(),
-                " a a\ta \v a \v\f\n a \t ".split(),
-                " a a a a a ".split(" "),
-                "a a a a a ".split(" ", 10),
-                "a a a a a ".split(" ", -1),
-                "a a a a a ".split(" ", 3),
-                " a a a a a ".split("*"),
-                " a*a a*a a".split("*"),
-                " a*a a*a a ".split("*", -1),
-                " a*a a*a a ".split("a*", 10),
+                ["a", "a", "a", "a", "a"],
+                ["a", "a", "a", "a", "a"],
+                ["a", "a", "a", "a", "a"],
+                ["", "a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a a "],
+                [" a a a a a "],
+                [" a", "a a", "a a"],
+                [" a", "a a", "a a "],
+                [" ", "a ", "a a "],
             )
         self.checkScript(test_split, ())

@@ -266,15 +266,15 @@ def test_split_empty_separator():
         def test_rsplit() -> tuple[list[str], list[str], list[str], list[str], list[str],
                                    list[str], list[str], list[str], list[str]]:
             return (
-                "a a a a a".rsplit(),
-                " a a a a a ".rsplit(" "),
-                "a a a a a ".rsplit(" ", 10),
-                "a a a a a ".rsplit(" ", -1),
-                "a a a a a ".rsplit(" ", 3),
-                " a a a a a ".rsplit("*"),
-                " a*a a*a a ".rsplit("*"),
-                " a*a a*a a ".rsplit("*", -1),
-                " a*a a*a a".rsplit("a*", 10),
+                ["a", "a", "a", "a", "a"],
+                ["", "a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a", "a", "a", "a", "a", ""],
+                ["a a a", "a", "a", ""],
+                [" a a a a a "],
+                [" a", "a a", "a a "],
+                [" a", "a a", "a a "],
+                [" ", "a ", "a a"],
             )
         self.checkScript(test_rsplit, ())
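
Because these TorchScript string tests now hard-code the expected lists rather than recomputing them with str.split at run time, it is worth spot-checking a couple of the less obvious entries against CPython's semantics (these asserts are a sanity check, not part of the commit):

# rsplit with maxsplit=3 only splits at the three rightmost separators
assert "a a a a a ".rsplit(" ", 3) == ["a a a", "a", "a", ""]

# split(" ") keeps empty strings around leading/trailing separators
assert " a a a a a ".split(" ") == ["", "a", "a", "a", "a", "a", ""]
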

test/test_scaled_matmul_cuda.py

Lines changed: 1 addition & 1 deletion
@@ -956,7 +956,7 @@ def test_honor_sm_carveout(self) -> None:
             events = sorted(events, key=lambda x: x['ts'])
             # ROCm carveout is invisible except for kernels running slower on fewer CUs
             no_carveout, carveout_0, carveout, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
-            if True or not (no_carveout < carveout and carveout_0 < carveout and no_carveout_again < carveout):
+            if True or not (no_carveout < carveout and carveout_0 < carveout and no_carveout_again < carveout):  # noqa: SIM222
                 # something went wrong, print more info to help debug flaky test
                 print("ROCm debug info for test_honor_sm_carveout")
                 print("cu_count", cu_count)

test/test_sparse.py

Lines changed: 1 addition & 1 deletion
@@ -4913,7 +4913,7 @@ def generic_constructor(*args, **kwargs):
                     lambda i, v, sz: cnstr(i, v, sz, **kwargs_).to_dense(masked_grad=masked),
                     args_, masked=masked)
             else:
-                if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and 0:
+                if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and 0:  # noqa: SIM223
                     # TODO: remove this if-block after gh-107370 is resolved
                     continue
                 torch.autograd.gradcheck(
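
The last two hunks (here and in test_scaled_matmul_cuda.py above) are the opposite of a cleanup: the always-true `True or ...` and the always-false `... and 0` are deliberate switches that force the debug path on, respectively keep a block disabled until gh-107370 is resolved, so the lint findings are suppressed with `# noqa: SIM222` / `# noqa: SIM223` instead of being "fixed". Why the linter complains, in two lines:

x = 5
assert (True or x > 10) is True    # SIM222: the right-hand side can never matter
assert bool(x > 0 and 0) is False  # SIM223: the expression can only ever be falsy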
