
Commit 6b23d99

feat: 对接 ollama (integrate with Ollama)

1 parent 41058be

11 files changed: +201 -45 lines changed

.gitignore
+2 -1

@@ -11,4 +11,5 @@ data/
 .DS_Store
 *.pt
 *.pth
-*.log
+*.log
+*.lock

README.md
+26

@@ -2,6 +2,32 @@
 
 This project aims using AI to simplify the file management on NAS.
 
+## TODO List
+
+- File System Base:
+  - [x] File Browser
+  - [ ] File Manual Tag
+  - [ ] File index created and updated in real time
+  - [ ] File system event watching
+  - [ ] File encryption at write
+  - [ ] Multi-NAS support
+- Image Files:
+  - [x] Image Browser
+  - [x] Image Snapshot
+  - [x] Image caption using a local vision model
+  - [x] Image auto-tagging
+  - [ ] Image caption and tags using an LLM
+  - [ ] Image search by tag and caption
+  - [ ] Image search by similarity
+  - [ ] RAW image support
+- Video Files:
+  - [ ] Video Player
+  - [ ] Video Caption
+- Document Files:
+  - [ ] Support document preview and edit
+  - [ ] Use RAG to build a knowledge base
+  - [ ] Document search using vectors
+
 ### extra things when use this project
 
 decord install failed when install LAVIS:

file-server-dl/config.py
+17

@@ -0,0 +1,17 @@
+from infra.ollama import Config as OllamaConfig
+import yaml
+
+class Config:
+    def __init__(self):
+        self.ollama = OllamaConfig()
+        self.nas_root_path = ""
+    def from_yaml_file(self, file_path: str):
+        with open(file_path, 'r') as f:
+            config = yaml.safe_load(f)
+        self.ollama = OllamaConfig()
+        self.ollama.enabled = config['ollama']['enabled']
+        self.ollama.model = config['ollama']['model']
+        self.ollama.host = config['ollama']['host']
+        self.ollama.port = config['ollama']['port']
+        self.nas_root_path = config['nas_root_path']
+        return self

file-server-dl/config.yaml
+7 -1

@@ -1 +1,7 @@
-nas_root_path: /Users/weibo/code/myself/anfm/tests
+nas_root_path: /Users/weibo/code/myself/anfm/tests
+ollama:
+  enabled: true
+  host: localhost
+  port: 11434
+  model: llama3.2-vision
+
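
Once the server is running, the endpoint configured above can be smoke-tested with a few lines. This is a sketch, not part of the commit; it uses Ollama's `/api/tags` endpoint (which lists locally pulled models but is not called anywhere in this codebase) and assumes the host/port/model values from config.yaml:

```python
import requests

# /api/tags lists the models pulled on this Ollama instance
resp = requests.get("http://localhost:11434/api/tags")
models = [m["name"] for m in resp.json()["models"]]
# pulled models carry a tag suffix such as ":latest", hence startswith
assert any(name.startswith("llama3.2-vision") for name in models), "model not pulled"
```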

file-server-dl/infra/ollama.py
+31

@@ -0,0 +1,31 @@
+import requests
+import json
+
+class Config:
+    enabled: bool
+    model: str
+    host: str
+    port: int
+
+class OllamaClient:
+    def __init__(self, config: Config):
+        self.config = config
+
+    def request_ollama_generate(self, body: str, model: str = None, image: list[str] = None, format: dict = None) -> dict:
+        '''
+        Ask Ollama to generate text; returns the parsed JSON response.
+        '''
+        req_body = {
+            'prompt': body,
+            'stream': False
+        }
+        if model:
+            req_body['model'] = model
+        else:
+            req_body['model'] = self.config.model
+        if image:
+            req_body['images'] = image  # /api/generate expects base64 images under the 'images' key
+        if format:
+            req_body['format'] = format
+        resp = requests.post(f'http://{self.config.host}:{self.config.port}/api/generate', json=req_body)
+        return json.loads(resp.json()['response'])
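
For reference, a minimal usage sketch of the client above (not part of the commit). The base64 payload is a placeholder, and the `format` dict follows the JSON-schema style structured output that Ollama's `/api/generate` accepts; with a schema supplied, the `response` field is a JSON string, which `request_ollama_generate` parses into a dict:

```python
from infra.ollama import OllamaClient, Config

cfg = Config()
cfg.enabled = True
cfg.model = "llama3.2-vision"
cfg.host = "localhost"
cfg.port = 11434

client = OllamaClient(cfg)
result = client.request_ollama_generate(
    body="Describe this image in one sentence.",
    image=["<base64-encoded JPEG>"],  # placeholder, not a real image
    format={
        "type": "object",
        "properties": {"caption": {"type": "string"}},
        "required": ["caption"],
    },
)
print(result["caption"])
```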

file-server-dl/main.py
+6 -7

@@ -1,16 +1,15 @@
 from magika import Magika
 from flask import Flask, request
 from service.file_understanding import FileUnderstanding
+from config import Config
 import logging
 import yaml
 
-logging.basicConfig(level=logging.DEBUG)
-global config
-
-# read config from config.yaml file then parse into config object
-with open("config.yaml", "r") as file:
-    config = yaml.safe_load(file)
-
+logging.basicConfig(level=logging.INFO)
+config = Config().from_yaml_file("config.yaml")
+print("===============")
+print(config.__dict__)
+print(config.ollama.__dict__)
 fileUnderstanding = FileUnderstanding(config=config)
 logging.info("File Understanding Service Started")

file-server-dl/service/file_understanding.py
+6 -4

@@ -2,19 +2,21 @@
 from pathlib import Path
 from models.file import FileUnderstandingResult
 from service.image_understanding import ImageUnderstanding
+from config import Config
 import logging
 
+
 class FileUnderstanding:
-    def __init__(self, config: any):
+    def __init__(self, config: Config):
         self.magika = Magika()
-        self.image_understanding = ImageUnderstanding()
+        self.image_understanding = ImageUnderstanding(config)
         self.config = config
 
     def understand(self, path: str) -> FileUnderstandingResult:
-        result = self.magika.identify_path(Path(self.config['nas_root_path'] + path))
+        result = self.magika.identify_path(Path(self.config.nas_root_path + path))
         file_understanding = FileUnderstandingResult(result.output.ct_label, result.output.group, result.output.description)
         logging.info("File Understanding: %s", file_understanding)
         if file_understanding.group == 'image':
-            file_understanding.set_ext(self.image_understanding.understand( self.config['nas_root_path'] + path))
+            file_understanding.set_ext(self.image_understanding.understand(self.config.nas_root_path + path))
         logging.info("File Understanding Result: %s", file_understanding)
         return file_understanding
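
A hypothetical smoke test of the new wiring (the image path is a placeholder): the service now receives a typed Config instead of a raw dict, so attribute access replaces the old key lookups:

```python
from config import Config
from service.file_understanding import FileUnderstanding

config = Config().from_yaml_file("config.yaml")
fu = FileUnderstanding(config=config)
# the path is resolved against config.nas_root_path, as in understand()
result = fu.understand("/photos/example.jpg")  # placeholder file
print(result)
```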

file-server-dl/service/image_understanding.py
+77 -6

@@ -1,5 +1,7 @@
 import logging
 import torch
+import base64
+from io import BytesIO
 import cn_clip.clip as clip
 from cn_clip.clip import load_from_name
 from lavis.models import load_model_and_preprocess
@@ -9,24 +11,29 @@
 import numpy as np
 from transformers import AutoModel, AutoImageProcessor
 from infra.milvus import conn as milvus_conn
+from infra.ollama import OllamaClient, Config as OllamaConfig
+from config import Config
 
 class ImageUnderstanding:
-    def __init__(self):
+    def __init__(self, config: Config):
         # check if torch_directml is available
         if importutil.find_spec("torch_directml") is None:
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         else:
             import torch_directml
             self.device = torch_directml.device()
+        self.config = config
         self.clip_model = None
         self.clip_preprocess = None
         self.text_labels = None
         self.text_feature = None
         self.caption_model = None
         self.caption_vis_processors = None
         self.milvus_conn = milvus_conn
-        self.__init_clip_model__()
-        self.__init_caption_model__()
+        if not config.ollama.enabled:
+            # if ollama is not enabled, initialize the local clip and caption models
+            self.__init_clip_model__()
+            self.__init_caption_model__()
         self.__init_embedding_model__()
 
     def __init_clip_model__(self):
@@ -88,20 +95,84 @@ def image_embedding(self, path: str) -> np.ndarray:
         inputs = self.embedding_processor(image, return_tensors="pt").to(self.device)
         outputs = self.embedding_model(**inputs)
         embedding = outputs.pooler_output.cpu().detach().numpy().flatten()
+        return embedding
 
-    def understand(self, path: str) -> ImageUnderstandingResult:
+    def image_understand_with_local_model(self, path: str):
         labels = self.label_image(path)
         logging.info("Image Labels: %s", labels)
         caption = self.caption_image(path)
         logging.info("Image Caption: %s", caption)
+        return labels, caption
+
+    def image_understand_with_ollama(self, path: str):
+        image = Image.open(path)
+        # resize the long side down to 1024
+        width, height = image.size
+        if width > height:
+            if width > 1024:
+                height = int(1024 * height / width)
+                width = 1024
+        if width <= height:
+            if height > 1024:
+                width = int(1024 * width / height)
+                height = 1024
+        image = image.resize((width, height))
+        # write the image into a byte buffer
+        buffered = BytesIO()
+        image.save(buffered, format="JPEG")
+        bts = buffered.getvalue()
+        b64str = base64.b64encode(bts).decode("utf-8")
+        prompt = '''You are an experienced art critic and photographer who specializes in evaluating works of art using simple and beautiful language.
+Now, please use a short paragraph to describe the content of the picture you see, and use this paragraph as the 'caption' in your answer.
+After that, give 3-5 high-level words that summarize and label the image; these words will be used as the 'tags' in your answer.
+Finally, rate the image from four perspectives: 'composition', 'light and shadow', 'color' and 'idea of the work', and give a final score of 0-10 in steps of 0.1. The four scores and the overall rating together will be used as the 'score' in your answer, and you must also give a reason why you gave those scores from the four perspectives above, which will be used as the 'reason' in your answer.
+Your answer must be returned in JSON format; if you are not sure what you are seeing, just return an empty JSON object, e.g. '{}'. The content of your answer MUST be in 'Chinese', including 'caption', 'tags' and 'reason'; any non-Chinese answer will be considered invalid.
+Your answer should not contain any subjective personal pronouns, e.g. 'I', 'we' etc. When you think you need to use them, please use words such as 'audience', 'others' etc. instead.'''
+        format = {
+            "type": "object",
+            "properties": {
+                "caption": {
+                    "type": "string"
+                },
+                "tags": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    }
+                },
+                "score": {
+                    "type": "array",
+                    "items": {
+                        "type": "number",
+                        "minimum": 0,
+                        "maximum": 10
+                    }
+                },
+                "reason": {
+                    "type": "string"
+                }
+            },
+            "required": ["caption", "tags", "score", "reason"]
+        }
+        result = OllamaClient(config=self.config.ollama).request_ollama_generate(body=prompt, image=[b64str], format=format)
+        # labels returned by ollama carry no confidence score
+        return [ImageLabel(x, 0.0) for x in result['tags']], result['caption']
+
+    def understand(self, path: str) -> ImageUnderstandingResult:
+        if self.config.ollama.enabled:
+            # use ollama
+            labels, caption = self.image_understand_with_ollama(path)
+        else:
+            # use the local models
+            labels, caption = self.image_understand_with_local_model(path)
         embedding = self.image_embedding(path)
         self.milvus_conn.insert(embedding, path)
-        logging.info("Image Embedding: %s", embedding)
         return ImageUnderstandingResult(labels, caption)
 
     def image_similarity(self, path: str) -> list[dict]:
         embedding = self.image_embedding(path)
         records = self.milvus_conn.search_by_vec(embedding)
         results = []
         for record in records:
-            results.append({"path": record.id, "score": record.distance})
+            results.append({"path": record.id, "score": record.distance})
+        return results
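
The resize-and-encode steps in image_understand_with_ollama can also be read as a standalone helper. Below is a sketch under the same assumptions; `encode_for_ollama` and `sample.jpg` are hypothetical names, and the `convert("RGB")` call is an addition (not in the commit) to guard against alpha-channel images, which JPEG cannot store:

```python
import base64
from io import BytesIO
from PIL import Image

def encode_for_ollama(path: str, max_side: int = 1024) -> str:
    """Cap the long side at max_side, re-encode as JPEG, return base64."""
    image = Image.open(path).convert("RGB")  # JPEG has no alpha channel
    width, height = image.size
    scale = max_side / max(width, height)
    if scale < 1:  # only downscale, never upscale
        image = image.resize((int(width * scale), int(height * scale)))
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

b64str = encode_for_ollama("sample.jpg")  # placeholder path
```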

file-server/internal/domain/file/tree.go
+1

@@ -11,6 +11,7 @@ func init() {
     loadTree()
 }
 
+// TODO: refactor into sqlite.
 var Root = &DirNode{
     Name: "/",
     Path: "/",

file-server/internal/tasks/file_process_task.go
+19 -16

@@ -10,7 +10,6 @@ import (
     "fileserver/utils"
     "log"
     "strings"
-    "sync"
     "time"
 )
 
@@ -55,6 +54,7 @@ func (t *FileProcessTaskHandler) Start(ctx context.Context) error {
         case task := <-t.taskChan:
             _task := (task).(*entity.FileProcessTask)
             t.singleFileHandler(ctx, _task.File)
+            log.Default().Printf("file process task: %s complete", _task.File)
         }
     }
 }
@@ -89,26 +89,29 @@ func (t *FileProcessTaskHandler) singleFileHandler(ctx context.Context, file str
         }
     }
 
-    wg := sync.WaitGroup{}
-    wg.Add(1)
-    go func() {
-        defer wg.Done()
-        result, err := dl.NewClient(t.config.DLConfiguration).Understanding(ctx, dl.UnderstandingRequest{
-            Path: file,
-        })
-        if err != nil {
-            log.Default().Printf("error getting file type: %v", err)
-            return
-        }
-        _file.SetFileTypeFromUnderstanding(result)
-    }()
+    // wg := sync.WaitGroup{}
+    // wg.Add(1)
+    // go func() {
+    //     defer wg.Done()
+    result, err := dl.NewClient(t.config.DLConfiguration).Understanding(ctx, dl.UnderstandingRequest{
+        Path: file,
+    })
+    if err != nil {
+        log.Default().Printf("error getting file type: %v", err)
+    } else {
+        log.Default().Printf("file understand result: %v", result)
+    }
+    _file.SetFileTypeFromUnderstanding(result)
+    // }()
 
     // insert into database
-    wg.Wait()
+    // wg.Wait()
     if _file.Group == "image" {
+        log.Default().Printf("send image compression task: %s", file)
         bus.Send(&entity.ImageCompressionTask{File: _file})
     }
-    err := t.repo.CreateOrUpdateFile(ctx, _file)
+    log.Default().Printf("insert file %s", file)
+    err = t.repo.CreateOrUpdateFile(ctx, _file)
     if err != nil {
         log.Default().Printf("error inserting file %s: %v", file, err)
     }

file-server/internal/tasks/task_bus.go
+9 -10

@@ -9,33 +9,32 @@ var bus *TaskBus
 
 func init() {
     bus = NewTaskBus()
-    go bus.TaskHandleLoop()
     log.Default().Println("task bus init")
 }
 
 // TaskBus is the task bus, used for communication between tasks
 type TaskBus struct {
-    bus      chan server.ITask
+    buses    map[string]chan server.ITask
     handlers map[string]server.BackendTaskHandler
 }
 
 func NewTaskBus() *TaskBus {
     return &TaskBus{
-        bus:      make(chan server.ITask),
+        buses:    make(map[string]chan server.ITask),
        handlers: make(map[string]server.BackendTaskHandler),
     }
 }
 
 func (b *TaskBus) Send(task server.ITask) {
-    b.bus <- task
+    b.buses[task.GetTaskName()] <- task
 }
 
 func (b *TaskBus) RegisterHandler(task server.BackendTaskHandler) {
     b.handlers[task.GetTaskName()] = task
-}
-
-func (b *TaskBus) TaskHandleLoop() {
-    for task := range b.bus {
-        b.handlers[task.GetTaskName()].Append(task)
-    }
+    b.buses[task.GetTaskName()] = make(chan server.ITask, 100)
+    go func() {
+        for _task := range b.buses[task.GetTaskName()] {
+            task.Append(_task)
+        }
+    }()
 }