# import torch
import torch
# load model
model = torch.hub.load('datvuthanh/hybridnets', 'hybridnets', pretrained=True)
#inference
img = torch.randn(1, 3, 640, 384)
features, regression, classification, anchors, segmentation = model(img)
파이토치(PyTorch) 소개
PyTorch(파이토치)는 FAIR(Facebook AI Research)에서 만든 연구용 프로토타입부터 상용 제품까지 빠르게 만들 수 있는 오픈 소스 머신러닝 프레임워크입니다.
PyTorch는 사용자 친화적인 프론트엔드(front-end)와 분산 학습, 다양한 도구와 라이브러리를 통해 빠르고 유연한 실험 및 효과적인 상용화를 가능하게 합니다.
PyTorch에 대한 더 자세한 소개는 공식 홈페이지에서 확인하실 수 있습니다.
파이토치 한국 사용자 모임 소개
파이토치 한국 사용자 모임은 2018년 중순 학습 목적으로 PyTorch 튜토리얼 문서를 한국어로 번역하면서 시작하였습니다.
PyTorch를 학습하고 사용하는 한국 사용자들이 시작한 사용자 커뮤니티로, 한국어를 사용하시는 많은 분들께 PyTorch를 소개하고 함께 배우며 성장하는 것을 목표로 합니다.
PyTorch를 사용하며 얻은 유용한 정보를 공유하고 싶으시거나 다른 사용자와 소통하고 싶으시다면 커뮤니티 공간에 방문해주세요!
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_script_module = torch.jit.script(MyModule(3, 4))
my_script_module.save("my_script_module.pt")