Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
f3fbac1b
authored
2022-12-14 16:47:11 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
modify vit cls_token
1 parent
0a93c10d
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
15 deletions
data/create_dataset.py
model/vit.py
solver/vit_solver.py
data/create_dataset.py
View file @
f3fbac1
...
...
@@ -55,10 +55,10 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
skip_list: list 跳过的图片列表
save_dir: str 数据集保存目录
"""
#
if os.path.exists(save_dir):
#
return
#
else:
#
os.makedirs(save_dir, exist_ok=True)
if
os
.
path
.
exists
(
save_dir
):
return
else
:
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
top_text_count
=
len
(
top_text_list
)
for
img_name
in
sorted
(
os
.
listdir
(
img_dir
)):
...
...
@@ -238,11 +238,11 @@ if __name__ == '__main__':
'CH-B102708352-2.jpg'
,
]
#
build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
build_dataset
(
train_image_path
,
train_go_path
,
train_label_path
,
filter_from_top_text_list
,
skip_list_train
,
train_dataset_dir
)
# build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
#
build_anno_file(train_dataset_dir, train_anno_file_path)
build_anno_file
(
train_dataset_dir
,
train_anno_file_path
)
# build_anno_file(valid_dataset_dir, valid_anno_file_path)
...
...
model/vit.py
View file @
f3fbac1
...
...
@@ -206,7 +206,8 @@ class VisionTransformer(nn.Module):
super
(
VisionTransformer
,
self
)
.
__init__
()
self
.
num_classes
=
num_classes
self
.
num_features
=
self
.
embed_dim
=
embed_dim
# num_features for consistency with other models
self
.
num_tokens
=
2
if
distilled
else
1
# self.num_tokens = 2 if distilled else 1
self
.
num_tokens
=
0
norm_layer
=
norm_layer
or
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
act_layer
=
act_layer
or
nn
.
GELU
...
...
@@ -260,17 +261,17 @@ class VisionTransformer(nn.Module):
# [1, 1, 768] -> [B, 1, 768]
# [B, 28+1, 8]
cls_token
=
self
.
cls_token
.
expand
(
x
.
shape
[
0
],
-
1
,
-
1
)
if
self
.
dist_token
is
None
:
x
=
torch
.
cat
((
cls_token
,
x
),
dim
=
1
)
# [B, 197, 768]
else
:
x
=
torch
.
cat
((
cls_token
,
self
.
dist_token
.
expand
(
x
.
shape
[
0
],
-
1
,
-
1
),
x
),
dim
=
1
)
#
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
#
if self.dist_token is None:
#
x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
#
else:
#
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
x
=
self
.
pos_drop
(
x
+
self
.
pos_embed
)
x
=
self
.
blocks
(
x
)
x
=
self
.
norm
(
x
)
if
self
.
dist_token
is
None
:
return
self
.
pre_logits
(
x
[:,
0
])
return
self
.
pre_logits
(
x
[:,
-
1
])
else
:
return
x
[:,
0
],
x
[:,
1
]
...
...
solver/vit_solver.py
View file @
f3fbac1
...
...
@@ -56,7 +56,7 @@ class VITSolver(object):
for
batch
,
(
X
,
y
)
in
enumerate
(
self
.
train_loader
):
X
,
y
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
pred
=
self
.
model
(
X
)
pred
=
torch
.
nn
.
Softmax
(
dim
=
1
)(
self
.
model
(
X
)
)
correct
+=
self
.
evaluate
(
pred
,
y
)
...
...
@@ -85,7 +85,7 @@ class VITSolver(object):
for
X
,
y
in
self
.
val_loader
:
X
,
y
=
X
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
pred
=
self
.
model
(
X
)
pred
=
torch
.
nn
.
Softmax
(
dim
=
1
)(
self
.
model
(
X
)
)
correct
+=
self
.
evaluate
(
pred
,
y
)
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment