본문 바로가기

개발 관련 기타/ML

[ML] 나는 솔로 테스트 - 출연하면 나의 이름은? AI 모델 teachable machine (reactjs, mui) - Part 1

과정:

- 모델 생성

- 화면 기획

- 화면 구현

- 모델 연결

- 결과 뿌리기

 

모델 생성:

- https://teachablemachine.withgoogle.com/ 에서 이미지 분류 모델을 학습시킨 뒤 업로드하여 공유 URL로 사용

 

화면 기획:

- figma 사용

화면 구현:

- reactjs

- mui

모델 연결:

- https://teachablemachine.withgoogle.com/ 에서 제공하는 예제 코드를 참고하여, @teachablemachine/image 라이브러리로 모델을 불러오고 업로드한 이미지로 예측 실행

 

결과 뿌리기(결과 표시):

- MUI의 LinearProgress를 사용해 클래스별 예측 확률을 막대 그래프로 표시

 

소스코드:

import './App.css';

import * as React from 'react';
import {useState} from 'react';
import PropTypes from 'prop-types';
import Box from '@mui/material/Box';
import Grid from '@mui/material/Grid';
import Button from '@mui/material/Button';
import { ThemeProvider, createTheme } from '@mui/material/styles';
import { Typography, Container } from '@mui/material';
//import { AppBar, Toolbar, IconButton, Button } from '@mui/material';
//import MenuIcon from '@mui/icons-material/Menu';
import Radio from '@mui/material/Radio';
import RadioGroup from '@mui/material/RadioGroup';
import FormControlLabel from '@mui/material/FormControlLabel';
import FormControl from '@mui/material/FormControl';
import FormLabel from '@mui/material/FormLabel';
import Stack from '@mui/material/Stack';
import LinearProgress from '@mui/material/LinearProgress';

import * as tmImage from '@teachablemachine/image';

//const Navbar = () => {
//  return (
//    <AppBar position="static">
//      <Toolbar>
//        <IconButton edge="start" color="inherit" aria-label="menu">
//          <MenuIcon />
//        </IconButton>
//        <Typography variant="h6" sx={{ flexGrow: 1 }}>
//          My Website
//        </Typography>
//        <Button color="inherit">Login</Button>
//      </Toolbar>
//    </AppBar>
//  );
//}

//const Item = styled(Paper)(({ theme }) => ({
//  backgroundColor: theme.palette.mode === 'dark' ? '#1A2027' : '#fff',
//  ...theme.typography.body2,
//  padding: theme.spacing(1),
//  textAlign: 'center',
//  color: theme.palette.text.secondary,
//}));

// Application-wide MUI theme. No palette overrides are supplied here.
const theme = createTheme({ palette: {} });

const MyTitle = () => {
  return (
    <Typography 
      variant="h4"
      color="secondary"
      align="center"
    >
      나는 솔로 테스트
    </Typography>
  );
};

const SubTitle = () => {
  return (
    <Typography 
      variant="h5"
      color="secondary"
      align="center"
    >
      출연하면 나의 이름은?
    </Typography>
  );
};

const Footer = () => {
  return (
    <Box sx={{
      position: 'fixed', 
      bottom: 0, 
      width: '100%', 
      backgroundColor: '#f0f0f0', 
      padding: '1rem 0',
    }}>
      <Container maxWidth="lg">
        <Typography variant="body1" align="center">
          Copyright © {new Date().getFullYear()} harang90 
        </Typography>
      </Container>
    </Box>
  );
};

const RadioButtonsGroup = () => {
  return (
    <FormControl>
      <FormLabel id="demo-radio-buttons-group-label">성별</FormLabel>
      <RadioGroup
        aria-labelledby="demo-radio-buttons-group-label"
        defaultValue="female"
        name="radio-buttons-group"
        row
      >
        <FormControlLabel value="female" control={<Radio />} label="여자" />
        <FormControlLabel value="male" control={<Radio />} label="남자" />
      </RadioGroup>
    </FormControl>
  );
};

const UploadButton = (props) => {
  const URL = "https://teachablemachine.withgoogle.com/models/TbtO4Hcl2/";

  let model;

  const modelURL = URL + "model.json";
  const metadataURL = URL + "metadata.json";

  // Load the image model and setup the webcam
  async function init() {

    // load the model and metadata
    // Refer to tmImage.loadFromFiles() in the API to support files from a file picker
    // or files from your local hard drive
    // Note: the pose library adds "tmImage" object to your window (window.tmImage)
    model = await tmImage.load(modelURL, metadataURL);

    let maxPredictions;
    maxPredictions = model.getTotalClasses();
  }

  // run the webcam image through the image model
  async function predict() {
    // predict can take in an image, video or canvas html element
    model = await tmImage.load(modelURL, metadataURL);
    const srcImg = document.getElementById('srcImg');
    const prediction = await model.predict(srcImg, false);
    prediction.sort((a, b) => parseFloat(b.probability) - parseFloat(a.probability));
    console.log("가장높은확률 : ", prediction[0].className)
    props.setResult(prediction)
  }

  const inputRef = React.useRef();

  const [imgBase64, setImgBase64] = useState("");

  const handleButtonClick = () => {
    inputRef.current.click();
  };

  const handleFileChange = (event) => {
    let reader = new FileReader();

    reader.onloadend = () => {
      // 2. 읽기가 완료되면 아래코드가 실행됩니다.
      const base64 = reader.result;
      if (base64) {
        setImgBase64(base64.toString()); // 파일 base64 상태 업데이트
      }
    }
    if (event.target.files[0]) {
      reader.readAsDataURL(event.target.files[0]);
      init().then(
        value => {
        console.log("init 모델");
        predict().then(
          value => {
            props.setHidden(true);
          })
        }
      )
    }
  };

  return (
    <div>
      <input
        type="file"
        style={{ display: 'none' }}
        ref={inputRef}
        onChange={handleFileChange}
      />
      <Button variant="contained" color="primary" onClick={handleButtonClick}>
        Upload File
      </Button>
      {imgBase64?
        <Box
          component="img"
          sx={{
            height: '100%',
            width: '100%'
          }}
          id="srcImg" src={imgBase64} alt="" />: 
        <>
        </>
      }
    </div>
  );
}

const MyStack = (props) => {
  console.log(props.result);

  const LinearProgressWithLabel = (props) => {
    return (
      <Box sx={{ display: 'flex', alignItems: 'center' }}>
        <Box sx={{ minWidth: 35 }}>
          <Typography variant="body2" color="text.secondary">{props.className}</Typography>
        </Box>
        <Box sx={{ width: '100%', mr: 1 }}>
          <LinearProgress variant="determinate" value={props.probability * 100} />
        </Box>
        <Box sx={{ minWidth: 35 }}>
          <Typography variant="body2" color="text.secondary">{`${Math.round(
            props.probability * 100,
          )}%`}</Typography>
        </Box>
      </Box>
    );
  }

  LinearProgressWithLabel.propTypes = {
    value: PropTypes.number.isRequired,
};

  return (
    <>
      {props.result[0] ?
        <Stack sx={{ width: '100%', color: 'grey.500' }} spacing={2}>
          <LinearProgressWithLabel className={props.result[0]["className"]} probability={props.result[0]["probability"]} color="primary" />
          <LinearProgressWithLabel className={props.result[1]["className"]} probability={props.result[1]["probability"]} color="secondary" />
          <LinearProgressWithLabel className={props.result[2]["className"]} probability={props.result[2]["probability"]} color="primary" />
          <LinearProgressWithLabel className={props.result[3]["className"]} probability={props.result[3]["probability"]} color="secondary" />
          <LinearProgressWithLabel className={props.result[4]["className"]} probability={props.result[4]["probability"]} color="primary" />
          <LinearProgressWithLabel className={props.result[5]["className"]} probability={props.result[5]["probability"]} color="secondary" />
        </Stack>
        :
        <Stack sx={{ width: '100%', color: 'grey.500' }} spacing={2}></Stack>
      }
    </>
  );
}

function App() {

  const [hidden, setHidden]  = useState(false);
  const [result, setResult] = useState([]);

  return (
    <ThemeProvider theme={theme}>
      <Box sx={{ flexGrow: 1 }}>
        <Grid container spacing={3}>
          <Grid item xs={12}>
            <Box m={2} mb={0}>
            </Box>
          </Grid>
          <Grid item xs={12}>
            <Box m={2} mb={0}>
              <MyTitle />
            </Box>
            <Box m={2} mb={0}>
							<SubTitle />
            </Box>
          </Grid>
          <Grid item xs={12}>
            <Box m={2} mb={0}>
              <RadioButtonsGroup></RadioButtonsGroup>
            </Box>
          </Grid>
          <Grid item xs={12}>
            <Box m={2} mb={0} component="button"
              hidden={hidden}
              sx={{
                width: 300,
                height: 300,
                backgroundColor: 'primary.dark',
                '&:hover': {
                  backgroundColor: 'primary.main',
                  opacity: [0.9, 0.8, 0.7],
                },
              }}
            >
              <UploadButton 
                setHidden={setHidden}
                setResult={setResult}
              />
            </Box>
            <Box m={2}
              hidden={!hidden}
              sx={{
                width: 300,
                height: 300,
                backgroundColor: 'primary.dark',
                '&:hover': {
                  backgroundColor: 'primary.main',
                  opacity: [0.9, 0.8, 0.7],
                },
              }}
            >
              <MyStack 
                result={result}
              />
            </Box>
          </Grid>
        </Grid>
      </Box>
    </ThemeProvider>
  );
}

export default App;