Coverage for mongo/problem/problem.py: 89%

238 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-14 03:01 +0000

1import json 

2import enum 

3from datetime import datetime, timedelta 

4from typing import ( 

5 Any, 

6 BinaryIO, 

7 Dict, 

8 List, 

9 Optional, 

10) 

11from dataclasses import dataclass 

12from io import BytesIO 

13from ulid import ULID 

14from .. import engine 

15from ..base import MongoBase 

16from ..course import * 

17from ..utils import (RedisCache, doc_required, drop_none, MinioClient) 

18from ..user import User 

19from .exception import BadTestCase 

20from .test_case import ( 

21 SimpleIO, 

22 ContextIO, 

23 IncludeDirectory, 

24 TestCaseRule, 

25) 

26 

# Only the Problem model is part of this module's star-export surface.
__all__ = ('Problem',)

28 

29 

@dataclass
class UploadInfo:
    '''
    Result of preparing a multipart test case upload.

    Attributes are positional in this order: `urls`, then `upload_id`.
    '''
    # one presigned PUT url per part, in part order
    urls: List[str]
    # multipart upload id that must be passed back when completing the upload
    upload_id: str

34 

35 

class Problem(MongoBase, engine=engine.Problem):
    '''
    Wrapper around the `engine.Problem` document.

    Adds permission checks, test case storage (MinIO, with a legacy GridFS
    fallback) and helpers to create, edit, copy and list problems.
    '''

    class Permission(enum.IntFlag):
        '''Capability bits a user may hold on a problem.'''
        VIEW = enum.auto()  # user view permission
        ONLINE = enum.auto()  # user can view problem or not
        MANAGE = enum.auto()  # user manage problem permission

    def detailed_info(self, *ks, **kns) -> Dict[str, Any]:
        '''
        Return detailed info about this problem.

        The test case archive references (`caseZip`, `caseZipMinioPath`)
        are never included. Other routes must be used to fetch test case
        data.

        Args:
            ks (*str): field names to copy; each is stored under its own
                name.
            kns (**str): mapping of output key -> field name. Both the key
                and the field name may use `__` to address nested values.
        Return:
            a dict containing the requested problem data, or `{}` if this
            problem does not exist.
        '''
        if not self:
            return {}
        # problem -> dict (keys are the stored field names)
        _ret = self.to_mongo()
        # case zip references can not / should not be serialized.
        # Some problems (e.g. those created without test case info) may
        # have no `testCase` at all, so do not assume it exists.
        test_case = _ret.get('testCase')
        if test_case is not None:
            test_case.pop('caseZip', None)
            test_case.pop('caseZipMinioPath', None)
        # convert course documents to course names
        _ret['courses'] = [course.course_name for course in self.courses]
        ret = {}
        for k in ks:
            kns[k] = k
        for k, n in kns.items():
            # extract wanted value, following `__` into nested values
            v = _ret
            for s_n in n.split('__'):
                v = v[s_n]
            # store it under the wanted (possibly nested) key
            e = ret
            *parents, leaf = k.split('__')
            for s_k in parents:
                e = e.setdefault(s_k, {})
            e[leaf] = v
        return ret

    def allowed(self, language: int) -> bool:
        '''
        Check whether `language` may be used to submit to this problem.

        `allowed_language` is a bitmask where bit `i` enables language `i`.
        Only languages 0..2 are recognised. Problems of type 2 accept any
        language.
        '''
        if self.problem_type == 2:
            return True
        if language >= 3 or language < 0:
            return False
        return bool((1 << language) & self.allowed_language)

    def submit_count(self, user: User) -> int:
        '''
        Return how many times `user` has submitted to this problem today.

        If the user's last submission was not made today, the user's
        per-problem counter map is cleared and 0 is returned.
        '''
        # reset quota if it's a new day
        if user.last_submit.date() != datetime.now().date():
            user.update(problem_submission={})
            return 0
        return user.problem_submission.get(str(self.problem_id), 0)

    def running_homeworks(self) -> List:
        '''
        Return the homeworks in `self.homeworks` whose duration contains
        the current time, wrapped as `Homework` objects.
        '''
        from ..homework import Homework
        now = datetime.now()
        return [Homework(hw.id) for hw in self.homeworks if now in hw.duration]

    def is_valid_ip(self, ip: str) -> bool:
        '''
        Return True if every currently running homework accepts `ip`.

        Also returns True when no homework is running.
        '''
        return all(hw.is_valid_ip(ip) for hw in self.running_homeworks())

    def get_submission_status(self) -> Dict[int, int]:
        '''Count this problem's submissions grouped by their `status`.'''
        pipeline = {
            "$group": {
                "_id": "$status",
                "count": {
                    "$sum": 1
                },
            }
        }
        cursor = engine.Submission.objects(problem=self.id).aggregate(
            [pipeline], )
        return {item['_id']: item['count'] for item in cursor}

    def get_ac_user_count(self) -> int:
        '''Number of distinct users with a submission of status 0 (AC).'''
        ac_users = engine.Submission.objects(
            problem=self.id,
            status=0,
        ).distinct('user')
        return len(ac_users)

    def get_tried_user_count(self) -> int:
        '''Number of distinct users who submitted to this problem.'''
        tried_users = engine.Submission.objects(
            problem=self.id, ).distinct('user')
        return len(tried_users)

    @doc_required('user', User)
    def high_score_key(self, user: User) -> str:
        '''Redis key caching `user`'s highest score on this problem.'''
        return f'PROBLEM_{self.id}_{user.id}_HIGH_SCORE'

    @doc_required('user', User)
    def get_high_score(self, user: User) -> int:
        '''
        Get the highest score `user` got on this problem (never below 0).

        The result is cached in Redis for 600 seconds.
        '''
        cache = RedisCache()
        key = self.high_score_key(user=user)
        if (val := cache.get(key)) is not None:
            return int(val.decode())
        # TODO: avoid calling mongoengine API directly
        # `first()` fetches the top submission in one query; the previous
        # `.limit(1).count()` ignored the limit and counted every
        # submission before issuing a second query.
        best = engine.Submission.objects(
            user=user.id,
            problem=self.id,
        ).only('score').order_by('-score').first()
        if best is None:
            high_score = 0
        else:
            # It might < 0 if there is only incomplete submission
            high_score = max(best.score, 0)
        cache.set(key, high_score, ex=600)
        return high_score

    @doc_required('user', User)
    def own_permission(self, user: User) -> Permission:
        '''
        Compute the capability bits `user` holds on this problem.

        - VIEW: granted if the user can view any course the problem is in.
        - ONLINE: only considered when `problem_status == 0`. For each
          course, granted when no homework of that course contains this
          problem, or when such a homework has already started.
        - Admins (`role == 0`) and the problem owner get every bit.
        '''
        user_cap = self.Permission(0)
        for course in map(Course, self.courses):
            # inherit course permission
            if course.permission(user, Course.Permission.VIEW):
                user_cap |= self.Permission.VIEW

            # online problem
            if self.problem_status == 0:
                check_public_problem = True
                for homework in course.homeworks:
                    if self.problem_id in homework.problem_ids:
                        check_public_problem = False
                        # current time after homework start then online
                        if datetime.now() >= homework.duration.start:
                            user_cap |= self.Permission.ONLINE

                # problem does not belong to any homework of this course
                if check_public_problem:
                    user_cap |= self.Permission.ONLINE

        # Admin, or the owner of this problem
        if user.role == 0 or self.owner == user.username:
            user_cap |= self.Permission.VIEW
            user_cap |= self.Permission.ONLINE
            user_cap |= self.Permission.MANAGE

        return user_cap

    def permission(self, user: User, req: Permission) -> bool:
        '''Check whether `user` owns every bit of the `req` permission.'''
        return (self.own_permission(user=user) & req) == req

    @classmethod
    def get_problem_list(
        cls,
        user,
        offset: int = 0,
        count: int = -1,
        problem_id: int = None,
        name: str = None,
        tags: list = None,
        course: str = None,
    ):
        '''
        Get a list of problems that `user` has ONLINE permission on.

        Args:
            user: the requesting user.
            offset: index of the first problem to return.
            count: maximum number of problems; negative means no limit.
            problem_id, name, tags, course: optional filters. `None`
                means "do not filter"; `tags` matches any listed tag.
        Return:
            a list of `engine.Problem` documents ordered by problem id, or
            `[]` if `course` is given but does not exist.
        Raises:
            IndexError: if `offset` is negative, or is past the end of a
                non-empty result list.
        '''
        if course is not None:
            course = Course(course)
            if not course:
                return []
            course = course.obj
        # query args
        ks = drop_none({
            'problem_id': problem_id,
            'problem_name': name,
            'courses': course,
            'tags__in': tags,
        })
        problems = [
            p for p in engine.Problem.objects(**ks).order_by('problemId')
            if cls(p).permission(user=user, req=cls.Permission.ONLINE)
        ]
        # truncate
        if offset < 0 or (offset >= len(problems) and len(problems)):
            raise IndexError
        right = len(problems) if count < 0 else offset + count
        right = min(len(problems), right)
        return problems[offset:right]

    @classmethod
    def add(
        cls,
        user: User,
        courses: List[str],
        problem_name: str,
        status: Optional[int] = None,
        description: Optional[Dict[str, Any]] = None,
        tags: Optional[List[str]] = None,
        type: Optional[int] = None,
        test_case_info: Optional[Dict[str, Any]] = None,
        can_view_stdout: bool = False,
        allowed_language: Optional[int] = None,
        quota: Optional[int] = None,
        default_code: Optional[str] = None,
    ):
        '''
        Create a new problem owned by `user` and return its problem id.

        `None` arguments are left unset so that document defaults apply.
        The programming-only fields (`test_case_info`, `can_view_stdout`,
        `allowed_language`) are ignored when `type == 2`.

        Raises:
            ValueError: if `courses` is empty.
            engine.DoesNotExist: if any course name does not exist.
        '''
        if not courses:
            raise ValueError('No course provided')
        course_objs = []
        for course in map(Course, courses):
            if not course:
                raise engine.DoesNotExist
            course_objs.append(course.id)
        problem_args = drop_none({
            'courses': course_objs,
            'problem_status': status,
            'problem_type': type,
            'problem_name': problem_name,
            'description': description,
            'owner': user.username,
            'tags': tags,
            'quota': quota,
            'default_code': default_code,
        })
        problem = cls.engine(**problem_args).save()
        programming_problem_args = drop_none({
            'test_case': test_case_info,
            'can_view_stdout': can_view_stdout,
            'allowed_language': allowed_language,
        })
        if programming_problem_args and type != 2:
            problem.update(**programming_problem_args)
        return problem.problem_id

    @classmethod
    def edit_problem(
        cls,
        user: User,
        problem_id: int,
        courses: List[str],
        status: int,
        problem_name: str,
        description: Dict[str, Any],
        tags: List[str],
        type,
        test_case_info: Optional[Dict[str, Any]] = None,
        allowed_language: int = 7,
        can_view_stdout: bool = False,
        quota: int = -1,
        default_code: str = '',
    ):
        '''
        Overwrite an existing problem's fields. `user` becomes the owner.

        When `test_case_info` is falsy, the current test case settings are
        kept. Otherwise they are replaced, but the stored archive
        references are carried over.

        Raises:
            ValueError: if the task scores of `test_case_info` do not sum
                to 100 (only checked for `type != 2`).
            engine.DoesNotExist: if any course name does not exist.
        '''
        # Validate only when new test case info is supplied. It defaults to
        # None and is optional below, so indexing it unconditionally
        # crashed with a TypeError.
        if type != 2 and test_case_info:
            score = sum(t['taskScore'] for t in test_case_info['tasks'])
            if score != 100:
                raise ValueError("Cases' scores should be 100 in total")
        problem = cls(problem_id).obj
        course_objs = []
        for name in courses:
            if not (course := Course(name)):
                raise engine.DoesNotExist
            course_objs.append(course.obj)
        problem.update(
            courses=course_objs,
            problem_status=status,
            problem_type=type,
            problem_name=problem_name,
            description=description,
            owner=user.username,
            tags=tags,
            quota=quota,
            default_code=default_code,
        )
        if type != 2:
            # preprocess test case; `problem` still holds the pre-update
            # values, so the old archive references can be copied over
            test_case = problem.test_case
            if test_case_info:
                test_case = engine.ProblemTestCase.from_json(
                    json.dumps(test_case_info))
                test_case.case_zip = problem.test_case.case_zip
                test_case.case_zip_minio_path = problem.test_case.case_zip_minio_path
            problem.update(
                allowed_language=allowed_language,
                can_view_stdout=can_view_stdout,
                test_case=test_case,
            )

    def update_test_case(self, test_case: BinaryIO):
        '''
        Validate and store a new test case archive for this problem.

        Args:
            test_case: seekable test case zip file.
        Raises:
            BadTestCase: if the archive fails validation.
            Any error raised by the validation rules or the MinIO client
            (e.g. for a file that is not a zip) propagates unchanged.
        '''
        self._validate_test_case(test_case)
        # validation consumed the stream; rewind before uploading
        test_case.seek(0)
        self._save_test_case_zip(test_case)

    def _save_test_case_zip(self, test_case: BinaryIO):
        '''
        Upload `test_case` to MinIO under a fresh path, point this problem
        at it, and reload the `test_case` field.
        '''
        minio_client = MinioClient()
        path = self._generate_test_case_obj_path()
        minio_client.client.put_object(
            minio_client.bucket,
            path,
            test_case,
            -1,  # unknown length: stream it as a multipart upload
            part_size=5 * 1024 * 1024,
            content_type='application/zip',
        )
        self.update(test_case__case_zip_minio_path=path)
        self.reload('test_case')

    def _generate_test_case_obj_path(self):
        '''Return a new, unique object path for a test case zip.'''
        return f'problem-test-case/{ULID()}.zip'

    def _validate_test_case(self, test_case: BinaryIO):
        '''
        Validate a test case archive; raise BadTestCase if invalid.

        Every `IncludeDirectory` rule must pass. After that, exactly one of
        the `SimpleIO` / `ContextIO` formats must accept the archive.
        '''
        rules: List[TestCaseRule] = [
            IncludeDirectory(self, 'include'),
            IncludeDirectory(self, 'share'),
            # for backward compatibility
            IncludeDirectory(self, 'chaos'),
        ]
        for rule in rules:
            rule.validate(test_case)

        # Should only match one format
        rules = [
            SimpleIO(self, ['include/', 'share/', 'chaos/']),
            ContextIO(self),
        ]
        excs = []
        for rule in rules:
            try:
                rule.validate(test_case)
            except BadTestCase as e:
                excs.append(e)

        if len(excs) == 0:
            raise BadTestCase('ambiguous test case format')
        elif len(excs) == 2:
            raise BadTestCase(
                f'invalid test case format\n\n{excs[0]}\n\n{excs[1]}')

    @classmethod
    def copy_problem(cls, user, problem_id):
        '''
        Save a copy of problem `problem_id` owned by `user`.

        Only status, type, name, description, tags and test case are
        copied. Courses and language/quota settings are not.
        '''
        problem = cls(problem_id).obj
        engine.Problem(
            problem_status=problem.problem_status,
            problem_type=problem.problem_type,
            problem_name=problem.problem_name,
            description=problem.description,
            owner=user.username,
            tags=problem.tags,
            test_case=problem.test_case,
        ).save()

    @doc_required('target', Course, src_none_allowed=True)
    def copy_to(
        self,
        user: User,
        target: Optional[Course] = None,
        **override,
    ) -> int:
        '''
        Copy a problem to the target course, hidden by default.

        Args:
            user (User): the user who executes this action. They become
                the owner of the copied problem.
            target (Optional[Course] = None): the course this problem will
                be copied to. Defaults to the first of the original courses.
            override: field values overriding those passed to `Problem.add`.
        Return:
            the problem id of the copy (as returned by `Problem.add`).
        '''
        target = self.courses[0] if target is None else target
        # Copied problem is hidden by default
        status = override.pop('status', Problem.engine.Visibility.HIDDEN)
        ks = dict(
            user=user,
            courses=[target.course_name],
            problem_name=self.problem_name,
            status=status,
            description=self.description.to_mongo(),
            tags=self.tags,
            type=self.problem_type,
            test_case_info=self.test_case.to_mongo(),
            can_view_stdout=self.can_view_stdout,
            allowed_language=self.allowed_language,
            quota=self.quota,
            default_code=self.default_code,
        )
        ks.update(override)
        return self.add(**ks)

    @classmethod
    def release_problem(cls, problem_id):
        '''Move a problem into the `Public` course, owned by `first_admin`.'''
        course = Course('Public').obj
        problem = cls(problem_id).obj
        problem.courses = [course]
        problem.owner = 'first_admin'
        problem.save()

    def is_test_case_ready(self) -> bool:
        '''True if a test case archive is stored in GridFS or MinIO.'''
        return (self.test_case.case_zip.grid_id is not None
                or self.test_case.case_zip_minio_path is not None)

    def get_test_case(self) -> BinaryIO:
        '''
        Return the test case archive.

        Reads the archive fully into memory from MinIO when a MinIO path is
        set. Otherwise returns the legacy GridFS `case_zip` field.
        '''
        path = self.test_case.case_zip_minio_path
        if path is not None:
            minio_client = MinioClient()
            resp = minio_client.client.get_object(minio_client.bucket, path)
            try:
                return BytesIO(resp.read())
            finally:
                # always return the connection to the pool
                resp.close()
                resp.release_conn()

        # fallback to legacy GridFS storage
        return self.test_case.case_zip

    # TODO: hope minio SDK to provide more high-level API
    def generate_urls_for_uploading_test_case(
        self,
        length: int,
        part_size: int,
    ) -> UploadInfo:
        '''
        Start a multipart upload for a new test case archive.

        Args:
            length: total archive size in bytes.
            part_size: size of each part in bytes; must be positive.
        Return:
            one presigned PUT url per part (valid for 30 minutes) plus the
            upload id needed by `complete_test_case_upload`.
        Raises:
            ValueError: if `part_size` is not positive.
        '''
        # TODO: update url after uploading completed
        # TODO: handle failed uploading
        if part_size <= 0:
            raise ValueError('part_size must be positive')
        path = self._generate_test_case_obj_path()

        minio_client = MinioClient()
        upload_id = minio_client.client._create_multipart_upload(
            minio_client.bucket,
            path,
            headers={'Content-Type': 'application/zip'},
        )
        # Point the problem at the new object only once the upload exists,
        # so a failure above leaves the current test case path untouched.
        self.update(test_case__case_zip_minio_path=path)
        part_count = (length + part_size - 1) // part_size  # ceil division

        def presign_part(i: int):
            # S3 part numbers are 1-based
            return minio_client.client.get_presigned_url(
                'PUT',
                minio_client.bucket,
                path,
                expires=timedelta(minutes=30),
                extra_query_params={
                    'partNumber': str(i + 1),
                    'uploadId': upload_id
                },
            )

        return UploadInfo(
            urls=[presign_part(i) for i in range(part_count)],
            upload_id=upload_id,
        )

    def complete_test_case_upload(self, upload_id: str, parts: list):
        '''
        Finish the multipart upload started by
        `generate_urls_for_uploading_test_case` and validate the result.

        Raises:
            BadTestCase: if the uploaded archive is invalid. In that case
                the problem's MinIO path is cleared before re-raising.
        '''
        minio_client = MinioClient()
        minio_client.client._complete_multipart_upload(
            minio_client.bucket,
            self.test_case.case_zip_minio_path,
            upload_id,
            parts,
        )

        try:
            test_case = self.get_test_case()
            self._validate_test_case(test_case)
        except BadTestCase:
            self.update(test_case__case_zip_minio_path=None)
            raise